from __future__ import annotations
from copy import copy
from typing import ClassVar
import numpy as np
import sasktran2 as sk2
import skretrieval.retrieval.prior as prior
from skretrieval.retrieval.rodgers import Rodgers
from skretrieval.retrieval.scipy import SciPyMinimizer, SciPyMinimizerGrad
from skretrieval.retrieval.statevector.constituent import StateVectorElementConstituent
from skretrieval.retrieval.statevector.shifts import WavenumberShift
from skretrieval.retrieval.statevector.spline import (
MultiplicativeSpline,
MultiplicativeSplineOne,
)
from .ancillary import Ancillary, US76Ancillary
from .forwardmodel import ForwardModelHandler, IdealViewingSpectrograph
from .measvec import MeasurementVector, select
from .observation import Observation
from .statevector.altitude import AltitudeNativeStateVector
from .target.mvtarget import MeasVecTarget
# ("[docs]" source-link marker left over from the HTML documentation export; not code.)
class Retrieval:
    """
    Assemble and run a retrieval for a single observation.

    The constructor wires together a measurement vector, ancillary data, a
    context, a state vector, a forward model and a retrieval target.
    :meth:`retrieve` then hands these to the selected minimizer.

    Optical properties, priors, state vector element factories and the context
    provider are looked up in class-level registries that are filled through
    the ``register_*`` decorators. The registries are mutable class attributes,
    so a registration is visible to every class that shares the same registry
    object (i.e. every subclass that does not rebind the attribute).
    """

    # Optical property factories keyed by species name. ``_optical_property``
    # calls the factory with no arguments.
    _optical_property_fns: ClassVar[dict[str, callable]] = {}

    # Prior factories keyed by species name. Filled by ``register_prior``;
    # nothing inside this class reads it.
    _prior_fns: ClassVar[dict[str, callable]] = {}

    # State element factories, grouped by category. Every factory is invoked as
    # ``factory(retrieval, name, native_alt_grid, cfg)``.
    _state_fns: ClassVar[dict[str, dict[str, callable]]] = {
        "absorbers": {},
        "aerosols": {},
        "splines": {},
        "shifts": {},
        "surface": {},
        "other": {},
    }

    # Options handed to the Rodgers minimizer unless overridden through
    # ``minimizer_kwargs``.
    _RODGERS_DEFAULTS: ClassVar[dict] = {
        "lm_damping_method": "fletcher",
        "lm_damping": 0.1,
        "max_iter": 30,
        "lm_change_factor": 2,
        "iterative_update_lm": True,
        "retreat_lm": False,
        "apply_cholesky_scaling": True,
        "convergence_factor": 1e-2,
        "convergence_check_method": "dcost",
    }

    def _context_fn(_):
        """Default context provider; returns an empty context."""
        return {}

    @classmethod
    def register_context(cls):
        """
        Decorator factory that replaces the context provider of ``cls``.

        The decorated function is stored as ``cls._context_fn`` and is later
        called as a method, i.e. it receives the retrieval instance as its
        only argument. The function itself is returned unchanged.
        """

        def decorator(context_fn: callable):
            cls._context_fn = context_fn
            return context_fn

        return decorator

    @classmethod
    def register_optical_property(cls, species_name: str):
        """
        Decorator factory registering an optical property factory.

        Parameters
        ----------
        species_name : str
            Key under which the decorated function is stored.
        """

        def decorator(optical_property_fn: callable):
            cls._optical_property_fns[species_name] = optical_property_fn
            return optical_property_fn

        return decorator

    @classmethod
    def register_prior(cls, species_name: str):
        """
        Decorator factory registering a prior factory.

        Parameters
        ----------
        species_name : str
            Key under which the decorated function is stored.
        """

        def decorator(prior_fn: callable):
            cls._prior_fns[species_name] = prior_fn
            return prior_fn

        return decorator

    @classmethod
    def register_state(cls, category: str, species_name: str):
        """
        Decorator factory registering a state vector element factory.

        Parameters
        ----------
        category : str
            One of the keys of ``_state_fns`` ("absorbers", "aerosols",
            "splines", "shifts", "surface", "other"). An unknown category
            raises ``KeyError`` when the decorator is applied.
        species_name : str
            Key inside the category under which the function is stored.
        """

        def decorator(state_fn: callable):
            cls._state_fns[category][species_name] = state_fn
            return state_fn

        return decorator

    def __init__(
        self,
        observation: Observation,
        measvec: dict[str, MeasurementVector] | None = None,
        forward_model_cfg: dict | None = None,
        minimizer="rodgers",
        ancillary: Ancillary | None = None,
        l1_kwargs: dict | None = None,
        model_kwargs: dict | None = None,
        minimizer_kwargs: dict | None = None,
        target_kwargs: dict | None = None,
        state_kwargs: dict | None = None,
        **kwargs,
    ) -> None:
        """
        The main processing script that handles the retrieval

        Parameters
        ----------
        observation : Observation
            The observation that is retrieved from.
        measvec : dict[str, MeasurementVector] | None, optional
            Measurement vectors to use. When None a single "measurement"
            vector that selects from the raw L1 is used, by default None
        forward_model_cfg : dict | None, optional
            Forward model configuration. When None every measurement ("*") uses
            ``IdealViewingSpectrograph``, by default None
        minimizer : str, optional
            Selects which minimizer to use: "rodgers", "scipy" or "scipy_grad"
            (case-insensitive), by default "rodgers"
        ancillary : Ancillary | None, optional
            Ancillary data. When None a ``US76Ancillary`` is constructed,
            by default None
        l1_kwargs : dict | None, optional
            Additional arguments passed to the observation when constructing the L1, by default None
        model_kwargs : dict | None, optional
            Attributes set on the SASKTRAN2 engine configuration, by default None
        minimizer_kwargs : dict | None, optional
            Additional arguments passed to the minimizer. For "rodgers" they
            are merged over a set of default options, by default None
        target_kwargs : dict | None, optional
            Additional arguments passed to the retrieval target, by default None
        state_kwargs : dict | None, optional
            Arguments to construct the state vector. Must contain an
            "altitude_grid" entry, by default None
        **kwargs
            Stored unchanged on the instance as ``_options``.
        """
        # Only the Rodgers minimizer gets a set of default options; copy them
        # so user overrides never leak into the class-level defaults.
        if minimizer.lower() == "rodgers":
            self._minimizer_kwargs = dict(self._RODGERS_DEFAULTS)
        else:
            self._minimizer_kwargs = {}
        if minimizer_kwargs is not None:
            self._minimizer_kwargs.update(minimizer_kwargs)

        state_kwargs = {} if state_kwargs is None else state_kwargs
        target_kwargs = {} if target_kwargs is None else target_kwargs
        model_kwargs = {} if model_kwargs is None else model_kwargs
        l1_kwargs = {} if l1_kwargs is None else l1_kwargs
        if forward_model_cfg is None:
            forward_model_cfg = {"*": {"class": IdealViewingSpectrograph}}

        self._options = kwargs
        self._l1_kwargs = l1_kwargs
        self._minimizer = minimizer
        self._target_kwargs = target_kwargs
        self._state_kwargs = state_kwargs
        self._model_kwargs = model_kwargs
        self._measurement_vector = measvec
        self._forward_model_cfg = forward_model_cfg
        self._ancillary = ancillary

        # The construction order matters: each step reads attributes set by
        # the previous ones.
        self._measurement_vector = self._construct_measurement_vector()
        self._observation = observation
        self._anc = self._construct_ancillary()
        self._context = self._context_fn()
        self._state_vector = self._construct_state_vector()
        self._forward_model = self._construct_forward_model()
        self._target = self._construct_target()
        self._obs_l1 = self._observation.skretrieval_l1(
            self._forward_model, self._state_vector, self._l1_kwargs
        )

    def _construct_measurement_vector(self):
        """Return the user supplied measurement vectors, or a default one built on ``select``."""
        if self._measurement_vector is not None:
            return self._measurement_vector
        return {
            "measurement": MeasurementVector(
                lambda l1, ctxt, **kwargs: select(l1, **kwargs)  # noqa: ARG005
            )
        }

    def _construct_forward_model(self):
        """Build the forward model handler with an engine config patched by ``model_kwargs``."""
        engine_config = sk2.Config()
        for k, v in self._model_kwargs.items():
            setattr(engine_config, k, v)
        return ForwardModelHandler(
            self._forward_model_cfg,
            self._observation,
            self._state_vector,
            self._measurement_vector,
            self._anc,
            engine_config,
        )

    def _const_from_mipas(
        self,
        alt_grid,
        species_name,
        optical,
        prior_infl=1e-2,
        tikh=1e8,
        log_space=False,
        min_val=0,
        max_val=1,
    ):
        """
        Build a VMR state element initialised from the MIPAS climatology.

        The climatological VMR is linearly interpolated onto ``alt_grid`` and
        is used both as the initial profile and as the prior state of the
        Tikhonov term.

        Parameters
        ----------
        alt_grid
            Altitude grid the VMR is interpolated to.
        species_name
            Species name passed to the climatology and used as element name.
        optical
            Optical property used for the constituent.
        prior_infl
            Factor applied to the ``ConstantDiagonalPrior`` term.
        tikh
            Factor applied to the ``VerticalTikhonov`` term.
        log_space
            Passed to the state element. When true the lower bound is forced
            to 1e-40 and ``min_val`` is ignored.
        min_val, max_val
            Bounds on the VMR.
        """
        const = sk2.climatology.mipas.constituent(species_name, optical)
        altitudes_m = const._constituent.altitudes_m
        new_vmr = np.interp(alt_grid, altitudes_m, const.vmr)
        new_const = sk2.constituent.VMRAltitudeAbsorber(
            optical, alt_grid, new_vmr, out_of_bounds_mode="extend"
        )
        min_val = 1e-40 if log_space else min_val
        return StateVectorElementConstituent(
            new_const,
            species_name,
            ["vmr"],
            min_value={"vmr": min_val},
            max_value={"vmr": max_val},
            prior={
                "vmr": tikh * prior.VerticalTikhonov(1, new_vmr)
                + prior_infl * prior.ConstantDiagonalPrior()
            },
            log_space=log_space,
        )

    def _optical_property(self, species_name: str):
        """Call the optical property factory registered for ``species_name`` (``KeyError`` if none)."""
        return self._optical_property_fns[species_name]()

    @staticmethod
    def _default_state_absorber(
        processor: Retrieval, name: str, native_alt_grid: np.ndarray, cfg: dict
    ):
        """
        Fallback absorber factory: a MIPAS based VMR profile.

        ``cfg`` must provide "tikh_factor", "prior_influence", "log_space",
        "min_value" and "max_value"; "enabled" is optional (default True).
        """
        const = processor._const_from_mipas(
            native_alt_grid,
            name,
            processor._optical_property(name),
            tikh=cfg["tikh_factor"],
            prior_infl=cfg["prior_influence"],
            log_space=cfg["log_space"],
            min_val=cfg["min_value"],
            max_val=cfg["max_value"],
        )
        const.enabled = cfg.get("enabled", True)
        return const

    @staticmethod
    def _default_state_surface(
        _processor: Retrieval,
        name: str,
        _native_alt_grid: np.ndarray,
        cfg: dict,  # noqa: ARG004
    ):
        """Fallback for unregistered surface elements; always raises ``ValueError``."""
        msg = f"Surface {name} does not have a default implementation"
        raise ValueError(msg)

    @staticmethod
    def _default_state_spline(
        _processor: Retrieval,
        name: str,
        _native_alt_grid: np.ndarray,
        cfg: dict,  # noqa: ARG004
    ):
        """Fallback for unregistered spline elements; always raises ``ValueError``."""
        msg = f"Spline {name} does not have a default implementation"
        raise ValueError(msg)

    @staticmethod
    def _default_state_aerosol(
        _processor: Retrieval,
        name: str,
        _native_alt_grid: np.ndarray,
        cfg: dict,  # noqa: ARG004
    ):
        """Fallback for unregistered aerosol types; always raises ``ValueError``."""
        msg = f"aerosol {name} does not have a default implementation"
        raise ValueError(msg)

    @staticmethod
    def _default_state_shift(
        _processor: Retrieval,
        name: str,
        _native_alt_grid: np.ndarray,
        cfg: dict,  # noqa: ARG004
    ):
        """Fallback for unregistered shift types; always raises ``ValueError``."""
        msg = f"shift {name} does not have a default implementation"
        raise ValueError(msg)

    @staticmethod
    def _default_state_other(
        _processor: Retrieval,
        name: str,
        _native_alt_grid: np.ndarray,
        cfg: dict,  # noqa: ARG004
    ):
        """Fallback for unregistered "other" types; always raises ``ValueError``."""
        msg = f"other {name} does not have a default implementation"
        raise ValueError(msg)

    def _build_state_elements(
        self, category: str, default_fn: callable, native_alt_grid, *, by_type: bool
    ) -> dict:
        """
        Build every configured state element of one category.

        The factory is looked up in ``_state_fns[category]`` either by the
        "type" entry of the element configuration (``by_type=True``) or by the
        element name, falling back to ``default_fn``.
        """
        elements = {}
        for name, cfg in self._state_kwargs.get(category, {}).items():
            lookup_key = cfg["type"] if by_type else name
            factory = self._state_fns[category].get(lookup_key, default_fn)
            elements[name] = factory(self, name, native_alt_grid, cfg)
        return elements

    def _construct_state_vector(self):
        """
        Build the state vector from ``state_kwargs``.

        Absorbers, surface elements and splines are resolved by their name;
        aerosols, shifts and "other" elements by their "type" entry.

        Raises
        ------
        KeyError
            If ``state_kwargs`` has no "altitude_grid" entry.
        ValueError
            If an element has no registered factory and its category has no
            usable default.
        """
        try:
            native_alt_grid = self._state_kwargs["altitude_grid"]
        except KeyError:
            msg = "state_kwargs must contain an 'altitude_grid' entry"
            raise KeyError(msg) from None

        absorbers = self._build_state_elements(
            "absorbers", self._default_state_absorber, native_alt_grid, by_type=False
        )
        surface = self._build_state_elements(
            "surface", self._default_state_surface, native_alt_grid, by_type=False
        )
        aerosols = self._build_state_elements(
            "aerosols", self._default_state_aerosol, native_alt_grid, by_type=True
        )
        splines = self._build_state_elements(
            "splines", self._default_state_spline, native_alt_grid, by_type=False
        )
        shifts = self._build_state_elements(
            "shifts", self._default_state_shift, native_alt_grid, by_type=True
        )
        others = self._build_state_elements(
            "other", self._default_state_other, native_alt_grid, by_type=True
        )

        # Element names must be unique across categories: they are all passed
        # as keyword arguments to the same call.
        return AltitudeNativeStateVector(
            native_alt_grid,
            **absorbers,
            **surface,
            **aerosols,
            **splines,
            **shifts,
            **others,
        )

    def _construct_ancillary(self):
        """Return the user supplied ancillary, or a ``US76Ancillary`` when none was given."""
        if self._ancillary is None:
            return US76Ancillary()
        return self._ancillary

    def _construct_target(self):
        """Build the retrieval target from the state vector, measurement vectors and context."""
        return MeasVecTarget(
            self._state_vector,
            self._measurement_vector,
            self._context,
            **self._target_kwargs,
        )

    def _construct_output(self, rodgers_output: dict):
        """Hook for subclasses to post-process the result dictionary; identity by default."""
        return rodgers_output

    def retrieve(
        self,
        enabled_state_elements: list[str] | None = None,
        enabled_measurement_vectors: list[str] | None = None,
    ):
        """
        Run the retrieval.

        Parameters
        ----------
        enabled_state_elements : list[str] | None, optional
            When given, only the state elements named here are enabled for
            this call; all others are disabled, by default None
        enabled_measurement_vectors : list[str] | None, optional
            When given, only the measurement vectors named here are enabled
            for this call; all others are disabled, by default None

        Returns
        -------
        The value of ``_construct_output`` applied to a dictionary with the
        keys "minimizer", "meas_l1", "simulated_l1" and "state".

        Raises
        ------
        ValueError
            If the configured minimizer name is not supported.
        """
        # The constructor treats the name case-insensitively, so do the same here.
        minimizer_name = self._minimizer.lower()
        if minimizer_name == "rodgers":
            minimizer = Rodgers(**self._minimizer_kwargs)
        elif minimizer_name == "scipy":
            minimizer = SciPyMinimizer(**self._minimizer_kwargs)
        elif minimizer_name == "scipy_grad":
            # NOTE(review): minimizer_kwargs are not forwarded here — confirm
            # whether SciPyMinimizerGrad is meant to take them.
            minimizer = SciPyMinimizerGrad()
        else:
            msg = (
                f"Unknown minimizer {self._minimizer!r}; "
                "expected one of 'rodgers', 'scipy', 'scipy_grad'"
            )
            raise ValueError(msg)

        if enabled_state_elements is not None:
            for key, val in self._state_vector.sv.items():
                val.enabled = key in enabled_state_elements
        if enabled_measurement_vectors is not None:
            for key, val in self._measurement_vector.items():
                val.enabled = key in enabled_measurement_vectors

        try:
            self._target.update_state_slices()
            min_results = minimizer.retrieve(
                self._obs_l1, self._forward_model, self._target
            )
        finally:
            # Reset the enabled flags, also when the minimizer raised, so a
            # failed call does not leave the per-call overrides behind.
            # NOTE(review): this re-enables everything, including elements
            # that were disabled through their "enabled" config entry.
            for val in self._state_vector.sv.values():
                val.enabled = True
            for val in self._measurement_vector.values():
                val.enabled = True

        # Post process
        final_l1 = self._forward_model.calculate_radiance()
        results = {
            "minimizer": min_results,
            "meas_l1": self._obs_l1,
            "simulated_l1": final_l1,
            "state": self._state_vector.describe(min_results),
        }
        return self._construct_output(results)
# Register all the default optical properties
@Retrieval.register_optical_property("o3")
def o3_optical_property(*args, **kwargs):
    """Default "o3" optical property: a fresh ``sk2.optical.O3DBM``. All arguments are ignored."""
    optical_prop = sk2.optical.O3DBM()
    return optical_prop
@Retrieval.register_optical_property("no2")
def no2_optical_property(*args, **kwargs):
    """Default "no2" optical property: a fresh ``sk2.optical.NO2Vandaele``. All arguments are ignored."""
    optical_prop = sk2.optical.NO2Vandaele()
    return optical_prop
@Retrieval.register_optical_property("bro")
def bro_optical_property(*args, **kwargs):
    """Default "bro" optical property: ``sk2.optical.HITRANUV`` for "BrO". All arguments are ignored."""
    optical_prop = sk2.optical.HITRANUV("BrO")
    return optical_prop
@Retrieval.register_optical_property("so2")
def so2_optical_property(*args, **kwargs):
    """Default "so2" optical property: ``sk2.optical.HITRANUV`` for "SO2". All arguments are ignored."""
    optical_prop = sk2.optical.HITRANUV("SO2")
    return optical_prop
# Register the default Lambertian surface state
@Retrieval.register_state("surface", "lambertian_albedo")
def lambertian_state(self, name, native_alt_grid: np.array, cfg: dict):  # noqa: ARG001
    """
    Build the state element for a Lambertian surface albedo.

    Parameters
    ----------
    self
        The retrieval object invoking the factory (unused).
    name
        Name given to the state element.
    native_alt_grid : np.array
        Unused.
    cfg : dict
        Requires "wavelengths", "initial_value", "tikh_factor" and
        "prior_influence". Optional: "out_of_bounds_mode" (default "extend")
        and "enabled" (default True).

    Returns
    -------
    A ``StateVectorElementConstituent`` retrieving "albedo", bounded to [0, 1].
    """
    wavelengths = cfg["wavelengths"]
    # Same starting albedo at every wavelength node.
    initial_albedo = np.ones(len(wavelengths)) * cfg["initial_value"]
    surface = sk2.constituent.LambertianSurface(
        initial_albedo, wavelengths, cfg.get("out_of_bounds_mode", "extend")
    )

    smoothing_term = cfg["tikh_factor"] * prior.VerticalTikhonov(1)
    diagonal_term = cfg["prior_influence"] * prior.ConstantDiagonalPrior()

    element = StateVectorElementConstituent(
        surface,
        name,
        ["albedo"],
        min_value={"albedo": 0},
        max_value={"albedo": 1},
        prior={"albedo": smoothing_term + diagonal_term},
        log_space=False,
    )
    element.enabled = cfg.get("enabled", True)
    return element
@Retrieval.register_state("aerosols", "extinction_profile")
def aerosol_extinction_profile(self, name: str, native_alt_grid: np.array, cfg: dict):
    """
    Build the state element for an aerosol described by an extinction profile.

    Parameters
    ----------
    self
        The retrieval object invoking the factory; its ``_optical_property``
        is used to look up the optical property registered under ``name``.
    name : str
        Name of the aerosol and of the resulting state element.
    native_alt_grid : np.array
        Altitude grid of the profile.
    cfg : dict
        Requires "prior" (per quantity dict with a "value") and
        "retrieved_quantities" (per quantity dict with "min_value",
        "max_value", "tikh_factor", "prior_influence"). Optional:
        "scale_factor" (default 1), "nominal_wavelength" (default 745),
        "enabled" (default True) and "initial_guess".
    """
    # NOTE(review): a "prior_state" entry in cfg is currently ignored; the
    # branch that used it was disabled (`... and False`) in the original code.
    reference = sk2.test_util.scenarios.test_aerosol_constituent(native_alt_grid)
    ext = copy(reference.extinction_per_m)

    # Fill everything below the first non-zero sample with that sample, then
    # floor the whole profile at 1e-15.
    first_nonzero = np.nonzero(ext)[0][0]
    ext[:first_nonzero] = ext[first_nonzero]
    ext[ext < 1e-15] = 1e-15

    scale_factor = cfg.get("scale_factor", 1)

    # Every prior quantity except the extinction becomes a constant profile.
    secondary_kwargs = {}
    for quantity, settings in cfg["prior"].items():
        if quantity == "extinction_per_m":
            continue
        secondary_kwargs[quantity] = np.ones_like(native_alt_grid) * settings["value"]

    db = self._optical_property(name)
    scatterer = sk2.constituent.ExtinctionScatterer(
        db,
        native_alt_grid,
        ext * scale_factor,
        cfg.get("nominal_wavelength", 745),
        "extend",
        **secondary_kwargs,
    )

    retrieved = cfg["retrieved_quantities"]
    lower_bounds = {quantity: opts["min_value"] for quantity, opts in retrieved.items()}
    upper_bounds = {quantity: opts["max_value"] for quantity, opts in retrieved.items()}
    # The prior state is the constant secondary profile when one exists and
    # the scaled extinction otherwise.
    priors = {
        quantity: opts["tikh_factor"]
        * prior.VerticalTikhonov(
            1, prior_state=secondary_kwargs.get(quantity, ext * scale_factor)
        )
        + opts["prior_influence"] * prior.ConstantDiagonalPrior()
        for quantity, opts in retrieved.items()
    }

    element = StateVectorElementConstituent(
        scatterer,
        f"{name}",
        retrieved.keys(),
        min_value=lower_bounds,
        max_value=upper_bounds,
        prior=priors,
        log_space=False,
    )
    element.enabled = cfg.get("enabled", True)

    initial_guess = cfg.get("initial_guess")
    if initial_guess is not None:
        element.update_state(initial_guess)
    return element
@Retrieval.register_state("shifts", "wavenumber_shift")
def wavenumber_shift(
    self, name: str, native_alt_grid: np.array, cfg: dict  # noqa: ARG001
):
    """
    Build a ``WavenumberShift`` state element from its configuration.

    Only ``cfg`` is used. It requires "num_los"; optional entries and their
    defaults are "numerical_delta" (1e-5), "tikh_factor" (1e4),
    "prior_factor" (0), "min_shift" (-0.1), "max_shift" (0.1),
    "apply_to_measurement" (None) and "enabled" (True).
    """
    num_los = cfg["num_los"]
    options = {
        "numerical_delta": cfg.get("numerical_delta", 1e-5),
        "tikh_factor": cfg.get("tikh_factor", 1e4),
        "prior_factor": cfg.get("prior_factor", 0),
        "min_shift": cfg.get("min_shift", -0.1),
        "max_shift": cfg.get("max_shift", 0.1),
        "apply_to_measurement": cfg.get("apply_to_measurement"),
    }
    shift_element = WavenumberShift(num_los, **options)
    shift_element.enabled = cfg.get("enabled", True)
    return shift_element
@Retrieval.register_state("splines", "constant_los")
def constant_los(self, name: str, native_alt_grid: np.array, cfg: dict):  # noqa: ARG001
    """
    Build a ``MultiplicativeSplineOne`` state element from its configuration.

    Only ``cfg`` is used. It requires "low_wavelength_nm",
    "high_wavelength_nm", "num_wv" and "order"; optional entries and their
    defaults are "s" (1), "min_value" (0.5), "max_value" (1.5) and
    "enabled" (True).
    """
    sv_ele = MultiplicativeSplineOne(
        cfg["low_wavelength_nm"],
        cfg["high_wavelength_nm"],
        cfg["num_wv"],
        cfg.get("s", 1),
        order=cfg["order"],
        min_value=cfg.get("min_value", 0.5),
        max_value=cfg.get("max_value", 1.5),
    )
    # Honour the "enabled" switch like the surface, aerosol and shift
    # factories do; it used to be silently ignored for splines.
    sv_ele.enabled = cfg.get("enabled", True)
    return sv_ele
@Retrieval.register_state("aerosols", "gaussian_extinction_profile")
def gaussian_extinction_profile(self, name: str, native_alt_grid: np.array, cfg: dict):
    """
    Build the state element for an aerosol with a Gaussian extinction profile.

    Parameters
    ----------
    self
        The retrieval object invoking the factory; its ``_optical_property``
        is used to look up the optical property registered under ``name``.
    name : str
        Name of the aerosol and of the resulting state element.
    native_alt_grid : np.array
        Altitude grid of the constituent.
    cfg : dict
        Requires "prior" with "height_m", "width_fwhm_m" and
        "vertical_optical_depth" (each a dict with a "value"), and
        "retrieved_quantities" (per quantity dict with "min_value",
        "max_value", "tikh_factor", "prior_influence"). Optional:
        "scale_factor" (default 1), "nominal_wavelength" (default 745),
        "enabled" (default True) and "initial_guess".
    """
    scale_factor = cfg.get("scale_factor", 1)

    # The three Gaussian shape parameters are passed positionally below; every
    # other prior quantity becomes a constant profile keyword argument.
    shape_parameters = ["vertical_optical_depth", "width_fwhm_m", "height_m"]
    secondary_kwargs = {}
    for quantity, settings in cfg["prior"].items():
        if quantity in shape_parameters:
            continue
        secondary_kwargs[quantity] = np.ones_like(native_alt_grid) * settings["value"]

    db = self._optical_property(name)
    gaussian_layer = sk2.constituent.GaussianHeightExtinction(
        db,
        cfg["prior"]["height_m"]["value"],
        cfg["prior"]["width_fwhm_m"]["value"],
        cfg["prior"]["vertical_optical_depth"]["value"],
        cfg.get("nominal_wavelength", 745),
        native_alt_grid,
        "extend",
        **secondary_kwargs,
    )

    retrieved = cfg["retrieved_quantities"]
    lower_bounds = {quantity: opts["min_value"] for quantity, opts in retrieved.items()}
    upper_bounds = {quantity: opts["max_value"] for quantity, opts in retrieved.items()}
    # NOTE(review): scale_factor is applied to the prior state of the shape
    # parameters here but not to the values used to build the constituent
    # above — confirm this is intended.
    priors = {
        quantity: opts["tikh_factor"]
        * prior.VerticalTikhonov(
            1,
            prior_state=secondary_kwargs.get(
                quantity, np.atleast_1d(cfg["prior"][quantity]["value"] * scale_factor)
            ),
        )
        + opts["prior_influence"] * prior.ConstantDiagonalPrior()
        for quantity, opts in retrieved.items()
    }

    element = StateVectorElementConstituent(
        gaussian_layer,
        f"{name}",
        retrieved.keys(),
        min_value=lower_bounds,
        max_value=upper_bounds,
        prior=priors,
        log_space=False,
    )
    element.enabled = cfg.get("enabled", True)

    initial_guess = cfg.get("initial_guess")
    if initial_guess is not None:
        element.update_state(initial_guess)
    return element
@Retrieval.register_state("splines", "multiplicative")
def multiplicative(
    self, name: str, native_alt_grid: np.array, cfg: dict  # noqa: ARG001
):
    """
    Build a ``MultiplicativeSpline`` state element from its configuration.

    Only ``cfg`` is used. It requires "num_los", "low_wavelength_nm",
    "high_wavelength_nm", "num_wv" and "order"; optional entries and their
    defaults are "s" (1), "min_value" (0.5), "max_value" (1.5) and
    "enabled" (True).
    """
    sv_ele = MultiplicativeSpline(
        cfg["num_los"],
        cfg["low_wavelength_nm"],
        cfg["high_wavelength_nm"],
        cfg["num_wv"],
        cfg.get("s", 1),
        order=cfg["order"],
        min_value=cfg.get("min_value", 0.5),
        max_value=cfg.get("max_value", 1.5),
    )
    # Honour the "enabled" switch like the surface, aerosol and shift
    # factories do; it used to be silently ignored for splines.
    sv_ele.enabled = cfg.get("enabled", True)
    return sv_ele
@Retrieval.register_state("other", "volume_emission_rate")
def volume_emission_rate(
    self, name: str, native_alt_grid: np.array, cfg: dict  # noqa: ARG001
):
    """
    Build the state element for a monochromatic volume emission rate profile.

    Parameters
    ----------
    self
        The retrieval object invoking the factory (unused).
    name : str
        Name given to the state element.
    native_alt_grid : np.array
        Altitude grid of the profile.
    cfg : dict
        Requires "wavelength_nm" and "tikhonov_factor". Optional: "enabled"
        (default True).

    Returns
    -------
    A ``StateVectorElementConstituent`` retrieving "ver", bounded to
    [-1e-9, 1e-9].
    """
    emission_wavelength = cfg["wavelength_nm"]
    # Constant starting profile; a copy of it is also the Tikhonov prior state.
    initial_ver = np.ones_like(native_alt_grid) * 1e-12
    p = 1e-12 * prior.ConstantDiagonalPrior() + cfg[
        "tikhonov_factor"
    ] * prior.VerticalTikhonov(1, prior_state=copy(initial_ver))
    const = sk2.constituent.MonochromaticVolumeEmissionRate(
        native_alt_grid, initial_ver, emission_wavelength
    )
    sv_ele = StateVectorElementConstituent(
        const,
        name,
        ["ver"],
        prior={"ver": p},
        log_space=False,
        min_value={"ver": -1e-9},
        max_value={"ver": 1e-9},
    )
    # Honour the "enabled" switch like the other constituent factories do;
    # it used to be silently ignored here.
    sv_ele.enabled = cfg.get("enabled", True)
    return sv_ele