# Source code for sofia_redux.instruments.exes.coadd

# Licensed under a 3-clause BSD style license - see LICENSE.rst

import warnings

from astropy import log
import numpy as np

from sofia_redux.instruments.exes import utils, make_template
from sofia_redux.toolkit.image.combine import combine_images

__all__ = ['coadd']


def coadd(data, header, flat, variance, illum=None, weights=None,
          good_frames=None, weight_mode=None, std_wt=False, threshold=0):
    """
    Combine individual frames to increase signal-to-noise.

    First, a template of the spectrum is created from all frames,
    averaging in the spectral direction. By default, all input frames
    are weighted by their correlation with this spatial template, so
    that if a frame is unusually noisy or had some other error, it will
    not contribute significantly to the coadded frame. If this is not
    desired, an unweighted coadd can be performed, or the user can
    directly specify the weights to use. With fewer than 4 frames,
    an unweighted coadd is always performed.

    The weighted frames are then combined with a weighted mean. If the
    threshold parameter is provided, outlier rejection is additionally
    performed in the mean combination. The variance is propagated
    through the same (optionally robust) mean combination.

    Parameters
    ----------
    data : numpy.ndarray
        3D data cube [nframe, ny, nx]. Input data is assumed to be
        distortion corrected and rotated as needed to align the spectral
        direction along the x-axis.
    header : fits.Header
        Header of FITS file. Updated in place (EXPTIME, NEXP) when more
        than one frame is present.
    flat : numpy.ndarray
        2D processed flat, as produced by makeflat [ny, nx].
    variance : numpy.ndarray
        3D variance cube [nframe, ny, nx].
    illum : numpy.ndarray of int, optional
        Indicates illuminated regions of the frame [ny, nx].
        1=illuminated, 0=unilluminated, -1=pixel that does not
        correspond to any region in the raw frame. If not provided,
        all pixels are treated as illuminated.
    weights : array-like, optional
        Array [nframe] of fractional weights to use for input frames.
        If a weight is set to zero the frame will be skipped; a negative
        weight multiplies the frame by -1 before combination. When
        `weight_mode` is 'useweights', the absolute weights must add
        up to 1.0. If not provided (or of the wrong length), all good
        frames start with equal weight.
    good_frames : array-like, optional
        Array of indices of good frames, max `nframe` length. If
        provided, any frame not in `good_frames` will be skipped.
    weight_mode : {'unweighted', 'useweights', None}, optional
        If 'unweighted', all good frames will be given equal weight.
        If 'useweights' (or 'use_weights'), the `weights` array will be
        used to weight the frames. Otherwise, weights will be calculated
        from a correlation with the spatial template.
    std_wt : bool, optional
        If set, the square root of each frame's variance plane is used
        as the noise estimate in the template correlation. Otherwise,
        the flat is used.
    threshold : float, optional
        If >0, will be used as an outlier rejection threshold in the
        mean combination.

    Returns
    -------
    coadded_data, coadded_variance : numpy.ndarray, numpy.ndarray
        The coadded data and variance, both [ny, nx]. If only one frame
        is available, that frame and its variance are returned unchanged
        and the header is not modified.

    Raises
    ------
    ValueError
        If the inputs have the wrong dimensions, there are no good
        frames, all weights are zero, or user weights do not sum to 1
        in 'useweights' mode.
    """
    params = _verify_inputs(data, header, flat, variance, illum=illum,
                            weights=weights, good_frames=good_frames,
                            weight_mode=weight_mode, std_wt=std_wt,
                            threshold=threshold)

    if params['nz'] == 1:
        log.info('Only 1 frame available; no coadd performed.')
        # The dimension check may accept a bare 2D frame as well as a
        # one-plane cube (depends on utils.check_data_dimensions - TODO
        # confirm); only strip the leading axis when it actually exists,
        # otherwise data[0] would be a single row.
        single_data = data[0] if np.ndim(data) == 3 else data
        single_var = variance[0] if np.ndim(variance) == 3 else variance
        return single_data, single_var

    _determine_weighting_method(params)
    _generate_template(params)

    log.info('Coadding diffs')
    _calculate_weights(params)
    _combine_data(params)
    _update_integration_time(params)

    return params['coadded'], params['coadded_var']
def _verify_inputs(data, header, flat, variance, illum, weights, good_frames, weight_mode, std_wt, threshold): """ Check input shape and values. Later functions in this step take the output params dictionary as input and add to it or update it as output. Parameters ---------- data : numpy.ndarray header : fits.Header flat : numpy.ndarray variance : numpy.ndarray illum : numpy.ndarray weights : array-like of float good_frames : array-like of int weight_mode : str std_wt : bool threshold : float Returns ------- params : dict Contains all input data, reformatted as needed. Keys are: - 'data': data array - 'variance': variance array - 'header': header object - 'nx': data shape, x-direction - 'ny': data shape, y-direction - 'illum': illumination array - 'weights': list of weights for frames - 'good_frames': input list of good frame indices - 'weight_mode': input weight mode - 'std_wt': input standard deviation weight flag - 'threshold': input robust threshold - 'flat': flat array - 'suball': list of good frames indices - 'zwt_sum': sum of frame weights """ nx = header['NSPAT'] ny = header['NSPEC'] params = {'data': data, 'variance': variance, 'header': header, 'nx': nx, 'ny': ny, 'illum': illum, 'weights': weights, 'good_frames': good_frames} try: nz = utils.check_data_dimensions(params=params) except RuntimeError: raise ValueError(f'Data has wrong dimensions {data.shape}. ' f'Not coadding images.') from None params['nz'] = nz try: utils.check_variance_dimensions(variance, nx, ny, nz) except RuntimeError: raise ValueError(f'Variance has wrong dimensions {variance.shape}. ' f'Not coadding images.') from None if params['illum'] is None: params['illum'] = np.ones((ny, nx)) # Check which frames are good all_frames = np.arange(nz) if params['good_frames'] is None: params['good_frames'] = np.arange(nz) _, suball, _ = np.intersect1d(all_frames, params['good_frames'], return_indices=True) if len(suball) == 0: raise ValueError('No good frames. 
Not coadding images.') # Weight good frames equally if not provided if params['weights'] is None or len(params['weights']) != nz: weights = np.zeros(nz) weights[suball] = 1 else: weights = params['weights'] zwt_sum = np.nansum(np.abs(weights)) if zwt_sum == 0: raise ValueError('All weights are zero. Not coadding images.') params.update({'weight_mode': weight_mode, 'std_wt': std_wt, 'threshold': threshold, 'flat': flat, 'weights': weights, 'suball': suball, 'zwt_sum': zwt_sum}) return params def _determine_weighting_method(params): """Determine weighting method from input parameters.""" unweighted = False do_weights = True if params['weight_mode'] is not None: if params['weight_mode'].lower() in ['unweighted']: unweighted = True elif params['weight_mode'].lower() in ['useweights', 'use_weights']: do_weights = False else: pass if unweighted or params['nz'] < 4: log.info('Doing unweighted addition of pairs.') unweighted = True # Initially weight good frames equally if not provided if params['weights'] is None or len(params['weights']) != params['nz']: params['weights'] = np.zeros(params['nz']) params['weights'][params['suball']] = 1 do_weights = True zwt_sum = np.sum(np.abs(params['weights'])) if zwt_sum == 0: raise ValueError('All weights are zero. ' 'Not coadding images.') if not do_weights and not np.allclose(zwt_sum, 1): raise ValueError('Weights do not add up to 1. ' 'Not coadding images.') params['unweighted'] = unweighted params['do_weights'] = do_weights params['zwt_sum'] = zwt_sum def _generate_template(params): """Make a spatial template from the input data.""" if params['std_wt']: weight_frame = np.sqrt(params['variance']) else: weight_frame = params['flat'] if params['unweighted']: template = None else: template = make_template.make_template( params['data'], params['header'], weight_frame, good_frames=params['good_frames'], illum=params['illum']) if template is None: log.error('Problem making template. 
Using unweighted coadd.') params['unweighted'] = True params['template'] = template params['weight_frame'] = weight_frame def _calculate_weights(params): """Calculate weights from correlation with spatial template.""" # Weight in proportion to S/N in spectrum extracted by # multiplying by template # Get good data from illum weights = params['weights'] if params['do_weights']: gz = (params['illum'] == 1) & (params['flat'] != 0) params['gz'] = gz params['bz'] = ~params['gz'] if np.sum(gz) == 0: raise ValueError('No good data. Not coadding images.') for i in range(params['nz']): # Check if weights should be calculated or just used as is if params['weights'][i] == 0: continue # Check if frame should be multiplied by -1 if params['weights'][i] < 0: sign = -1 else: sign = 1 if params['unweighted']: weights[i] = sign continue t = params['template'][gz] a = (params['data'][i] * sign)[gz] if params['std_wt']: s2 = (params['weight_frame'][i] ** 2)[gz] else: s2 = (params['weight_frame'] ** 2)[gz] sum1 = np.nansum(t * a / s2) sum2 = np.nansum(t ** 2 / s2) wti = sum1 / sum2 if wti == 0 or not np.isfinite(wti): log.warning(f'Correlation zero on pair {i + 1}') wti = 0 elif wti < 0.1: log.warning(f'Correlation negative on pair {i + 1}') wti = 0 weights[i] = sign * wti sum_wt = np.nansum(np.abs(weights)) sum_wt_sq = np.nansum(np.abs(weights) ** 2) if (params['do_weights'] or params['unweighted']) and sum_wt > 0: weights = weights / sum_wt wt_max = np.nanmax(weights) idx = weights == 0 if np.sum(idx) > 0: log.info(f'{np.sum(idx)} pair(s) given zero weight') if np.sum(idx) == params['nz']: raise ValueError('All weights zero. 
Not coadding images.') params['sum_wt'] = sum_wt params['sum_wt_sq'] = sum_wt_sq params['weights'] = weights params['wt_max'] = wt_max def _combine_data(params): """Combine data with weighted mean.""" # weights are by frame at this point: expand them to match data log.info(f'Weights: {params["weights"]}') weight_image = np.zeros_like(params['data']) with warnings.catch_warnings(): warnings.simplefilter('ignore', RuntimeWarning) weight_image[:] = 1 / params['weights'][:, None, None] if params['threshold'] > 0: robust = True sigma = params['threshold'] log.info(f'Performing robust mean with threshold {sigma}.') else: robust = False sigma = None coadded = combine_images(params['data'], variance=weight_image, method='mean', weighted=True, robust=robust, returned=False, sigma=sigma) _, coadded_var = combine_images(params['data'], variance=params['variance'], method='mean', weighted=True, robust=robust, returned=True, sigma=sigma) params['coadded'] = coadded params['coadded_var'] = coadded_var def _update_integration_time(params): """Update integration time for combined data.""" int_time = (params['header']['BEAMTIME'] * params['nz'] * params['header']['NINT']) instmode = str(params['header'].get('INSTMODE', 'UNKNOWN')).strip().upper() if instmode == 'NOD_ON_SLIT': int_time *= 2 log.info(f'Total on-source integration time: {int_time}') params['header']['EXPTIME'] = int_time params['header']['NEXP'] = params['nz'] * params['header']['NINT']