from __future__ import annotations
import abc
import dataclasses
from copy import copy
import numpy as np
from skretrieval.retrieval.statevector import StateVectorElement
from skretrieval.retrieval.tikhonov import (
two_dim_vertical_first_deriv,
two_dim_vertical_second_deriv,
)
@dataclasses.dataclass
class Prior:
    """
    A dataclass to hold a prior state of size (n,) and an inverse covariance of size (n, n)
    """

    # The prior state, size (n,)
    state: np.ndarray
    # The inverse covariance of the prior state, size (n, n)
    inverse_covariance: np.ndarray
class BasePrior:
    """
    Base class to handle prior states

    A prior exposes a prior state and an inverse covariance. Priors can be scaled
    with ``*`` (giving a ``MultipliedPrior``) and combined with ``+`` (giving an
    ``AdditivePrior``).
    """

    # NOTE(review): the two properties below are decorated with abc.abstractmethod,
    # but this class does not derive from abc.ABC, so the decorator is not enforced
    # and BasePrior() can be instantiated. Turning this into a real ABC would stop
    # VerticalPrior (which overrides neither property) from being instantiated, so
    # confirm with callers before changing it.

    @property
    @abc.abstractmethod
    def state(self) -> np.ndarray:
        """
        The prior state $x_a$ of size (n,)
        """

    @property
    @abc.abstractmethod
    def inverse_covariance(self) -> np.ndarray:
        """
        The inverse covariance of the prior state $S_a^{-1}$ of size (n, n)
        """

    def __mul__(self, other):
        """
        Return a ``MultipliedPrior`` wrapping this prior with ``other`` as the multiplier
        """
        return MultipliedPrior(self, other)

    # ``other * prior`` is handled exactly like ``prior * other``
    __rmul__ = __mul__

    def __add__(self, other):
        """
        Return an ``AdditivePrior`` combining this prior with ``other``
        """
        return AdditivePrior(self, other)

    def init(self, sv: StateVectorElement, sv_slice: slice | None = None):
        """
        Hook that receives the state vector element the prior is used with. The base
        implementation does nothing.

        Parameters
        ----------
        sv : StateVectorElement
        sv_slice : slice | None, optional
            Subclasses in this module use it to index ``sv.state()``, by default None
        """
class MultipliedPrior(BasePrior):
    def __init__(self, prior: BasePrior, multiplier: float):
        """
        Wrap a prior so that its inverse covariance is scaled by a scalar. The prior
        state of the wrapped prior is passed through unchanged.

        Parameters
        ----------
        prior : BasePrior
            The prior being wrapped
        multiplier : float
            Factor applied to the inverse covariance of ``prior``
        """
        self._multiplier = multiplier
        self._prior = prior

    @property
    def state(self):
        # The scaling only affects the inverse covariance, not the state
        wrapped = self._prior
        return wrapped.state

    @property
    def inverse_covariance(self):
        unscaled = self._prior.inverse_covariance
        return unscaled * self._multiplier

    def init(self, sv: StateVectorElement, sv_slice: slice | None = None):
        # Nothing to set up here, initialization is forwarded to the wrapped prior
        wrapped = self._prior
        wrapped.init(sv, sv_slice)
class AdditivePrior(BasePrior):
    def __init__(self, prior1: BasePrior, prior2: BasePrior):
        """
        A prior where two priors are added together. This results in a sum of the inverse covariance,
        and then a new prior state x_a

        Parameters
        ----------
        prior1 : BasePrior
        prior2 : BasePrior
        """
        self._prior1 = prior1
        self._prior2 = prior2

    @property
    def state(self):
        """
        The equivalent prior state x_a of the combined prior, obtained from

            (inv_S_a_1 + inv_S_a_2) x_a = inv_S_a_1 x_a_1 + inv_S_a_2 x_a_2
        """
        # Have to solve the system to get the equivalent prior state
        inv_S_a_1 = self._prior1.inverse_covariance
        inv_S_a_2 = self._prior2.inverse_covariance

        x_a_1 = self._prior1.state
        x_a_2 = self._prior2.state

        full_inv_S_a = inv_S_a_1 + inv_S_a_2
        rhs = inv_S_a_1 @ x_a_1 + inv_S_a_2 @ x_a_2

        # For some priors the inverse covariance will be singular
        try:
            return np.linalg.solve(full_inv_S_a, rhs)
        except np.linalg.LinAlgError:
            # The system has no unique solution. The plain average of the two prior
            # states (the previous fallback) only satisfies the system in special
            # cases, e.g. when both states are equal. Start from the average and add
            # the least squares correction for what is left over: when the average
            # already satisfies the system the correction is zero, otherwise the
            # result is moved onto a solution (one exists if both inverse covariances
            # are symmetric positive semi-definite).
            x_a_mean = 0.5 * (x_a_1 + x_a_2)
            leftover = rhs - full_inv_S_a @ x_a_mean
            try:
                correction = np.linalg.lstsq(full_inv_S_a, leftover, rcond=None)[0]
            except np.linalg.LinAlgError:
                # Least squares did not converge, keep the best-effort average
                return x_a_mean
            return x_a_mean + correction

    @property
    def inverse_covariance(self):
        """
        The sum of the inverse covariances of the two priors
        """
        return self._prior1.inverse_covariance + self._prior2.inverse_covariance

    def init(self, sv: StateVectorElement, sv_slice: slice | None = None):
        """
        Initialize both of the combined priors with the same arguments
        """
        self._prior1.init(sv, sv_slice)
        self._prior2.init(sv, sv_slice)
class VerticalPrior(BasePrior):
    def __init__(self, altitudes: np.ndarray):
        """
        A prior that stores a set of altitudes. Presumably intended as a common base
        for priors defined on a vertical grid; this class does not use the altitudes
        itself.

        Parameters
        ----------
        altitudes : np.ndarray
            Stored as given on the instance
        """
        self._altitudes = altitudes
[docs]
class VerticalTikhonov(VerticalPrior):
    def __init__(
        self,
        order: int,
        prior_state: np.ndarray | None = None,
        tikhonov: np.ndarray | None = None,
    ):
        """
        A prior that is constructed as a Tikhonov constraint.

        The constraint itself is only built once `init` is called, `state` and
        `inverse_covariance` are not available before that.

        Parameters
        ----------
        order : int
            Order of the Tikhonov constraint, only 1 and 2 are supported
        prior_state : np.ndarray, optional
            Prior state. If set to None a zero prior is used, by default None
        tikhonov : np.ndarray, optional
            Array of factors to multiply the constraint by, by default None
        """
        # NOTE(review): VerticalPrior.__init__ is not called here, so instances of
        # this class do not get the _altitudes attribute - confirm that is intended
        self._tikhonov = tikhonov
        self._prior_state = prior_state
        self._order = order

    def init(self, sv: StateVectorElement, sv_slice: slice | None = None):
        """
        Build the constraint matrix and the prior for the given state vector element

        Parameters
        ----------
        sv : StateVectorElement
            Element whose state determines the size of the prior
        sv_slice : slice | None, optional
            Part of ``sv.state()`` the prior applies to. None selects the whole
            state, by default None

        Raises
        ------
        ValueError
            If the order is not 1 or 2
        """
        # Indexing a numpy array with None inserts a new axis instead of selecting
        # everything, so the default has to become an explicit full slice
        if sv_slice is None:
            sv_slice = slice(None)

        n = len(sv.state()[sv_slice])

        if self._order == 1:
            self._gamma = two_dim_vertical_first_deriv(1, n, factor=1)
        elif self._order == 2:
            self._gamma = two_dim_vertical_second_deriv(1, n, factor=1)
        else:
            msg = f"Order {self._order} not supported."
            raise ValueError(msg)

        if self._tikhonov is not None:
            # Scale by the weights, broadcast along the last axis of gamma.
            # asarray lets plain sequences be used as well as arrays.
            self._gamma *= np.asarray(self._tikhonov)[np.newaxis, :]

        if self._prior_state is None:
            prior_state = np.zeros(n)
        else:
            prior_state = np.asarray(self._prior_state)

        self._prior = Prior(
            inverse_covariance=self._gamma.T @ self._gamma,
            state=prior_state,
        )

    @property
    def state(self):
        """
        The prior state, only available after `init` has been called
        """
        return self._prior.state

    @property
    def inverse_covariance(self):
        """
        The inverse covariance gamma^T gamma, only available after `init` has been called
        """
        return self._prior.inverse_covariance
[docs]
class ManualPrior(BasePrior):
    def __init__(self, state: np.ndarray, inverse_covariance: np.ndarray):
        """
        A prior that is manually specified, both the prior state and its inverse covariance

        Parameters
        ----------
        state : np.ndarray
            Returned as is by the `state` property
        inverse_covariance : np.ndarray
            Returned as is by the `inverse_covariance` property
        """
        self._state = state
        self._inverse_covariance = inverse_covariance

    @property
    def state(self):
        """
        The prior state given at construction
        """
        return self._state

    @property
    def inverse_covariance(self):
        """
        The inverse covariance given at construction
        """
        return self._inverse_covariance

    def init(self, sv: StateVectorElement, sv_slice: slice | None = None):
        """
        Does nothing, the prior is fully specified at construction
        """
[docs]
class ConstantDiagonalPrior(BasePrior):
    def __init__(self, value: float = 1.0):
        """
        A prior that is constant along the diagonal. The initial state is pulled from
        the StateVectorElement upon initialization.

        Parameters
        ----------
        value : float, optional
            Value placed on the diagonal of the inverse covariance, by default 1.0
        """
        self._value = value

    def init(self, sv: StateVectorElement, sv_slice: slice | None = None):
        """
        Build the prior from the current state of the state vector element

        Parameters
        ----------
        sv : StateVectorElement
            Element whose state is copied to become the prior state
        sv_slice : slice | None, optional
            Part of ``sv.state()`` the prior applies to. None selects the whole
            state, by default None
        """
        # Indexing a numpy array with None inserts a new axis instead of selecting
        # everything, so the default has to become an explicit full slice
        if sv_slice is None:
            sv_slice = slice(None)

        # Copied, so the stored prior state is a separate object from the state
        # vector's own array
        prior_state = copy(sv.state()[sv_slice])

        self._prior = Prior(
            inverse_covariance=np.eye(len(prior_state)) * self._value,
            state=prior_state,
        )

    @property
    def state(self):
        """
        The prior state, only available after `init` has been called
        """
        return self._prior.state

    @property
    def inverse_covariance(self):
        """
        The diagonal inverse covariance, only available after `init` has been called
        """
        return self._prior.inverse_covariance