# Source code for skretrieval.retrieval.statevector.constituent

from __future__ import annotations

import numpy as np
import sasktran2 as sk2
import xarray as xr
from scipy.linalg import block_diag

from skretrieval.retrieval.prior import BasePrior

from . import StateVectorElement


class StateVectorElementConstituent(
    StateVectorElement, sk2.constituent.base.Constituent
):
    """A state vector element backed by a sasktran2 constituent.

    The retrieved properties of the wrapped constituent are laid out in the
    state vector one after another, in the order of ``property_names``.  A
    property value ``v`` maps to the state value ``v * scale_factor`` in
    linear space, or ``log(v * scale_factor)`` when ``log_space`` is True.
    """

    def __init__(
        self,
        constituent: sk2.constituent.base.Constituent,
        constituent_name: str,
        property_names: list[str],
        min_value=None,
        max_value=None,
        prior: dict[str, BasePrior] | None = None,
        log_space=False,
        enabled=True,
        scale_factor: float = 1.0,
    ):
        """
        A state vector element that is a sasktran2.constituent

        Parameters
        ----------
        constituent : sk2.constituent.base.Constituent
            The sasktran2 constituent
        constituent_name : str
            A name for the constituent
        property_names : list[str]
            Property names of the constituent that will be retrieved
        min_value : dict, optional
            Minimum values for the property names as a dictionary, by default {}
        max_value : dict, optional
            Maximum values for the property names as a dictionary, by default {}
        prior : dict, optional
            Prior objects for each property name, by default {}. Properties
            without an entry receive a default ``BasePrior()``. The mapping is
            copied, so the caller's dictionary is never modified.
        log_space : bool, optional
            If true then the state elements will be rescaled to logarithmic
            space, by default False
        enabled : bool, optional
            Forwarded to the base-class initializer, by default True
        scale_factor : float, optional
            Constant multiplicative factor between constituent properties and
            retrieval state. A state value of 1 corresponds to a constituent
            property value of 1 / scale_factor, by default 1.0

        Raises
        ------
        ValueError
            If ``scale_factor`` is not positive.
        """
        if scale_factor <= 0:
            msg = "scale_factor must be positive"
            raise ValueError(msg)

        self._log_space = log_space
        self._scale_factor = float(scale_factor)
        self._constituent = constituent
        self._property_names = property_names
        self._constituent_name = constituent_name
        self._min_value = {} if min_value is None else min_value
        self._max_value = {} if max_value is None else max_value
        # Copy so the default priors filled in below do not leak into the
        # caller's dictionary (which may be shared between elements).
        self._prior = {} if prior is None else dict(prior)

        # Each supplied prior is told which slice of this element's state it
        # covers.  The offset must advance for EVERY property, including those
        # that fall back to a default prior, otherwise later slices would be
        # misaligned with state() / apriori_state().
        start = 0
        for property_name in self._property_names:
            end = start + self._property_size(property_name)
            if property_name in self._prior:
                self._prior[property_name].init(self, slice(start, end))
            else:
                self._prior[property_name] = BasePrior()
            start = end

        super().__init__(enabled)

    def _property_size(self, property_name: str) -> int:
        """Return the number of state elements occupied by ``property_name``."""
        return len(np.atleast_1d(getattr(self._constituent, property_name)))

    def state(self) -> np.ndarray:
        """Return this element's current state vector.

        Every retrieved property is multiplied by the scale factor and the
        results are concatenated; in log space the natural log is then taken.
        """
        scaled = np.hstack(
            [
                getattr(self._constituent, property_name) * self._scale_factor
                for property_name in self._property_names
            ]
        )
        if self._log_space:
            return np.log(scaled)
        return scaled

    def _bound_vector(self, limits: dict, default: float) -> np.ndarray:
        """Expand per-property bounds into a state-space bound vector.

        ``limits`` maps property names to a bound in property units; missing
        properties use ``default``.
        """
        bounds = np.hstack(
            [
                np.full(
                    self._property_size(property_name),
                    limits.get(property_name, default),
                    dtype=float,
                )
                * self._scale_factor
                for property_name in self._property_names
            ]
        )
        if self._log_space:
            # Every log-space state maps to a positive property value, so a
            # non-positive bound (including the default -inf) imposes no
            # constraint and maps to -inf instead of the NaN np.log would give.
            with np.errstate(divide="ignore"):
                return np.log(np.where(bounds > 0, bounds, 0.0))
        return bounds

    def lower_bound(self) -> np.ndarray:
        """Return the state-space lower bound (``-inf`` where unbounded)."""
        return self._bound_vector(self._min_value, -np.inf)

    def upper_bound(self) -> np.ndarray:
        """Return the state-space upper bound (``inf`` where unbounded)."""
        return self._bound_vector(self._max_value, np.inf)

    def inverse_apriori_covariance(self) -> np.ndarray:
        """Return the block-diagonal inverse a priori covariance.

        One block per retrieved property, taken from that property's prior.
        """
        blocks = []
        for property_name in self._property_names:
            prior = self._prior[property_name]
            inv_cov = prior.inverse_covariance
            if not self._log_space:
                # Dividing by outer(x_a, x_a) treats the prior inverse
                # covariance as relative to the prior state; presumably priors
                # are specified fractionally (TODO confirm against BasePrior).
                inv_cov = inv_cov / np.outer(prior.state, prior.state)
            blocks.append(inv_cov)
        return block_diag(*blocks)

    def apriori_state(self) -> np.ndarray:
        """Return the concatenated prior states of all retrieved properties."""
        # atleast_1d lets scalar prior states concatenate like 1-D ones.
        return np.concatenate(
            [
                np.atleast_1d(self._prior[property_name].state)
                for property_name in self._property_names
            ]
        )

    def name(self) -> str:
        """Return the name given to this constituent."""
        return self._constituent_name

    def propagate_wf(self, radiance: xr.Dataset) -> xr.Dataset:
        """Convert property weighting functions into state weighting functions.

        For each property the variable ``wf_<constituent_name>_<property>`` is
        read from ``radiance``, its first dimension is renamed to ``"x"`` and
        it is rescaled by the chain rule:

        * log space:    dR/ds = dR/dv * v
        * linear space: dR/ds = dR/dv / scale_factor

        The results are concatenated along ``"x"``.  The input dataset is not
        modified.
        """
        prefix = f"wf_{self._constituent_name}"
        if "extinction_per_m" in self._property_names:
            radiance = radiance.rename(
                {f"{prefix}_extinction": f"{prefix}_extinction_per_m"}
            )

        wfs = []
        for property_name in self._property_names:
            wf = radiance[f"{prefix}_{property_name}"]
            # NOTE(review): assumes the first dimension indexes the property's
            # elements (e.g. altitude) - confirm against the radiance producer.
            wf = wf.rename({wf.dims[0]: "x"})
            if self._log_space:
                value = np.atleast_1d(getattr(self._constituent, property_name))
                # Broadcast the property values along the first axis only,
                # whatever the rank of the weighting function.
                factor = value.reshape((-1,) + (1,) * (wf.ndim - 1))
                scaled = wf.values * factor
            else:
                scaled = wf.values / self._scale_factor
            # copy(data=...) builds a new array rather than writing in place;
            # rename() only shallow-copies, so in-place edits would change the
            # caller's dataset.
            wfs.append(wf.copy(data=scaled))
        return xr.concat(wfs, dim="x")

    def update_state(self, x: np.array):
        """Write the state vector ``x`` back into the constituent properties.

        ``x`` is split into consecutive per-property segments, converted back
        to property units, clipped to the configured min/max values and
        assigned to the constituent.  In log space, NaNs produced by the
        conversion are replaced by the property's max value (a ``max_value``
        entry is then required, otherwise ``KeyError`` is raised).
        """
        start = 0
        for property_name in self._property_names:
            end = start + self._property_size(property_name)
            segment = x[start:end]
            if self._log_space:
                sv = np.exp(segment) / self._scale_factor
                nan_mask = np.isnan(sv)
                if nan_mask.any():
                    sv[nan_mask] = self._max_value[property_name]
            else:
                sv = segment / self._scale_factor

            if property_name in self._min_value:
                lower = self._min_value[property_name]
                sv[sv < lower] = lower
            if property_name in self._max_value:
                upper = self._max_value[property_name]
                sv[sv > upper] = upper

            setattr(self._constituent, property_name, sv)
            start = end

    def modify_input_radiance(self, radiance: xr.Dataset):
        """Return ``radiance`` unchanged; this element does not alter it."""
        return radiance

    def add_to_atmosphere(self, atmo: sk2.Atmosphere):
        """Delegate to the wrapped constituent's ``add_to_atmosphere``."""
        return self._constituent.add_to_atmosphere(atmo)

    def register_derivative(self, atmo: sk2.Atmosphere, name: str):
        """Delegate to the wrapped constituent's ``register_derivative``."""
        return self._constituent.register_derivative(atmo, name)

    def adjust_constituent_attributes(self, **kwargs):
        """Adjust attributes of the wrapped constituent.

        Each keyword names a constituent attribute.  A plain value multiplies
        the attribute.  A dict value may contain ``"scale"`` (multiply) and/or
        ``"set"`` (replace) keys, matched case-insensitively and applied in
        the dict's order.
        """
        for key, value in kwargs.items():
            if not isinstance(value, dict):
                setattr(self._constituent, key, getattr(self._constituent, key) * value)
                continue
            for action, amount in value.items():
                if action.lower() == "scale":
                    setattr(
                        self._constituent,
                        key,
                        getattr(self._constituent, key) * amount,
                    )
                if action.lower() == "set":
                    setattr(self._constituent, key, amount)

    @staticmethod
    def _as_float(value) -> float:
        """Convert a size-1 scalar or array into a Python float."""
        # .item() avoids NumPy's deprecated float(ndarray with ndim > 0).
        return float(np.asarray(value).item())

    def _prior_in_property_units(self, property_name: str):
        """Return the prior state of ``property_name`` in property units."""
        prior_state = self._prior[property_name].state
        if self._log_space:
            return np.exp(prior_state) / self._scale_factor
        return prior_state / self._scale_factor

    def _state_sigma(self, covariance, start: int, end: int, value):
        """Return 1-sigma errors in property units for state slice [start, end).

        ``covariance`` is in state space.  In log space d(v) = v * d(s); in
        linear space d(v) = d(s) / scale_factor.
        """
        sigma = np.sqrt(np.diag(covariance)[start:end])
        if self._log_space:
            return sigma * value
        return sigma / self._scale_factor

    def describe(self, **kwargs) -> xr.Dataset | None:
        """Summarise the retrieved properties as an ``xr.Dataset``.

        Values, priors and, when supplied through ``kwargs``, 1-sigma errors
        (``covariance``, state space) and averaging kernels
        (``averaging_kernel``) are reported in property units.  Scalar
        properties become 0-d variables; vector properties use an
        ``"altitude"`` dimension.  A Lambertian surface reports only its first
        property along the surface's own interpolation variable.
        """
        ds = xr.Dataset()

        if (
            type(self._constituent)
            is sk2.constituent.brdf.lambertiansurface.LambertianSurface
        ):
            albedo = getattr(self._constituent, self._property_names[0])
            interp_var = self._constituent._interp_var
            coords = {interp_var: self._constituent._x}
            ds[self._constituent_name] = xr.DataArray(
                albedo, dims=[interp_var], coords=coords
            )
            if "covariance" in kwargs:
                # The first property occupies the leading slice of the state.
                n = len(np.atleast_1d(albedo))
                ds[self._constituent_name + "_1sigma_error"] = xr.DataArray(
                    self._state_sigma(kwargs["covariance"], 0, n, albedo),
                    dims=[interp_var],
                    coords=coords,
                )
            return ds

        start = 0
        for property_name in self._property_names:
            value = getattr(self._constituent, property_name)
            end = start + len(np.atleast_1d(value))
            key = f"{self._constituent_name}_{property_name}"
            prior_values = self._prior_in_property_units(property_name)

            if end - start == 1:  # scalar property
                ds[key] = xr.DataArray(self._as_float(value))
                ds[key + "_prior"] = self._as_float(prior_values)
                if "covariance" in kwargs:
                    ds[key + "_1sigma_error"] = self._as_float(
                        self._state_sigma(kwargs["covariance"], start, end, value)
                    )
                if "averaging_kernel" in kwargs:
                    ds[key + "_averaging_kernel"] = self._as_float(
                        kwargs["averaging_kernel"][start:end, start:end]
                    )
            else:
                ds[key] = xr.DataArray(value, dims=["altitude"])
                ds[key + "_prior"] = xr.DataArray(prior_values, dims=["altitude"])
                if "covariance" in kwargs:
                    ds[key + "_1sigma_error"] = xr.DataArray(
                        self._state_sigma(kwargs["covariance"], start, end, value),
                        dims=["altitude"],
                    )
                if "averaging_kernel" in kwargs:
                    ds[key + "_averaging_kernel"] = xr.DataArray(
                        kwargs["averaging_kernel"][start:end, start:end],
                        dims=["altitude", "altitude_2"],
                    )
            start = end
        return ds