Source code for skretrieval.retrieval.measvec

from __future__ import annotations

import fnmatch
from collections.abc import Callable
from dataclasses import dataclass

import numpy as np
import xarray as xr
from scipy import sparse
from scipy.linalg import block_diag
from simpleeval import simple_eval

from skretrieval.core.radianceformat import RadianceGridded


def _resolve_value(expr, variables):
    if isinstance(expr, str):
        expr = expr.replace("$", "")
        return simple_eval(expr, names=variables)
    return expr


@dataclass
class Measurement:
    """
    A dataclass representing the core objects of a measurement vector. This is an internal object
    that is passed around when doing measurement vector transformations
    """

    y: np.array
    K: np.array
    Sy: np.array


[docs] class MeasurementVector:
[docs] def __init__(self, fn: Callable, apply_to_filter="*", sample_fn=None): """ A class that represents a measurement vector. This is a callable object that can be used to transform L1 data to a measurement vector. Parameters ---------- fn : Callable Function which takes in L1 data and returns a Measurement object apply_to_filter : str, optional Only L1 data matching the apply_to_filter will be affected by this measurement vector, by default "*" """ self._sample_fn = sample_fn self._fn = fn self._filter = apply_to_filter self._enabled = True
@property def fn(self): return self._fn @property def filter(self): return self._filter @property def enabled(self): return self._enabled @enabled.setter def enabled(self, value: bool): self._enabled = value
[docs] def apply( self, l1_data: dict[RadianceGridded], ctxt: dict | None = None ) -> Measurement: """ Applies the function to the l1 data, returning back a Measurement object Parameters ---------- l1_data : dict[RadianceGridded] Returns ------- Measurement """ if not self._enabled: return None apply_vals = { k: d for k, d in l1_data.items() if fnmatch.fnmatch(k, self._filter) } if len(apply_vals) > 0: local_ctxt = ctxt if ctxt is not None else {} return self._fn(apply_vals, ctxt=local_ctxt, filter=self._filter) return None
[docs] def required_sample_wavelengths( self, obs_samples: dict[np.array] ) -> dict[np.array]: """ Determines which sample wavelengths are required for this measurement vector Default is to just return back all of the observation wavelengths Parameters ---------- obs_samples : dict[np.array] Returns ------- dict[np.array] """ if self._sample_fn is None: return obs_samples return self._sample_fn(obs_samples)
def pre_process(l1: dict[RadianceGridded], n: int = 1) -> dict[RadianceGridded]: """ Called before the measurement vector is applied. This function will ensure that the L1 data always has the necessary fields for the measurement vector to work. Parameters ---------- l1 : dict[RadianceGridded] n: int, optional Number of elements in the state vector, used to create the dummy Jacobian. Default is 1 which can be used if the number of elements in the state vector is not important. Returns ------- dict[RadianceGridded] """ # Copy and modify the data to always include wf and noise values new_l1 = {} for key, val in l1.items(): new_val = val.data.copy(deep=True) if "wf" not in new_val: new_val["wf"] = xr.zeros_like( new_val["radiance"].expand_dims({"x": n}, axis=-1) ) if "radiance_noise" not in new_val: new_val["radiance_noise"] = new_val["radiance"] * 1 new_l1[key] = RadianceGridded(new_val) return new_l1
[docs] def concat(measurements: list[Measurement]) -> Measurement: """ Concatenates a list of measurements into a single measurement Parameters ---------- measurements : list[Measurement] Returns ------- Measurement """ if len(measurements) == 0: return None return Measurement( y=np.concatenate([m.y for m in measurements]), K=np.vstack([m.K for m in measurements]), Sy=sparse.block_diag( [sparse.csc_matrix(m.Sy) for m in measurements], format="csc" ), )
def post_process(measurement: Measurement) -> dict: """ Called after the measurement vector is applied. This function will convert the measurement object back into a dictionary for the retrieval to use. Parameters ---------- measurement : Measurement Returns ------- dict """ # At this stage we have to remove the jacobian if it was a dummy one in the beginning res = {"y": measurement.y, "jacobian": measurement.K, "y_error": measurement.Sy} if measurement.K.shape[-1] == 0: del res["jacobian"] return res
[docs] def select(l1: dict[RadianceGridded], filter: str = "*", **kwargs) -> Measurement: """ Selects the L1 data that matches the filter and applies the selector stored in kwargs to the underlying xarray datasets Parameters ---------- l1 : dict[RadianceGridded] filter : str, optional by default "*" Returns ------- Measurement """ measurements = [] for key, val in l1.items(): if fnmatch.fnmatch(key, filter): selected = val.data.sel(**kwargs) measurements.append( Measurement( y=selected["radiance"].to_numpy().flatten(), K=selected["wf"].to_numpy().reshape((-1, len(selected["x"]))), Sy=sparse.csc_matrix( sparse.diags( selected["radiance_noise"].to_numpy().flatten() ** 2, 0 ) ), ) ) return concat(measurements)
def nearest_selector(l1: dict[RadianceGridded], filter: str = "*", **kwargs) -> dict: """ A special selector that will select the nearest value to the kwargs in the L1 data. Returns back another dictionary with the same keys as the input dictionary but with the data modified to only contain the nearest values to the kwargs Parameters ---------- l1 : dict[RadianceGridded] filter : str, optional , by default "*" Returns ------- dict """ res = {} for key, val in l1.items(): if fnmatch.fnmatch(key, filter): res[key] = RadianceGridded( val.data.sel(**kwargs, method="nearest").assign_coords(**kwargs) ) return res
[docs] def log(measurement: Measurement) -> Measurement: """ Log transform the measurement Parameters ---------- measurement : Measurement Returns ------- Measurement """ return Measurement( y=np.log(measurement.y), K=measurement.K / measurement.y[:, np.newaxis], Sy=measurement.Sy / np.outer(measurement.y, measurement.y), )
[docs] def mean(measurement: Measurement) -> Measurement: """ Take the mean of the measurement Parameters ---------- measurement : Measurement Returns ------- Measurement """ return Measurement( y=np.mean(measurement.y), K=np.mean(measurement.K, axis=0), Sy=np.mean(measurement.Sy.diagonal()), )
[docs] def multiply(measurement: Measurement, factor: float) -> Measurement: """ Multiply the measurement by a factor Parameters ---------- measurement : Measurement factor : float Returns ------- Measurement """ return Measurement( y=measurement.y * factor, K=measurement.K * factor, Sy=measurement.Sy * factor**2, )
[docs] def subtract(measurement: Measurement, other: Measurement) -> Measurement: """ Subtract one measurement from another Parameters ---------- measurement : Measurement other : Measurement Returns ------- Measurement """ return Measurement( y=measurement.y - other.y, K=measurement.K - other.K, Sy=( measurement.Sy.toarray() if sparse.issparse(measurement.Sy) else measurement.Sy ) + other.Sy, )
[docs] def add(measurement: Measurement, other: Measurement) -> Measurement: """ Add two measurements together Parameters ---------- measurement : Measurement other : Measurement Returns ------- Measurement """ return Measurement( y=measurement.y + other.y, K=measurement.K + other.K, Sy=( measurement.Sy.toarray() if sparse.issparse(measurement.Sy) else measurement.Sy ) + other.Sy, )
def wavelength_mean( l1: dict[RadianceGridded], filter: str = "*", **kwargs ) -> Measurement: """ Takes the mean over a wavelength band Parameters ---------- l1 : dict[RadianceGridded] filter : str, optional by default "*" Returns ------- Measurement """ measurements = [] for key, val in l1.items(): if fnmatch.fnmatch(key, filter): selected = val.data.sel(**kwargs).mean(dim="wavelength") measurements.append( Measurement( y=selected["radiance"].to_numpy().flatten(), K=selected["wf"].to_numpy().reshape((-1, len(selected["x"]))), Sy=sparse.csc_matrix( sparse.diags( selected["radiance_noise"].to_numpy().flatten() ** 2, 0 ) ), ) ) return concat(measurements)
[docs] class Triplet(MeasurementVector):
[docs] def __init__( self, wavelength: list[int], weights: list[float], altitude_range: list[float], normalization_range: list[float], normalize=True, log_space=True, **kwargs, ): """ A class that represents a measurement vector that is a weighted combination of log radiances, high altitude normalized Note that this measurement vector requires the l1 data to contain the "tangent_altitude" field. Both altitude_range and normalization_range can be set through the retrieval context by prefixing the value with a '$' Parameters ---------- wavelength : list[int] Wavelengths to select weights : list[float] Weights to apply to the wavelengths altitude_range : list[float] Altidude range to select normalization_range : list[float] Altitude range to normalize to """ self._wavelength = wavelength def y(l1, ctxt, **kwargs): res_altitude_range = [_resolve_value(v, ctxt) for v in altitude_range] res_norm_range = [_resolve_value(v, ctxt) for v in normalization_range] t_vals = [] for w, weight in zip(wavelength, weights): # Get the useful wavelength data if log_space: wavel_data = log( select( nearest_selector(l1, wavelength=w), tangent_altitude=slice( res_altitude_range[0], res_altitude_range[1] ), **kwargs, ) ) else: wavel_data = select( nearest_selector(l1, wavelength=w), tangent_altitude=slice( res_altitude_range[0], res_altitude_range[1] ), **kwargs, ) # And the normalization value norm_vals = mean( log( select( nearest_selector(l1, wavelength=w), tangent_altitude=slice( res_norm_range[0], res_norm_range[1] ), **kwargs, ) ) ) # The triplet value is the difference of the log radiances subtracted by the normalization multiplied by the weight if normalize: t_vals.append(multiply(subtract(wavel_data, norm_vals), weight)) else: t_vals.append(multiply(wavel_data, weight)) # Add all of the wavelengths together res = t_vals[0] for i in range(1, len(t_vals)): res = add(res, t_vals[i]) return res super().__init__(y, **kwargs)
[docs] def required_sample_wavelengths( self, obs_samples: dict[np.array] ) -> dict[np.array]: """ Determines which sample wavelengths are required for this measurement vector Default is to just return back all of the observation wavelengths Parameters ---------- obs_samples : dict[np.array] Returns ------- dict[np.array] """ all_wv = {} for key, val in obs_samples.items(): all_wv[key] = [] if fnmatch.fnmatch(key, self.filter): all_wv[key] = np.array( [val[np.abs(val - w).argmin()] for w in self._wavelength] ) return all_wv
class IntegratedLine(MeasurementVector): def __init__( self, central_wavelength: float, integration_range: float, baseline_range: float, **kwargs, ): self._left_boundary = central_wavelength - integration_range - baseline_range self._right_boundary = central_wavelength + integration_range + baseline_range def y(l1, ctxt, **kwargs): # noqa: ARG001 ta_s = slice(70000, 110000) integration_vals = wavelength_mean( l1, wavelength=slice( central_wavelength - integration_range, central_wavelength + integration_range, ), tangent_altitude=ta_s, ) baseline_left = wavelength_mean( l1, wavelength=slice( central_wavelength - integration_range - baseline_range, central_wavelength - integration_range, ), tangent_altitude=ta_s, ) baseline_right = wavelength_mean( l1, wavelength=slice( central_wavelength + integration_range, central_wavelength + integration_range + baseline_range, ), tangent_altitude=ta_s, ) baseline = multiply(add(baseline_left, baseline_right), 0.5) return subtract(integration_vals, baseline) super().__init__(y, **kwargs) def required_sample_wavelengths( self, obs_samples: dict[np.array] ) -> dict[np.array]: """ Determines which sample wavelengths are required for this measurement vector Default is to just return back all of the observation wavelengths Parameters ---------- obs_samples : dict[np.array] Returns ------- dict[np.array] """ all_wv = {} for key, val in obs_samples.items(): all_wv[key] = [] if fnmatch.fnmatch(key, self.filter): all_wv[key] = val[ (val > self._left_boundary) & (val < self._right_boundary) ] return all_wv class WavelengthAltitude(MeasurementVector): def __init__( self, wavelength_range: list[float], altitude_range: list[float], **kwargs, ): """ A measurement vector that selects all measurements inside wavelength and tangent altitude ranges. Both ranges can be set through the retrieval context by prefixing values with '$'. Parameters ---------- wavelength_range : list[float] Wavelength range to select as [min, max] altitude_range : list[float] Tangent altitude range to select as [min, max] """ self._wavelength_range = wavelength_range def y(l1, ctxt, **kwargs): res_wavelength_range = [_resolve_value(v, ctxt) for v in wavelength_range] res_altitude_range = [_resolve_value(v, ctxt) for v in altitude_range] return select( l1, wavelength=slice(res_wavelength_range[0], res_wavelength_range[1]), tangent_altitude=slice(res_altitude_range[0], res_altitude_range[1]), **kwargs, ) super().__init__(y, **kwargs) def required_sample_wavelengths( self, obs_samples: dict[np.array] ) -> dict[np.array]: """ Determines which sample wavelengths are required for this measurement vector. Parameters ---------- obs_samples : dict[np.array] Returns ------- dict[np.array] """ all_wv = {} # If range values are context-dependent, we cannot resolve static sampling here. if any(isinstance(v, str) for v in self._wavelength_range): return obs_samples left = self._wavelength_range[0] right = self._wavelength_range[1] for key, val in obs_samples.items(): all_wv[key] = [] if fnmatch.fnmatch(key, self.filter): all_wv[key] = val[(val >= left) & (val <= right)] return all_wv