pymc-extras 0.6.0__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- pymc_extras/distributions/timeseries.py +10 -10
- pymc_extras/inference/dadvi/dadvi.py +14 -83
- pymc_extras/inference/laplace_approx/laplace.py +187 -159
- pymc_extras/inference/pathfinder/pathfinder.py +12 -7
- pymc_extras/inference/smc/sampling.py +2 -2
- pymc_extras/model/marginal/distributions.py +4 -2
- pymc_extras/model/marginal/marginal_model.py +12 -2
- pymc_extras/prior.py +3 -3
- pymc_extras/statespace/core/properties.py +276 -0
- pymc_extras/statespace/core/statespace.py +182 -45
- pymc_extras/statespace/filters/distributions.py +19 -34
- pymc_extras/statespace/filters/kalman_filter.py +13 -12
- pymc_extras/statespace/filters/kalman_smoother.py +2 -2
- pymc_extras/statespace/models/DFM.py +179 -168
- pymc_extras/statespace/models/ETS.py +177 -151
- pymc_extras/statespace/models/SARIMAX.py +149 -152
- pymc_extras/statespace/models/VARMAX.py +134 -145
- pymc_extras/statespace/models/__init__.py +8 -1
- pymc_extras/statespace/models/structural/__init__.py +30 -8
- pymc_extras/statespace/models/structural/components/autoregressive.py +87 -45
- pymc_extras/statespace/models/structural/components/cycle.py +119 -80
- pymc_extras/statespace/models/structural/components/level_trend.py +95 -42
- pymc_extras/statespace/models/structural/components/measurement_error.py +27 -17
- pymc_extras/statespace/models/structural/components/regression.py +105 -68
- pymc_extras/statespace/models/structural/components/seasonality.py +138 -100
- pymc_extras/statespace/models/structural/core.py +397 -286
- pymc_extras/statespace/models/utilities.py +5 -20
- {pymc_extras-0.6.0.dist-info → pymc_extras-0.8.0.dist-info}/METADATA +4 -4
- {pymc_extras-0.6.0.dist-info → pymc_extras-0.8.0.dist-info}/RECORD +31 -30
- {pymc_extras-0.6.0.dist-info → pymc_extras-0.8.0.dist-info}/WHEEL +0 -0
- {pymc_extras-0.6.0.dist-info → pymc_extras-0.8.0.dist-info}/licenses/LICENSE +0 -0
pymc_extras/inference/pathfinder/pathfinder.py CHANGED

@@ -278,12 +278,13 @@ def alpha_recover(
     z = pt.diff(g, axis=0)
     alpha_l_init = pt.ones(N)

-    alpha
+    alpha = pytensor.scan(
         fn=compute_alpha_l,
         outputs_info=alpha_l_init,
         sequences=[s, z],
         n_steps=Lp1 - 1,
         allow_gc=False,
+        return_updates=False,
     )

     # assert np.all(alpha.eval() > 0), "alpha cannot be negative"
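Note: every `pytensor.scan` change in this file follows the same pattern, so one illustration covers the call sites below as well. With `return_updates=False` (a keyword accepted by the PyTensor versions this release targets, per the diff itself), `scan` returns only its outputs instead of an `(outputs, updates)` pair, so update-free loops no longer need a throwaway unpacking. A minimal sketch with illustrative names, not code from pymc-extras:

    import pytensor
    import pytensor.tensor as pt

    s = pt.matrix("s")
    z = pt.matrix("z")

    # Accumulate s_t * z_t row by row; the loop produces no shared-variable
    # updates, so return_updates=False lets us bind the output directly.
    acc = pytensor.scan(
        fn=lambda s_t, z_t, acc_tm1: acc_tm1 + s_t * z_t,
        sequences=[s, z],
        outputs_info=pt.zeros_like(s[0]),
        allow_gc=False,
        return_updates=False,  # previously: acc, _ = pytensor.scan(...)
    )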
@@ -334,11 +335,12 @@ def inverse_hessian_factors(
         return pt.set_subtensor(chi_l[j_last], diff_l)

     chi_init = pt.zeros((J, N))
-    chi_mat
+    chi_mat = pytensor.scan(
         fn=chi_update,
         outputs_info=chi_init,
         sequences=[diff],
         allow_gc=False,
+        return_updates=False,
     )

     chi_mat = pt.matrix_transpose(chi_mat)
@@ -377,14 +379,14 @@ def inverse_hessian_factors(
     eta = pt.diagonal(E, axis1=-2, axis2=-1)

     # beta: (L, N, 2J)
-    alpha_diag
+    alpha_diag = pytensor.scan(lambda a: pt.diag(a), sequences=[alpha], return_updates=False)
     beta = pt.concatenate([alpha_diag @ Z, S], axis=-1)

     # more performant and numerically precise to use solve than inverse: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.linalg.inv.html

     # E_inv: (L, J, J)
-    E_inv = pt.
-    eta_diag
+    E_inv = pt.linalg.solve_triangular(E, Ij, check_finite=False)
+    eta_diag = pytensor.scan(pt.diag, sequences=[eta], return_updates=False)

     # block_dd: (L, J, J)
     block_dd = (
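Note: the rewritten `E_inv` line acts on the advice in the comment above it: solving `E @ X = I` directly is cheaper and more numerically stable than materialising `inv(E)` and multiplying. A sketch of the same pattern with assumed shapes and a stand-in triangular factor (in the real code `E` comes from the L-BFGS factors and `Ij = pt.eye(J)`):

    import pytensor.tensor as pt

    J = 4
    E = pt.tril(pt.matrix("E"))  # stand-in lower-triangular factor
    Ij = pt.eye(J)

    # E^{-1} obtained as the solution of E @ X = I, never forming inv(E)
    E_inv = pt.linalg.solve_triangular(E, Ij, check_finite=False, lower=True)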
@@ -530,7 +532,9 @@ def bfgs_sample_sparse(

     # qr_input: (L, N, 2J)
     qr_input = inv_sqrt_alpha_diag @ beta
-
+    Q, R = pytensor.scan(
+        fn=pt.linalg.qr, sequences=[qr_input], allow_gc=False, return_updates=False
+    )

     IdN = pt.eye(R.shape[1])[None, ...]
     IdN += IdN * REGULARISATION_TERM
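Note: `pt.linalg.qr` factors a single matrix, so the new code maps it over the leading batch dimension with `scan`; because the step returns two outputs, the scan result unpacks directly into stacked `Q` and `R`. A sketch with assumed shapes:

    import pytensor
    import pytensor.tensor as pt

    batched = pt.tensor3("batched")  # (L, N, M): one matrix per leading index

    Q, R = pytensor.scan(
        fn=pt.linalg.qr,
        sequences=[batched],
        allow_gc=False,
        return_updates=False,
    )
    # Q: (L, N, K) and R: (L, K, M), with K = min(N, M) in reduced mode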
@@ -623,10 +627,11 @@ def bfgs_sample(

     L, N, JJ = beta.shape

-
+    alpha_diag, inv_sqrt_alpha_diag, sqrt_alpha_diag = pytensor.scan(
         lambda a: [pt.diag(a), pt.diag(pt.sqrt(1.0 / a)), pt.diag(pt.sqrt(a))],
         sequences=[alpha],
         allow_gc=False,
+        return_updates=False,
     )

     u = pt.random.normal(size=(L, num_samples, N))
pymc_extras/inference/smc/sampling.py CHANGED

@@ -238,7 +238,7 @@ class SMCDiagnostics(NamedTuple):
     def update_diagnosis(i, history, info, state):
         le, lli, ancestors, weights_evolution = history
         return SMCDiagnostics(
-            le.at[i].set(state.
+            le.at[i].set(state.tempering_param),
             lli.at[i].set(info.log_likelihood_increment),
             ancestors.at[i].set(info.ancestors),
             weights_evolution.at[i].set(state.weights),
@@ -265,7 +265,7 @@ def inference_loop(rng_key, initial_state, kernel, iterations_to_diagnose, n_par

     def cond(carry):
         i, state, _, _ = carry
-        return state.
+        return state.tempering_param < 1

     def one_step(carry):
         i, state, k, previous_info = carry
pymc_extras/model/marginal/distributions.py CHANGED

@@ -282,11 +282,12 @@ def finite_discrete_marginal_rv_logp(op: MarginalFiniteDiscreteRV, values, *inpu
     def logp_fn(marginalized_rv_const, *non_sequences):
         return graph_replace(joint_logp, replace={marginalized_vv: marginalized_rv_const})

-    joint_logps
+    joint_logps = scan_map(
         fn=logp_fn,
         sequences=marginalized_rv_domain_tensor,
         non_sequences=[*values, *inputs],
         mode=Mode().including("local_remove_check_parameter"),
+        return_updates=False,
     )

     joint_logp = pt.logsumexp(joint_logps, axis=0)
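Note: the `logsumexp` on the last line is what performs the marginalization: the scan evaluates the joint logp once per value in the discrete variable's domain, and summing those probabilities in log space yields the marginal. A plain-NumPy illustration with made-up numbers:

    import numpy as np
    from scipy.special import logsumexp

    # joint logp of (x = k, y = observed) for each k in the support of x
    joint_logps = np.log([0.2, 0.5, 0.1])

    # log p(y) = log sum_k p(x = k, y)
    marginal_logp = logsumexp(joint_logps, axis=0)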
@@ -350,12 +351,13 @@ def marginal_hmm_logp(op, values, *inputs, **kwargs):

     P_bcast_dims = (len(chain_shape) - 1) - (P.type.ndim - 2)
     log_P = pt.shape_padright(pt.log(P), P_bcast_dims)
-    log_alpha_seq
+    log_alpha_seq = scan(
         step_alpha,
         non_sequences=[log_P],
         outputs_info=[log_alpha_init],
         # Scan needs the time dimension first, and we already consumed the 1st logp computing the initial value
         sequences=pt.moveaxis(batch_logp_emissions[..., 1:], -1, 0),
+        return_updates=False,
     )
     # Final logp is just the sum of the last scan state
     joint_logp = pt.logsumexp(log_alpha_seq[-1], axis=0)
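Note: the scan here runs the log-space forward algorithm for an HMM: each step combines the previous forward log-probabilities with the log transition matrix and the current emission logp. A minimal unbatched NumPy sketch of that recursion (illustrative only; the real `step_alpha` is defined earlier in this file):

    import numpy as np
    from scipy.special import logsumexp

    def step_alpha(logp_emission, log_alpha_prev, log_P):
        # log_P[i, j]: log transition probability from state i to state j
        return logp_emission + logsumexp(log_alpha_prev[:, None] + log_P, axis=0)

    log_P = np.log([[0.9, 0.1], [0.2, 0.8]])
    log_alpha = np.log([0.5, 0.5])  # initial term (the real code folds the first emission logp in here)
    for logp_emission in np.log([[0.7, 0.3], [0.4, 0.6]]):
        log_alpha = step_alpha(logp_emission, log_alpha, log_P)

    joint_logp = logsumexp(log_alpha)  # marginal likelihood of the emission sequence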
pymc_extras/model/marginal/marginal_model.py CHANGED

@@ -11,7 +11,7 @@ from pymc.backends.arviz import coords_and_dims_for_inferencedata, dataset_to_po
 from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform
 from pymc.distributions.transforms import Chain
 from pymc.logprob.transforms import IntervalTransform
-from pymc.model import Model
+from pymc.model import Model, modelcontext
 from pymc.model.fgraph import (
     ModelFreeRV,
     ModelValuedVar,
@@ -337,8 +337,9 @@ def transform_posterior_pts(model, posterior_pts):


 def recover_marginals(
-    model: Model,
     idata: InferenceData,
+    *,
+    model: Model | None = None,
     var_names: Sequence[str] | None = None,
     return_samples: bool = True,
     extend_inferencedata: bool = True,
@@ -389,6 +390,15 @@ def recover_marginals(


     """
+    # Temporary error message for helping with migration
+    # Will be removed in a future release
+    if isinstance(idata, Model):
+        raise TypeError(
+            "The order of arguments of `recover_marginals` changed. The first input must be an idata"
+        )
+
+    model = modelcontext(model)
+
     unmarginal_model = unmarginalize(model)

     # Find the names of the marginalized variables
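Note: the new signature leans on `pymc.model.modelcontext`, which returns the model passed to it or, when given `None`, the innermost model on the context stack; that is why `recover_marginals(idata)` can now be called inside a `with model:` block without naming the model. A runnable sketch of just that mechanism:

    import pymc as pm
    from pymc.model import modelcontext

    with pm.Model() as m:
        x = pm.Normal("x")
        assert modelcontext(None) is m  # resolved from the context stack

    assert modelcontext(m) is m  # an explicit model passes straight through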
pymc_extras/prior.py CHANGED
@@ -1575,9 +1575,9 @@ def __getattr__(name: str):
     samples = dist.sample_prior(coords={"channel": ["C1", "C2", "C3"]})

     """
-    #
-    if name
-
+    # Ignore Python internal attributes needed for introspection
+    if name.startswith("__"):
+        raise AttributeError(name)

     _get_pymc_distribution(name)
     return partial(Prior, distribution=name)
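Note: this is the standard guard for a module-level `__getattr__` (PEP 562). Introspection machinery such as `copy`, pickle, and IPython probes modules for dunder attributes like `__wrapped__`, and without the early `AttributeError` those probes would fall through to the distribution lookup. A generic sketch of the pattern, with a stand-in for the real lookup:

    # module-level __getattr__: called for names not found in the module
    def __getattr__(name: str):
        if name.startswith("__"):
            # let introspection probes fail fast instead of hitting the lookup below
            raise AttributeError(name)
        return f"resolved:{name}"  # stand-in for partial(Prior, distribution=name)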
pymc_extras/statespace/core/properties.py ADDED

@@ -0,0 +1,276 @@
+from __future__ import annotations
+
+from collections.abc import Iterator
+from copy import deepcopy
+from dataclasses import dataclass, fields
+from typing import Generic, Protocol, Self, TypeVar
+
+from pytensor.tensor.variable import TensorVariable
+
+from pymc_extras.statespace.utils.constants import (
+    ALL_STATE_AUX_DIM,
+    ALL_STATE_DIM,
+    OBS_STATE_AUX_DIM,
+    OBS_STATE_DIM,
+    SHOCK_AUX_DIM,
+    SHOCK_DIM,
+)
+
+
+class StateSpaceLike(Protocol):
+    @property
+    def state_names(self) -> tuple[str, ...]: ...
+
+    @property
+    def observed_states(self) -> tuple[str, ...]: ...
+
+    @property
+    def shock_names(self) -> tuple[str, ...]: ...
+
+
+@dataclass(frozen=True)
+class Property:
+    def __str__(self) -> str:
+        return "\n".join(f"{f.name}: {getattr(self, f.name)}" for f in fields(self))
+
+
+T = TypeVar("T", bound=Property)
+
+
+@dataclass(frozen=True)
+class Info(Generic[T]):
+    items: tuple[T, ...] | None
+    key_field: str | tuple[str, ...] = "name"
+    _index: dict[str | tuple, T] | None = None
+
+    def __post_init__(self):
+        index = {}
+        if self.items is None:
+            object.__setattr__(self, "items", ())
+        else:
+            object.__setattr__(self, "items", tuple(self.items))
+
+        for item in self.items:
+            key = self._key(item)
+            if key in index:
+                raise ValueError(f"Duplicate {self.key_field} '{key}' detected.")
+            index[key] = item
+        object.__setattr__(self, "_index", index)
+
+    def _key(self, item: T) -> str | tuple:
+        if isinstance(self.key_field, tuple):
+            return tuple(getattr(item, f) for f in self.key_field)
+        return getattr(item, self.key_field)
+
+    def get(self, key: str | tuple, default=None) -> T | None:
+        return self._index.get(key, default)
+
+    def __getitem__(self, key: str | tuple) -> T:
+        try:
+            return self._index[key]
+        except KeyError as e:
+            available = ", ".join(str(k) for k in self._index.keys())
+            raise KeyError(f"No {self.key_field} '{key}'. Available: [{available}]") from e
+
+    def __contains__(self, key: object) -> bool:
+        return key in self._index
+
+    def __iter__(self) -> Iterator[T]:
+        return iter(self.items)
+
+    def __len__(self) -> int:
+        return len(self.items)
+
+    def __str__(self) -> str:
+        return f"{self.key_field}s: {tuple(self._index.keys())}"
+
+    def add(self, new_item: T) -> Self:
+        return type(self)((*self.items, new_item))
+
+    def merge(self, other: Self, overwrite_duplicates: bool = False) -> Self:
+        if not isinstance(other, type(self)):
+            raise TypeError(f"Cannot merge {type(other).__name__} with {type(self).__name__}")
+
+        overlapping = set(self._index.keys()) & set(other._index.keys())
+        if overlapping and overwrite_duplicates:
+            return type(self)(
+                (
+                    *self.items,
+                    *(item for item in other.items if self._key(item) not in overlapping),
+                )
+            )
+
+        return type(self)(self.items + other.items)
+
+    @property
+    def names(self) -> tuple[str, ...]:
+        if isinstance(self.key_field, tuple):
+            return tuple(item.name for item in self.items)
+        return tuple(self._index.keys())
+
+    def copy(self) -> Info[T]:
+        return deepcopy(self)
+
+
+@dataclass(frozen=True)
+class Parameter(Property):
+    name: str
+    shape: tuple[int, ...] | None = None
+    dims: tuple[str, ...] | None = None
+    constraints: str | None = None
+
+
+@dataclass(frozen=True)
+class ParameterInfo(Info[Parameter]):
+    def __init__(self, parameters: tuple[Parameter, ...] | None):
+        super().__init__(items=parameters, key_field="name")
+
+    def to_dict(self):
+        return {
+            param.name: {"shape": param.shape, "constraints": param.constraints, "dims": param.dims}
+            for param in self.items
+        }
+
+
+@dataclass(frozen=True)
+class Data(Property):
+    name: str
+    shape: tuple[int | None, ...] | None = None
+    dims: tuple[str, ...] | None = None
+    is_exogenous: bool = False
+
+
+@dataclass(frozen=True)
+class DataInfo(Info[Data]):
+    def __init__(self, data: tuple[Data, ...] | None):
+        super().__init__(items=data, key_field="name")
+
+    @property
+    def needs_exogenous_data(self) -> bool:
+        return any(d.is_exogenous for d in self.items)
+
+    @property
+    def exogenous_names(self) -> tuple[str, ...]:
+        return tuple(d.name for d in self.items if d.is_exogenous)
+
+    def __str__(self) -> str:
+        return f"data: {[d.name for d in self.items]}\nneeds exogenous data: {self.needs_exogenous_data}"
+
+    def to_dict(self):
+        return {
+            data.name: {"shape": data.shape, "dims": data.dims, "exogenous": data.is_exogenous}
+            for data in self.items
+        }
+
+
+@dataclass(frozen=True)
+class Coord(Property):
+    dimension: str
+    labels: tuple[str | int, ...]
+
+
+@dataclass(frozen=True)
+class CoordInfo(Info[Coord]):
+    def __init__(self, coords: tuple[Coord, ...] | None = None):
+        super().__init__(items=coords, key_field="dimension")
+
+    def __str__(self) -> str:
+        base = "coordinates:"
+        for coord in self.items:
+            coord_str = str(coord)
+            indented = "\n".join(" " + line for line in coord_str.splitlines())
+            base += "\n" + indented + "\n"
+        return base
+
+    @classmethod
+    def default_coords_from_model(cls, model: StateSpaceLike) -> CoordInfo:
+        states = tuple(model.state_names)
+        obs_states = tuple(model.observed_states)
+        shocks = tuple(model.shock_names)
+
+        dim_to_labels = (
+            (ALL_STATE_DIM, states),
+            (ALL_STATE_AUX_DIM, states),
+            (OBS_STATE_DIM, obs_states),
+            (OBS_STATE_AUX_DIM, obs_states),
+            (SHOCK_DIM, shocks),
+            (SHOCK_AUX_DIM, shocks),
+        )
+
+        coords = tuple(Coord(dimension=dim, labels=labels) for dim, labels in dim_to_labels)
+        return cls(coords=coords)
+
+    def to_dict(self):
+        return {coord.dimension: tuple(coord.labels) for coord in self.items}
+
+
+@dataclass(frozen=True)
+class State(Property):
+    name: str
+    observed: bool
+    shared: bool = False
+
+
+@dataclass(frozen=True)
+class StateInfo(Info[State]):
+    def __init__(self, states: tuple[State, ...] | None):
+        super().__init__(items=states, key_field=("name", "observed"))
+
+    def __contains__(self, key: object) -> bool:
+        if isinstance(key, str):
+            return any(s.name == key for s in self.items)
+        return key in self._index
+
+    def __str__(self) -> str:
+        return (
+            f"states: {[s.name for s in self.items]}\nobserved: {[s.observed for s in self.items]}"
+        )
+
+    @property
+    def observed_state_names(self) -> tuple[str, ...]:
+        return tuple(s.name for s in self.items if s.observed)
+
+    @property
+    def unobserved_state_names(self) -> tuple[str, ...]:
+        return tuple(s.name for s in self.items if not s.observed)
+
+
+@dataclass(frozen=True)
+class Shock(Property):
+    name: str
+
+
+@dataclass(frozen=True)
+class ShockInfo(Info[Shock]):
+    def __init__(self, shocks: tuple[Shock, ...] | None):
+        super().__init__(items=shocks, key_field="name")
+
+
+@dataclass(frozen=True)
+class SymbolicVariable(Property):
+    name: str
+    symbolic_variable: TensorVariable
+
+
+@dataclass(frozen=True)
+class SymbolicVariableInfo(Info[SymbolicVariable]):
+    def __init__(self, symbolic_variables: tuple[SymbolicVariable, ...] | None = None):
+        super().__init__(items=symbolic_variables, key_field="name")
+
+    def to_dict(self):
+        return {variable.name: variable.symbolic_variable for variable in self.items}
+
+
+@dataclass(frozen=True)
+class SymbolicData(Property):
+    name: str
+    symbolic_data: TensorVariable
+
+
+@dataclass(frozen=True)
+class SymbolicDataInfo(Info[SymbolicData]):
+    def __init__(self, symbolic_data: tuple[SymbolicData, ...] | None = None):
+        super().__init__(items=symbolic_data, key_field="name")
+
+    def to_dict(self):
+        return {data.name: data.symbolic_data for data in self.items}