pymc-extras 0.2.4__tar.gz → 0.2.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/PKG-INFO +4 -2
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/__init__.py +1 -3
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/distributions/__init__.py +2 -0
- pymc_extras-0.2.5/pymc_extras/distributions/transforms/__init__.py +3 -0
- pymc_extras-0.2.5/pymc_extras/distributions/transforms/partial_order.py +227 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/inference/__init__.py +4 -2
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/inference/fit.py +6 -4
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/inference/laplace.py +4 -1
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/inference/pathfinder/lbfgs.py +49 -13
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/inference/pathfinder/pathfinder.py +89 -103
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/statespace/utils/data_tools.py +24 -9
- pymc_extras-0.2.5/pymc_extras/version.txt +1 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras.egg-info/PKG-INFO +4 -2
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras.egg-info/SOURCES.txt +3 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/setup.py +1 -0
- pymc_extras-0.2.5/tests/distributions/test_transform.py +77 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/tests/statespace/test_coord_assignment.py +65 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/tests/test_laplace.py +16 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/tests/test_pathfinder.py +101 -7
- pymc_extras-0.2.4/pymc_extras/version.txt +0 -1
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/CODE_OF_CONDUCT.md +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/CONTRIBUTING.md +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/LICENSE +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/MANIFEST.in +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/README.md +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/distributions/continuous.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/distributions/discrete.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/distributions/histogram_utils.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/distributions/multivariate/__init__.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/distributions/multivariate/r2d2m2cp.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/distributions/timeseries.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/gp/__init__.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/gp/latent_approx.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/inference/find_map.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/inference/pathfinder/__init__.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/inference/pathfinder/importance_sampling.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/inference/smc/__init__.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/inference/smc/sampling.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/linearmodel.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/model/__init__.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/model/marginal/__init__.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/model/marginal/distributions.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/model/marginal/graph_analysis.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/model/marginal/marginal_model.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/model/model_api.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/model/transforms/__init__.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/model/transforms/autoreparam.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/model_builder.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/preprocessing/__init__.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/preprocessing/standard_scaler.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/printing.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/statespace/__init__.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/statespace/core/__init__.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/statespace/core/compile.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/statespace/core/representation.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/statespace/core/statespace.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/statespace/filters/__init__.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/statespace/filters/distributions.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/statespace/filters/kalman_filter.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/statespace/filters/kalman_smoother.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/statespace/filters/utilities.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/statespace/models/ETS.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/statespace/models/SARIMAX.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/statespace/models/VARMAX.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/statespace/models/__init__.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/statespace/models/structural.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/statespace/models/utilities.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/statespace/utils/__init__.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/statespace/utils/constants.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/statespace/utils/coord_tools.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/utils/__init__.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/utils/linear_cg.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/utils/model_equivalence.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/utils/pivoted_cholesky.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/utils/prior.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/utils/spline.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras/version.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras.egg-info/dependency_links.txt +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras.egg-info/requires.txt +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pymc_extras.egg-info/top_level.txt +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/pyproject.toml +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/requirements-dev.txt +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/requirements-docs.txt +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/requirements.txt +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/setup.cfg +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/tests/__init__.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/tests/distributions/__init__.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/tests/distributions/test_continuous.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/tests/distributions/test_discrete.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/tests/distributions/test_discrete_markov_chain.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/tests/distributions/test_multivariate.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/tests/model/__init__.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/tests/model/marginal/__init__.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/tests/model/marginal/test_distributions.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/tests/model/marginal/test_graph_analysis.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/tests/model/marginal/test_marginal_model.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/tests/model/test_model_api.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/tests/statespace/__init__.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/tests/statespace/test_ETS.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/tests/statespace/test_SARIMAX.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/tests/statespace/test_VARMAX.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/tests/statespace/test_distributions.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/tests/statespace/test_kalman_filter.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/tests/statespace/test_representation.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/tests/statespace/test_statespace.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/tests/statespace/test_statespace_JAX.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/tests/statespace/test_structural.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/tests/statespace/utilities/__init__.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/tests/statespace/utilities/shared_fixtures.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/tests/statespace/utilities/statsmodel_local_level.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/tests/statespace/utilities/test_helpers.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/tests/test_blackjax_smc.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/tests/test_find_map.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/tests/test_histogram_approximation.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/tests/test_linearmodel.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/tests/test_model_builder.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/tests/test_pivoted_cholesky.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/tests/test_printing.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/tests/test_prior_from_trace.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/tests/test_splines.py +0 -0
- {pymc_extras-0.2.4 → pymc_extras-0.2.5}/tests/utils.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
2
|
Name: pymc-extras
|
|
3
|
-
Version: 0.2.
|
|
3
|
+
Version: 0.2.5
|
|
4
4
|
Summary: A home for new additions to PyMC, which may include unusual probability distribitions, advanced model fitting algorithms, or any code that may be inappropriate to include in the pymc repository, but may want to be made available to users.
|
|
5
5
|
Home-page: http://github.com/pymc-devs/pymc-extras
|
|
6
6
|
Maintainer: PyMC Developers
|
|
@@ -12,6 +12,7 @@ Classifier: Programming Language :: Python :: 3
|
|
|
12
12
|
Classifier: Programming Language :: Python :: 3.10
|
|
13
13
|
Classifier: Programming Language :: Python :: 3.11
|
|
14
14
|
Classifier: Programming Language :: Python :: 3.12
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
15
16
|
Classifier: License :: OSI Approved :: Apache Software License
|
|
16
17
|
Classifier: Intended Audience :: Science/Research
|
|
17
18
|
Classifier: Topic :: Scientific/Engineering
|
|
@@ -40,6 +41,7 @@ Dynamic: description
|
|
|
40
41
|
Dynamic: description-content-type
|
|
41
42
|
Dynamic: home-page
|
|
42
43
|
Dynamic: license
|
|
44
|
+
Dynamic: license-file
|
|
43
45
|
Dynamic: maintainer
|
|
44
46
|
Dynamic: maintainer-email
|
|
45
47
|
Dynamic: provides-extra
|
|
@@ -15,9 +15,7 @@ import logging
|
|
|
15
15
|
|
|
16
16
|
from pymc_extras import gp, statespace, utils
|
|
17
17
|
from pymc_extras.distributions import *
|
|
18
|
-
from pymc_extras.inference
|
|
19
|
-
from pymc_extras.inference.fit import fit
|
|
20
|
-
from pymc_extras.inference.laplace import fit_laplace
|
|
18
|
+
from pymc_extras.inference import find_MAP, fit, fit_laplace, fit_pathfinder
|
|
21
19
|
from pymc_extras.model.marginal.marginal_model import (
|
|
22
20
|
MarginalModel,
|
|
23
21
|
marginalize,
|
|
@@ -26,6 +26,7 @@ from pymc_extras.distributions.discrete import (
|
|
|
26
26
|
from pymc_extras.distributions.histogram_utils import histogram_approximation
|
|
27
27
|
from pymc_extras.distributions.multivariate import R2D2M2CP
|
|
28
28
|
from pymc_extras.distributions.timeseries import DiscreteMarkovChain
|
|
29
|
+
from pymc_extras.distributions.transforms import PartialOrder
|
|
29
30
|
|
|
30
31
|
__all__ = [
|
|
31
32
|
"Chi",
|
|
@@ -37,4 +38,5 @@ __all__ = [
|
|
|
37
38
|
"R2D2M2CP",
|
|
38
39
|
"Skellam",
|
|
39
40
|
"histogram_approximation",
|
|
41
|
+
"PartialOrder",
|
|
40
42
|
]
|
|
@@ -0,0 +1,227 @@
|
|
|
1
|
+
# Copyright 2025 The PyMC Developers
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
import numpy as np
|
|
15
|
+
import pytensor.tensor as pt
|
|
16
|
+
|
|
17
|
+
from pymc.logprob.transforms import Transform
|
|
18
|
+
|
|
19
|
+
# Only the transform class is exported by ``from ... import *``.
__all__ = ["PartialOrder"]
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def dtype_minval(dtype):
    """Return the smallest representable value of ``dtype``.

    Integer dtypes are looked up with :func:`numpy.iinfo`; every other dtype is
    handed to :func:`numpy.finfo`.
    """
    if np.issubdtype(dtype, np.integer):
        limits = np.iinfo(dtype)
    else:
        limits = np.finfo(dtype)
    return limits.min
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def padded_where(x, to_len, padval=-1):
    """Return the nonzero positions of ``x`` (first axis) right-padded to ``to_len``.

    Parameters
    ----------
    x : array_like
        array whose nonzero entries are located with :func:`numpy.where`
    to_len : int
        total length of the returned index vector
    padval : int, default -1
        value used to fill the slots after the real indices

    Returns
    -------
    ndarray
        ``np.where(x)[0]`` followed by ``to_len - len(np.where(x)[0])`` copies
        of ``padval``.
    """
    hits = np.where(x)[0]
    filler = np.full(to_len - len(hits), padval)
    return np.concatenate([hits, filler])
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class PartialOrder(Transform):
    """Create a PartialOrder transform

    A more flexible version of the pymc ordered transform that
    allows specifying a (strict) partial order on the elements.

    Examples
    --------
    .. code:: python

        import numpy as np
        import pymc as pm
        import pymc_extras as pmx

        # Define two partial orders on 4 elements
        # am[i,j] = 1 means i < j
        adj_mats = np.array([
            # 0 < {1, 2} < 3
            [[0, 1, 1, 0],
             [0, 0, 0, 1],
             [0, 0, 0, 1],
             [0, 0, 0, 0]],

            # 1 < 0 < 3 < 2
            [[0, 0, 0, 1],
             [1, 0, 0, 0],
             [0, 0, 0, 0],
             [0, 0, 1, 0]],
        ])

        # Create the partial order from the adjacency matrices
        po = pmx.PartialOrder(adj_mats)

        with pm.Model() as model:
            # Generate 3 samples from both partial orders
            pm.Normal("po_vals", shape=(3,2,4), transform=po,
                        initval=po.initvals((3,2,4)))

            idata = pm.sample()

        # Verify that for first po, the zeroth element is always the smallest
        assert (idata.posterior['po_vals'][:,:,:,0,0] <
                idata.posterior['po_vals'][:,:,:,0,1:]).all()

        # Verify that for second po, the second element is always the largest
        assert (idata.posterior['po_vals'][:,:,:,1,2] >=
                idata.posterior['po_vals'][:,:,:,1,:]).all()

    Technical notes
    ----------------
    Partial order needs to be strict, i.e. without equalities.
    A DAG defining the partial order is sufficient, as transitive closure is automatically computed.
    Any number of leading batch dimensions (one DAG per batch entry) is accepted.
    Code works in O(N*D) in runtime, but takes O(N^3) in initialization,
    where N is the number of nodes in the dag and D is the maximum
    in-degree of a node in the transitive reduction.
    """

    name = "partial_order"

    def __init__(self, adj_mat):
        """
        Initialize the PartialOrder transform

        Parameters
        ----------
        adj_mat: array_like
            adjacency matrix for the DAG that generates the partial order,
            where ``adj_mat[i][j] = 1`` denotes ``i < j``.
            Note this also accepts multiple DAGs if RV is multidimensional:
            all dimensions before the last two are treated as batch dimensions.

        Raises
        ------
        ValueError
            if ``adj_mat`` has fewer than 2 dimensions, is not square in its last
            two dimensions, contains entries other than 0 and 1 (or no 1 at all),
            or describes a cyclic relation.
        """
        adj_mat = np.asarray(adj_mat)

        # Basic input checks
        if adj_mat.ndim < 2:
            raise ValueError("Adjacency matrix must have at least 2 dimensions")
        if adj_mat.shape[-2] != adj_mat.shape[-1]:
            raise ValueError("Adjacency matrix is not square")
        # Every entry must be exactly 0 or 1. The min/max test additionally rejects a
        # matrix without any relation (all zeros), which would leave the padded parent
        # lists built below with zero width.
        if (
            adj_mat.min() != 0
            or adj_mat.max() != 1
            or not np.isin(adj_mat, (0, 1)).all()
        ):
            raise ValueError("Adjacency matrix must contain only 0s and 1s")

        n_nodes = adj_mat.shape[-1]

        # Open-mesh index arrays over the leading batch dimensions (empty for one DAG)
        idx = np.ix_(*[np.arange(s) for s in adj_mat.shape[:-2]])

        # Transitive closure using Floyd-Warshall (astype makes a copy, so the
        # in-place update never touches the caller's array)
        tc = adj_mat.astype(bool)
        for k in range(n_nodes):
            tc |= np.logical_and(tc[..., :, k, None], tc[..., None, k, :])

        # A cycle shows up as a node that precedes itself in the closure
        if np.any(tc.diagonal(axis1=-2, axis2=-1)):
            raise ValueError("Partial order contains equalities")

        # Transitive reduction using the closure: keep i -> j only when there is no
        # k with i -> k -> j. This gives the minimal description of the partial order
        # and so minimises the in-degree the transform has to take a max over.
        adj_mat = tc * (1 - np.matmul(tc, tc))

        # Maximum in-degree of the reduced DAG (width of the padded parent lists)
        dag_idim = adj_mat.sum(axis=-2).max()

        # Topological sort: repeatedly pick a node whose remaining in-degree is zero
        ts_inds = np.zeros(adj_mat.shape[:-1], dtype=int)
        dm = adj_mat.copy()
        for i in range(n_nodes):
            assert dm.sum(axis=-2).min() == 0  # DAG is acyclic
            nind = np.argmin(dm.sum(axis=-2), axis=-1)
            dm[(*idx, slice(None), nind)] = 1  # Make nind not show up again
            dm[(*idx, nind, slice(None))] = 0  # Allow its children to show
            ts_inds[(*idx, i)] = nind
        # ts_inds[..., i] is the node at position i of the topological order
        self.ts_inds = ts_inds

        # Change the dag to parent lists (padded with -1 for NA):
        # self.dag[..., j, :] holds the direct parents of node j
        dag_T = np.apply_along_axis(padded_where, axis=-2, arr=adj_mat, padval=-1, to_len=dag_idim)
        self.dag = np.swapaxes(dag_T, -2, -1)
        # Nodes without any parent are left unconstrained by the transform
        self.is_start = np.all(self.dag == -1, axis=-1)

    def _batch_grids(self):
        """Return full-shape index grids over the leading batch dimensions of ``self.dag``.

        One integer array of shape ``self.dag.shape[:-2]`` is returned per batch
        dimension; the tuple is empty when the transform holds a single DAG.
        Materialising the grids (instead of relying on index broadcasting) keeps every
        advanced index used in ``backward``/``forward`` at exactly the same shape.
        """
        batch_shape = self.dag.shape[:-2]
        return tuple(
            np.broadcast_to(grid, batch_shape).copy()
            for grid in np.ix_(*[np.arange(s) for s in batch_shape])
        )

    def initvals(self, shape=None, lower=-1, upper=1):
        """
        Create a set of appropriate initial values for the variable.
        NB! It is important that proper initial values are used,
        as only properly ordered values are in the range of the transform.

        Parameters
        ----------
        shape: tuple, default None
            shape of the initial values; its trailing dimensions must equal
            ``adj_mat.shape[:-1]``. If None, ``adj_mat.shape[:-1]`` is used.
            Any sequence of ints is accepted.
        lower: float, default -1
            lower bound for the initial values
        upper: float, default 1
            upper bound for the initial values

        Returns
        -------
        vals: ndarray
            initial values for the transformed variable

        Raises
        ------
        ValueError
            if the trailing dimensions of ``shape`` do not match the adjacency matrix
        """
        base_shape = self.dag.shape[:-1]
        shape = base_shape if shape is None else tuple(shape)
        n_base = len(base_shape)

        if shape[-n_base:] != base_shape:
            raise ValueError("Shape must match the shape of the adjacency matrix")

        # Create the initial values: evenly spaced values assigned by topological
        # position. ts_inds maps position -> node, argsort inverts it to node -> position,
        # so every parent receives a strictly smaller value than its children.
        vals = np.linspace(lower, upper, self.dag.shape[-2])
        inds = np.argsort(self.ts_inds, axis=-1)
        ivals = vals[inds]

        # Expand the initial values to the extra leading dimensions
        extra_dims = shape[:-n_base]
        ivals = np.tile(ivals, extra_dims + (1,) * n_base)

        return ivals

    def backward(self, value, *inputs):
        """Map unconstrained values to values that respect the partial order.

        Start nodes (no parents) keep their value; every other node becomes
        ``max(parent values) + exp(value)``.
        """
        minv = dtype_minval(value.dtype)
        # Working tensor with one extra trailing slot filled with the dtype minimum;
        # parent lists are padded with -1, which indexes this slot and so never wins the max
        x = pt.concatenate(
            [pt.zeros_like(value), pt.full(value.shape[:-1], minv)[..., None]], axis=-1
        )

        # Batch indices, and the same indices repeated along the parent-list axis so they
        # have the shape of self.dag[ni] and select one parent set per batch entry
        batch_idx = self._batch_grids()
        parent_shape = (*self.dag.shape[:-2], self.dag.shape[-1])
        parent_idx = tuple(
            np.broadcast_to(grid[..., None], parent_shape).copy() for grid in batch_idx
        )

        # Has to be done stepwise as next steps depend on previous values
        # Also has to be done in topological order, hence the ts_inds
        for i in range(self.dag.shape[-2]):
            tsi = self.ts_inds[..., i]
            if tsi.ndim == 0:
                tsi = int(tsi)  # single DAG: use a plain integer index
            ni = (*batch_idx, tsi)  # i-th node in topological order
            eni = (Ellipsis, *ni)
            ist = self.is_start[ni]

            mval = pt.max(x[(Ellipsis, *parent_idx, self.dag[ni])], axis=-1)
            x = pt.set_subtensor(x[eni], ist * value[eni] + (1 - ist) * (mval + pt.exp(value[eni])))
        return x[..., :-1]

    def forward(self, value, *inputs):
        """Inverse of :meth:`backward`.

        Start nodes keep their value; every other node becomes
        ``log(value - max(parent values))``.
        """
        minv = dtype_minval(value.dtype)
        # Extra trailing slot with the dtype minimum, addressed by the -1 padding
        vx = pt.concatenate([value, pt.full(value.shape[:-1], minv)[..., None]], axis=-1)

        # Batch indices broadcast to the full (..., node, parent) shape of self.dag
        idx = tuple(
            np.broadcast_to(grid[..., None, None], self.dag.shape).copy()
            for grid in self._batch_grids()
        )

        y = self.is_start * value + (1 - self.is_start) * (
            pt.log(value - pt.max(vx[(Ellipsis, *idx, self.dag)], axis=-1))
        )

        return y

    def log_jac_det(self, value, *inputs):
        """Log absolute Jacobian determinant of :meth:`backward`.

        Each non-start node is ``m + exp(value)`` with ``m`` depending only on earlier
        nodes in topological order, so the Jacobian is triangular in that order with
        diagonal ``exp(value)``; start nodes contribute ``log(1) = 0``.
        """
        return pt.sum(value * (1 - self.is_start), axis=-1)
|
|
@@ -12,7 +12,9 @@
|
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
|
-
|
|
15
|
+
from pymc_extras.inference.find_map import find_MAP
|
|
16
16
|
from pymc_extras.inference.fit import fit
|
|
17
|
+
from pymc_extras.inference.laplace import fit_laplace
|
|
18
|
+
from pymc_extras.inference.pathfinder.pathfinder import fit_pathfinder
|
|
17
19
|
|
|
18
|
-
__all__ = ["fit"]
|
|
20
|
+
__all__ = ["fit", "fit_pathfinder", "fit_laplace", "find_MAP"]
|
|
@@ -11,11 +11,13 @@
|
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
12
|
# See the License for the specific language governing permissions and
|
|
13
13
|
# limitations under the License.
|
|
14
|
+
import arviz as az
|
|
14
15
|
|
|
15
16
|
|
|
16
|
-
def fit(method, **kwargs):
|
|
17
|
+
def fit(method: str, **kwargs) -> az.InferenceData:
|
|
17
18
|
"""
|
|
18
|
-
Fit a model with an inference algorithm
|
|
19
|
+
Fit a model with an inference algorithm.
|
|
20
|
+
See :func:`fit_pathfinder` and :func:`fit_laplace` for more details.
|
|
19
21
|
|
|
20
22
|
Parameters
|
|
21
23
|
----------
|
|
@@ -23,11 +25,11 @@ def fit(method, **kwargs):
|
|
|
23
25
|
Which inference method to run.
|
|
24
26
|
Supported: pathfinder or laplace
|
|
25
27
|
|
|
26
|
-
kwargs are passed on.
|
|
28
|
+
kwargs: keyword arguments are passed on to the inference method.
|
|
27
29
|
|
|
28
30
|
Returns
|
|
29
31
|
-------
|
|
30
|
-
arviz.InferenceData
|
|
32
|
+
:class:`~arviz.InferenceData`
|
|
31
33
|
"""
|
|
32
34
|
if method == "pathfinder":
|
|
33
35
|
from pymc_extras.inference.pathfinder import fit_pathfinder
|
|
@@ -377,7 +377,10 @@ def sample_laplace_posterior(
|
|
|
377
377
|
posterior_dist = stats.multivariate_normal(
|
|
378
378
|
mean=mu.data, cov=H_inv, allow_singular=True, seed=rng
|
|
379
379
|
)
|
|
380
|
+
|
|
380
381
|
posterior_draws = posterior_dist.rvs(size=(chains, draws))
|
|
382
|
+
if mu.data.shape == (1,):
|
|
383
|
+
posterior_draws = np.expand_dims(posterior_draws, -1)
|
|
381
384
|
|
|
382
385
|
if transform_samples:
|
|
383
386
|
constrained_rvs, unconstrained_vector = _unconstrained_vector_to_constrained_rvs(model)
|
|
@@ -506,7 +509,7 @@ def fit_laplace(
|
|
|
506
509
|
|
|
507
510
|
Returns
|
|
508
511
|
-------
|
|
509
|
-
|
|
512
|
+
:class:`~arviz.InferenceData`
|
|
510
513
|
An InferenceData object containing the approximated posterior samples.
|
|
511
514
|
|
|
512
515
|
Examples
|
|
@@ -37,11 +37,14 @@ class LBFGSHistoryManager:
|
|
|
37
37
|
initial position
|
|
38
38
|
maxiter : int
|
|
39
39
|
maximum number of iterations to store
|
|
40
|
+
epsilon : float
|
|
41
|
+
tolerance for lbfgs update
|
|
40
42
|
"""
|
|
41
43
|
|
|
42
44
|
value_grad_fn: Callable[[NDArray[np.float64]], tuple[np.float64, NDArray[np.float64]]]
|
|
43
45
|
x0: NDArray[np.float64]
|
|
44
46
|
maxiter: int
|
|
47
|
+
epsilon: float
|
|
45
48
|
x_history: NDArray[np.float64] = field(init=False)
|
|
46
49
|
g_history: NDArray[np.float64] = field(init=False)
|
|
47
50
|
count: int = field(init=False)
|
|
@@ -52,7 +55,7 @@ class LBFGSHistoryManager:
|
|
|
52
55
|
self.count = 0
|
|
53
56
|
|
|
54
57
|
value, grad = self.value_grad_fn(self.x0)
|
|
55
|
-
if
|
|
58
|
+
if self.entry_condition_met(self.x0, value, grad):
|
|
56
59
|
self.add_entry(self.x0, grad)
|
|
57
60
|
|
|
58
61
|
def add_entry(self, x: NDArray[np.float64], g: NDArray[np.float64]) -> None:
|
|
@@ -75,18 +78,39 @@ class LBFGSHistoryManager:
|
|
|
75
78
|
x=self.x_history[: self.count], g=self.g_history[: self.count], count=self.count
|
|
76
79
|
)
|
|
77
80
|
|
|
81
|
+
def entry_condition_met(self, x, value, grad) -> bool:
    """Decide whether the point ``x`` should be recorded in the L-BFGS history.

    The point is rejected when the objective value or any gradient entry is
    non-finite, or when ``self.count`` has already reached ``self.maxiter + 1``.
    The very first entry (``self.count == 0``) is always accepted. Every later
    entry must pass the check ``s . z > epsilon * ||z||``, where ``s`` and ``z``
    are the changes in position and gradient relative to the most recently
    stored entry.

    Parameters
    ----------
    x : NDArray[np.float64]
        candidate position
    value : float
        objective value at ``x``
    grad : NDArray[np.float64]
        gradient at ``x``

    Returns
    -------
    bool
        True if ``x`` and ``grad`` should be appended to the history.
    """
    all_finite = np.all(np.isfinite(grad)) and np.isfinite(value)
    if not all_finite or self.count >= self.maxiter + 1:
        return False

    if self.count == 0:
        # Nothing stored yet to compare against.
        return True

    last = self.count - 1
    step = x - self.x_history[last]
    grad_change = grad - self.g_history[last]
    curvature = (step * grad_change).sum(axis=-1)
    threshold = self.epsilon * np.sqrt(np.sum(grad_change**2, axis=-1))
    return True if curvature > threshold else False
|
|
99
|
+
|
|
78
100
|
def __call__(self, x: NDArray[np.float64]) -> None:
|
|
79
101
|
value, grad = self.value_grad_fn(x)
|
|
80
|
-
if
|
|
102
|
+
if self.entry_condition_met(x, value, grad):
|
|
81
103
|
self.add_entry(x, grad)
|
|
82
104
|
|
|
83
105
|
|
|
84
106
|
class LBFGSStatus(Enum):
|
|
85
107
|
CONVERGED = auto()
|
|
86
108
|
MAX_ITER_REACHED = auto()
|
|
87
|
-
|
|
109
|
+
NON_FINITE = auto()
|
|
110
|
+
LOW_UPDATE_PCT = auto()
|
|
88
111
|
# Statuses that lead to Exceptions:
|
|
89
112
|
INIT_FAILED = auto()
|
|
113
|
+
INIT_FAILED_LOW_UPDATE_PCT = auto()
|
|
90
114
|
LBFGS_FAILED = auto()
|
|
91
115
|
|
|
92
116
|
|
|
@@ -101,8 +125,8 @@ class LBFGSException(Exception):
|
|
|
101
125
|
class LBFGSInitFailed(LBFGSException):
|
|
102
126
|
DEFAULT_MESSAGE = "LBFGS failed to initialise."
|
|
103
127
|
|
|
104
|
-
def __init__(self, message=None):
|
|
105
|
-
super().__init__(message or self.DEFAULT_MESSAGE,
|
|
128
|
+
def __init__(self, status: LBFGSStatus, message=None):
|
|
129
|
+
super().__init__(message or self.DEFAULT_MESSAGE, status)
|
|
106
130
|
|
|
107
131
|
|
|
108
132
|
class LBFGS:
|
|
@@ -122,10 +146,12 @@ class LBFGS:
|
|
|
122
146
|
gradient tolerance for convergence, defaults to 1e-8
|
|
123
147
|
maxls : int, optional
|
|
124
148
|
maximum number of line search steps, defaults to 1000
|
|
149
|
+
epsilon : float, optional
|
|
150
|
+
tolerance for lbfgs update, defaults to 1e-8
|
|
125
151
|
"""
|
|
126
152
|
|
|
127
153
|
def __init__(
|
|
128
|
-
self, value_grad_fn, maxcor, maxiter=1000, ftol=1e-5, gtol=1e-8, maxls=1000
|
|
154
|
+
self, value_grad_fn, maxcor, maxiter=1000, ftol=1e-5, gtol=1e-8, maxls=1000, epsilon=1e-8
|
|
129
155
|
) -> None:
|
|
130
156
|
self.value_grad_fn = value_grad_fn
|
|
131
157
|
self.maxcor = maxcor
|
|
@@ -133,6 +159,7 @@ class LBFGS:
|
|
|
133
159
|
self.ftol = ftol
|
|
134
160
|
self.gtol = gtol
|
|
135
161
|
self.maxls = maxls
|
|
162
|
+
self.epsilon = epsilon
|
|
136
163
|
|
|
137
164
|
def minimize(self, x0) -> tuple[NDArray, NDArray, int, LBFGSStatus]:
|
|
138
165
|
"""minimizes objective function starting from initial position.
|
|
@@ -157,7 +184,7 @@ class LBFGS:
|
|
|
157
184
|
x0 = np.array(x0, dtype=np.float64)
|
|
158
185
|
|
|
159
186
|
history_manager = LBFGSHistoryManager(
|
|
160
|
-
value_grad_fn=self.value_grad_fn, x0=x0, maxiter=self.maxiter
|
|
187
|
+
value_grad_fn=self.value_grad_fn, x0=x0, maxiter=self.maxiter, epsilon=self.epsilon
|
|
161
188
|
)
|
|
162
189
|
|
|
163
190
|
result = minimize(
|
|
@@ -177,13 +204,22 @@ class LBFGS:
|
|
|
177
204
|
history = history_manager.get_history()
|
|
178
205
|
|
|
179
206
|
# warnings and suggestions for LBFGSStatus are displayed at the end
|
|
180
|
-
if
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
207
|
+
# threshold determining if the number of lbfgs updates is low compared to iterations
|
|
208
|
+
low_update_threshold = 3
|
|
209
|
+
|
|
210
|
+
if history.count <= 1: # triggers LBFGSInitFailed
|
|
211
|
+
if result.nit < low_update_threshold:
|
|
184
212
|
lbfgs_status = LBFGSStatus.INIT_FAILED
|
|
185
|
-
|
|
186
|
-
lbfgs_status = LBFGSStatus.
|
|
213
|
+
else:
|
|
214
|
+
lbfgs_status = LBFGSStatus.INIT_FAILED_LOW_UPDATE_PCT
|
|
215
|
+
elif result.status == 1:
|
|
216
|
+
# (result.nit > maxiter) or (result.nit > maxls)
|
|
217
|
+
lbfgs_status = LBFGSStatus.MAX_ITER_REACHED
|
|
218
|
+
elif result.status == 2:
|
|
219
|
+
# precision loss resulting to inf or nan
|
|
220
|
+
lbfgs_status = LBFGSStatus.NON_FINITE
|
|
221
|
+
elif history.count * low_update_threshold < result.nit:
|
|
222
|
+
lbfgs_status = LBFGSStatus.LOW_UPDATE_PCT
|
|
187
223
|
else:
|
|
188
224
|
lbfgs_status = LBFGSStatus.CONVERGED
|
|
189
225
|
|