class Drawing(object):
    """
    Class to hold an individual matplotlib artist.

    Wraps a single Matplotlib artist together with the metadata used to
    identify it (model names, model ID, drawing kind, pane, axes, plot
    fields) and any pending updates to apply to it.

    Parameters
    ----------
    **kwargs
        Recognized keys are 'artist', 'high_model', 'mid_model',
        'data_id', 'model_id', 'kind', 'pane', 'axes', 'fields',
        'label', and 'updates'. If an artist is supplied, the keys
        'visible', 'color', 'marker', 'alpha', and 'linestyle' are
        also applied to it (see `populate_properties`).
    """

    # __eq__ is defined without a consistent hash, so instances are
    # unhashable (this is also Python's implicit behavior).
    __hash__ = None

    def __init__(self, **kwargs):
        self._artist = kwargs.get('artist', None)
        self._high_model = str(kwargs.get('high_model', ''))
        self._mid_model = str(kwargs.get('mid_model', ''))
        # Reference data ID, e.g. the wavelength label of a reference line.
        self._data_id = str(kwargs.get('data_id', ''))
        # UUID of the model (input file) this drawing belongs to.
        self._model_id = kwargs.get('model_id', None)
        self._kind = kwargs.get('kind', '')
        self._pane = kwargs.get('pane', None)
        # Stored as given; only the `axes` setter normalizes aliases.
        self._axes = kwargs.get('axes', 'primary')
        # _parse_fields returns a new dict, so later merges through the
        # `fields` setter never modify the caller's object.
        self._fields = self._parse_fields(kwargs.get('fields'))
        self._label = kwargs.get('label', '')
        # Copied so that `set_updates` does not mutate a dict shared with
        # the caller (or with other drawings).
        self._updates = dict(kwargs.get('updates') or {})
        self._state = 'new'
        self._update = False
        if self.artist:
            self.populate_properties(kwargs)

    def __eq__(self, other):
        """
        Compare drawings by metadata rather than by artist.

        Two drawings are equal when their high model, kind, mid model,
        data ID, pane, fields, and axes all match according to the
        ``match_*`` methods. Because `match_high_model` is a substring
        test, ``a == b`` is not guaranteed to equal ``b == a``.
        Non-Drawing objects never compare equal.
        """
        if not isinstance(other, Drawing):
            return False
        # Build the full list (no short-circuit) so every check runs,
        # matching the historical evaluation order.
        checks = [self.match_high_model(other.get_high_model()),
                  self.match_kind(other.get_kind()),
                  self.match_mid_model(other.get_mid_model()),
                  self.match_data_id(other.get_data_id()),
                  self.match_pane(other.get_pane()),
                  self.match_fields(other.get_fields()),
                  self.match_axes(other.get_axes())]
        return all(checks)

    @staticmethod
    def _type_error(type_, field):
        """Raise a TypeError describing an improper type for `field`."""
        raise TypeError(f'Improper type {type_} for {field}.')

    @staticmethod
    def _value_error(value, field):
        """Raise a ValueError describing an invalid value for `field`."""
        raise ValueError(f'Invalid value {value} for {field}.')

    @property
    def artist(self):
        """
        matplotlib.artist.Artist : Matplotlib artist.

        Matplotlib artist appropriate to the data type, such as
        a rectangle for the border or a line for the data plot.
        """
        return self._artist

    @artist.setter
    def artist(self, artist):
        if isinstance(artist, mart.Artist):
            self._artist = artist
        else:
            self._type_error(type(artist), 'artist')

    @property
    def high_model(self):
        """str : High-level model name for the drawing."""
        return self._high_model

    @high_model.setter
    def high_model(self, model):
        self._high_model = str(model)

    @property
    def mid_model(self):
        """str : Mid-level model name for the drawing."""
        return self._mid_model

    @mid_model.setter
    def mid_model(self, model):
        self._mid_model = str(model)

    @property
    def model_id(self) -> uuid.UUID:
        """
        uuid.UUID : Model ID for the drawing.

        Unique id associated with a single input file.
        """
        return self._model_id

    @model_id.setter
    def model_id(self, model_id: uuid.UUID) -> None:
        self._model_id = model_id

    @property
    def kind(self):
        """
        str : Classification category for the drawing.

        Available kinds:
            - line
            - border
            - error
            - crosshair
            - guide
            - cursor
            - fit_line
            - fit_center
            - text
            - reference
            - patch
        """
        return self._kind

    @kind.setter
    def kind(self, kind):
        self._kind = str(kind)

    @property
    def data_id(self):
        """
        str : Reference data ID.

        Labels reference data lines, e.g. with wavelength values.
        """
        return self._data_id

    @data_id.setter
    def data_id(self, data_id):
        self._data_id = str(data_id)

    @property
    def color(self):
        """str : Matplotlib color for artist; None if there is no artist."""
        if self.artist is None:
            return None
        try:
            return self.artist.get_color()
        except AttributeError:  # pragma: no cover
            # doesn't seem reachable with current artists
            return self.artist.get_facecolor()

    @color.setter
    def color(self, color):
        if self.artist is None:
            return
        try:
            self.artist.set_color(color)
        except AttributeError:  # pragma: no cover
            # doesn't seem reachable with current artists
            self.artist.set_facecolor(color)

    @property
    def visible(self):
        """bool : Visibility state for the artist; None if no artist."""
        if self.artist is None:
            return None
        return self.artist.get_visible()

    @visible.setter
    def visible(self, visible: bool):
        if self.artist is not None:
            self.artist.set_visible(visible)

    @property
    def marker(self):
        """
        str : Marker associated with the artist.

        None if there is no artist or the artist has no marker.
        """
        if self.artist is None:
            return None
        try:
            return self.artist.get_marker()
        except AttributeError:
            return None

    @marker.setter
    def marker(self, marker: str):
        if self.artist is not None:
            try:
                self.artist.set_marker(marker)
            except AttributeError:
                # Artists without markers (patches, text) ignore this.
                pass

    @property
    def state(self):
        """
        str : Current state of the drawing.

        Set to 'new' on initialization.
        """
        return self._state

    @state.setter
    def state(self, state: str):
        self._state = str(state)

    @property
    def pane(self):
        """Pane : Display pane containing the drawing."""
        return self._pane

    @pane.setter
    def pane(self, pane):
        self._pane = pane

    @property
    def axes(self):
        """
        str : Axes containing the drawing artist.

        The setter stores 'primary', 'alt', or 'any'. It accepts the
        aliases 'pri'/'p' for 'primary' and 'sec'/'s'/'secondary'/
        'alternate' for 'alt' (case-insensitive, surrounding whitespace
        ignored). Values given to the constructor are stored unchanged.
        """
        return self._axes

    @axes.setter
    def axes(self, axes):
        if not isinstance(axes, str):
            self._type_error(type(axes), 'axes')
        value = axes.strip().lower()
        if value in ('pri', 'p', 'primary'):
            self._axes = 'primary'
        elif value in ('sec', 's', 'secondary', 'alt', 'alternate'):
            self._axes = 'alt'
        elif value == 'any':
            # Wildcard understood by match_axes.
            self._axes = 'any'
        else:
            self._value_error(value, 'axes')

    @property
    def fields(self):
        """
        dict : Plot fields associated with drawing axes.

        Keys are axis names ('x', 'y', 'y_alt'). Typical values are:
        'wavepos', 'spectral_flux', 'spectral_error', 'transmission',
        'response'. Assigning to this property merges the new fields
        into the existing ones rather than replacing them.
        """
        return self._fields

    @fields.setter
    def fields(self, fields):
        new = self._parse_fields(fields)
        self._fields.update(new)

    @property
    def label(self):
        """str : Label for the drawing."""
        return self._label

    @label.setter
    def label(self, label):
        self._label = str(label)

    @property
    def updates(self):
        """
        dict : New updates to apply to the drawing.

        Keys represent the property or axis being changed (e.g. 'color',
        'visible', 'marker', 'x_data', 'y_data', 'artist') and the values
        are the new values to apply.
        """
        return self._updates

    @updates.setter
    def updates(self, updates):
        if isinstance(updates, dict):
            # Copy so later merges via set_updates stay local.
            self._updates = dict(updates)
        else:
            self._type_error(type(updates), 'updates')

    @property
    def update(self):
        """bool : Flag to indicate an update is required."""
        return self._update

    @update.setter
    def update(self, update):
        self._update = bool(update)

    def update_options(self, other, kind='default') -> bool:
        """
        Update plot options from another drawing.

        Parameters
        ----------
        other : Drawing
            Drawing to copy updates from.
        kind : str
            Kind of drawing to apply updates to. If default, all available
            updates are applied. Properties not present in the other
            drawing's updates keep their current values.

        Returns
        -------
        success : bool
            True if any properties were applied to the artist; False if
            this drawing has no artist or `kind` selects no properties.

        Raises
        ------
        RuntimeError
            If `other` is not a Drawing.
        """
        if not isinstance(other, Drawing):
            raise RuntimeError('Can only update with another Drawing')
        if self._artist is None:
            return False
        updates = other.get_updates()
        props = dict()
        if kind in ('default', 'cursor', 'fit', 'ref', 'border',
                    'error_range', 'line', 'patch'):
            props['color'] = updates.get('color', self.color)
        if kind in ('default', 'fit', 'ref', 'line', 'patch', 'border',
                    'error_range'):
            props['visible'] = bool(updates.get('visible', self.visible))
        # Matplotlib's Artist.update raises AttributeError for properties
        # without a setter, so only send a marker to artists that have one
        # (e.g. lines, but not scatter collections or patches).
        if (kind in ('default', 'line')
                and callable(getattr(self._artist, 'set_marker', None))):
            props['marker'] = updates.get('marker', self.marker)
        if props:
            self._artist.update(props)
            return True
        return False

    def clear_updates(self) -> None:
        """Clear any existing updates."""
        self.updates = dict()

    def populate_properties(self, kwargs: Dict):
        """
        Set standard properties for the artist.

        Parameters
        ----------
        kwargs : dict
            Keys may be: 'visible', 'color', 'marker', 'alpha', 'linestyle'.
            Values must be appropriate for the Matplotlib artist for the
            associated property. Other keys are ignored.
        """
        artist_props = ('visible', 'color', 'marker', 'alpha', 'linestyle')
        props = {key: kwargs[key] for key in artist_props if key in kwargs}
        self.artist.update(props)

    def _parse_fields(self, fields):
        """
        Parse plot fields by axis from an input list, tuple, or dict.

        Sequences are mapped positionally onto 'x', 'y', 'y_alt'. A new
        dict is always returned, so the caller's object is never aliased.
        None yields an empty dict; other types raise TypeError.
        """
        if fields is None:
            return dict()
        if isinstance(fields, dict):
            return dict(fields)
        if isinstance(fields, (list, tuple)):
            return dict(zip(['x', 'y', 'y_alt'], fields))
        self._type_error(type(fields), 'fields')

    def matches(self, other, strict=False):
        """
        Check if this drawing matches another.

        Required attributes for matching are model_id, kind,
        mid_model, and data_id. If `strict` is set, pane,
        fields, and axes must additionally match.

        Parameters
        ----------
        other : Drawing
            The other drawing to compare against.
        strict : bool
            If True, drawings must match exactly.

        Returns
        -------
        match : bool
            True if drawings match; False otherwise.
        """
        # Compare UUIDs instead of high_model.
        checks = [self.match_id(other.model_id),
                  self.match_kind(other.kind),
                  self.match_mid_model(other.mid_model),
                  self.match_data_id(other.data_id)]
        if strict:
            checks.extend([self.match_pane(other.pane),
                           self.match_fields(other.fields),
                           self.match_axes(other.axes)])
        return all(checks)

    def match_id(self, model_id: uuid.UUID) -> bool:
        """
        Match model IDs.

        Parameters
        ----------
        model_id : uuid.UUID
            Model ID to compare to this drawing's model_id attribute.

        Returns
        -------
        success : bool
            True if model IDs are equal; False otherwise.
        """
        return model_id == self._model_id

    def match_high_model(self, name):
        """
        Match high models.

        Parameters
        ----------
        name : str
            Model name to compare to this drawing's high_model attribute.

        Returns
        -------
        success : bool
            True if `name` is a case-insensitive substring of this
            drawing's high_model; False otherwise.
        """
        return str(name).lower() in self.high_model.lower()

    def match_kind(self, kind):
        """
        Match drawing kinds.

        Any two kinds that both contain 'fit' match, as do any two
        that both contain 'error'. Otherwise the kinds must be equal.

        Parameters
        ----------
        kind : str
            Kind to compare to this drawing's kind attribute.

        Returns
        -------
        success : bool
            True if kinds match; False otherwise.
        """
        for family in ('fit', 'error'):
            if family in kind and family in self.kind:
                return True
        return kind == self.kind

    def match_mid_model(self, name):
        """
        Match mid models.

        Parameters
        ----------
        name : str
            Model name to compare to this drawing's mid_model attribute.
            For multi-order spectra, mid-model is formatted as
            <order>.<aperture>. Both must match.

        Returns
        -------
        success : bool
            True if the mid models match; False otherwise. Multi-order
            names must be equal; integer-like names are compared as
            integers; anything else matches if `name` is a
            case-insensitive substring of this drawing's mid_model.
        """
        if re.match(r'\d+\.\d+', self.mid_model) is not None:
            return name == self.mid_model
        try:
            return int(name) == int(self.mid_model)
        except (TypeError, ValueError):
            # Non-numeric names (including None) fall back to substring.
            return str(name).lower() in self.mid_model.lower()

    def match_data_id(self, data_id):
        """
        Match data IDs.

        Parameters
        ----------
        data_id : str
            Data ID to compare to this drawing's data_id attribute.

        Returns
        -------
        success : bool
            True if data IDs match (case-insensitive); False otherwise.
        """
        return str(data_id).lower() == self.data_id.lower()

    def match_pane(self, pane):
        """
        Match panes.

        Parameters
        ----------
        pane : Pane
            Pane to compare to this drawing's pane attribute.

        Returns
        -------
        success : bool
            True if panes match; False otherwise.
        """
        return self.pane == pane

    def match_fields(self, fields):
        """
        Match fields.

        Parameters
        ----------
        fields : list, tuple, or dict
            Fields to compare to this drawing's fields attribute.
            Sequences are matched positionally against 'x', 'y', 'y_alt'.
            Any other type is treated as "no constraint" and matches.

        Returns
        -------
        success : bool
            True if every given field equals this drawing's field for
            the same axis; False otherwise, including when this drawing
            has no field for one of the requested axes.
        """
        missing = object()
        if isinstance(fields, dict):
            pairs = fields.items()
        elif isinstance(fields, (list, tuple)):
            pairs = zip(['x', 'y', 'y_alt'], fields)
        else:
            return True
        return all(field == self.fields.get(ax, missing)
                   for ax, field in pairs)

    def match_axes(self, axes):
        """
        Match axes.

        If either value is 'any', True is always returned.

        Parameters
        ----------
        axes : str
            Axes to compare to this drawing's axes attribute.

        Returns
        -------
        success : bool
            True if axes match; False otherwise.
        """
        if axes == 'any' or self.axes == 'any':
            return True
        return self.axes == axes

    def match_text(self, artist):
        """
        Match text.

        Parameters
        ----------
        artist : matplotlib.Artist
            Artist containing text values, retrievable via get_text().

        Returns
        -------
        success : bool
            True if this drawing's artist is an instance of the same type
            and both text values match; False otherwise.
        """
        if not isinstance(self.artist, type(artist)):
            return False
        try:
            return str(self.artist.get_text()) == str(artist.get_text())
        except AttributeError:
            return False

    def apply_updates(self, updates):
        """
        Apply updates to the current drawing.

        Parameters
        ----------
        updates : dict
            Updates to apply, with keys as accepted by the `update`
            argument of `set_data` ('x_data', 'y_data', or 'artist').
        """
        self.set_data(update=updates)

    def set_data(self, data=None, axis: Optional[str] = None,
                 update: Optional = None):
        """
        Set data for the current artist.

        Parameters
        ----------
        data : array-like, optional
            If provided, may be used to directly set the data for the artist.
        axis : {'x', 'y'}, optional
            Specifies the axis to set data for.
        update : dict or Drawing, optional
            Either a dict or an object with an ``updates`` dict attribute.
            Keys may be 'x_data', 'y_data', 'artist', checked in that order
            and only the first present is used. If 'x_data' or 'y_data'
            are provided, they override the `data` and `axis` inputs; an
            'artist' is used to copy data only when `data`/`axis` are not
            both given.
        """
        artist = None
        if update is not None:
            # Drawing-like objects carry their changes in `updates`.
            source = update.updates if hasattr(update, 'updates') else update
            if 'x_data' in source:
                data, axis = source['x_data'], 'x'
            elif 'y_data' in source:
                data, axis = source['y_data'], 'y'
            elif 'artist' in source:
                artist = source['artist']
            else:
                log.debug('No x_data, y_data, or artist found in update')
        if isinstance(self.artist, Line2D):
            self._set_line_data(data=data, axis=axis, artist=artist)
        elif isinstance(self.artist, PathCollection):
            self._set_scatter_data(data=data, axis=axis, artist=artist)
        else:
            log.debug(f'Unable to process artist type '
                      f'{type(self.artist)}')

    def _set_line_data(self, data=None, axis=None, artist=None):
        """
        Set line data for one axis, or copy both axes from `artist`.

        Failures are best-effort: an AttributeError (unsupported axis
        name, missing artist, or a source artist without line data) is
        logged at debug level and otherwise ignored.
        """
        try:
            if data is not None and axis is not None:
                props = {f'{axis}data': data}
            elif artist is not None:
                props = {'xdata': artist.get_xdata(),
                         'ydata': artist.get_ydata()}
            else:
                return
            self.artist.update(props)
        except AttributeError as err:
            log.debug(f'Unable to set line data: {err}')

    def _set_scatter_data(self, data=None, axis=None, artist=None):
        """
        Set scatter offsets for one axis, all axes, or from `artist`.

        With `axis` 'x' or 'y', only that column of the current offsets
        is replaced; with 'all', `data` replaces the offsets entirely.
        Without data+axis, offsets are copied from `artist`. If there is
        nothing to apply, the artist is left untouched.
        """
        if data is not None and axis is not None:
            if axis == 'all':
                new_data = mask.array(data)
            else:
                current_data = self.get_artist().get_offsets()
                x_data = current_data[:, 0]
                y_data = current_data[:, 1]
                if axis == 'x':
                    x_data = data
                elif axis == 'y':
                    y_data = data
                new_data = mask.array(np.vstack((x_data, y_data)).T)
        elif artist is not None:
            new_data = artist.get_offsets()
        else:
            return
        self.artist.set_offsets(new_data)

    def update_line_fields(self, update):  # pragma: no cover
        """
        Update line fields.

        Currently has no effect.
        """
        # todo - implement or remove placeholder
        pass

    def in_pane(self, pane, alt=False) -> bool:
        """
        Check if current drawing is in specified pane.

        Parameters
        ----------
        pane : Pane
            Pane instance to check.
        alt : bool, optional
            If set, alternate axes are checked as well as the primary.

        Returns
        -------
        bool
            True if artist is in specified pane; False otherwise.
        """
        if self.artist in pane.ax.get_children():
            return True
        return bool(alt) and self.artist in pane.ax_alt.get_children()

    def set_artist(self, artist):
        """Set the `artist` attribute."""
        self.artist = artist

    def set_high_model(self, high_model):
        """Set the `high_model` attribute."""
        self.high_model = str(high_model)

    def set_mid_model(self, mid_model):
        """Set the `mid_model` attribute."""
        self.mid_model = str(mid_model)

    def set_data_id(self, data_id):
        """Set the `data_id` attribute."""
        self.data_id = str(data_id)

    def set_model_id(self, model_id):
        """
        Set the `model_id` attribute.

        The value is stored as given (typically a uuid.UUID). It is not
        converted to str, so that `match_id` with the same UUID succeeds.
        """
        self.model_id = model_id

    def set_kind(self, kind):
        """Set the `kind` attribute."""
        self.kind = kind

    def set_pane(self, pane):
        """Set the `pane` attribute."""
        self.pane = pane

    def set_axes(self, axes):
        """Set the `axes` attribute."""
        self.axes = axes

    def set_label(self, label):
        """Set the `label` attribute."""
        self.label = label

    def set_state(self, state):
        """Set the `state` attribute."""
        self.state = state

    def set_update(self, update):
        """Set the `update` attribute."""
        self.update = update

    def set_updates(self, updates):
        """Merge `updates` into the existing `updates` attribute."""
        self.updates.update(updates)

    def set_fields(self, fields):
        """Merge `fields` into the `fields` attribute."""
        self.fields = fields

    def set_visible(self, visible):
        """Set the `visible` attribute."""
        self.visible = visible

    def set_color(self, color):
        """Set the `color` attribute."""
        self.color = color

    def set_marker(self, marker):
        """Set the `marker` attribute."""
        self.marker = marker

    def get_artist(self):
        """Get the `artist` attribute."""
        return self.artist

    def get_high_model(self):
        """Get the `high_model` attribute."""
        return self.high_model

    def get_mid_model(self):
        """Get the `mid_model` attribute."""
        return self.mid_model

    def get_data_id(self):
        """Get the `data_id` attribute."""
        return self.data_id

    def get_model_id(self):
        """Get the `model_id` attribute."""
        return self.model_id

    def get_kind(self):
        """Get the `kind` attribute."""
        return self.kind

    def get_pane(self):
        """Get the `pane` attribute."""
        return self.pane

    def get_axes(self):
        """Get the `axes` attribute."""
        return self.axes

    def get_label(self):
        """Get the `label` attribute."""
        return self.label

    def get_state(self):
        """Get the `state` attribute."""
        return self.state

    def get_update(self):
        """Get the `update` attribute."""
        return self.update

    def get_updates(self):
        """Get the `updates` attribute."""
        return self.updates

    def get_fields(self):
        """Get the `fields` attribute."""
        return self.fields

    def get_visible(self):
        """Get the `visible` attribute."""
        return self.visible

    def get_color(self):
        """Get the `color` attribute."""
        return self.color

    def get_marker(self):
        """Get the `marker` attribute."""
        return self.marker

    def get_linestyle(self):
        """
        Get the linestyle associated with the artist.

        Returns None if there is no artist or the artist has no
        linestyle (mirroring the `marker` property).
        """
        if self.artist is None:
            return None
        try:
            return self.artist.get_linestyle()
        except AttributeError:
            return None

    def set_animated(self, state):
        """Set the animated state for the artist, if there is one."""
        if self._artist is not None:
            self._artist.set_animated(state)

    def get_animated(self):
        """Get the animated state for the artist; None if no artist."""
        if self._artist is not None:
            return self._artist.get_animated()
        return None

    @staticmethod
    def convert_to_scatter(line_artist: 'Line2D',
                           marker: str) -> 'PathCollection':
        """
        Convert a line plot to a scatter plot.

        The new scatter is created on the same axes with the line's
        data, color, and label, and is marked as animated. The original
        line artist is not removed.

        Parameters
        ----------
        line_artist : Line2D
            The line artist to replace.
        marker : str
            The marker symbol to use in the plot. If None,
            the default 'o' symbol is used.

        Returns
        -------
        scatter_artist : PathCollection
            The new scatter plot artist.
        """
        x, y = line_artist.get_data()
        color = line_artist.get_color()
        ax = line_artist.axes
        label = line_artist.get_label()
        if marker is None:
            marker = 'o'
        scatter_artist = ax.scatter(x, y, color=color, label=label,
                                    animated=True, marker=marker)
        return scatter_artist

    def convert_to_line(self, drawstyle: str, marker: str) -> None:
        """
        Convert a scatter plot to a line plot.

        The current scatter artist is removed from its axes and replaced
        in this drawing by a new animated line with the same data, first
        face color, and label.

        Parameters
        ----------
        drawstyle : str
            The line drawing style for the new line plot.
            Should be 'line' or 'step'; anything other than 'line'
            produces a 'steps-mid' line.
        marker : str
            The marker symbol to use in the plot. If None, Matplotlib's
            default line marker (rcParams['lines.marker']) is used.
        """
        data = self.artist.get_offsets()
        color = self.artist.get_facecolor()[0]
        ax = self.artist.axes
        label = self.artist.get_label()
        line_artist = ax.plot(data[:, 0], data[:, 1], c=color,
                              label=label, animated=True,
                              marker=marker)[0]
        if drawstyle == 'line':
            line_artist.set_drawstyle('default')
        else:
            line_artist.set_drawstyle('steps-mid')
        self.artist.remove()
        self.artist = line_artist

    def remove(self):
        """Remove an artist from the plot and from this drawing."""
        if self._artist is not None:
            self._artist.remove()
        self._artist = None