pymc-extras 0.2.6__py3-none-any.whl → 0.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pymc-extras
3
- Version: 0.2.6
3
+ Version: 0.2.7
4
4
  Summary: A home for new additions to PyMC, which may include unusual probability distribitions, advanced model fitting algorithms, or any code that may be inappropriate to include in the pymc repository, but may want to be made available to users.
5
5
  Project-URL: Documentation, https://pymc-extras.readthedocs.io/
6
6
  Project-URL: Repository, https://github.com/pymc-devs/pymc-extras.git
@@ -234,6 +234,8 @@ Classifier: Topic :: Scientific/Engineering
234
234
  Classifier: Topic :: Scientific/Engineering :: Mathematics
235
235
  Requires-Python: >=3.10
236
236
  Requires-Dist: better-optimize>=0.1.2
237
+ Requires-Dist: preliz
238
+ Requires-Dist: pydantic>=2.0.0
237
239
  Requires-Dist: pymc>=5.21.1
238
240
  Requires-Dist: scikit-learn
239
241
  Provides-Extra: complete
@@ -245,6 +247,7 @@ Requires-Dist: xhistogram; extra == 'dask-histogram'
245
247
  Provides-Extra: dev
246
248
  Requires-Dist: blackjax; extra == 'dev'
247
249
  Requires-Dist: dask[all]<2025.1.1; extra == 'dev'
250
+ Requires-Dist: pytest-mock; extra == 'dev'
248
251
  Requires-Dist: pytest>=6.0; extra == 'dev'
249
252
  Requires-Dist: statsmodels; extra == 'dev'
250
253
  Provides-Extra: docs
@@ -1,7 +1,9 @@
1
1
  pymc_extras/__init__.py,sha256=YsR6OG72aW73y6dGS7w3nGGMV-V-ImHkmUOXKMPfMRA,1230
2
+ pymc_extras/deserialize.py,sha256=dktK5gsR96X3zAUoRF5udrTiconknH3uupiAWqkZi0M,5937
2
3
  pymc_extras/linearmodel.py,sha256=6eitl15Ec15mSZu7zoHZ7Wwy4U1DPwqfAgwEt6ILeIc,3920
3
4
  pymc_extras/model_builder.py,sha256=sAw77fxdiy046BvDPjocuMlbJ0Efj-CDAGtmcwYmoG0,26361
4
5
  pymc_extras/printing.py,sha256=G8mj9dRd6i0PcsbcEWZm56ek6V8mmil78RI4MUhywBs,6506
6
+ pymc_extras/prior.py,sha256=dmw9Jz4DXRxT9jA-L3QSgMOODKqcim4NC5XguARSbxU,38718
5
7
  pymc_extras/distributions/__init__.py,sha256=fDbrBt9mxEVp2CDPwnyCW3oiutzZ0PduB8EUH3fUrjI,1377
6
8
  pymc_extras/distributions/continuous.py,sha256=530wvcO-QcYVdiVN-iQRveImWfyJzzmxiZLMVShP7w4,11251
7
9
  pymc_extras/distributions/discrete.py,sha256=HNi-K0_hnNWTcfyBkWGh26sc71FwBgukQ_EjGAaAOjY,13036
@@ -56,10 +58,9 @@ pymc_extras/statespace/utils/data_tools.py,sha256=01sz6XDtLYK9I5xghxYpD-PuDzGXv9
56
58
  pymc_extras/utils/__init__.py,sha256=yxI9cJ7fCtVQS0GFw0y6mDGZIQZiK53vm3UNKqIuGSk,758
57
59
  pymc_extras/utils/linear_cg.py,sha256=KkXhuimFsrKtNd_0By2ApxQQQNm5FdBtmDQJOVbLYkA,10056
58
60
  pymc_extras/utils/model_equivalence.py,sha256=8QIftID2HDxD659i0RXHazQ-l2Q5YegCRLcDqb2p9Pc,2187
59
- pymc_extras/utils/pivoted_cholesky.py,sha256=QtnjP0pAl9b77fLAu-semwT4_9dcoiqx3dz1xKGBjMk,1871
60
61
  pymc_extras/utils/prior.py,sha256=QlWVr7uKIK9VncBw7Fz3YgaASKGDfqpORZHc-vz_9gQ,6841
61
62
  pymc_extras/utils/spline.py,sha256=qGq0gcoMG5dpdazKFzG0RXkkCWP8ADPPXN-653-oFn4,4820
62
- pymc_extras-0.2.6.dist-info/METADATA,sha256=zzdhVkdzXhL7MQH3R0uiCsrcl5i5uh1JLVdRBG6jJyY,18813
63
- pymc_extras-0.2.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
64
- pymc_extras-0.2.6.dist-info/licenses/LICENSE,sha256=WjiLhUKEysJvy5e9jk6WwFv9tmAPtnov1uJ6gcH1kIs,11720
65
- pymc_extras-0.2.6.dist-info/RECORD,,
63
+ pymc_extras-0.2.7.dist-info/METADATA,sha256=hJCZrC9jdx_GkpFhHSzWgdUGxx6TZGw6jwYXeNIJ8-c,18909
64
+ pymc_extras-0.2.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
65
+ pymc_extras-0.2.7.dist-info/licenses/LICENSE,sha256=WjiLhUKEysJvy5e9jk6WwFv9tmAPtnov1uJ6gcH1kIs,11720
66
+ pymc_extras-0.2.7.dist-info/RECORD,,
@@ -1,69 +0,0 @@
1
- try:
2
- import torch
3
-
4
- from gpytorch.utils.permutation import apply_permutation
5
- except ImportError as e:
6
- raise ImportError("PyTorch and GPyTorch not found.") from e
7
-
8
- import numpy as np
9
-
10
-
11
- def pp(x):
12
- return np.array2string(x, precision=4, floatmode="fixed")
13
-
14
-
15
- def pivoted_cholesky(mat: np.matrix, error_tol=1e-6, max_iter=np.inf):
16
- """
17
- mat: numpy matrix of N x N
18
-
19
- This is to replicate what is done in GPyTorch verbatim.
20
- """
21
- n = mat.shape[-1]
22
- max_iter = min(int(max_iter), n)
23
-
24
- d = np.array(np.diag(mat))
25
- orig_error = np.max(d)
26
- error = np.linalg.norm(d, 1) / orig_error
27
- pi = np.arange(n)
28
-
29
- L = np.zeros((max_iter, n))
30
-
31
- m = 0
32
- while m < max_iter and error > error_tol:
33
- permuted_d = d[pi]
34
- max_diag_idx = np.argmax(permuted_d[m:])
35
- max_diag_idx = max_diag_idx + m
36
- max_diag_val = permuted_d[max_diag_idx]
37
- i = max_diag_idx
38
-
39
- # swap pi_m and pi_i
40
- pi[m], pi[i] = pi[i], pi[m]
41
- pim = pi[m]
42
-
43
- L[m, pim] = np.sqrt(max_diag_val)
44
-
45
- if m + 1 < n:
46
- row = apply_permutation(
47
- torch.from_numpy(mat), torch.tensor(pim), right_permutation=None
48
- ) # left permutation just swaps row
49
- row = row.numpy().flatten()
50
- pi_i = pi[m + 1 :]
51
- L_m_new = row[pi_i] # length = 9
52
-
53
- if m > 0:
54
- L_prev = L[:m, pi_i]
55
- update = L[:m, pim]
56
- prod = update @ L_prev
57
- L_m_new = L_m_new - prod # np.sum(prod, axis=-1)
58
-
59
- L_m = L[m, :]
60
- L_m_new = L_m_new / L_m[pim]
61
- L_m[pi_i] = L_m_new
62
-
63
- matrix_diag_current = d[pi_i]
64
- d[pi_i] = matrix_diag_current - L_m_new**2
65
-
66
- L[m, :] = L_m
67
- error = np.linalg.norm(d[pi_i], 1) / orig_error
68
- m = m + 1
69
- return L, pi