pymc-extras 0.2.0__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pymc_extras/__init__.py +5 -1
- pymc_extras/distributions/timeseries.py +1 -1
- pymc_extras/model/marginal/distributions.py +100 -3
- pymc_extras/model/marginal/graph_analysis.py +8 -9
- pymc_extras/model/marginal/marginal_model.py +437 -424
- pymc_extras/statespace/models/structural.py +21 -6
- pymc_extras/utils/model_equivalence.py +66 -0
- pymc_extras/version.txt +1 -1
- {pymc_extras-0.2.0.dist-info → pymc_extras-0.2.1.dist-info}/METADATA +3 -4
- {pymc_extras-0.2.0.dist-info → pymc_extras-0.2.1.dist-info}/RECORD +18 -17
- {pymc_extras-0.2.0.dist-info → pymc_extras-0.2.1.dist-info}/WHEEL +1 -1
- tests/model/marginal/test_distributions.py +12 -11
- tests/model/marginal/test_marginal_model.py +301 -201
- tests/statespace/test_structural.py +10 -3
- tests/test_pivoted_cholesky.py +1 -1
- tests/utils.py +0 -31
- {pymc_extras-0.2.0.dist-info → pymc_extras-0.2.1.dist-info}/LICENSE +0 -0
- {pymc_extras-0.2.0.dist-info → pymc_extras-0.2.1.dist-info}/top_level.txt +0 -0
|
@@ -2,6 +2,7 @@ import functools as ft
|
|
|
2
2
|
import warnings
|
|
3
3
|
|
|
4
4
|
from collections import defaultdict
|
|
5
|
+
from copyreg import remove_extension
|
|
5
6
|
from typing import Optional
|
|
6
7
|
|
|
7
8
|
import numpy as np
|
|
@@ -592,13 +593,18 @@ def test_autoregressive_model(order, rng):
|
|
|
592
593
|
|
|
593
594
|
@pytest.mark.parametrize("s", [10, 25, 50])
|
|
594
595
|
@pytest.mark.parametrize("innovations", [True, False])
|
|
595
|
-
|
|
596
|
+
@pytest.mark.parametrize("remove_first_state", [True, False])
|
|
597
|
+
def test_time_seasonality(s, innovations, remove_first_state, rng):
|
|
596
598
|
def random_word(rng):
|
|
597
599
|
return "".join(rng.choice(list("abcdefghijklmnopqrstuvwxyz")) for _ in range(5))
|
|
598
600
|
|
|
599
601
|
state_names = [random_word(rng) for _ in range(s)]
|
|
600
602
|
mod = st.TimeSeasonality(
|
|
601
|
-
season_length=s,
|
|
603
|
+
season_length=s,
|
|
604
|
+
innovations=innovations,
|
|
605
|
+
name="season",
|
|
606
|
+
state_names=state_names,
|
|
607
|
+
remove_first_state=remove_first_state,
|
|
602
608
|
)
|
|
603
609
|
x0 = np.zeros(mod.k_states, dtype=floatX)
|
|
604
610
|
x0[0] = 1
|
|
@@ -615,7 +621,8 @@ def test_time_seasonality(s, innovations, rng):
|
|
|
615
621
|
# Check coords
|
|
616
622
|
mod.build(verbose=False)
|
|
617
623
|
_assert_basic_coords_correct(mod)
|
|
618
|
-
|
|
624
|
+
test_slice = slice(1, None) if remove_first_state else slice(None)
|
|
625
|
+
assert mod.coords["season_state"] == state_names[test_slice]
|
|
619
626
|
|
|
620
627
|
|
|
621
628
|
def get_shift_factor(s):
|
tests/test_pivoted_cholesky.py
CHANGED
tests/utils.py
CHANGED
|
@@ -1,31 +0,0 @@
|
|
|
1
|
-
from collections.abc import Sequence
|
|
2
|
-
|
|
3
|
-
from pytensor.compile import SharedVariable
|
|
4
|
-
from pytensor.graph import Constant, graph_inputs
|
|
5
|
-
from pytensor.graph.basic import Variable, equal_computations
|
|
6
|
-
from pytensor.tensor.random.type import RandomType
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
def equal_computations_up_to_root(
|
|
10
|
-
xs: Sequence[Variable], ys: Sequence[Variable], ignore_rng_values=True
|
|
11
|
-
) -> bool:
|
|
12
|
-
# Check if graphs are equivalent even if root variables have distinct identities
|
|
13
|
-
|
|
14
|
-
x_graph_inputs = [var for var in graph_inputs(xs) if not isinstance(var, Constant)]
|
|
15
|
-
y_graph_inputs = [var for var in graph_inputs(ys) if not isinstance(var, Constant)]
|
|
16
|
-
if len(x_graph_inputs) != len(y_graph_inputs):
|
|
17
|
-
return False
|
|
18
|
-
for x, y in zip(x_graph_inputs, y_graph_inputs):
|
|
19
|
-
if x.type != y.type:
|
|
20
|
-
return False
|
|
21
|
-
if x.name != y.name:
|
|
22
|
-
return False
|
|
23
|
-
if isinstance(x, SharedVariable):
|
|
24
|
-
if not isinstance(y, SharedVariable):
|
|
25
|
-
return False
|
|
26
|
-
if isinstance(x.type, RandomType) and ignore_rng_values:
|
|
27
|
-
continue
|
|
28
|
-
if not x.type.values_eq(x.get_value(), y.get_value()):
|
|
29
|
-
return False
|
|
30
|
-
|
|
31
|
-
return equal_computations(xs, ys, in_xs=x_graph_inputs, in_ys=y_graph_inputs)
|
|
File without changes
|
|
File without changes
|