Source code for sofia_redux.visualization.utils.model_fit

#  Licensed under a 3-clause BSD style license - see LICENSE.rst

import itertools
import uuid
from typing import (List, Dict, Optional, Union,
                    Any, Iterable, TypeVar, Tuple, Sequence)

from astropy import modeling as am
from matplotlib.axis import Axis
import numpy as np

from sofia_redux.visualization.utils.eye_error import EyeError

__all__ = ['ModelFit']


ArrayLike = TypeVar('ArrayLike', List, Tuple, Sequence, np.ndarray)


[docs] class ModelFit(object): """ Hold parameters and descriptions of generic model fits. Parameters ---------- params : dict, optional A nested dictionary that describes a single fit to data. The top level key is the `model_id` and the corresponding value is a dictionary whose only key is the `order`. The values of the inner dictionary are the actual parameters of the fit. If not provided, an empty instance of `ModelFit` with default values is returned. Attributes ---------- id_tag : str Represents the model fit ID number. model_id : uuid.UUID The name of the data model the fit was made to. Typically, the name of a FITS file. filename : str Name of the file feature : str Name of the model used to fit the feature. background: str Name of the model used to fit the background. Typical examples are 'constant' or 'linear'. fit_type : list Strings describing the base model fit to the data. Typical values include 'gauss', 'moffat', 'const', 'linear'. Often composite models are fit to data, in which case all valid descriptions are included. The fit_type is used to define the parameter names to include. order : int Order number of the `model_id` data set that the model was fit to. aperture : int Aperture number of the `model_id` data set that the model was fit to. status : str Status of the last attempted fit. May be 'pass' or 'fail'. fit : am.Model The actual fit to the data. axis_names : list Names for the fit axes, used as keys for several other attributes. units : dict Name of the units used for each axis. limits : dict The upper and lower selection bounds of the data fit. fields : dict Field names for each axis. dataset : dict Data arrays for each axis. axis : ma.Axes The axes object the fit has been plotted on. visible : bool Flag to indicate visibility of the fit plot. columns : array-like Column values corresponding to the x-axis data. color : str Color hex value for the fit plot. param_names : dict Names of the parameters for each type of fitted model. """ id_iter = itertools.count() """int : ID iteration value.""" def __init__(self, params=None): self.id_tag = f'model_fit_{next(self.id_iter)}' self.model_id = '' self.filename = '' self.feature = '' self.background = '' self.fit_type = ['gauss'] self.order = 0 self.aperture = 0 self.status = 'pass' self.fit = None self.axis_names = ['x', 'y'] self.units = dict.fromkeys(self.axis_names, '') self.limits = dict.fromkeys(['lower', 'upper'], None) self.fields = dict.fromkeys(self.axis_names, '') self.dataset = dict.fromkeys(self.axis_names, None) self.axis = None self.visible = True self.columns = None self.color = None self.param_names = {'gauss': ['mean', 'stddev', 'amplitude'], 'moffat': ['x_0', 'gamma', 'amplitude', 'alpha'], 'linear': ['intercept', 'slope'], 'constant': ['amplitude']} if params: self.load_parameters(params)
[docs] def get_id(self) -> str: """Get the ID for the current fit model.""" return self.id_tag
[docs] def load_parameters(self, parameters: Dict[str, Dict]) -> None: """ Parse model fit parameters. Parameters ---------- parameters : dict A nested dictionary that describes a single fit to data. The top level key is the `model_id` and the corresponding value is a dictionary whose only key is the `order`. The values of the inner dictionary are the actual parameters of the fit. If not provided, an empty instance of `ModelFit` with default values is returned. """ self.model_id = list(parameters.keys())[0] self.order = int(list(list(parameters.values())[0].keys())[0]) params = parameters[self.model_id][self.order] self.filename = params.get('filename', '') keys = ['fit', 'axis', 'status'] for key in keys: setattr(self, key, params.get(key, None)) self.visible = params.get('visible', True) for axis in self.axis_names: self.units[axis] = params.get(f'{axis}_unit', None) self.fields[axis] = params.get(f'{axis}_field', None) self.limits['lower'] = params.get('lower_limit', None) self.limits['upper'] = params.get('upper_limit', None) if isinstance(self.fit, am.CompoundModel): self.fit_type = [self._determine_fit_type(fit) for fit in self.fit] else: self.fit_type = [self._determine_fit_type(self.fit)]
[docs] def set_fit_type(self, fit_type: Optional[Union[str, List[str]]] = None, feature: Optional[str] = None, background: Optional[str] = None): """ Set the type of fit used. Parameters ---------- fit_type : list, optional If given, `self.fits` will be set to it directly. If not given, the fit type will be parsed from `feature` and `background`. feature : str, optional Name of the model used to fit the feature. Typical examples are 'gaussian' or 'moffat'. background : str, optional Name of the model used to fit the background. Typical examples are 'constant' or 'linear'. Raises ------ EyeError If no arguments are given. """ if fit_type: if not isinstance(fit_type, list): fit_type = [fit_type] self.fit_type = fit_type elif feature or background: self.fit_type = list() if feature: self.fit_type.append(feature) self.set_feature(feature) if background: self.fit_type.append(background) self.set_background(background) else: raise EyeError('Need to provide fit type')
[docs] def set_feature(self, feature: str) -> None: """ Set the feature name for the model fitting. Parameters ---------- feature : str Name of the model used to fit the feature. """ self.feature = feature
[docs] def get_feature(self) -> str: """ Get the feature name for the model fitting. Returns ------- feature : str Name of the model used to fit the feature. """ return self.feature
[docs] def set_background(self, background: str) -> None: """ Set the background name for the model fitting. Parameters ---------- background: str Name of the model used to fit the background. Typical examples are 'constant' or 'linear'. """ self.background = background
[docs] def get_background(self) -> str: """ Get the background name for the model fitting Returns ------- background: str Name of the model used to fit the background. Typical examples are 'constant' or 'linear'. """ return self.background
[docs] def get_fields(self, axis: Optional[str] = None ) -> Optional[Union[str, Dict[str, str]]]: """ Retrieve field names for a given axis. Parameters ---------- axis : str, optional Names of the axes to be used as keys for fields. Returns ------- fields : str or dict If an axis has been provided, the corresponding value is returned. Otherwise, a dictionary containing all values is returned. """ if axis is not None: try: return self.fields[axis] except KeyError: return None else: return self.fields
[docs] def set_fields(self, fields: Dict[str, str]) -> None: """ Set field names. Parameters ---------- fields : dict Keys are axis names; values are field names. """ for key, field in fields.items(): self.fields[key] = field
[docs] def set_axis(self, axis: Axis) -> None: """ Set the fitting axis. Parameters ---------- axis : matplotlib.Axes Assign the fitting axis """ self.axis = axis
[docs] def get_axis(self) -> Axis: """ Get the fitting axis. Returns ------- axis : matplotlib.Axes The fitting axis. """ return self.axis
[docs] def set_status(self, status: str) -> None: """ Set the last fit status. Parameters ---------- status : str Status to set. """ self.status = status
[docs] def get_status(self) -> str: """ Get the last fit status. Returns ------- str Last status set. """ return self.status
[docs] def set_filename(self, filename: str) -> None: """ Set the filename. Parameters ---------- filename : str Filename to set. """ self.filename = str(filename)
[docs] def get_filename(self) -> str: """ Get the filename. Returns ------- str The filename. """ return self.filename
[docs] def set_model_id(self, model_id: uuid.UUID) -> None: """ Set a model ID. Parameters ---------- model_id : uuid.UUID Unique model ID to set. """ self.model_id = model_id
[docs] def get_model_id(self) -> uuid.UUID: """ Get the model ID. Returns ------- uuid.UUID The model ID. """ return self.model_id
[docs] def set_order(self, order: int) -> None: """ Set an order number. Parameters ---------- order : int The order to set. """ self.order = order
[docs] def get_order(self) -> int: """ Get the order number. Returns ------- int The order number. """ return self.order
[docs] def set_aperture(self, aperture: int) -> None: """ Set an aperture number. Parameters ---------- aperture : int The aperture number to set. """ self.aperture = aperture
[docs] def get_aperture(self) -> int: """ Get the aperture number. Returns ------- int The aperture number. """ return self.aperture
[docs] def set_columns(self, columns: ArrayLike) -> None: """ Set column data. Parameters ---------- columns : array-like The column data to set. """ self.columns = columns
[docs] def get_columns(self) -> int: """ Get the columns array. Returns ------- array-like The column data. """ return self.columns
[docs] def set_dataset(self, dataset: Optional[Dict[str, Iterable]] = None, x: Optional[Iterable] = None, y: Optional[Iterable] = None) -> None: """ Set a data set to fit to. Parameters ---------- dataset : dict, optional If provided, must have axis names for keys and data arrays for values. x : array-like, optional Directly specifies the x-axis data values. y : array-like, optional Directly specifies the y-axis data values. Raises ------ EyeError If all expected axes are not provided. """ if dataset is not None: self.dataset = dataset elif x is not None and y is not None: self.dataset['x'] = x self.dataset['y'] = y else: raise EyeError('Must provide all axes of dataset')
[docs] def get_dataset(self) -> Dict[str, Optional[Iterable]]: """ Get the data set to fit to. Returns ------- dict Keys are axis names, values are data arrays. """ return self.dataset
@staticmethod def _determine_fit_type(fit: am.Model) -> str: """ Map the type of a model fit to a string description. Parameters ---------- fit : astropy.modeling.Model A simple, non-compound, astropy model. Returns ------- fit_type : str String description of the type of fit. If fit type cannot be identified, return 'UNKNOWN'. """ if isinstance(fit, am.models.Gaussian1D): return 'gauss' elif isinstance(fit, am.models.Moffat1D): return 'moffat' elif isinstance(fit, am.models.Linear1D): return 'linear' elif isinstance(fit, am.models.Const1D): return 'constant' else: return 'UNKNOWN'
[docs] def set_visibility(self, state: bool): """ Set the visibility of a fit. Parameters ---------- state : bool If the fitting is going to visible or not. """ if isinstance(state, bool): self.visible = state
[docs] def get_visibility(self) -> bool: """ Get the visibility of the fitting model. Returns ------- visible : bool True, if the fitting is going to visible otherwise False. """ return self.visible
[docs] def set_color(self, color: str) -> None: """ Set the plot color for the fit artist. Parameters ---------- color : str The color hex value to set. """ self.color = color
[docs] def get_color(self) -> str: """ Get the plot color for the fit artist. Returns ------- str The plot color. """ return self.color
[docs] def get_fit(self): """ Obtain a copy of a fit. Returns ------- fit : astropy.modeling.Model A simple, non-compound, astropy model. Returns None if there is no fitting model. """ if self.fit is None: return None else: return self.fit.copy()
[docs] def set_fit(self, fit) -> None: """ Set the fitting model to a given fit. Parameters ---------- fit : astropy.modeling.Model A simple, non-compound, astropy model. """ self.fit = fit.copy()
[docs] def get_units(self, key: Optional[str] = None ) -> Optional[Union[Dict[str, str], str]]: """ Obtain the units. Parameters ---------- key : str, optional Obtain the unit for a given key (axis). Returns ------- units : str or Dict A str if valid key is provided. For an invalid key it returns None. If no key is provided, it returns a dict. """ if key: try: return self.units[key] except KeyError: return None else: return self.units
[docs] def set_units(self, units: Dict[str, str]) -> None: """ Set the units of the fitting model Parameters ---------- units : dict Keys are axis names; values are corresponding units. """ self.units = units
[docs] def get_limits(self, limit: Optional[str] = None ) -> Optional[Union[float, Dict[str, Optional[float]]]]: """ Obtain the limit values. Parameters ---------- limit : str, optional May be 'upper' or 'lower'. Returns ------- float or dict If a valid limit name is specified, only that value is returned. If an invalid limit is specified, None is returned. Otherwise, the limits dictionary is returned. """ if limit is not None: try: return self.limits[limit] except KeyError: return None else: return self.limits
[docs] def set_limits(self, limits: Union[float, Dict[str, float], List[List[float]]], key: Optional[str] = None) -> None: """ Set the limit values. Parameters ---------- limits : float, dict, or list of list If key is provided, should be specified as a float. Otherwise, should contain both lower and upper limits. If provided as a list, it is assumed the lower limit is the first element in the first list and the upper limit is the second element in the first list. key : str, optional May be 'upper' or 'lower'. """ if key: self.limits[key] = limits elif isinstance(limits, dict): self.limits = limits elif isinstance(limits, list): self.limits['lower'] = limits[0][0] self.limits['upper'] = limits[1][0]
[docs] def get_fit_types(self) -> List[str]: """ Get the fit type. Returns ------- list of str Fit types as [feature, background]. """ return self.fit_type
[docs] def parameters_as_string(self) -> Dict[str, str]: """ Format all fitting details as strings for display in table. Returns ------- param : dict Dictionary with parameter names for keys and string formatted parameter values for values. """ param = {'model_id': str(self.model_id), 'filename': str(self.filename), 'order': f'{self.order+1:d}', 'aperture': f'{self.aperture+1:d}', 'x_field': f"{self.fields['x']} [{self.units['x']}]", 'y_field': f"{self.fields['y']} [{self.units['y']}]", 'lower_limit': f"{self.limits['lower']:.10g}", 'upper_limit': f"{self.limits['upper']:.10g}", 'type': ', '.join(self.fit_type), 'axis': self.axis, 'visible': self.visible, 'baseline': f'{self.get_baseline():.10g}', 'mid_point': f'{self.get_mid_point():.10g}', 'mid_point_column': f'{self.get_mid_point_column():.10g}', 'fwhm': f'{self.get_fwhm():.10g}'} if self.fit is not None: for i, fit_type in enumerate(self.fit_type): for name in self.param_names[fit_type]: try: value = getattr(self.fit, name).value except AttributeError: value = getattr(self.fit, f'{name}_{i}').value param[name] = f'{value:.10g}' return param
[docs] def parameters_as_dict(self) -> Dict[str, Any]: """ Fit parameters in dictionary form for returning to the View. Returns ------- param : dict Keys are the names of each parameter and values are the corresponding values. """ param = {'model_id': self.model_id, 'filename': self.filename, 'order': self.order, 'aperture': self.aperture, 'x_field': self.fields['x'], 'y_field': self.fields['y'], 'x_unit': self.units['x'], 'y_unit': self.units['y'], 'fit': self.fit, 'lower_limit': self.limits['lower'], 'upper_limit': self.limits['upper'], 'type': self.fit_type, 'axis': self.axis, 'visible': self.visible, 'baseline': self.get_baseline(), 'mid_point': self.get_mid_point(), 'mid_point_column': self.get_mid_point_column(), 'fwhm': self.get_fwhm()} for i, fit_type in enumerate(self.fit_type): for name in self.param_names[fit_type]: try: value = getattr(self.fit, name).value except AttributeError: value = getattr(self.fit, f'{name}_{i}').value param[name] = value return param
[docs] def parameters_as_list(self) -> List[Any]: """ Fit parameters in a list structure for writing to disk. Returns ------- params : list List of all fit parameter values. Parameter names are not included. """ # For selection param = [self.filename, self.order, self.aperture, self.fields['x'], self.fields['y'], self.units['x'], self.units['y'] ] for i, fit_type in enumerate(self.fit_type): for name in self.param_names[fit_type]: try: value = getattr(self.fit, name).value except AttributeError: value = getattr(self.fit, f'{name}_{i}').value param.append(f'{value:.10g}') param.extend([f'{self.get_baseline():.10g}', f'{self.get_mid_point():.10g}', f'{self.get_mid_point_column():.10g}', f'{self.get_fwhm():.10g}', self.limits['lower'], self.limits['upper'], self.visible]) return param
[docs] def parameters_as_html(self) -> Dict[str, str]: """ Format all fitting details as HTML for display in text view. Returns ------- param : dict Dictionary with parameter names for keys and string formatted parameter values for values. """ param = self.parameters_as_string() html = ['<html>'] status = self.get_status() if status == 'pass': style = 'color: green' else: # Fit was not successful style = 'color: red' status_string = status.replace('_', ' ').title() html.append(f'Last fit status: ' f'<span style="{style}">{status_string}</span>') html.append('Parameters: <pre>') skip = ['axis', 'visible', 'model_id'] for par, value in param.items(): if par not in skip: html.append(f' {par}: {value}') html.append('</pre></html>') return '<br>'.join(html)
[docs] def get_mid_point(self) -> float: """ Determine the midpoint of the feature. Returns ------- mid_point : float Midpoint of the feature of the fit. If not found, return NaN. """ if not self.fit: return np.nan if isinstance(self.fit, am.CompoundModel): fits = self.fit else: fits = [self.fit] for model in fits: if isinstance(model, am.models.Gaussian1D): return model.mean.value elif isinstance(model, am.models.Moffat1D): return model.x_0.value # if no feature was fit, but limits are available, # return the midpoint of the fit range try: return (self.limits['lower'] + self.limits['upper']) / 2 except (TypeError, ValueError): return np.nan
[docs] def get_mid_point_column(self) -> float: """ Determine the midpoint of the feature in column number. Returns ------- mid_point : float Midpoint of the feature of the fit in pixels. If not found, return NaN. """ midpoint = self.get_mid_point() if np.isnan(midpoint): return np.nan try: mid_col = np.interp(midpoint, self.columns, np.arange(self.columns.size)) except (TypeError, ValueError, IndexError, AttributeError): mid_col = np.nan return mid_col
[docs] def get_fwhm(self) -> float: """ Determine the FWHM of the feature. Returns ------- fwhm : float FWHM of the feature of the fit. If not found, return NaN. """ if not self.fit: return np.nan if isinstance(self.fit, am.CompoundModel): fits = self.fit else: fits = [self.fit] for model in fits: if isinstance(model, am.models.Gaussian1D): return model.fwhm elif isinstance(model, am.models.Moffat1D): return model.fwhm return np.nan
[docs] def get_baseline(self) -> float: """ Determine the baseline of the feature. Returns ------- baseline : float Baseline of the feature of the fit. If not found, return NaN. """ if not self.fit: return np.nan midpoint = self.get_mid_point() if np.isnan(midpoint): return np.nan if isinstance(self.fit, am.CompoundModel): fits = self.fit else: fits = [self.fit] for model in fits: if isinstance(model, am.models.Linear1D): return model.slope.value * midpoint \ + model.intercept.value elif isinstance(model, am.models.Const1D): return model.amplitude.value # if no background fit was found, model was feature only, # so return 0 return 0.0
[docs] def scale_parameters(self, x_scale: float, y_scale: float) -> None: """ Scale parameters after a fit. This is used to restore a baseline scale, after fitting to normalized data. Parameters ---------- x_scale : float Value to scale the x data. y_scale : float Value to scale the y data """ if not self.fit: return if isinstance(self.fit, am.CompoundModel): fits = self.fit else: fits = [self.fit] for model in fits: if isinstance(model, am.models.Linear1D): model.intercept *= y_scale model.slope *= y_scale / x_scale elif isinstance(model, am.models.Const1D): model.amplitude *= y_scale elif isinstance(model, am.models.Gaussian1D): model.mean *= x_scale model.stddev *= x_scale model.amplitude *= y_scale elif isinstance(model, am.models.Moffat1D): model.x_0 *= x_scale model.gamma *= x_scale model.amplitude *= y_scale