"""Forward models for retrievals (module ``skretrieval.retrieval.forwardmodel``)."""

from __future__ import annotations

import abc
from collections.abc import Callable

import numpy as np
import sasktran2 as sk

from skretrieval.core.lineshape import DeltaFunction, LineShape
from skretrieval.core.sasktranformat import SASKTRANRadiance
from skretrieval.core.sensor.spectrograph import SpectrographOnlySpectral
from skretrieval.geodetic import geodetic
from skretrieval.retrieval import ForwardModel

from .ancillary import Ancillary
from .measvec import MeasurementVector
from .observation import FilteredObservation, Observation
from .statevector.altitude import AltitudeNativeStateVector


class StandardForwardModel(ForwardModel):
    """
    Base forward model that runs one SASKTRAN2 engine per observation key.

    Subclasses must implement the abstract ``_construct_*`` hooks; in this module
    they are supplied by ``IdealViewingMixin`` (viewing and model geometry) and
    ``SpectrometerMixin`` (model wavelengths and instrument model).
    """

    def __init__(
        self,
        observation: Observation,
        state_vector: AltitudeNativeStateVector,
        meas_vec: MeasurementVector,
        ancillary: Ancillary,
        engine_config: sk.Config,
        **kwargs,
    ) -> None:
        """
        A forward model for the Retrieval class.  This is a base class that should be inherited from.


        Parameters
        ----------
        observation : Observation
            The observation
        state_vector : AltitudeNativeStateVector
            The State Vector
        meas_vec : MeasurementVector
            The Measurement Vector; stored as ``self._meas_vec`` for use by the
            ``_construct_*`` hooks
        ancillary : Ancillary
            The Ancillary Object
        engine_config : sk.Config
            Configuration for the engine
        kwargs : dict
            Accepted so subclasses can forward extra options; not used here.
        """
        ForwardModel.__init__(self)

        # These are assigned before any _construct_* hook runs because the hooks
        # read them (e.g. SpectrometerMixin uses _observation and _meas_vec).
        self._state_vector = state_vector
        self._engine_config = engine_config
        self._ancillary = ancillary
        self._observation = observation
        self._meas_vec = meas_vec

        # Order matters: the atmosphere is built from the model geometry and
        # wavelengths, and it also sets the geometry's refractive index, which is
        # presumably needed before sk.Engine is constructed from that geometry.
        self._viewing_geo = self._construct_viewing_geo()
        self._model_geometry = self._construct_model_geometry()
        self._model_wavelength = self._construct_model_wavelength()

        self._atmosphere = self._construct_atmosphere()

        self._engine = self._construct_engine()

        self._inst_model = self._construct_inst_model()

        self._solar_model = self._construct_solar_model()

    @abc.abstractmethod
    def _construct_model_geometry(self):
        """Return a dict mapping observation key to the model geometry passed to sk.Atmosphere/sk.Engine."""

    @abc.abstractmethod
    def _construct_model_wavelength(self):
        """Return a dict mapping observation key to the model wavelengths [nm] for that key's atmosphere."""

    @abc.abstractmethod
    def _construct_viewing_geo(self):
        """Return a dict mapping observation key to the viewing geometry passed to sk.Engine."""

    @abc.abstractmethod
    def _construct_inst_model(self):
        """Return a dict mapping observation key to an object exposing ``model_radiance(radiance, None)``."""

    def _construct_solar_model(self):
        """Optional hook for a solar model; the base implementation returns None."""

    def _construct_atmosphere(self):
        """
        Build one ``sk.Atmosphere`` per model-geometry key.

        Pressure and temperature derivatives are disabled. The state vector and
        then the ancillary object populate each atmosphere. As a side effect the
        refractive index of ``self._model_geometry[key]`` is set from the Ciddor
        formula evaluated at the atmosphere's mean wavelength.

        Returns
        -------
        dict
            Observation key -> ``sk.Atmosphere``.
        """
        atmo = {}

        for key in self._model_geometry:
            atmo[key] = sk.Atmosphere(
                self._model_geometry[key],
                self._engine_config,
                wavelengths_nm=self._model_wavelength[key],
                pressure_derivative=False,
                temperature_derivative=False,
            )

            self._state_vector.add_to_atmosphere(atmo[key])
            self._ancillary.add_to_atmosphere(atmo[key])

            # Computed after the state vector / ancillary have filled in the
            # temperature and pressure profiles. A single representative
            # wavelength (the nan-mean) is used for the whole spectral range.
            # NOTE(review): the zeros argument is presumably the water-vapour
            # term and 450 the CO2 concentration in ppm -- confirm against the
            # sasktran2 ciddor_index_of_refraction signature.
            self._model_geometry[key].refractive_index = (
                sk.optical.refraction.ciddor_index_of_refraction(
                    atmo[key].temperature_k,
                    atmo[key].pressure_pa,
                    np.zeros_like(atmo[key].temperature_k),
                    450,
                    np.nanmean(atmo[key].wavelengths_nm),
                )
            )

        return atmo

    def _construct_engine(self):
        """Return a dict mapping each model-geometry key to an ``sk.Engine``."""
        engines = {}

        for key in self._model_geometry:
            engines[key] = sk.Engine(
                self._engine_config, self._model_geometry[key], self._viewing_geo[key]
            )

        return engines

    def calculate_radiance(self):
        """
        Run every engine and apply the matching instrument model.

        For each key the engine radiance is post-processed by the state vector,
        wrapped in ``SASKTRANRadiance`` and passed through the instrument model.

        Returns
        -------
        dict
            L1 data. A non-dict instrument result, or a single-entry dict, is
            stored under ``key``; a multi-entry dict is stored entry by entry
            under ``f"{key}_{subkey}"``.
        """
        l1 = {}
        for key, engine in self._engine.items():
            sk2_rad = engine.calculate_radiance(self._atmosphere[key])
            sk2_rad = self._state_vector.post_process_sk2_radiances(sk2_rad)
            sk2_rad = SASKTRANRadiance.from_sasktran2(sk2_rad)

            model_result = self._inst_model[key].model_radiance(sk2_rad, None)

            if isinstance(model_result, dict):
                if len(model_result) == 1:
                    l1[key] = next(iter(model_result.values()))
                else:
                    for k, v in model_result.items():
                        l1[f"{key}_{k}"] = v
            else:
                l1[key] = model_result

        # Called once, after every key has been filled, so the observation sees
        # the complete L1 set and earlier entries are not re-processed on every
        # loop iteration.
        self._observation.append_information_to_l1(l1)

        return l1


class SpectrometerMixin:
    """
    Mixin that supplies the spectral side of a forward model: the model
    wavelength grid and one ``SpectrographOnlySpectral`` instrument model per
    observation key.

    The methods read ``self._observation`` and ``self._meas_vec``, which are set by
    ``StandardForwardModel.__init__``, so this mixin is meant to be combined with
    that class.
    """

    def __init__(
        self,
        lineshape_fn: Callable[[float], LineShape] | None = None,
        model_res_nm=0.02,
        model_res_cminv=0.02,
        spectral_native_coordinate="wavelength_nm",
        round_decimal=2,
        stokes_sensitivities=None,
    ) -> None:
        """
        Mixin for adding a spectrometer to the forward model

        Parameters
        ----------
        lineshape_fn : Callable[[float], LineShape] | None, optional
            Function that takes in wavelength in nm and returns back a LineShape, by default None,
            in which case a ``DeltaFunction`` line shape is used at every sample
        model_res_nm : float, optional
            Model Resolution to use in [nm], by default 0.02. Used when
            ``spectral_native_coordinate == "wavelength_nm"``.
        model_res_cminv : float, optional
            Model Resolution to use in [cm^-1], by default 0.02. Used for any other
            ``spectral_native_coordinate``.
        spectral_native_coordinate : str, optional
            The native coordinate for the spectral axis, by default "wavelength_nm", can also be "wavenumber_cminv"
            (any value other than "wavelength_nm" selects the wavenumber grid)
        round_decimal : int, optional
            Decimal points to round the model spectral grid to in the radiative transfer calculation,
            by default 2. Rounding happens in the native coordinate (nm, or cm^-1 for the wavenumber grid).
        stokes_sensitivities : dict, optional
            Dictionary of stokes sensitivities, by default None. Can be set to multiple measurements,
            e.g., {"I": np.array([1, 0, 0, 0]), "Q": np.array([0, 1, 0, 0]), "U": np.array([0, 0, 1, 0])}
            to measure multiple stokes parameters separately.
        """
        if lineshape_fn is None:
            self._lineshape_fn = lambda _: DeltaFunction()
        else:
            self._lineshape_fn = lineshape_fn
        self._model_res_nm = model_res_nm
        self._model_res_cminv = model_res_cminv
        self._round_decimal = round_decimal
        self._spectral_native_coordinate = spectral_native_coordinate
        self._stokes_sensitivities = stokes_sensitivities

    def _get_required_wavelength(self):
        """
        Return, per observation key, the sorted union of the sample wavelengths
        required by every measurement-vector element.

        Each element's ``required_sample_wavelengths(obs_samples)`` must return a
        mapping containing every key of ``obs_samples``. An empty measurement
        vector makes ``np.concatenate`` raise ``ValueError``.
        """
        obs_samples = self._observation.sample_wavelengths()

        mv_required_samples = {
            key: val.required_sample_wavelengths(obs_samples)
            for key, val in self._meas_vec.items()
        }

        return {
            key: np.unique(
                np.concatenate(
                    [np.atleast_1d(d[key]) for d in mv_required_samples.values()]
                )
            )
            for key in obs_samples
        }

    def _lineshape_bounds(self, samples, center_fn):
        """
        Return the ``(lower, upper)`` line-shape bounds for each sample.

        ``self._lineshape_fn`` is evaluated once per sample. Zero-centred line
        shapes are queried with ``center=0``; others with ``center=center_fn(w)``.
        """
        bounds = []
        for w in samples:
            lineshape = self._lineshape_fn(w)
            center = 0 if lineshape.zero_centered() else center_fn(w)
            bounds.append(lineshape.bounds(center=center))
        return bounds

    def _wavelength_model_grid(self, samples):
        """Model grid in nm: each sample's bounds stepped by ``model_res_nm`` and merged."""
        bounds = self._lineshape_bounds(samples, lambda w: -w)
        res = self._model_res_nm
        pieces = [
            # res / 2 on the stop value keeps the upper bound despite float error
            np.around(np.arange(a, b + res / 2, res) + w, self._round_decimal)
            for (a, b), w in zip(bounds, samples)
        ]
        return np.unique(np.concatenate(pieces))

    def _wavenumber_model_grid(self, samples):
        """Model grid built in cm^-1 (step ``model_res_cminv``), returned in nm, ascending."""
        bounds = self._lineshape_bounds(samples, lambda w: -1e7 / w)
        res = self._model_res_cminv
        pieces = [
            # The upper wavelength bound (b + w) maps to the lower wavenumber.
            np.around(
                np.arange(1e7 / (b + w), 1e7 / (a + w) + res / 2, res),
                self._round_decimal,
            )
            for (a, b), w in zip(bounds, samples)
        ]
        # Reverse the ascending wavenumbers so the converted wavelengths ascend.
        return 1e7 / np.unique(np.concatenate(pieces))[::-1]

    def _construct_model_wavelength(self):
        """
        Evaluates the lineshape at the sample wavelengths and returns back the model
        wavelengths spaced by the model resolution

        Returns
        -------
        dict
            Observation key -> sorted array of model wavelengths [nm].
        """
        sample_wavelengths = self._get_required_wavelength()

        if self._spectral_native_coordinate == "wavelength_nm":
            build_grid = self._wavelength_model_grid
        else:
            build_grid = self._wavenumber_model_grid

        return {key: build_grid(samples) for key, samples in sample_wavelengths.items()}

    def _construct_inst_model(self):
        """
        Constructs the instrument model: one ``SpectrographOnlySpectral`` per
        observation key, using the line shape at each required sample wavelength.
        """
        sample_wavelengths = self._get_required_wavelength()

        assign_coord = (
            "wavelength"
            if self._spectral_native_coordinate == "wavelength_nm"
            else "wavenumber"
        )

        inst_models = {}
        for key, samples in sample_wavelengths.items():
            inst_models[key] = SpectrographOnlySpectral(
                samples,
                [self._lineshape_fn(x) for x in samples],
                spectral_native_coordinate=self._spectral_native_coordinate,
                assign_coord=assign_coord,
                stokes_sensitivity=self._stokes_sensitivities,
            )

        return inst_models
class IdealViewingMixin:
    """
    Mixin providing an ideal viewing geometry: a single line of sight per
    observation is used by the forward model rather than a spatial PSF.
    """

    def __init__(
        self, observation: Observation, model_altitude_grid: np.ndarray
    ) -> None:
        """
        Mixin for adding an ideal viewing geometry to the forward model.  This means that a single
        line of sight is used by the forward model for each observation rather than using a spatial PSF.

        Parameters
        ----------
        observation : Observation
            Supplies the viewing geometry (``sk2_geometry``) and the reference point
            (``reference_cos_sza``, ``reference_latitude``, ``reference_longitude``).
        model_altitude_grid : np.ndarray
            Altitude grid passed directly to ``sk.Geometry1D`` (presumably metres,
            matching the ``altitude_grid_m`` name used when building the geometry).
        """
        self._model_altitude_grid = model_altitude_grid
        self._obs = observation

    def _construct_viewing_geo(self):
        """Return the observation's SASKTRAN2 viewing geometry unchanged."""
        return self._obs.sk2_geometry()

    def _construct_model_geometry(self):
        """
        Build one spherical ``sk.Geometry1D`` per observation key.

        The altitude grid is the one given at construction; the solar zenith
        cosine and Earth radius come from the observation's reference point.
        Linear interpolation is used and the second positional argument
        (presumably the solar azimuth) is fixed at 0.0.

        Returns
        -------
        dict
            Observation key -> ``sk.Geometry1D``.
        """
        altitude_grid_m = self._model_altitude_grid

        # Observation tells us the reference point
        cos_sza = self._obs.reference_cos_sza()
        ref_lat = self._obs.reference_latitude()
        ref_lon = self._obs.reference_longitude()

        geometry = {}
        geo = geodetic()
        for key in cos_sza:
            # Earth radius at the reference point: the norm of the location of the
            # surface point (altitude 0) below it -- presumably Earth-centred
            # coordinates in metres.
            geo.from_lat_lon_alt(ref_lat[key], ref_lon[key], 0.0)
            earth_radius_m = np.linalg.norm(geo.location)

            geometry[key] = sk.Geometry1D(
                cos_sza[key],
                0.0,
                earth_radius_m,
                altitude_grid_m,
                sk.InterpolationMethod.LinearInterpolation,
                sk.GeometryType.Spherical,
            )

        return geometry
class IdealViewingSpectrograph(
    IdealViewingMixin, SpectrometerMixin, StandardForwardModel
):
    """
    Forward model combining an ideal (single line of sight) viewing geometry
    with a spectrometer instrument model.
    """

    # Keyword arguments forwarded to SpectrometerMixin.__init__. Keys that are
    # absent fall back to the mixin's own defaults, so defaults are defined in
    # exactly one place.
    _SPECTROMETER_KWARGS = (
        "lineshape_fn",
        "model_res_nm",
        "model_res_cminv",
        "spectral_native_coordinate",
        "round_decimal",
        "stokes_sensitivities",
    )

    def __init__(
        self,
        observation: Observation,
        state_vector: AltitudeNativeStateVector,
        meas_vec: MeasurementVector,
        ancillary: Ancillary,
        engine_config: sk.Config,
        **kwargs,
    ) -> None:
        """
        A forward model for the retrieval that uses an ideal viewing geometry and a spectrometer

        Parameters
        ----------
        observation : Observation
        state_vector : AltitudeNativeStateVector
            Its ``altitude_grid`` is used as the model altitude grid.
        meas_vec : MeasurementVector
        ancillary : Ancillary
        engine_config : sk.Config
        kwargs : dict
            Recognised `SpectrometerMixin` options (``lineshape_fn``, ``model_res_nm``,
            ``model_res_cminv``, ``spectral_native_coordinate``, ``round_decimal``,
            ``stokes_sensitivities``) are passed to `SpectrometerMixin`; all kwargs
            are also forwarded to `StandardForwardModel`.
        """
        IdealViewingMixin.__init__(self, observation, state_vector.altitude_grid)
        SpectrometerMixin.__init__(
            self,
            **{
                name: kwargs[name]
                for name in self._SPECTROMETER_KWARGS
                if name in kwargs
            },
        )
        # Called last: its __init__ runs the _construct_* hooks, which rely on
        # the attributes the two mixins have just set.
        StandardForwardModel.__init__(
            self,
            observation,
            state_vector,
            meas_vec,
            ancillary,
            engine_config,
            **kwargs,
        )
class ForwardModelHandler(ForwardModel):
    """
    Forward model that runs several internal forward models, each on a filtered
    view of the same observation, and merges their L1 output.
    """

    def __init__(
        self,
        cfg: dict,
        observation: Observation,
        state_vector: AltitudeNativeStateVector,
        meas_vec: MeasurementVector,
        ancillary: Ancillary,
        engine_config: sk.Config,
        **kwargs,
    ):
        """
        Parameters
        ----------
        cfg : dict
            Maps a key to a per-model config dict. The key is used to build
            ``FilteredObservation(observation, key)``. Each config dict may contain
            ``"class"`` (forward model class, default ``IdealViewingSpectrograph``)
            and ``"kwargs"`` (extra keyword arguments for that class, default ``{}``).
        observation : Observation
        state_vector : AltitudeNativeStateVector
        meas_vec : MeasurementVector
        ancillary : Ancillary
        engine_config : sk.Config
            Shared by every internal forward model.
        kwargs : dict
            Accepted for interface compatibility; not used.
        """
        super().__init__()

        # Construct the internal forward models
        self._forward_models = {}
        for key, model_cfg in cfg.items():
            model_cls = model_cfg.get("class", IdealViewingSpectrograph)
            self._forward_models[key] = model_cls(
                FilteredObservation(observation, key),
                state_vector,
                meas_vec,
                ancillary,
                engine_config,
                **model_cfg.get("kwargs", {}),
            )

    def calculate_radiance(self):
        """
        Run every internal forward model and merge their L1 dicts.

        Models are evaluated in ``cfg`` order; if two models produce the same L1
        key, the later one overwrites the earlier (``dict.update`` semantics).
        """
        result = {}
        for model in self._forward_models.values():
            result.update(model.calculate_radiance())
        return result