Source code for skretrieval.retrieval.statevector
from __future__ import annotations
import abc
from collections.abc import Iterable
import numpy as np
import xarray as xr
[docs]
class StateVectorElement(abc.ABC):
    """
    A state vector element is a component of the full state vector used in the retrieval. Each state vector element
    has a state, and a prior state/covariance associated with it. The state vector element must also be able
    to update itself and calculate the jacobian matrix for itself.

    Subclasses must implement :meth:`state`, :meth:`name`, :meth:`propagate_wf` and :meth:`update_state`. The
    remaining methods provide defaults describing an unconstrained element: an all-zero inverse a priori
    covariance, an all-zero a priori state, and infinite lower/upper bounds.
    """

    def __init__(self, enabled: bool = True):
        """
        Parameters
        ----------
        enabled: bool, Optional
            Initial value of the :attr:`enabled` flag. Default True
        """
        self._enabled = enabled

    @abc.abstractmethod
    def state(self) -> np.ndarray:
        """
        Returns
        -------
        np.ndarray
            The current state of this element. Its length is the number of state vector entries this element
            contributes; the default prior/bound methods below size their outputs from it.
        """

    def inverse_apriori_covariance(self) -> np.ndarray:
        """
        Inverse of the a priori covariance for this element.

        Returns
        -------
        np.ndarray
            By default an all-zero ``(n, n)`` matrix with ``n = len(self.state())``, i.e. no prior constraint.
        """
        n = len(self.state())
        return np.zeros((n, n))

    def apriori_state(self) -> np.ndarray:
        """
        A priori state for this element.

        Returns
        -------
        np.ndarray
            By default an array of zeros with the same shape and dtype as :meth:`state`.
        """
        return np.zeros_like(self.state())

    def lower_bound(self) -> np.ndarray:
        """
        Lower bound on each entry of the state.

        Returns
        -------
        np.ndarray
            By default a float array of length ``len(self.state())`` filled with ``-inf`` (unbounded).
        """
        return np.full(len(self.state()), -np.inf)

    def upper_bound(self) -> np.ndarray:
        """
        Upper bound on each entry of the state.

        Returns
        -------
        np.ndarray
            By default a float array of length ``len(self.state())`` filled with ``+inf`` (unbounded).
        """
        return np.full(len(self.state()), np.inf)

    @property
    def enabled(self) -> bool:
        """Whether this element is flagged as enabled. This class only stores the flag; callers interpret it."""
        return self._enabled

    @enabled.setter
    def enabled(self, e: bool):
        self._enabled = e

    @abc.abstractmethod
    def name(self) -> str:
        """Name of this element."""

    @abc.abstractmethod
    def propagate_wf(self, radiance: xr.Dataset) -> xr.Dataset:
        """
        Compute the jacobian of the radiance with respect to this element's state.

        Parameters
        ----------
        radiance: xr.Dataset
            Radiance (including raw weighting functions) from the forward model.

        Returns
        -------
        xr.Dataset
            The jacobian for this element. It is expected to be concatenable with the other elements' jacobians
            along an ``x`` dimension.
        """

    @abc.abstractmethod
    def update_state(self, x: np.ndarray):
        """
        Update this element with a new state ``x``.

        NOTE(review): presumably ``x`` is the slice of the full state vector owned by this element -- confirm
        against the retrieval driver.
        """

    def modify_input_radiance(self, radiance: xr.Dataset) -> xr.Dataset:
        """
        Hook allowing an element to alter the forward-model radiance (for example to apply a wavelength shift).

        The default implementation returns ``radiance`` unchanged (the same object).
        """
        return radiance

    def describe(self, **kwargs) -> xr.Dataset | None:
        """
        Optionally describe this element (e.g. its retrieved state and uncertainties) as a Dataset.

        Parameters
        ----------
        kwargs:
            Extra information about the retrieval for this element, such as its ``covariance`` and
            ``averaging_kernel`` sub-blocks.

        Returns
        -------
        xr.Dataset | None
            ``None`` by default, meaning "nothing to describe".
        """
        return None
[docs]
class StateVector:
    """
    A full state vector made up of an ordered collection of state vector elements.

    The element order defines where each element's entries sit in the full state vector: both the jacobian built by
    :meth:`update_sasktran_radiance` and the covariance slicing in :meth:`describe` follow it.
    """

    def __init__(self, elements: Iterable[StateVectorElement]):
        """
        A full state vector made up of a collection of state vector elements.

        Parameters
        ----------
        elements: Iterable[StateVectorElement]
            A collection of state vector elements. They are copied into a list so that one-shot iterables
            (e.g. generators) can be traversed by every method of this class, not just the first one called.
        """
        # Materialize: several methods iterate the elements, and a stored generator would be empty after one pass.
        self._elements = list(elements)

    @property
    def state_elements(self) -> list[StateVectorElement]:
        """The state vector elements, in state vector order."""
        return self._elements

    def update_sasktran_radiance(self, radiance: xr.Dataset, drop_old_wf: bool = False) -> xr.Dataset:
        """
        Modifies radiances output from sasktran based on the state vector elements if applicable, e.g., if a state
        vector element is a wavelength shift this will apply it.

        Propagates weighting functions from the sasktran radiance raw output to weighting functions for each
        state vector element.

        If drop_old_wf is set to true then the old weighting functions are removed from the radiance.

        Parameters
        ----------
        radiance: xr.Dataset
            Output from sk.Engine.calculate_radiance(output_format='xarray'). Note that the 'wf' variable is
            assigned in place, so the Dataset passed in is modified whenever the elements' modify_input_radiance
            returns the same object (the default behaviour).
        drop_old_wf: bool, Optional
            If true then the old weighting functions (string-named variables starting with "wf_") are removed
            after being propagated to the state vector. Default False

        Returns
        -------
        radiance: xr.Dataset
            Modified radiance with a new key 'wf' that is the jacobian with respect to the full state vector.

        Raises
        ------
        ValueError
            If the state vector has no elements, so no jacobian can be built.
        """
        all_jacobian = []
        for state_element in self._elements:
            # Each element sees the radiance as already modified by the elements before it, so the
            # jacobian is propagated before this element applies its own modification.
            all_jacobian.append(state_element.propagate_wf(radiance))
            radiance = state_element.modify_input_radiance(radiance)

        if not all_jacobian:
            raise ValueError("StateVector has no elements; cannot build the 'wf' jacobian")

        radiance["wf"] = xr.concat(all_jacobian, dim="x")

        if drop_old_wf:
            # Dataset variable names are arbitrary hashables; only strings can carry the "wf_" prefix.
            wf_names = [key for key in radiance if isinstance(key, str) and key.startswith("wf_")]
            # Dataset.drop is deprecated for removing variables; drop_vars is the supported equivalent.
            radiance = radiance.drop_vars(wf_names)

        return radiance

    def describe(self, rodgers_output: dict, **kwargs) -> xr.Dataset:
        """
        Merge the per-element descriptions into a single Dataset.

        Parameters
        ----------
        rodgers_output: dict
            Retrieval output. Must contain "error_covariance_from_noise" and "averaging_kernel", each indexable
            with a pair of slices over the full state vector (presumably square 2-D arrays -- confirm).
        kwargs:
            Forwarded unchanged to every element's describe.

        Returns
        -------
        xr.Dataset
            xr.merge of all element descriptions that were not None.
        """
        covar = rodgers_output["error_covariance_from_noise"]
        averaging_kernel = rodgers_output["averaging_kernel"]

        all_ds = []
        start = 0
        for state_element in self._elements:
            end = start + len(state_element.state())
            s = slice(start, end)
            ds = state_element.describe(
                covariance=covar[s, s],
                averaging_kernel=averaging_kernel[s, s],
                **kwargs,
            )
            if ds is not None:
                all_ds.append(ds)
            # Advance even when describe() returned None so later elements keep their correct offsets.
            start = end
        return xr.merge(all_ds)