pymc-extras 0.2.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,6 +2,7 @@ import functools as ft
2
2
  import warnings
3
3
 
4
4
  from collections import defaultdict
5
+ from copyreg import remove_extension
5
6
  from typing import Optional
6
7
 
7
8
  import numpy as np
@@ -592,13 +593,18 @@ def test_autoregressive_model(order, rng):
592
593
 
593
594
  @pytest.mark.parametrize("s", [10, 25, 50])
594
595
  @pytest.mark.parametrize("innovations", [True, False])
595
- def test_time_seasonality(s, innovations, rng):
596
+ @pytest.mark.parametrize("remove_first_state", [True, False])
597
+ def test_time_seasonality(s, innovations, remove_first_state, rng):
596
598
  def random_word(rng):
597
599
  return "".join(rng.choice(list("abcdefghijklmnopqrstuvwxyz")) for _ in range(5))
598
600
 
599
601
  state_names = [random_word(rng) for _ in range(s)]
600
602
  mod = st.TimeSeasonality(
601
- season_length=s, innovations=innovations, name="season", state_names=state_names
603
+ season_length=s,
604
+ innovations=innovations,
605
+ name="season",
606
+ state_names=state_names,
607
+ remove_first_state=remove_first_state,
602
608
  )
603
609
  x0 = np.zeros(mod.k_states, dtype=floatX)
604
610
  x0[0] = 1
@@ -615,7 +621,8 @@ def test_time_seasonality(s, innovations, rng):
615
621
  # Check coords
616
622
  mod.build(verbose=False)
617
623
  _assert_basic_coords_correct(mod)
618
- assert mod.coords["season_state"] == state_names[1:]
624
+ test_slice = slice(1, None) if remove_first_state else slice(None)
625
+ assert mod.coords["season_state"] == state_names[test_slice]
619
626
 
620
627
 
621
628
  def get_shift_factor(s):
@@ -8,7 +8,7 @@
8
8
  # pass
9
9
  # import numpy as np
10
10
  #
11
- # import pymc_experimental as pmx
11
+ # import pymc_extras as pmx
12
12
  #
13
13
  #
14
14
  # def test_match_gpytorch_linearcg_output():
tests/utils.py CHANGED
@@ -1,31 +0,0 @@
1
- from collections.abc import Sequence
2
-
3
- from pytensor.compile import SharedVariable
4
- from pytensor.graph import Constant, graph_inputs
5
- from pytensor.graph.basic import Variable, equal_computations
6
- from pytensor.tensor.random.type import RandomType
7
-
8
-
9
- def equal_computations_up_to_root(
10
- xs: Sequence[Variable], ys: Sequence[Variable], ignore_rng_values=True
11
- ) -> bool:
12
- # Check if graphs are equivalent even if root variables have distinct identities
13
-
14
- x_graph_inputs = [var for var in graph_inputs(xs) if not isinstance(var, Constant)]
15
- y_graph_inputs = [var for var in graph_inputs(ys) if not isinstance(var, Constant)]
16
- if len(x_graph_inputs) != len(y_graph_inputs):
17
- return False
18
- for x, y in zip(x_graph_inputs, y_graph_inputs):
19
- if x.type != y.type:
20
- return False
21
- if x.name != y.name:
22
- return False
23
- if isinstance(x, SharedVariable):
24
- if not isinstance(y, SharedVariable):
25
- return False
26
- if isinstance(x.type, RandomType) and ignore_rng_values:
27
- continue
28
- if not x.type.values_eq(x.get_value(), y.get_value()):
29
- return False
30
-
31
- return equal_computations(xs, ys, in_xs=x_graph_inputs, in_ys=y_graph_inputs)