pymc-extras 0.2.5__py3-none-any.whl → 0.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. pymc_extras/__init__.py +5 -1
  2. pymc_extras/deserialize.py +224 -0
  3. pymc_extras/distributions/continuous.py +3 -2
  4. pymc_extras/distributions/discrete.py +3 -1
  5. pymc_extras/inference/find_map.py +62 -17
  6. pymc_extras/inference/laplace.py +10 -7
  7. pymc_extras/prior.py +1356 -0
  8. pymc_extras/statespace/core/statespace.py +191 -52
  9. pymc_extras/statespace/filters/distributions.py +15 -16
  10. pymc_extras/statespace/filters/kalman_filter.py +1 -18
  11. pymc_extras/statespace/filters/kalman_smoother.py +2 -6
  12. pymc_extras/statespace/models/ETS.py +10 -0
  13. pymc_extras/statespace/models/SARIMAX.py +26 -5
  14. pymc_extras/statespace/models/VARMAX.py +12 -2
  15. pymc_extras/statespace/models/structural.py +18 -5
  16. pymc_extras-0.2.7.dist-info/METADATA +321 -0
  17. pymc_extras-0.2.7.dist-info/RECORD +66 -0
  18. {pymc_extras-0.2.5.dist-info → pymc_extras-0.2.7.dist-info}/WHEEL +1 -2
  19. pymc_extras/utils/pivoted_cholesky.py +0 -69
  20. pymc_extras/version.py +0 -11
  21. pymc_extras/version.txt +0 -1
  22. pymc_extras-0.2.5.dist-info/METADATA +0 -112
  23. pymc_extras-0.2.5.dist-info/RECORD +0 -108
  24. pymc_extras-0.2.5.dist-info/top_level.txt +0 -2
  25. tests/__init__.py +0 -13
  26. tests/distributions/__init__.py +0 -19
  27. tests/distributions/test_continuous.py +0 -185
  28. tests/distributions/test_discrete.py +0 -210
  29. tests/distributions/test_discrete_markov_chain.py +0 -258
  30. tests/distributions/test_multivariate.py +0 -304
  31. tests/distributions/test_transform.py +0 -77
  32. tests/model/__init__.py +0 -0
  33. tests/model/marginal/__init__.py +0 -0
  34. tests/model/marginal/test_distributions.py +0 -132
  35. tests/model/marginal/test_graph_analysis.py +0 -182
  36. tests/model/marginal/test_marginal_model.py +0 -967
  37. tests/model/test_model_api.py +0 -38
  38. tests/statespace/__init__.py +0 -0
  39. tests/statespace/test_ETS.py +0 -411
  40. tests/statespace/test_SARIMAX.py +0 -405
  41. tests/statespace/test_VARMAX.py +0 -184
  42. tests/statespace/test_coord_assignment.py +0 -181
  43. tests/statespace/test_distributions.py +0 -270
  44. tests/statespace/test_kalman_filter.py +0 -326
  45. tests/statespace/test_representation.py +0 -175
  46. tests/statespace/test_statespace.py +0 -872
  47. tests/statespace/test_statespace_JAX.py +0 -156
  48. tests/statespace/test_structural.py +0 -836
  49. tests/statespace/utilities/__init__.py +0 -0
  50. tests/statespace/utilities/shared_fixtures.py +0 -9
  51. tests/statespace/utilities/statsmodel_local_level.py +0 -42
  52. tests/statespace/utilities/test_helpers.py +0 -310
  53. tests/test_blackjax_smc.py +0 -222
  54. tests/test_find_map.py +0 -103
  55. tests/test_histogram_approximation.py +0 -109
  56. tests/test_laplace.py +0 -281
  57. tests/test_linearmodel.py +0 -208
  58. tests/test_model_builder.py +0 -306
  59. tests/test_pathfinder.py +0 -297
  60. tests/test_pivoted_cholesky.py +0 -24
  61. tests/test_printing.py +0 -98
  62. tests/test_prior_from_trace.py +0 -172
  63. tests/test_splines.py +0 -77
  64. tests/utils.py +0 -0
  65. {pymc_extras-0.2.5.dist-info → pymc_extras-0.2.7.dist-info}/licenses/LICENSE +0 -0
pymc_extras-0.2.5.dist-info/RECORD DELETED
@@ -1,108 +0,0 @@
1
- pymc_extras/__init__.py,sha256=lYGf9TcwUHROIElkX7Epnb7-IppcmiSEYuxdtRzqS3s,1195
2
- pymc_extras/linearmodel.py,sha256=6eitl15Ec15mSZu7zoHZ7Wwy4U1DPwqfAgwEt6ILeIc,3920
3
- pymc_extras/model_builder.py,sha256=sAw77fxdiy046BvDPjocuMlbJ0Efj-CDAGtmcwYmoG0,26361
4
- pymc_extras/printing.py,sha256=G8mj9dRd6i0PcsbcEWZm56ek6V8mmil78RI4MUhywBs,6506
5
- pymc_extras/version.py,sha256=VxPGCBzhtSegu-Jp5cjzn0n4DGU0wuPUh-KyZKB6uPM,240
6
- pymc_extras/version.txt,sha256=6Vn3UOktu3YUriislvCjcnLK7YHYu7dYeRr3v7thBqA,6
7
- pymc_extras/distributions/__init__.py,sha256=fDbrBt9mxEVp2CDPwnyCW3oiutzZ0PduB8EUH3fUrjI,1377
8
- pymc_extras/distributions/continuous.py,sha256=z-nvQgGncYISdRY8cWsa-56V0bQGq70jYwU-i8VZ0Uk,11253
9
- pymc_extras/distributions/discrete.py,sha256=vrARNuiQAEXrs7yQgImV1PO8AV1uyEC_LBhr6F9IcOg,13032
10
- pymc_extras/distributions/histogram_utils.py,sha256=5RTvlGCUrp2qzshrchmPyWxjhs6RIYL62SMikjDM1TU,5814
11
- pymc_extras/distributions/timeseries.py,sha256=M5MZ-nik_tgkaoZ1hdUGEZ9g04DQyVLwszVJqSKwNcY,12719
12
- pymc_extras/distributions/multivariate/__init__.py,sha256=E8OeLW9tTotCbrUjEo4um76-_WQD56PehsPzkKmhfyA,93
13
- pymc_extras/distributions/multivariate/r2d2m2cp.py,sha256=bUj9bB-hQi6CpaJfvJjgNPi727uTbvAdxl9fm1zNBqY,16005
14
- pymc_extras/distributions/transforms/__init__.py,sha256=FUp2vyRE6_2eUcQ_FVt5Dn0-vy5I-puV-Kz13-QtLNc,104
15
- pymc_extras/distributions/transforms/partial_order.py,sha256=oEZlc9WgnGR46uFEjLzKEUxlhzIo2vrUUbBE3vYrsfQ,8404
16
- pymc_extras/gp/__init__.py,sha256=sFHw2y3lEl5tG_FDQHZUonQ_k0DF1JRf0Rp8dpHmge0,745
17
- pymc_extras/gp/latent_approx.py,sha256=cDEMM6H1BL2qyKg7BZU-ISrKn2HJe7hDaM4Y8GgQDf4,6682
18
- pymc_extras/inference/__init__.py,sha256=UH6S0bGfQKKyTSuqf7yezdy9PeE2bDU8U1v4eIRv4ZI,887
19
- pymc_extras/inference/find_map.py,sha256=vl5l0ei48PnX-uTuHVTr-9QpCEHc8xog-KK6sOnJ8LU,16513
20
- pymc_extras/inference/fit.py,sha256=oe20RAajImZ-VD9Ucbzri8Bof4Y2KHNhNRG19v9O3lI,1336
21
- pymc_extras/inference/laplace.py,sha256=cqarAdbFaOH74AkPUF4c7c_Hswa5mqmhgHpsgrkebHY,21860
22
- pymc_extras/inference/pathfinder/__init__.py,sha256=FhAYrCWNx_dCrynEdjg2CZ9tIinvcVLBm67pNx_Y3kA,101
23
- pymc_extras/inference/pathfinder/importance_sampling.py,sha256=NwxepXOFit3cA5zEebniKdlnJ1rZWg56aMlH4MEOcG4,6264
24
- pymc_extras/inference/pathfinder/lbfgs.py,sha256=GOoJBil5Kft_iFwGNUGKSeqzI5x_shA4KQWDwgGuQtQ,7110
25
- pymc_extras/inference/pathfinder/pathfinder.py,sha256=GW04HQurj_3Nlo1C6_K2tEIeigo8x0buV3FqDLA88PQ,64439
26
- pymc_extras/inference/smc/__init__.py,sha256=wyaT4NJl1YsSQRLiDy-i0Jq3CbJZ2BQd4nnCk-dIngY,603
27
- pymc_extras/inference/smc/sampling.py,sha256=AYwmKqGoV6pBtKnh9SUbBKbN7VcoFgb3MmNWV7SivMA,15365
28
- pymc_extras/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
29
- pymc_extras/model/model_api.py,sha256=UHMfQXxWBujeSiUySU0fDUC5Sd_BjT8FoVz3iBxQH_4,2400
30
- pymc_extras/model/marginal/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
31
- pymc_extras/model/marginal/distributions.py,sha256=iM1yT7_BmivgUSloQPKE2QXGPgjvLqDMY_OTBGsdAWg,15563
32
- pymc_extras/model/marginal/graph_analysis.py,sha256=0hWUH_PjfpgneQ3NaT__pWHS1fh50zNbI86kH4Nub0E,15693
33
- pymc_extras/model/marginal/marginal_model.py,sha256=oIdikaSnefCkyMxmzAe222qGXNucxZpHYk7548fK6iA,23631
34
- pymc_extras/model/transforms/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
35
- pymc_extras/model/transforms/autoreparam.py,sha256=_NltGWmNqi_X9sHCqAvWcBveLTPxVy11-wENFTcN6kk,12377
36
- pymc_extras/preprocessing/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
37
- pymc_extras/preprocessing/standard_scaler.py,sha256=Vajp33ma6OkwlU54JYtSS8urHbMJ3CRiRFxZpvFNuus,600
38
- pymc_extras/statespace/__init__.py,sha256=0MtZj7yT6jcyERvITnn-nkhyY8fO6Za4_vV53CF6ND0,429
39
- pymc_extras/statespace/core/__init__.py,sha256=huHEiXAm8zV2MZyZ8GBHp6q7_fnWqveM7lC6ilpb3iE,309
40
- pymc_extras/statespace/core/compile.py,sha256=9FZfE8Bi3VfElxujfOIKRVvmyL9M5R0WfNEqPc5kbVQ,1603
41
- pymc_extras/statespace/core/representation.py,sha256=DwNIun6wdeEA20oWBx5M4govyWTf5JI87aGQ_E6Mb4U,18956
42
- pymc_extras/statespace/core/statespace.py,sha256=Tx-821UNNLqsZgHzRmwaQ6s-agp_OthqSsbfwDpA1o0,96927
43
- pymc_extras/statespace/filters/__init__.py,sha256=N9Q4D0gAq_ZtT-GtrqiX1HkSg6Orv7o1TbrWUtnbTJE,420
44
- pymc_extras/statespace/filters/distributions.py,sha256=ejimTFLgBFZMEznxY5zh6u4Vrqij60i0k2_sxdPcZ3A,11878
45
- pymc_extras/statespace/filters/kalman_filter.py,sha256=HELC3aK4k8EdWlUAk5_F7y7YkIz-Xi_0j2AwRgAXgcc,31949
46
- pymc_extras/statespace/filters/kalman_smoother.py,sha256=ypH9K_88nfJ5K2Cq737aWL3p8v4UfI7MxnYs54WPdDs,4329
47
- pymc_extras/statespace/filters/utilities.py,sha256=iwdaYnO1cO06t_XUjLLRmqb8vwzzVH6Nx1iyZcbJL2k,1584
48
- pymc_extras/statespace/models/ETS.py,sha256=o039M-6aCxyMXbbKvUeNVZhheCKvvNIAmuj0f-ziMEc,28047
49
- pymc_extras/statespace/models/SARIMAX.py,sha256=SX0eiSK1pOt4dHBjWzBqVpRz67pBGLN5pQQgXcOiOgY,21607
50
- pymc_extras/statespace/models/VARMAX.py,sha256=xkIuftNc_5NHFpqZalExni99-1kovnzm5OjMIDNgaxY,15989
51
- pymc_extras/statespace/models/__init__.py,sha256=U79b8rTHBNijVvvGOd43nLu4PCloPUH1rwlN87-n88c,317
52
- pymc_extras/statespace/models/structural.py,sha256=sep9pesJdRN4X8Bea6_RhO3112uWOZRuYRxO6ibl_OA,63943
53
- pymc_extras/statespace/models/utilities.py,sha256=G9GuHKsghmIYOlfkPtvxBWF-FZY5-5JI1fJQM8N7EnE,15373
54
- pymc_extras/statespace/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
55
- pymc_extras/statespace/utils/constants.py,sha256=Kf6j75ABaDQeRODxKQ76wTUQV4F5sTjn1KBcZgCQx20,2403
56
- pymc_extras/statespace/utils/coord_tools.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
57
- pymc_extras/statespace/utils/data_tools.py,sha256=01sz6XDtLYK9I5xghxYpD-PuDzGXv9D-wFGfTV6FGEw,6566
58
- pymc_extras/utils/__init__.py,sha256=yxI9cJ7fCtVQS0GFw0y6mDGZIQZiK53vm3UNKqIuGSk,758
59
- pymc_extras/utils/linear_cg.py,sha256=KkXhuimFsrKtNd_0By2ApxQQQNm5FdBtmDQJOVbLYkA,10056
60
- pymc_extras/utils/model_equivalence.py,sha256=8QIftID2HDxD659i0RXHazQ-l2Q5YegCRLcDqb2p9Pc,2187
61
- pymc_extras/utils/pivoted_cholesky.py,sha256=QtnjP0pAl9b77fLAu-semwT4_9dcoiqx3dz1xKGBjMk,1871
62
- pymc_extras/utils/prior.py,sha256=QlWVr7uKIK9VncBw7Fz3YgaASKGDfqpORZHc-vz_9gQ,6841
63
- pymc_extras/utils/spline.py,sha256=qGq0gcoMG5dpdazKFzG0RXkkCWP8ADPPXN-653-oFn4,4820
64
- pymc_extras-0.2.5.dist-info/licenses/LICENSE,sha256=WjiLhUKEysJvy5e9jk6WwFv9tmAPtnov1uJ6gcH1kIs,11720
65
- tests/__init__.py,sha256=-ree9OWVCyTeXLR944OWjrQX2os15HXrRNkhJ7QdRjc,603
66
- tests/test_blackjax_smc.py,sha256=jcNgcMBxaKyPg9UvHnWQtwoL79LXlSpZfALe3RGEZnQ,7233
67
- tests/test_find_map.py,sha256=B8ThnXNyfTQeem24QaLoTitFrsxKoq2VQINUdOwzna0,3379
68
- tests/test_histogram_approximation.py,sha256=w-xb2Rr0Qft6sm6F3BTmXXnpuqyefC1SUL6YxzqA5X4,4674
69
- tests/test_laplace.py,sha256=fArHjwMR7x98K-gZLvrvb3AwNZ7_fo7E0A4SJyt4EGU,9843
70
- tests/test_linearmodel.py,sha256=iB8ApNqIX9_nUHoo-Tm51xuPdrva5t4VLLut6qXB5Ao,6906
71
- tests/test_model_builder.py,sha256=QiINEihBR9rx8xM4Nqlg4urZKoyo58aTKDtxl9SJF1s,11249
72
- tests/test_pathfinder.py,sha256=vlMI1p2Ja5X4QIaSV4h6U41I303rEppfO0JqE3xe1Rs,10023
73
- tests/test_pivoted_cholesky.py,sha256=PuMdMSCzO4KdQWpUF4SEBeuH_qsINCIH8TYtmmJ1NKo,692
74
- tests/test_printing.py,sha256=HnvwwjrjBuxXFAJdyU0K_lvKGLgh4nzHAnhsIUpenbY,5211
75
- tests/test_prior_from_trace.py,sha256=HOzR3l98pl7TEJquo_kSugED4wBTgHo4-8lgnpmacs8,5516
76
- tests/test_splines.py,sha256=xSZi4hqqReN1H8LHr0xjDmpomhDQm8auIsWQjFOyjbM,2608
77
- tests/utils.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
78
- tests/distributions/__init__.py,sha256=jt-oloszTLNFwi9AgU3M4m6xKQ8xpQE338rmmaMZcMs,795
79
- tests/distributions/test_continuous.py,sha256=1-bu-IP6RgLUJnuPYpOD8ZS1ahYbKtsJ9oflBfqCaFo,5477
80
- tests/distributions/test_discrete.py,sha256=CjjaUpppsvQ6zLzV15ZsbwNOKrDmEdz4VWcleoCXUi0,7776
81
- tests/distributions/test_discrete_markov_chain.py,sha256=8RCHZXSB8IWjniuKaGGlM_iTWGmdrcOqginxmrAeEJg,9212
82
- tests/distributions/test_multivariate.py,sha256=LBvBuoT_3rzi8rR38b8L441Y-9Ff0cIXeRBKiEn6kjs,10452
83
- tests/distributions/test_transform.py,sha256=QM9sSQ5eSbuT2pM76nUMWqb-tQa7DGZbT9uwFDqIRUk,2672
84
- tests/model/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
85
- tests/model/test_model_api.py,sha256=FJvMTmexovRELZOUcUyk-6Vwk9qSiH7hIFoiArgl5mk,1040
86
- tests/model/marginal/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
87
- tests/model/marginal/test_distributions.py,sha256=p5f73g4ogxYkdZaBndZV_1ra8TCppXiRlUpaaTwEe-M,5195
88
- tests/model/marginal/test_graph_analysis.py,sha256=raoj41NusMOj1zzPCrxrlQODqX6Ey8Ft_o32pNTe5qg,6712
89
- tests/model/marginal/test_marginal_model.py,sha256=uOmARalkdWq3sDbnJQ0KjiLwviqauZOAnafmYS_Cnd8,35475
90
- tests/statespace/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
91
- tests/statespace/test_ETS.py,sha256=IPg3uQ7xEGqDMEHu993vtUTV7r-uNAxmw23sr5MVGfQ,15582
92
- tests/statespace/test_SARIMAX.py,sha256=1BYNOm9aSHnpn-qbpe3YsQVH8m-mXcp_gvKgWhWn1W4,12948
93
- tests/statespace/test_VARMAX.py,sha256=rJnea9_WEGo9I0iv2eaSbwwFQv0tlIjpN7KAE0eQewU,6336
94
- tests/statespace/test_coord_assignment.py,sha256=2Mo5196ibkBTwscE7kqQoUsgQphdaagVkOccDi7D4RI,5980
95
- tests/statespace/test_distributions.py,sha256=WQ_ROyd-PL3cimXTyEtyVaMEVtS7Hue2Z0lN7UnGDyo,9122
96
- tests/statespace/test_kalman_filter.py,sha256=s2n62FzXl9elU_uqaMNaEaexUfq3SXe3_YvQ2lM6hiQ,11600
97
- tests/statespace/test_representation.py,sha256=1KAJY4ZaVhb1WdAJLx2UYSXuVYsMNWX98gEDF7P0B4s,6210
98
- tests/statespace/test_statespace.py,sha256=JoupFFpG8PmpB_NFV471IuTmyXhEd6_vOISwVCRrBBM,30570
99
- tests/statespace/test_statespace_JAX.py,sha256=hZOc6xxYdVeATPCKmcHMLOVcuvdzGRzgQQ4RrDenwk8,5279
100
- tests/statespace/test_structural.py,sha256=HD8OaGbjuH4y3xv_uG-R1xLZpPpcb4-3dbcTeb_imLY,29306
101
- tests/statespace/utilities/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
102
- tests/statespace/utilities/shared_fixtures.py,sha256=SNw8Bvj1Yw11TxAW6n20Bq0B8oaYtVTiFFEVNH_wnp4,164
103
- tests/statespace/utilities/statsmodel_local_level.py,sha256=SQAzaYaSDwiVhUQ1iWjt4MgfAd54RuzVtnslIs3xdS8,1225
104
- tests/statespace/utilities/test_helpers.py,sha256=oH24a6Q45NFFFI3Kx9mhKbxsCvo9ErCorKFoTjDB3-4,9159
105
- pymc_extras-0.2.5.dist-info/METADATA,sha256=02v5liTQQ55sV8xeFl5EjFpwbOKSYKG6g5lE_4htpBo,5227
106
- pymc_extras-0.2.5.dist-info/WHEEL,sha256=lTU6B6eIfYoiQJTZNc-fyaR6BpL6ehTzU3xGYxn2n8k,91
107
- pymc_extras-0.2.5.dist-info/top_level.txt,sha256=D6RkgBiXiZCel0nvsYg_zYEoT_VuwocyIY98EMaulj0,18
108
- pymc_extras-0.2.5.dist-info/RECORD,,
pymc_extras-0.2.5.dist-info/top_level.txt DELETED
@@ -1,2 +0,0 @@
1
- pymc_extras
2
- tests
tests/__init__.py DELETED
@@ -1,13 +0,0 @@
1
- # Copyright 2020 The PyMC Developers
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
tests/distributions/__init__.py DELETED
@@ -1,19 +0,0 @@
1
- # Copyright 2022 The PyMC Developers
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
-
16
- from pymc_extras.distributions import histogram_utils
17
- from pymc_extras.distributions.histogram_utils import histogram_approximation
18
-
19
- __all__ = ["histogram_utils", "histogram_approximation"]
tests/distributions/test_continuous.py DELETED
@@ -1,185 +0,0 @@
1
- # Copyright 2020 The PyMC Developers
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- import numpy as np
15
- import pymc as pm
16
-
17
- # general imports
18
- import pytest
19
- import scipy.stats.distributions as sp
20
-
21
-
22
- # test support imports from pymc
23
- from pymc.testing import (
24
- BaseTestDistributionRandom,
25
- Domain,
26
- R,
27
- Rplus,
28
- Rplusbig,
29
- assert_support_point_is_expected,
30
- check_logcdf,
31
- check_logp,
32
- seeded_scipy_distribution_builder,
33
- select_by_precision,
34
- )
35
-
36
- # the distributions to be tested
37
- from pymc_extras.distributions import Chi, GenExtreme, Maxwell
38
-
39
-
40
- class TestGenExtremeClass:
41
- """
42
- Wrapper class so that tests of experimental additions can be dropped into
43
- PyMC directly on adoption.
44
-
45
- pm.logp(GenExtreme.dist(mu=0.,sigma=1.,xi=0.5),value=-0.01)
46
- """
47
-
48
- def test_logp(self):
49
- def ref_logp(value, mu, sigma, xi):
50
- if 1 + xi * (value - mu) / sigma > 0:
51
- return sp.genextreme.logpdf(value, c=-xi, loc=mu, scale=sigma)
52
- else:
53
- return -np.inf
54
-
55
- check_logp(
56
- GenExtreme,
57
- R,
58
- {
59
- "mu": R,
60
- "sigma": Rplusbig,
61
- "xi": Domain([-1, -0.99, -0.5, 0, 0.5, 0.99, 1]),
62
- },
63
- ref_logp,
64
- )
65
-
66
- def test_logcdf(self):
67
- def ref_logcdf(value, mu, sigma, xi):
68
- if 1 + xi * (value - mu) / sigma > 0:
69
- return sp.genextreme.logcdf(value, c=-xi, loc=mu, scale=sigma)
70
- else:
71
- return -np.inf
72
-
73
- check_logcdf(
74
- GenExtreme,
75
- R,
76
- {
77
- "mu": R,
78
- "sigma": Rplusbig,
79
- "xi": Domain([-1, -0.99, -0.5, 0, 0.5, 0.99, 1]),
80
- },
81
- ref_logcdf,
82
- decimal=select_by_precision(float64=6, float32=2),
83
- )
84
-
85
- @pytest.mark.parametrize(
86
- "mu, sigma, xi, size, expected",
87
- [
88
- (0, 1, 0, None, 0),
89
- (1, np.arange(1, 4), 0.1, None, 1 + np.arange(1, 4) * (1.1**-0.1 - 1) / 0.1),
90
- (np.arange(5), 1, 0.1, None, np.arange(5) + (1.1**-0.1 - 1) / 0.1),
91
- (
92
- 0,
93
- 1,
94
- np.linspace(-0.2, 0.2, 6),
95
- None,
96
- ((1 + np.linspace(-0.2, 0.2, 6)) ** -np.linspace(-0.2, 0.2, 6) - 1)
97
- / np.linspace(-0.2, 0.2, 6),
98
- ),
99
- (1, 2, 0.1, 5, np.full(5, 1 + 2 * (1.1**-0.1 - 1) / 0.1)),
100
- (
101
- np.arange(6),
102
- np.arange(1, 7),
103
- np.linspace(-0.2, 0.2, 6),
104
- (3, 6),
105
- np.full(
106
- (3, 6),
107
- np.arange(6)
108
- + np.arange(1, 7)
109
- * ((1 + np.linspace(-0.2, 0.2, 6)) ** -np.linspace(-0.2, 0.2, 6) - 1)
110
- / np.linspace(-0.2, 0.2, 6),
111
- ),
112
- ),
113
- ],
114
- )
115
- def test_genextreme_support_point(self, mu, sigma, xi, size, expected):
116
- with pm.Model() as model:
117
- GenExtreme("x", mu=mu, sigma=sigma, xi=xi, size=size)
118
- assert_support_point_is_expected(model, expected)
119
-
120
- def test_gen_extreme_scipy_kwarg(self):
121
- dist = GenExtreme.dist(xi=1, scipy=False)
122
- assert dist.owner.inputs[-1].eval() == 1
123
-
124
- dist = GenExtreme.dist(xi=1, scipy=True)
125
- assert dist.owner.inputs[-1].eval() == -1
126
-
127
-
128
- class TestGenExtreme(BaseTestDistributionRandom):
129
- pymc_dist = GenExtreme
130
- pymc_dist_params = {"mu": 0, "sigma": 1, "xi": -0.1}
131
- expected_rv_op_params = {"mu": 0, "sigma": 1, "xi": -0.1}
132
- # Notice, using different parametrization of xi sign to scipy
133
- reference_dist_params = {"loc": 0, "scale": 1, "c": 0.1}
134
- reference_dist = seeded_scipy_distribution_builder("genextreme")
135
- tests_to_run = [
136
- "check_pymc_params_match_rv_op",
137
- "check_pymc_draws_match_reference",
138
- "check_rv_size",
139
- ]
140
-
141
-
142
- class TestChiClass:
143
- """
144
- Wrapper class so that tests of experimental additions can be dropped into
145
- PyMC directly on adoption.
146
- """
147
-
148
- def test_logp(self):
149
- check_logp(
150
- Chi,
151
- Rplus,
152
- {"nu": Rplus},
153
- lambda value, nu: sp.chi.logpdf(value, df=nu),
154
- )
155
-
156
- def test_logcdf(self):
157
- check_logcdf(
158
- Chi,
159
- Rplus,
160
- {"nu": Rplus},
161
- lambda value, nu: sp.chi.logcdf(value, df=nu),
162
- )
163
-
164
-
165
- class TestMaxwell:
166
- """
167
- Wrapper class so that tests of experimental additions can be dropped into
168
- PyMC directly on adoption.
169
- """
170
-
171
- def test_logp(self):
172
- check_logp(
173
- Maxwell,
174
- Rplus,
175
- {"a": Rplus},
176
- lambda value, a: sp.maxwell.logpdf(value, scale=a),
177
- )
178
-
179
- def test_logcdf(self):
180
- check_logcdf(
181
- Maxwell,
182
- Rplus,
183
- {"a": Rplus},
184
- lambda value, a: sp.maxwell.logcdf(value, scale=a),
185
- )
tests/distributions/test_discrete.py DELETED
@@ -1,210 +0,0 @@
1
- # Copyright 2023 The PyMC Developers
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- import numpy as np
15
- import pymc as pm
16
- import pytensor
17
- import pytensor.tensor as pt
18
- import pytest
19
- import scipy.stats
20
-
21
- from pymc.logprob.utils import ParameterValueError
22
- from pymc.testing import (
23
- BaseTestDistributionRandom,
24
- Domain,
25
- I,
26
- Rplus,
27
- assert_support_point_is_expected,
28
- check_logp,
29
- discrete_random_tester,
30
- )
31
- from pytensor import config
32
-
33
- from pymc_extras.distributions import (
34
- BetaNegativeBinomial,
35
- GeneralizedPoisson,
36
- Skellam,
37
- )
38
-
39
-
40
- class TestGeneralizedPoisson:
41
- class TestRandomVariable(BaseTestDistributionRandom):
42
- pymc_dist = GeneralizedPoisson
43
- pymc_dist_params = {"mu": 4.0, "lam": 1.0}
44
- expected_rv_op_params = {"mu": 4.0, "lam": 1.0}
45
- tests_to_run = [
46
- "check_pymc_params_match_rv_op",
47
- "check_rv_size",
48
- ]
49
-
50
- def test_random_matches_poisson(self):
51
- discrete_random_tester(
52
- dist=self.pymc_dist,
53
- paramdomains={"mu": Rplus, "lam": Domain([0], edges=(None, None))},
54
- ref_rand=lambda mu, lam, size: scipy.stats.poisson.rvs(mu, size=size),
55
- )
56
-
57
- @pytest.mark.parametrize("mu", (2.5, 20, 50))
58
- def test_random_lam_expected_moments(self, mu):
59
- lam = np.array([-0.9, -0.7, -0.2, 0, 0.2, 0.7, 0.9])
60
- dist = self.pymc_dist.dist(mu=mu, lam=lam, size=(10_000, len(lam)))
61
- draws = dist.eval()
62
-
63
- expected_mean = mu / (1 - lam)
64
- np.testing.assert_allclose(draws.mean(0), expected_mean, rtol=1e-1)
65
-
66
- expected_std = np.sqrt(mu / (1 - lam) ** 3)
67
- np.testing.assert_allclose(draws.std(0), expected_std, rtol=1e-1)
68
-
69
- def test_logp_matches_poisson(self):
70
- # We are only checking this distribution for lambda=0 where it's equivalent to Poisson.
71
- mu = pt.scalar("mu")
72
- lam = pt.scalar("lam")
73
- value = pt.vector("value", dtype="int64")
74
-
75
- logp = pm.logp(GeneralizedPoisson.dist(mu, lam), value)
76
- logp_fn = pytensor.function([value, mu, lam], logp)
77
-
78
- test_value = np.array([0, 1, 2, 30])
79
- for test_mu in (0.01, 0.1, 0.9, 1, 1.5, 20, 100):
80
- np.testing.assert_allclose(
81
- logp_fn(test_value, test_mu, lam=0),
82
- scipy.stats.poisson.logpmf(test_value, test_mu),
83
- rtol=1e-7 if config.floatX == "float64" else 1e-5,
84
- )
85
-
86
- # Check out-of-bounds values
87
- value = pt.scalar("value")
88
- logp = pm.logp(GeneralizedPoisson.dist(mu, lam), value)
89
- logp_fn = pytensor.function([value, mu, lam], logp)
90
-
91
- logp_fn(-1, mu=5, lam=0) == -np.inf
92
- logp_fn(9, mu=5, lam=-1) == -np.inf
93
-
94
- # Check mu/lam restrictions
95
- with pytest.raises(ParameterValueError):
96
- logp_fn(1, mu=1, lam=2)
97
-
98
- with pytest.raises(ParameterValueError):
99
- logp_fn(1, mu=0, lam=0)
100
-
101
- with pytest.raises(ParameterValueError):
102
- logp_fn(1, mu=1, lam=-1)
103
-
104
- def test_logp_lam_expected_moments(self):
105
- mu = 30
106
- lam = np.array([-0.9, -0.7, -0.2, 0, 0.2, 0.7, 0.9])
107
- with pm.Model():
108
- x = GeneralizedPoisson("x", mu=mu, lam=lam)
109
- trace = pm.sample(chains=1, draws=10_000, random_seed=96).posterior
110
-
111
- expected_mean = mu / (1 - lam)
112
- np.testing.assert_allclose(trace["x"].mean(("chain", "draw")), expected_mean, rtol=1e-1)
113
-
114
- expected_std = np.sqrt(mu / (1 - lam) ** 3)
115
- np.testing.assert_allclose(trace["x"].std(("chain", "draw")), expected_std, rtol=1e-1)
116
-
117
- @pytest.mark.parametrize(
118
- "mu, lam, size, expected",
119
- [
120
- (50, [-0.6, 0, 0.6], None, np.floor(50 / (1 - np.array([-0.6, 0, 0.6])))),
121
- ([5, 50], -0.1, (4, 2), np.full((4, 2), np.floor(np.array([5, 50]) / 1.1))),
122
- ],
123
- )
124
- def test_moment(self, mu, lam, size, expected):
125
- with pm.Model() as model:
126
- GeneralizedPoisson("x", mu=mu, lam=lam, size=size)
127
- assert_support_point_is_expected(model, expected)
128
-
129
-
130
- class TestBetaNegativeBinomial:
131
- """
132
- Wrapper class so that tests of experimental additions can be dropped into
133
- PyMC directly on adoption.
134
- """
135
-
136
- def test_logp(self):
137
- """
138
-
139
- Beta Negative Binomial logp function test values taken from R package as
140
- there is currently no implementation in scipy.
141
- https://github.com/scipy/scipy/issues/17330
142
-
143
- The test values can be generated in R with the following code:
144
-
145
- .. code-block:: r
146
-
147
- library(extraDistr)
148
-
149
- create.test.rows <- function(alpha, beta, r, x) {
150
- logp <- dbnbinom(x, alpha, beta, r, log=TRUE)
151
- paste0("(", paste(alpha, beta, r, x, logp, sep=", "), ")")
152
- }
153
-
154
- x <- c(0, 1, 250, 5000)
155
- print(create.test.rows(1, 1, 1, x), quote=FALSE)
156
- print(create.test.rows(1, 1, 10, x), quote=FALSE)
157
- print(create.test.rows(1, 10, 1, x), quote=FALSE)
158
- print(create.test.rows(10, 1, 1, x), quote=FALSE)
159
- print(create.test.rows(10, 10, 10, x), quote=FALSE)
160
-
161
- """
162
- alpha, beta, r, value = pt.scalars("alpha", "beta", "r", "value")
163
- logp = pm.logp(BetaNegativeBinomial.dist(alpha, beta, r), value)
164
- logp_fn = pytensor.function([value, alpha, beta, r], logp)
165
-
166
- tests = [
167
- # 1, 1, 1
168
- (1, 1, 1, 0, -0.693147180559945),
169
- (1, 1, 1, 1, -1.79175946922805),
170
- (1, 1, 1, 250, -11.0548820266432),
171
- (1, 1, 1, 5000, -17.0349862828565),
172
- # 1, 1, 10
173
- (1, 1, 10, 0, -2.39789527279837),
174
- (1, 1, 10, 1, -2.58021682959232),
175
- (1, 1, 10, 250, -8.82261694534392),
176
- (1, 1, 10, 5000, -14.7359968760473),
177
- # 1, 10, 1
178
- (1, 10, 1, 0, -2.39789527279837),
179
- (1, 10, 1, 1, -2.58021682959232),
180
- (1, 10, 1, 250, -8.82261694534418),
181
- (1, 10, 1, 5000, -14.7359968760446),
182
- # 10, 1, 1
183
- (10, 1, 1, 0, -0.0953101798043248),
184
- (10, 1, 1, 1, -2.58021682959232),
185
- (10, 1, 1, 250, -43.5891148758123),
186
- (10, 1, 1, 5000, -76.2953173311091),
187
- # 10, 10, 10
188
- (10, 10, 10, 0, -5.37909807285049),
189
- (10, 10, 10, 1, -4.17512526852455),
190
- (10, 10, 10, 250, -21.781591505836),
191
- (10, 10, 10, 5000, -53.4836799634603),
192
- ]
193
- for test_alpha, test_beta, test_r, test_value, expected_logp in tests:
194
- np.testing.assert_allclose(
195
- logp_fn(test_value, test_alpha, test_beta, test_r), expected_logp
196
- )
197
-
198
-
199
- class TestSkellam:
200
- def test_logp(self):
201
- # Scipy Skellam underflows to -inf earlier than PyMC
202
- Rplus_small = Domain([0, 0.01, 0.1, 0.9, 0.99, 1, 1.5, 2, 10, np.inf])
203
- # Suppress warnings coming from Scipy logpmf implementation
204
- with np.errstate(divide="ignore", invalid="ignore"):
205
- check_logp(
206
- Skellam,
207
- I,
208
- {"mu1": Rplus_small, "mu2": Rplus_small},
209
- lambda value, mu1, mu2: scipy.stats.skellam.logpmf(value, mu1, mu2),
210
- )