[docs]
class FittingResults(QtWidgets.QDialog, frw.Ui_Dialog):
"""
Fitting results display widget.
After a curve has been fit to data in the Eye, a
``FittingResults`` window is opened to display the
parameters of the fit and all subsequent fits. The
window can also be used to remove the fit curves
from the plots in the Eye.
Parameters
----------
parent : View
Parent view widget for the dialog window.
Attributes
----------
model_fits : list
Feature fit parameters to display in the widget.
signals : sofia_redux.visualization.signals.Signals
Custom signals to pass events to the controller, as needed.
table_header : list
Column titles for table.
canvas : matplotlib.matplotlib.backends.backend_agg.FigureCanvasAgg
Matplotlib canvas for the most recent fit.
fig : matplotlib.figure.Figure
Matplotlib figure for the most recent fit.
ax : matplotlib.axes.Axes
Matplotlib axes for the most recent fit.
fit_color : str
Color for the most recent fit plot.
"""
def __init__(self, parent: Any):
if not HAS_PYQT6: # pragma: no cover
raise ImportError('PyQt6 package is required for the Eye.')
super(self.__class__, self).__init__(parent)
self.setupUi(self)
self.setModal(0)
self.model_fits = list()
self.signals = parent.signals
self.table_header = list()
self.canvas = self.last_fit_widget.canvas
self.fig = self.last_fit_widget.canvas.fig
self.ax = self.fig.add_subplot()
self.fit_color = '#2848ad'
self.splitter.setStretchFactor(1, 1)
self.currently_plotted = list()
self.save_button.clicked.connect(self.save_results)
self.close_button.clicked.connect(self.close)
self.clear_button.clicked.connect(self.clear_fit)
[docs]
def clear_fit(self) -> None:
"""Clear the table widget display and reset fits stored."""
# clear table
self.model_fits = list()
self._define_table()
self._clear_figure()
# will clear the plot in the Eye
self.signals.clear_fit.emit()
[docs]
def add_results(self, fit_params: List[model_fit.ModelFit]) -> None:
"""
Add fit results to the table display.
Parameters
----------
fit_params : list of ModelFit
Updated fit parameters.
] """
if not isinstance(fit_params, list):
fit_params = [fit_params]
for fit in fit_params:
if fit.get_status() == 'pass':
self.model_fits.append(fit)
self._update_table()
self._update_figure(fit_params)
def _update_table(self):
"""Reset the contents of the table to reflect all loaded fits."""
self.table_widget.clearContents()
self._define_table()
for row_index, model in enumerate(self.model_fits):
checkbox = QtWidgets.QCheckBox(parent=self.table_widget)
checkbox.setChecked(model.get_visibility())
checkbox.clicked.connect(self.signals.toggle_fit_visibility)
self.table_widget.setCellWidget(row_index, 0, checkbox)
details = model.parameters_as_string()
filename = details['filename']
vhead = QTableWidgetItem(os.path.basename(filename))
vhead.setTextAlignment(QtCore.Qt.AlignmentFlag.AlignLeft)
vhead.setToolTip(filename)
self.table_widget.setVerticalHeaderItem(row_index, vhead)
for col_index, col_name in enumerate(self.table_header[1:]):
try:
text = details[col_name]
except KeyError:
text = 'NA'
item = QTableWidgetItem(text)
item.setTextAlignment(QtCore.Qt.AlignmentFlag.AlignHCenter)
self.table_widget.setItem(row_index, col_index + 1, item)
# resize rows and columns to contents
self._resize()
# scroll to bottom to show latest update
self.table_widget.scrollToBottom()
def _resize(self) -> None:
"""Resize the table to current contents."""
self.table_widget.resizeColumnsToContents()
self.table_widget.resizeRowsToContents()
def _define_table(self) -> None:
"""Define the table widget row count."""
# don't redefine column count, set row count to number
# of parameter items (filenames/orders within the
# self.parameters list of dicts)
columns = {'generic': ['show', 'order', 'aperture',
'x_field', 'y_field', 'type'],
'gauss': ['mid_point', 'mid_point_column',
'fwhm', 'amplitude'],
'moffat': ['mid_point', 'mid_point_column',
'fwhm', 'amplitude'],
'linear': ['mid_point', 'mid_point_column',
'baseline', 'slope'],
'constant': ['baseline']}
fit_types = set()
for fit in self.model_fits:
fit_types.update(fit.get_fit_types())
self.table_header = columns['generic']
for fit_type in fit_types:
for col in columns[fit_type]:
if col not in self.table_header:
self.table_header.append(col)
self.table_widget.setColumnCount(len(self.table_header))
self.table_widget.setHorizontalHeaderLabels(
[self._labelize(s) for s in self.table_header])
self.table_widget.setRowCount(len(self.model_fits))
@staticmethod
def _labelize(value: str) -> str:
"""
Format a keyword to appropriate style for display.
Parameters
----------
value : str
String to be formatted.
Returns
----------
value : str
Formatted string.
"""
if value == 'fwhm':
value = 'FWHM'
else:
value = value.strip().replace('_', ' ').title()
return value
def _clear_figure(self):
"""Clear the last fit plot and description."""
self.ax.clear()
self.canvas.draw_idle()
self.last_fit_values.setText('')
self.currently_plotted = list()
def _update_figure(self, fits: Optional[List[model_fit.ModelFit]] = None,
current_only: Optional[bool] = False) -> None:
"""
Plot the last model fit.
Parameters
----------
fits : list of model_fit.ModelFit, optional
The fits to plot. If not provided, the plot is updated from
the current model fits.
"""
# clear old lines
self.ax.clear()
html = list()
if current_only:
fits = self.currently_plotted.copy()
elif fits is None:
fits = self.model_fits
if len(fits) == 0:
return
self.currently_plotted = list()
# Only fit most recent fit
for fit in fits:
status = fit.get_status()
# don't bother with completely empty data sets
if 'empty' in status.lower():
continue
# Plot the data fit was made to
dataset = fit.get_dataset()
if all([v is not None for v in dataset.values()]):
self.ax.step(dataset['x'], dataset['y'], where='mid',
color=fit.get_color(), alpha=0.9)
# If the fit was successful, plot the fit
if status == 'pass':
# model
x = np.linspace(fit.get_limits('lower'),
fit.get_limits('upper'), 100)
y = fit.get_fit()(x)
self.ax.plot(x, y, color='gray', linestyle='dashed',
alpha=0.8)
# mid point
self.ax.axvline(fit.get_mid_point(), color='gray',
linestyle='dotted', alpha=0.6)
self.currently_plotted.append(fit)
html.append(fit.parameters_as_html())
# show all the parameters
self.last_fit_values.setHtml('<br>'.join(html))
# Configure the plot
self.ax.set_xlabel(fits[0].get_fields('x'))
self.ax.set_ylabel(fits[0].get_fields('y'))
self.ax.set_title('Last Fit Feature')
self.canvas.draw_idle()
[docs]
def save_results(self) -> None:
"""Save current table to a CSV file."""
# Get save filename
filename = self._get_save_filename()
if not filename:
return
# Get the fit models to save based on selections
parameters = self._selected_parameters()
# Write to file
self._write_parameters(parameters, filename)
def _selected_parameters(self) -> List[List]:
"""
Retrieve selected parameters from the table.
Returns
-------
parameters : list of list
Fit parameters flattened into a single list, where
each element is a list of row values.
"""
selected = list()
flat_params = self.format_parameters(kind='string')
for index in self.table_widget.selectionModel().selectedRows():
selected.append(flat_params[index.row()])
if selected:
return selected
else:
return flat_params
[docs]
def gather_models(self) -> List[model_fit.ModelFit]:
"""
Return all loaded models with updated visibility.
Returns
-------
model_fits : list
Feature fit parameters to display in the widget.
"""
for row_index in range(self.table_widget.rowCount()):
checkbox = self.table_widget.cellWidget(row_index, 0)
try:
fit = self.model_fits[row_index]
except IndexError:
continue
fit.set_visibility(checkbox.isChecked())
return self.model_fits
def _get_save_filename(self) -> str:
"""
Query user for output filename.
Returns
-------
filename : str
The output path to save to.
"""
init_directory = os.path.expanduser('~')
filename = QtWidgets.QFileDialog.getSaveFileName(self, 'Save File',
init_directory)[0]
return os.path.normpath(filename) if filename else filename
@staticmethod
def _write_parameters(parameters: List[List[Union[str, float]]],
filename: str) -> None:
"""
Write fit parameter rows to a CSV file.
Parameters
----------
parameters : list of list
Fit parameters flattened into a single list, where
each element is a list of row values.
filename : str
Output file path.
"""
header = list(parameters[0].keys())
to_remove = ['axis', 'visible', 'model_id']
for key in to_remove:
try:
header.remove(key)
except ValueError: # pragma: no cover
# this shouldn't be hit, since these keys are always
# in the parameters as passed
continue
with open(filename, 'w', newline='', encoding='utf-8') as f:
writer = csv.writer(f)
writer.writerow(header)
for param in parameters:
row = list()
for key in header:
try:
value = param[key]
except KeyError:
value = 'NA'
row.append(value)
writer.writerow(row)
[docs]
def change_units(self, units: Dict[str, str],
panes: List[PT] = None,
return_new: Optional[bool] = False
) -> Optional[List[model_fit.ModelFit]]:
"""
Change the units of fits.
The 'axes' parameter of each model fit loaded in model_fits is
compared to the 'axes' of each Pane in `panes` to verify the
model should be updated. The fit along with the current and
desired units are passed to `uc.convert_model_fit` to perform
the actual conversion. If the conversion fails (more common with
changing flux units) then the model fit is left unchanged, but is
hidden so it will not affect the autoscaling of the new Pane.
The table widget is updated with the new units or visibility
settings at the end.
Parameters
----------
units : dict
The new units to apply. Keys are the axis directions 'x', 'y'.
Values are the string representations of the new unit,
like 'um' or 'Jy'.
panes : list, optional
List of Pane objects to apply the fit changes to. If
not provided, all fits will be updated.
return_new : bool, optional
If true, return the new fit parameters. Defaults to False.
Returns
-------
parameters: dict
Details of the new fit parameters.
"""
for model in self.model_fits:
# if panes are provided, check for match to fit
# and take end units from the pane, since actual pane
# units might not match desired end units (eg. for
# multiple panes with incompatible flux units)
if panes is not None:
matching_pane = False
for p in panes:
if p is not None:
if model.axis in p.axes():
matching_pane = True
ending = {'x': p.units['x'],
'y': p.units['y']}
break
# no matching pane for this fit, skip it
if not matching_pane:
continue
else:
ending = units
starting = model.get_units()
fit = model.get_fit()
wave = model.get_mid_point()
try:
uc.convert_model_fit(fit, starting, ending, wave)
except (ValueError, AttributeError) as e:
log.debug(e)
self.hide_fit(model)
continue
model.set_fit(fit)
model.set_units(ending)
lower_limit = uc.convert_wave(model.get_limits('lower'),
starting['x'], ending['x'])
upper_limit = uc.convert_wave(model.get_limits('upper'),
starting['x'], ending['x'])
model.set_limits({'lower': lower_limit, 'upper': upper_limit})
self._update_table()
if return_new:
return self.gather_models()
[docs]
@staticmethod
def hide_fit(fit: model_fit.ModelFit) -> None:
"""Set the fit visibility to False."""
fit.set_visibility(False)
[docs]
def hide_all_fits(self) -> None:
"""Set all fit visibilities to False"""
for fit in self.model_fits:
fit.set_visibility(False)
self._update_table()
[docs]
def update_colors(self, updates):
"""
Update plot colors.
Parameters
----------
updates : dict
Color updates to apply. Must include 'line' kind,
matching model ID and mid-model id, for the 'color'
update to be applied.
"""
fits = self.currently_plotted
for pane_, update in updates.items():
for up in update:
if up.kind == 'line':
for fit in fits:
mid_model = f'{fit.order}.{fit.aperture}'
checks = [fit.model_id == up.model_id,
mid_model == up.mid_model]
if all(checks):
fit.color = up.updates['color']
self._update_figure(current_only=True)