# Source code for sofia_redux.instruments.exes.spatial_shift

# Licensed under a 3-clause BSD style license - see LICENSE.rst

from astropy import log
import numpy as np

from sofia_redux.instruments.exes import utils
from sofia_redux.instruments.exes.make_template import make_template
from sofia_redux.toolkit.utilities.fits import set_log_level

# Public API of this module; the underscore-prefixed helpers are private.
__all__ = ['spatial_shift']


def spatial_shift(data, header, flat, variance, illum=None,
                  good_frames=None, sharpen=False):
    """
    Shift spectra for spatial alignment.

    Each good frame is collapsed to a spatial template with
    `make_template`, weighting by the flat squared. The template of the
    first good frame is the reference. Every later template is shifted
    by small integer amounts and added to the reference, and the sum of
    squares of the result is evaluated for each shift. The shift that
    maximizes this quantity is kept for the frame. Shifts that land on
    the edge of the search range are discarded (set to 0). The
    integer-truncated mean shift over all frames is then removed.
    Shifts of up to 4 pixels are searched (fewer for short slits).

    If `sharpen` is set, the signal itself is maximized. Otherwise, the
    value maximized is the signal-to-noise ratio: the signal divided by
    the collapsed uncertainty.

    Each frame is then shifted by its derived value in the spatial
    direction, i.e. along the y-axis (the first axis of each frame).
    Input data is assumed to be undistorted and rotated to orient
    spectra along the x-axis.

    Parameters
    ----------
    data : numpy.ndarray
        Data cube, [nframe, nspec, nspat].
    header : fits.Header
        Header from FITS file.
    flat : numpy.ndarray
        Processed flat image, [nspec, nspat].
    variance : numpy.ndarray
        Variance planes corresponding to data, [nframe, nspec, nspat].
    illum : numpy.ndarray, optional
        Indicates illuminated regions of the frame, [nspec, nspat].
        1=illuminated, 0=unilluminated, -1=pixel that does not
        correspond to any region in the raw frame. If not provided,
        all pixels are treated as illuminated (an array of ones).
    good_frames : array-like, optional
        Array of indices of good frames. If provided, any frame not in
        `good_frames` will be skipped (left unshifted).
    sharpen : bool, optional
        If set, signal will be maximized, rather than the
        signal-to-noise ratio.

    Returns
    -------
    data, variance : 2-tuple of numpy.ndarray
        Shifted data cube [nframe, nspec, nspat] and updated variance
        data [nframe, nspec, nspat].

    Raises
    ------
    RuntimeError
        If the data dimensions do not match the header, or if there
        are no good frames.
    """
    options = _verify_inputs(data, header, flat, variance, illum=illum,
                             good_frames=good_frames, sharpen=sharpen)
    log.info('Shifting data to match first image')

    # Restrict logging to WARNING and above while the per-frame
    # templates are built. This assumes set_log_level acts as a
    # temporary log-level context manager; confirm.
    with set_log_level('WARNING'):
        _make_all_templates(options)

    _find_shift(options)
    return _shift_data(options)
def _verify_inputs(data, header, flat, variance, illum, good_frames,
                   sharpen):
    """
    Check and assemble input data and options.

    Parameters
    ----------
    data : numpy.ndarray
        Data cube. Its dimensions are checked against the header NSPEC
        and NSPAT values with `utils.check_data_dimensions`.
    header : fits.Header
        Header from the FITS file. NSPEC and NSPAT are required;
        SLTH_PIX is optional.
    flat : numpy.ndarray
        Processed flat image.
    variance : numpy.ndarray
        Variance planes corresponding to `data`.
    illum : numpy.ndarray or None
        Illumination mask. If None, an array of ones shaped like one
        frame of `data` is used.
    good_frames : array-like or None
        Indices of good frames. If None, all frames are considered good.
    sharpen : bool
        Shift-selection option, stored for the shift-finding step.

    Returns
    -------
    params : dict
        Inputs and derived values, with keys 'data', 'header', 'flat',
        'variance', 'illum', 'good_frames', 'sharpen', 'nx', 'ny',
        'nz' (number of frames) and 'ns' (order height in pixels).

    Raises
    ------
    RuntimeError
        If the data dimensions do not match the header, or if no value
        in `good_frames` is a valid frame index.
    """
    ny = header['NSPEC']
    nx = header['NSPAT']
    try:
        nz = utils.check_data_dimensions(data=data, nx=nx, ny=ny)
    except RuntimeError as err:
        raise RuntimeError(f'Data has wrong dimensions ({data.shape}). '
                           f'Not shifting images.') from err

    # Order height in pixels, falling back to the full ny extent
    n_slit = header.get('SLTH_PIX', ny)

    # At least one requested good frame must be a valid frame index
    all_frames = np.arange(nz)
    if good_frames is None:
        good_frames = all_frames
    if np.intersect1d(all_frames, good_frames).size == 0:
        raise RuntimeError('No good frames. Not shifting images.')

    if illum is None:
        illum = np.ones_like(data[0])

    return {'data': data, 'header': header, 'flat': flat,
            'variance': variance, 'illum': illum,
            'good_frames': good_frames, 'sharpen': sharpen,
            'nx': nx, 'ny': ny, 'nz': nz, 'ns': n_slit}


def _make_all_templates(params):
    """
    Make collapsed spatial templates for all input frames.

    Adds the following keys to `params`:

    - 'data_templates': list with one collapsed data template per
      frame, or None for frames not in 'good_frames'.
    - 'std_templates': list with the square root of the collapsed
      variance template per frame, or None for frames not in
      'good_frames'.
    - 'weight_template': the NaN-ignoring total of the unweighted,
      collapsed flat**2 template divided by that template, i.e. a
      normalized, inverted weight profile.

    Parameters
    ----------
    params : dict
        As returned by `_verify_inputs`.
    """
    # Templates are built against a copy of the header so the caller's
    # header is never passed to make_template directly.
    header = params['header'].copy()
    weight_frame = params['flat'] ** 2
    illum = params['illum']
    good_frames = params['good_frames']
    variance = params['variance']

    data_templates = []
    std_templates = []
    for i, frame in enumerate(params['data']):
        if i not in good_frames:
            data_templates.append(None)
            std_templates.append(None)
            continue

        # Collapsed template from the single frame, weighted by flat**2
        data_templates.append(
            make_template(frame, header, weight_frame,
                          illum=illum, collapsed=True))

        # Matching template from the variance, converted to a std dev
        var_template = make_template(variance[i], header, weight_frame,
                                     illum=illum, collapsed=True)
        std_templates.append(np.sqrt(var_template))

    params['data_templates'] = data_templates
    params['std_templates'] = std_templates

    # Normalized, unweighted template of the weight frame itself,
    # inverted. NOTE(review): unlike the per-frame templates, illum is
    # not passed here; confirm that this is intended.
    template = make_template(weight_frame, header,
                             np.ones_like(weight_frame), collapsed=True)
    params['weight_template'] = np.nansum(template) / template
def _correlation(comparison_template, test_template, shift_array):
    """
    Select the shift that best aligns a template with a reference.

    For each candidate shift, `test_template` is shifted with
    `_shift_1d_array` and added to `comparison_template`. The
    NaN-ignoring sum of squares of the result is then computed.

    Parameters
    ----------
    comparison_template : numpy.ndarray
        1D reference template.
    test_template : numpy.ndarray
        1D template to shift, same length as `comparison_template`.
    shift_array : numpy.ndarray of int
        Candidate integer shifts.

    Returns
    -------
    int
        Element of `shift_array` giving the largest sum of squares
        (the first such element in case of ties).
    """
    corr = np.array([
        np.nansum((comparison_template
                   + _shift_1d_array(test_template, shift)) ** 2)
        for shift in shift_array])
    return shift_array[np.argmax(corr)]


def _shift_1d_array(xs, n):
    """
    Shift a 1D array by `n` elements, padding with the edge value.

    A positive `n` moves values toward higher indices and repeats
    ``xs[0]`` in the vacated leading elements. A negative `n` moves
    values toward lower indices and repeats ``xs[-1]`` in the vacated
    trailing elements.

    Parameters
    ----------
    xs : numpy.ndarray
        1D array to shift.
    n : int
        Number of elements to shift by.

    Returns
    -------
    numpy.ndarray
        A new shifted array, or `xs` itself (not a copy) if ``n == 0``.
    """
    if n == 0:
        return xs
    shifted = np.empty_like(xs)
    if n > 0:
        shifted[:n] = xs[0]
        shifted[n:] = xs[:-n]
    else:
        shifted[n:] = xs[-1]
        shifted[:n] = xs[-n:]
    return shifted


def _shift_2d_array(data, n):
    """
    Shift a 2D array by `n` rows along axis 0, padding with the edge row.

    Parameters
    ----------
    data : numpy.ndarray
        2D array to shift.
    n : int
        Number of rows to shift by. If positive, rows move toward higher
        indices and the vacated leading rows repeat ``data[0]``. If
        negative, rows move toward lower indices and the vacated
        trailing rows repeat ``data[-1]``.

    Returns
    -------
    numpy.ndarray
        A new shifted array, or `data` itself (not a copy) if
        ``n == 0``.
    """
    if n == 0:
        return data
    shifted = np.roll(data, n, axis=0)
    # Overwrite the rows that np.roll wrapped around with the edge row
    if n > 0:
        shifted[:n, :] = data[0, :]
    else:
        shifted[n:, :] = data[-1, :]
    return shifted


def _find_shift(params):
    """
    Find the optimum integer spatial shift for each good frame.

    Steps:

    1. The first frame with a data template is the reference.
    2. Each later template is weighted and compared to the reference
       with `_correlation`. Shifts from -n_shift to n_shift are tried,
       where n_shift is the smaller of 4 and ``ns / 3``, truncated to
       an integer.
    3. A best shift on either end of that range is treated as out of
       range and set to 0.
    4. The integer-truncated mean shift is subtracted from all frames.

    Parameters
    ----------
    params : dict
        Must contain 'ns', 'nz', 'sharpen', 'data_templates',
        'std_templates' and 'weight_template'. On return it also
        contains 'derived_shifts' (numpy.ndarray of int, length nz).
        The stored templates are not modified.
    """
    # Array of shifts from -4 to 4 (usually)
    n_shift = int(np.min([4, params['ns'] / 3]))
    shift_array = np.arange(-n_shift, n_shift + 1, dtype=int)

    # How much to shift each frame (default 0)
    i_shift_arr = np.zeros(params['nz'], dtype=int)
    comparison_template = None
    for i in range(params['nz']):
        test_template = params['data_templates'][i]
        if test_template is None:
            continue

        # Per the documented spatial_shift contract: with sharpen set,
        # maximize the signal S. Otherwise maximize S/N by dividing by
        # the noise template. New arrays are created so the templates
        # stored in params are left untouched.
        if not params['sharpen']:
            test_template = test_template / params['std_templates'][i]

        # Weight by the flat prior to shifting, to prioritize the source
        # trace.
        # TODO: check whether the intent is to correct for source shift
        # within the slit or for the overall shift including the slit;
        # if the latter, weighting should happen after the shift.
        test_template = test_template * params['weight_template']

        # Keep the first template to compare to
        if comparison_template is None:
            comparison_template = test_template
            continue

        # Find the shift which maximizes the contribution to S**2 or
        # (S/N)**2, checking all integers from -n_shift to n_shift.
        best_shift = _correlation(comparison_template, test_template,
                                  shift_array)

        # If the best shift is pegged at either limit, skip it
        if best_shift == shift_array[0] or best_shift == shift_array[-1]:
            log.debug(f'Spatial shift out of range for pair {i}. '
                      f'Setting to 0.')
        else:
            i_shift_arr[i] = best_shift

    # Remove the mean shift if it is at least one whole pixel.
    # NOTE(review): the mean includes skipped frames and the reference
    # frame (all 0); confirm that this is intended.
    mean_shift = int(np.mean(i_shift_arr))
    log.debug(f'Initial shifts: {i_shift_arr}')
    log.debug(f'Mean shift: {mean_shift}')
    if abs(mean_shift) > 0:
        i_shift_arr -= mean_shift

    log.info(f'Derived shifts for all frames: {i_shift_arr}')
    params['derived_shifts'] = i_shift_arr


def _shift_data(params):
    """
    Apply the derived shifts to the data and variance cubes.

    Frames not in 'good_frames', or with a derived shift of 0, are
    copied unchanged. All other frames are shifted along their first
    axis with `_shift_2d_array`.

    Parameters
    ----------
    params : dict
        Must contain 'data', 'variance', 'derived_shifts', 'nz' and
        'good_frames'.

    Returns
    -------
    shifted_data, shifted_variance : tuple of numpy.ndarray
        New float64 arrays with the same shape as params['data'].
    """
    shifted_data = np.full(params['data'].shape, np.nan)
    shifted_variance = np.full(params['data'].shape, np.nan)
    shifts = params['derived_shifts']
    for i in range(params['nz']):
        data = params['data'][i]
        var = params['variance'][i]
        if i not in params['good_frames'] or shifts[i] == 0:
            shifted_data[i] = data
            shifted_variance[i] = var
            continue
        shifted_data[i] = _shift_2d_array(data, shifts[i])
        shifted_variance[i] = _shift_2d_array(var, shifts[i])
    return shifted_data, shifted_variance