tinygp 0.2.4__tar.gz → 0.3.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {tinygp-0.2.4 → tinygp-0.3.0}/.github/workflows/news.yml +2 -2
- {tinygp-0.2.4 → tinygp-0.3.0}/.github/workflows/tests.yml +17 -7
- {tinygp-0.2.4 → tinygp-0.3.0}/.pre-commit-config.yaml +3 -3
- {tinygp-0.2.4 → tinygp-0.3.0}/.zenodo.json +20 -10
- {tinygp-0.2.4 → tinygp-0.3.0}/PKG-INFO +5 -3
- {tinygp-0.2.4 → tinygp-0.3.0}/docs/news.rst +17 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/docs/tutorials/derivative.ipynb +8 -5
- {tinygp-0.2.4 → tinygp-0.3.0}/docs/tutorials/kernels.ipynb +4 -5
- {tinygp-0.2.4 → tinygp-0.3.0}/noxfile.py +6 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/pyproject.toml +3 -2
- {tinygp-0.2.4 → tinygp-0.3.0}/src/tinygp/gp.py +25 -10
- tinygp-0.3.0/src/tinygp/helpers.py +19 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/src/tinygp/kernels/base.py +12 -22
- {tinygp-0.2.4 → tinygp-0.3.0}/src/tinygp/kernels/distance.py +4 -5
- {tinygp-0.2.4 → tinygp-0.3.0}/src/tinygp/kernels/quasisep.py +353 -133
- {tinygp-0.2.4 → tinygp-0.3.0}/src/tinygp/kernels/stationary.py +13 -16
- {tinygp-0.2.4 → tinygp-0.3.0}/src/tinygp/means.py +23 -19
- {tinygp-0.2.4 → tinygp-0.3.0}/src/tinygp/noise.py +5 -7
- {tinygp-0.2.4 → tinygp-0.3.0}/src/tinygp/solvers/direct.py +8 -14
- {tinygp-0.2.4 → tinygp-0.3.0}/src/tinygp/solvers/kalman.py +11 -11
- {tinygp-0.2.4 → tinygp-0.3.0}/src/tinygp/solvers/quasisep/core.py +6 -21
- {tinygp-0.2.4 → tinygp-0.3.0}/src/tinygp/solvers/quasisep/general.py +4 -9
- {tinygp-0.2.4 → tinygp-0.3.0}/src/tinygp/solvers/quasisep/ops.py +6 -3
- {tinygp-0.2.4 → tinygp-0.3.0}/src/tinygp/solvers/quasisep/solver.py +7 -8
- {tinygp-0.2.4 → tinygp-0.3.0}/src/tinygp/solvers/solver.py +18 -3
- tinygp-0.3.0/src/tinygp/test_utils.py +32 -0
- tinygp-0.3.0/src/tinygp/tinygp_version.py +16 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/src/tinygp/transforms.py +4 -7
- {tinygp-0.2.4 → tinygp-0.3.0}/tests/test_george_compat.py +30 -30
- {tinygp-0.2.4 → tinygp-0.3.0}/tests/test_gp.py +9 -8
- {tinygp-0.2.4 → tinygp-0.3.0}/tests/test_kernels/test_distance.py +5 -5
- {tinygp-0.2.4 → tinygp-0.3.0}/tests/test_kernels/test_kernels.py +56 -16
- tinygp-0.3.0/tests/test_kernels/test_quasisep.py +153 -0
- tinygp-0.3.0/tests/test_noise.py +71 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/tests/test_solvers/test_kalman.py +9 -8
- {tinygp-0.2.4 → tinygp-0.3.0}/tests/test_solvers/test_quasisep/test_core.py +94 -92
- tinygp-0.3.0/tests/test_solvers/test_quasisep/test_general.py +20 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/tests/test_solvers/test_quasisep/test_solver.py +29 -31
- {tinygp-0.2.4 → tinygp-0.3.0}/tests/test_transforms.py +8 -8
- tinygp-0.2.4/src/tinygp/helpers.py +0 -74
- tinygp-0.2.4/src/tinygp/tinygp_version.py +0 -8
- tinygp-0.2.4/tests/test_kernels/test_quasisep.py +0 -77
- tinygp-0.2.4/tests/test_noise.py +0 -68
- tinygp-0.2.4/tests/test_solvers/test_quasisep/test_general.py +0 -18
- {tinygp-0.2.4 → tinygp-0.3.0}/.gitattributes +0 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/.github/dependabot.yml +0 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/.gitignore +0 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/.readthedocs.yaml +0 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/CODE_OF_CONDUCT.md +0 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/CONTRIBUTING.md +0 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/LICENSE +0 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/MANIFEST.in +0 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/README.md +0 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/docs/.gitignore +0 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/docs/Makefile +0 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/docs/_static/favicon.png +0 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/docs/_static/zap.png +0 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/docs/_static/zap.svg +0 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/docs/_templates/autosummary/class.rst +0 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/docs/api/index.rst +0 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/docs/api/kernels.quasisep.rst +0 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/docs/api/kernels.rst +0 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/docs/api/means.rst +0 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/docs/api/noise.rst +0 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/docs/api/solvers.quasisep.rst +0 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/docs/api/solvers.rst +0 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/docs/api/transforms.rst +0 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/docs/benchmarks.ipynb +0 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/docs/code-of-conduct.md +0 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/docs/conf.py +0 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/docs/contributing.md +0 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/docs/guide.md +0 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/docs/index.md +0 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/docs/install.md +0 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/docs/motivation.md +0 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/docs/troubleshooting.md +0 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/docs/tutorials/geometry.ipynb +0 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/docs/tutorials/intro.ipynb +0 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/docs/tutorials/ipython_kernel_config.py +0 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/docs/tutorials/likelihoods.ipynb +0 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/docs/tutorials/matplotlibrc +0 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/docs/tutorials/means.ipynb +0 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/docs/tutorials/mixture.ipynb +0 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/docs/tutorials/modeling.ipynb +0 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/docs/tutorials/multivariate.ipynb +0 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/docs/tutorials/quasisep-custom.ipynb +0 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/docs/tutorials/quasisep.ipynb +0 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/docs/tutorials/quickstart.ipynb +0 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/docs/tutorials/transforms.ipynb +0 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/docs/tutorials.md +0 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/news/.gitignore +0 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/requirements.txt +0 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/src/tinygp/__init__.py +0 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/src/tinygp/kernels/__init__.py +0 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/src/tinygp/numpyro_support.py +0 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/src/tinygp/py.typed +0 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/src/tinygp/solvers/__init__.py +0 -0
- {tinygp-0.2.4 → tinygp-0.3.0}/src/tinygp/solvers/quasisep/__init__.py +0 -0
|
@@ -8,11 +8,11 @@ jobs:
|
|
|
8
8
|
if: ${{ github.actor != 'dependabot[bot]' && github.actor != 'pre-commit-ci[bot]' }}
|
|
9
9
|
runs-on: ubuntu-latest
|
|
10
10
|
steps:
|
|
11
|
-
- uses: actions/checkout@
|
|
11
|
+
- uses: actions/checkout@v4
|
|
12
12
|
with:
|
|
13
13
|
fetch-depth: 0
|
|
14
14
|
- name: Setup Python
|
|
15
|
-
uses: actions/setup-python@
|
|
15
|
+
uses: actions/setup-python@v5
|
|
16
16
|
with:
|
|
17
17
|
python-version: "3.10"
|
|
18
18
|
- name: Install dependencies
|
|
@@ -15,17 +15,25 @@ jobs:
|
|
|
15
15
|
matrix:
|
|
16
16
|
python-version: ["3.9", "3.10", "3.11"]
|
|
17
17
|
nox-session: ["test"]
|
|
18
|
+
x64: ["1"]
|
|
18
19
|
include:
|
|
20
|
+
- python-version: "3.10"
|
|
21
|
+
nox-session: "test"
|
|
22
|
+
x64: "0"
|
|
23
|
+
- python-version: "3.10"
|
|
24
|
+
nox-session: "comparison"
|
|
25
|
+
x64: "1"
|
|
19
26
|
- python-version: "3.10"
|
|
20
27
|
nox-session: "doctest"
|
|
28
|
+
x64: "1"
|
|
21
29
|
|
|
22
30
|
steps:
|
|
23
31
|
- name: Checkout
|
|
24
|
-
uses: actions/checkout@
|
|
32
|
+
uses: actions/checkout@v4
|
|
25
33
|
with:
|
|
26
34
|
fetch-depth: 0
|
|
27
35
|
- name: Setup Python
|
|
28
|
-
uses: actions/setup-python@
|
|
36
|
+
uses: actions/setup-python@v5
|
|
29
37
|
with:
|
|
30
38
|
python-version: ${{ matrix.python-version }}
|
|
31
39
|
- name: Install dependencies
|
|
@@ -36,14 +44,16 @@ jobs:
|
|
|
36
44
|
run: |
|
|
37
45
|
python -m nox --non-interactive --error-on-missing-interpreter \
|
|
38
46
|
--session ${{ matrix.nox-session }} --python ${{ matrix.python-version }}
|
|
47
|
+
env:
|
|
48
|
+
JAX_ENABLE_X64: ${{ matrix.x64 }}
|
|
39
49
|
|
|
40
50
|
build:
|
|
41
51
|
runs-on: ubuntu-latest
|
|
42
52
|
steps:
|
|
43
|
-
- uses: actions/checkout@
|
|
53
|
+
- uses: actions/checkout@v4
|
|
44
54
|
with:
|
|
45
55
|
fetch-depth: 0
|
|
46
|
-
- uses: actions/setup-python@
|
|
56
|
+
- uses: actions/setup-python@v5
|
|
47
57
|
name: Install Python
|
|
48
58
|
with:
|
|
49
59
|
python-version: "3.10"
|
|
@@ -55,7 +65,7 @@ jobs:
|
|
|
55
65
|
run: python -m build .
|
|
56
66
|
- name: Check the distribution
|
|
57
67
|
run: python -m twine check --strict dist/*
|
|
58
|
-
- uses: actions/upload-artifact@
|
|
68
|
+
- uses: actions/upload-artifact@v4
|
|
59
69
|
with:
|
|
60
70
|
path: dist/*
|
|
61
71
|
|
|
@@ -69,8 +79,8 @@ jobs:
|
|
|
69
79
|
runs-on: ubuntu-latest
|
|
70
80
|
if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/')
|
|
71
81
|
steps:
|
|
72
|
-
- uses: actions/download-artifact@
|
|
82
|
+
- uses: actions/download-artifact@v4
|
|
73
83
|
with:
|
|
74
84
|
name: artifact
|
|
75
85
|
path: dist
|
|
76
|
-
- uses: pypa/gh-action-pypi-publish@v1.8.
|
|
86
|
+
- uses: pypa/gh-action-pypi-publish@v1.8.11
|
|
@@ -3,17 +3,17 @@ ci:
|
|
|
3
3
|
|
|
4
4
|
repos:
|
|
5
5
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
|
6
|
-
rev: "v4.
|
|
6
|
+
rev: "v4.5.0"
|
|
7
7
|
hooks:
|
|
8
8
|
- id: trailing-whitespace
|
|
9
9
|
- id: end-of-file-fixer
|
|
10
10
|
exclude_types: [json, binary]
|
|
11
11
|
- repo: https://github.com/psf/black
|
|
12
|
-
rev: "23.
|
|
12
|
+
rev: "23.12.1"
|
|
13
13
|
hooks:
|
|
14
14
|
- id: black-jupyter
|
|
15
15
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
|
16
|
-
rev: "v0.
|
|
16
|
+
rev: "v0.1.9"
|
|
17
17
|
hooks:
|
|
18
18
|
- id: ruff
|
|
19
19
|
args: [--fix, --exit-non-zero-on-fix]
|
|
@@ -5,6 +5,11 @@
|
|
|
5
5
|
"affiliation": "Center for Computational Astrophysics, Flatiron Institute, New York, NY, USA",
|
|
6
6
|
"name": "Foreman-Mackey, Daniel"
|
|
7
7
|
},
|
|
8
|
+
{
|
|
9
|
+
"orcid": "0000-0003-1262-2897",
|
|
10
|
+
"affiliation": "Department of Physics and Astronomy, Bishop's University, Canada",
|
|
11
|
+
"name": "Weixiang Yu"
|
|
12
|
+
},
|
|
8
13
|
{
|
|
9
14
|
"orcid": "0000-0003-0048-1118",
|
|
10
15
|
"affiliation": "Indian Institute of Technology Gandhinagar: Gandhinagar, Gujarat, IN",
|
|
@@ -15,11 +20,26 @@
|
|
|
15
20
|
"affiliation": "Massachusetts Institute of Technology, Probabilistic Computing Project, Cambridge, MA, USA",
|
|
16
21
|
"name": "Becker, McCoy Reynolds"
|
|
17
22
|
},
|
|
23
|
+
{
|
|
24
|
+
"orcid": "0000-0003-3287-5250",
|
|
25
|
+
"affiliation": "Department of Astronomy and the DiRAC Institute, University of Washington, Seattle, WA, USA",
|
|
26
|
+
"name": "Caplar, Neven"
|
|
27
|
+
},
|
|
28
|
+
{
|
|
29
|
+
"orcid": "0000-0002-1169-7486",
|
|
30
|
+
"affiliation": "SRON Netherlands Institute for Space Research, Leiden, The Netherlands",
|
|
31
|
+
"name": "Huppenkothen, Daniela"
|
|
32
|
+
},
|
|
18
33
|
{
|
|
19
34
|
"orcid": "0000-0002-0440-9597",
|
|
20
35
|
"name": "Killestein, Thomas",
|
|
21
36
|
"affiliation": "Department of Physics, University of Warwick, Coventry, UK"
|
|
22
37
|
},
|
|
38
|
+
{
|
|
39
|
+
"orcid": "0000-0003-1001-0707",
|
|
40
|
+
"affiliation": "Department of Physics and Astronomy, Aarhus University, DK",
|
|
41
|
+
"name": "Tronsgaard, René"
|
|
42
|
+
},
|
|
23
43
|
{
|
|
24
44
|
"affiliation": "School of Public Health, Imperial College London, UK",
|
|
25
45
|
"name": "Rashid, Theo"
|
|
@@ -28,16 +48,6 @@
|
|
|
28
48
|
"orcid": "0000-0003-1354-0578",
|
|
29
49
|
"affiliation": "Helmholtz-Zentrum Dresden-Rossendorf e.V.",
|
|
30
50
|
"name": "Schmerler, Steve"
|
|
31
|
-
},
|
|
32
|
-
{
|
|
33
|
-
"orcid": "0000-0003-1001-0707",
|
|
34
|
-
"affiliation": "Department of Physics and Astronomy, Aarhus University, DK",
|
|
35
|
-
"name": "Tronsgaard, René"
|
|
36
|
-
},
|
|
37
|
-
{
|
|
38
|
-
"orcid": "0000-0003-3287-5250",
|
|
39
|
-
"affiliation": "Department of Astronomy and the DiRAC Institute, University of Washington, Seattle, WA, USA",
|
|
40
|
-
"name": "Caplar, Neven"
|
|
41
51
|
}
|
|
42
52
|
],
|
|
43
53
|
"license": "MIT",
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: tinygp
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.3.0
|
|
4
4
|
Summary: The tiniest of Gaussian Process libraries
|
|
5
5
|
Author-email: Dan Foreman-Mackey <foreman.mackey@gmail.com>
|
|
6
6
|
License: MIT
|
|
@@ -13,8 +13,12 @@ Classifier: Operating System :: OS Independent
|
|
|
13
13
|
Classifier: Programming Language :: Python
|
|
14
14
|
Classifier: Programming Language :: Python :: 3
|
|
15
15
|
Requires-Python: >=3.9
|
|
16
|
+
Requires-Dist: equinox
|
|
16
17
|
Requires-Dist: jax
|
|
17
18
|
Requires-Dist: jaxlib
|
|
19
|
+
Provides-Extra: comparison
|
|
20
|
+
Requires-Dist: celerite; extra == 'comparison'
|
|
21
|
+
Requires-Dist: george; extra == 'comparison'
|
|
18
22
|
Provides-Extra: docs
|
|
19
23
|
Requires-Dist: arviz; extra == 'docs'
|
|
20
24
|
Requires-Dist: flax; extra == 'docs'
|
|
@@ -27,8 +31,6 @@ Requires-Dist: optax; extra == 'docs'
|
|
|
27
31
|
Requires-Dist: sphinx-book-theme; extra == 'docs'
|
|
28
32
|
Requires-Dist: statsmodels; extra == 'docs'
|
|
29
33
|
Provides-Extra: test
|
|
30
|
-
Requires-Dist: celerite; extra == 'test'
|
|
31
|
-
Requires-Dist: george; extra == 'test'
|
|
32
34
|
Requires-Dist: pytest; extra == 'test'
|
|
33
35
|
Description-Content-Type: text/markdown
|
|
34
36
|
|
|
@@ -5,6 +5,23 @@ Release Notes
|
|
|
5
5
|
|
|
6
6
|
.. towncrier release notes start
|
|
7
7
|
|
|
8
|
+
tinygp 0.3.0 (2024-01-05)
|
|
9
|
+
-------------------------
|
|
10
|
+
|
|
11
|
+
Features
|
|
12
|
+
~~~~~~~~
|
|
13
|
+
|
|
14
|
+
- Added a more robust and better tested implementation of the ``CARMA`` kernel for
|
|
15
|
+
use with the ``QuasisepSolver``. (`#90 <https://github.com/dfm/tinygp/issues/90>`_)
|
|
16
|
+
- Switched all base classes to `equinox.Module <https://docs.kidger.site/equinox/api/module/module/>`_ objects to simplify dataclass handling. (`#200 <https://github.com/dfm/tinygp/issues/200>`_)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
Bugfixes
|
|
20
|
+
~~~~~~~~
|
|
21
|
+
|
|
22
|
+
- Fixed use of `jnp.roots` and `np.roll` to make CARMA kernel jit-compliant. (`#188 <https://github.com/dfm/tinygp/issues/188>`_)
|
|
23
|
+
|
|
24
|
+
|
|
8
25
|
tinygp 0.2.4 (2023-09-29)
|
|
9
26
|
-------------------------
|
|
10
27
|
|
|
@@ -105,8 +105,7 @@
|
|
|
105
105
|
"\n",
|
|
106
106
|
"\n",
|
|
107
107
|
"class DerivativeKernel(tinygp.kernels.Kernel):\n",
|
|
108
|
-
"
|
|
109
|
-
" self.kernel = kernel\n",
|
|
108
|
+
" kernel: tinygp.kernels.Kernel\n",
|
|
110
109
|
"\n",
|
|
111
110
|
" def evaluate(self, X1, X2):\n",
|
|
112
111
|
" t1, d1 = X1\n",
|
|
@@ -301,6 +300,10 @@
|
|
|
301
300
|
" shape as ``coeff_prim``.\n",
|
|
302
301
|
" \"\"\"\n",
|
|
303
302
|
"\n",
|
|
303
|
+
" kernel: tinygp.kernels.Kernel\n",
|
|
304
|
+
" coeff_prim: jax.Array\n",
|
|
305
|
+
" coeff_deriv: jax.Array\n",
|
|
306
|
+
"\n",
|
|
304
307
|
" def __init__(self, kernel, coeff_prim, coeff_deriv):\n",
|
|
305
308
|
" self.kernel = kernel\n",
|
|
306
309
|
" self.coeff_prim, self.coeff_deriv = jnp.broadcast_arrays(\n",
|
|
@@ -497,7 +500,7 @@
|
|
|
497
500
|
"hash": "d20ea8a315da34b3e8fab0dbd7b542a0ef3c8cf12937343660e6bc10a20768e3"
|
|
498
501
|
},
|
|
499
502
|
"kernelspec": {
|
|
500
|
-
"display_name": "Python 3
|
|
503
|
+
"display_name": "Python 3 (ipykernel)",
|
|
501
504
|
"language": "python",
|
|
502
505
|
"name": "python3"
|
|
503
506
|
},
|
|
@@ -511,9 +514,9 @@
|
|
|
511
514
|
"name": "python",
|
|
512
515
|
"nbconvert_exporter": "python",
|
|
513
516
|
"pygments_lexer": "ipython3",
|
|
514
|
-
"version": "3.
|
|
517
|
+
"version": "3.10.6"
|
|
515
518
|
}
|
|
516
519
|
},
|
|
517
520
|
"nbformat": 4,
|
|
518
|
-
"nbformat_minor":
|
|
521
|
+
"nbformat_minor": 4
|
|
519
522
|
}
|
|
@@ -54,10 +54,9 @@
|
|
|
54
54
|
"\n",
|
|
55
55
|
"\n",
|
|
56
56
|
"class SpectralMixture(tinygp.kernels.Kernel):\n",
|
|
57
|
-
"
|
|
58
|
-
"
|
|
59
|
-
"
|
|
60
|
-
" self.freq = jnp.atleast_1d(freq)\n",
|
|
57
|
+
" weight: jax.Array\n",
|
|
58
|
+
" scale: jax.Array\n",
|
|
59
|
+
" freq: jax.Array\n",
|
|
61
60
|
"\n",
|
|
62
61
|
" def evaluate(self, X1, X2):\n",
|
|
63
62
|
" tau = jnp.atleast_1d(jnp.abs(X1 - X2))[..., None]\n",
|
|
@@ -210,7 +209,7 @@
|
|
|
210
209
|
],
|
|
211
210
|
"metadata": {
|
|
212
211
|
"kernelspec": {
|
|
213
|
-
"display_name": "
|
|
212
|
+
"display_name": "Python 3 (ipykernel)",
|
|
214
213
|
"language": "python",
|
|
215
214
|
"name": "python3"
|
|
216
215
|
},
|
|
@@ -9,6 +9,12 @@ PYTHON_VERSIONS = ["3.9", "3.10", "3.11", "3.12"]
|
|
|
9
9
|
@nox.session(python=PYTHON_VERSIONS)
|
|
10
10
|
def test(session: nox.Session) -> None:
|
|
11
11
|
session.install(".[test]")
|
|
12
|
+
session.run("pytest", *session.posargs)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@nox.session(python=PYTHON_VERSIONS)
|
|
16
|
+
def comparison(session: nox.Session) -> None:
|
|
17
|
+
session.install(".[test,comparison]")
|
|
12
18
|
session.run("pytest", *session.posargs, env={"JAX_ENABLE_X64": "1"})
|
|
13
19
|
|
|
14
20
|
|
|
@@ -15,10 +15,11 @@ classifiers = [
|
|
|
15
15
|
"Programming Language :: Python :: 3",
|
|
16
16
|
]
|
|
17
17
|
dynamic = ["version"]
|
|
18
|
-
dependencies = ["jax", "jaxlib"]
|
|
18
|
+
dependencies = ["jax", "jaxlib", "equinox"]
|
|
19
19
|
|
|
20
20
|
[project.optional-dependencies]
|
|
21
|
-
test = ["pytest"
|
|
21
|
+
test = ["pytest"]
|
|
22
|
+
comparison = ["george", "celerite"]
|
|
22
23
|
docs = [
|
|
23
24
|
"sphinx-book-theme",
|
|
24
25
|
"myst-nb",
|
|
@@ -11,8 +11,10 @@ from typing import (
|
|
|
11
11
|
NamedTuple,
|
|
12
12
|
)
|
|
13
13
|
|
|
14
|
+
import equinox as eqx
|
|
14
15
|
import jax
|
|
15
16
|
import jax.numpy as jnp
|
|
17
|
+
import numpy as np
|
|
16
18
|
|
|
17
19
|
from tinygp import kernels, means
|
|
18
20
|
from tinygp.helpers import JAXArray
|
|
@@ -20,12 +22,13 @@ from tinygp.kernels.quasisep import Quasisep
|
|
|
20
22
|
from tinygp.noise import Diagonal, Noise
|
|
21
23
|
from tinygp.solvers import DirectSolver, QuasisepSolver
|
|
22
24
|
from tinygp.solvers.quasisep.core import SymmQSM
|
|
25
|
+
from tinygp.solvers.solver import Solver
|
|
23
26
|
|
|
24
27
|
if TYPE_CHECKING:
|
|
25
28
|
from tinygp.numpyro_support import TinyDistribution
|
|
26
29
|
|
|
27
30
|
|
|
28
|
-
class GaussianProcess:
|
|
31
|
+
class GaussianProcess(eqx.Module):
|
|
29
32
|
"""An interface for designing a Gaussian Process regression model
|
|
30
33
|
|
|
31
34
|
Args:
|
|
@@ -50,6 +53,15 @@ class GaussianProcess:
|
|
|
50
53
|
algebra.
|
|
51
54
|
"""
|
|
52
55
|
|
|
56
|
+
num_data: int = eqx.field(static=True)
|
|
57
|
+
dtype: np.dtype = eqx.field(static=True)
|
|
58
|
+
kernel: kernels.Kernel
|
|
59
|
+
X: JAXArray
|
|
60
|
+
mean_function: means.MeanBase
|
|
61
|
+
mean: JAXArray
|
|
62
|
+
noise: Noise
|
|
63
|
+
solver: Solver
|
|
64
|
+
|
|
53
65
|
def __init__(
|
|
54
66
|
self,
|
|
55
67
|
kernel: kernels.Kernel,
|
|
@@ -57,7 +69,7 @@ class GaussianProcess:
|
|
|
57
69
|
*,
|
|
58
70
|
diag: JAXArray | None = None,
|
|
59
71
|
noise: Noise | None = None,
|
|
60
|
-
mean: Callable[[JAXArray], JAXArray] | JAXArray | None = None,
|
|
72
|
+
mean: means.MeanBase | Callable[[JAXArray], JAXArray] | JAXArray | None = None,
|
|
61
73
|
solver: Any | None = None,
|
|
62
74
|
mean_value: JAXArray | None = None,
|
|
63
75
|
covariance_value: Any | None = None,
|
|
@@ -66,7 +78,7 @@ class GaussianProcess:
|
|
|
66
78
|
self.kernel = kernel
|
|
67
79
|
self.X = X
|
|
68
80
|
|
|
69
|
-
if
|
|
81
|
+
if isinstance(mean, means.MeanBase):
|
|
70
82
|
self.mean_function = mean
|
|
71
83
|
elif mean is None:
|
|
72
84
|
self.mean_function = means.Mean(jnp.zeros(()))
|
|
@@ -76,7 +88,7 @@ class GaussianProcess:
|
|
|
76
88
|
mean_value = jax.vmap(self.mean_function)(self.X)
|
|
77
89
|
self.num_data = mean_value.shape[0]
|
|
78
90
|
self.dtype = mean_value.dtype
|
|
79
|
-
self.
|
|
91
|
+
self.mean = mean_value
|
|
80
92
|
if self.mean.ndim != 1:
|
|
81
93
|
raise ValueError(
|
|
82
94
|
"Invalid mean shape: " f"expected ndim = 1, got ndim={self.mean.ndim}"
|
|
@@ -92,7 +104,7 @@ class GaussianProcess:
|
|
|
92
104
|
solver = QuasisepSolver
|
|
93
105
|
else:
|
|
94
106
|
solver = DirectSolver
|
|
95
|
-
self.solver = solver
|
|
107
|
+
self.solver = solver(
|
|
96
108
|
kernel,
|
|
97
109
|
self.X,
|
|
98
110
|
self.noise,
|
|
@@ -100,6 +112,10 @@ class GaussianProcess:
|
|
|
100
112
|
**solver_kwargs,
|
|
101
113
|
)
|
|
102
114
|
|
|
115
|
+
@property
|
|
116
|
+
def loc(self) -> JAXArray:
|
|
117
|
+
return self.mean
|
|
118
|
+
|
|
103
119
|
@property
|
|
104
120
|
def variance(self) -> JAXArray:
|
|
105
121
|
return self.solver.variance()
|
|
@@ -209,7 +225,6 @@ class GaussianProcess:
|
|
|
209
225
|
|
|
210
226
|
@partial(
|
|
211
227
|
jax.jit,
|
|
212
|
-
static_argnums=(0,),
|
|
213
228
|
static_argnames=("include_mean", "return_var", "return_cov"),
|
|
214
229
|
)
|
|
215
230
|
def predict(
|
|
@@ -281,7 +296,7 @@ class GaussianProcess:
|
|
|
281
296
|
|
|
282
297
|
return TinyDistribution(self, **kwargs)
|
|
283
298
|
|
|
284
|
-
@partial(jax.jit, static_argnums=(
|
|
299
|
+
@partial(jax.jit, static_argnums=(2,))
|
|
285
300
|
def _sample(
|
|
286
301
|
self,
|
|
287
302
|
key: jax.random.KeyArray,
|
|
@@ -296,16 +311,16 @@ class GaussianProcess:
|
|
|
296
311
|
self.solver.dot_triangular(normal_samples), 0, -1
|
|
297
312
|
)
|
|
298
313
|
|
|
299
|
-
@
|
|
314
|
+
@jax.jit
|
|
300
315
|
def _compute_log_prob(self, alpha: JAXArray) -> JAXArray:
|
|
301
316
|
loglike = -0.5 * jnp.sum(jnp.square(alpha)) - self.solver.normalization()
|
|
302
317
|
return jnp.where(jnp.isfinite(loglike), loglike, -jnp.inf)
|
|
303
318
|
|
|
304
|
-
@
|
|
319
|
+
@jax.jit
|
|
305
320
|
def _get_alpha(self, y: JAXArray) -> JAXArray:
|
|
306
321
|
return self.solver.solve_triangular(y - self.loc)
|
|
307
322
|
|
|
308
|
-
@partial(jax.jit, static_argnums=(
|
|
323
|
+
@partial(jax.jit, static_argnums=(3,))
|
|
309
324
|
def _condition(
|
|
310
325
|
self,
|
|
311
326
|
y: JAXArray,
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
__all__ = ["JAXArray", "dataclass", "field"]
|
|
4
|
+
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
import equinox as eqx
|
|
8
|
+
import jax
|
|
9
|
+
|
|
10
|
+
JAXArray = jax.Array
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
# The following is just for backwards compatibility since tinygp used to provide a
|
|
14
|
+
# custom dataclass implementation
|
|
15
|
+
field = eqx.field
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def dataclass(clz: type[Any]) -> type[Any]:
|
|
19
|
+
return clz
|
|
@@ -11,23 +11,23 @@ __all__ = [
|
|
|
11
11
|
"Polynomial",
|
|
12
12
|
]
|
|
13
13
|
|
|
14
|
-
from abc import
|
|
14
|
+
from abc import abstractmethod
|
|
15
15
|
from collections.abc import Sequence
|
|
16
16
|
from typing import TYPE_CHECKING, Any, Callable, Union
|
|
17
17
|
|
|
18
|
+
import equinox as eqx
|
|
18
19
|
import jax
|
|
19
20
|
import jax.numpy as jnp
|
|
20
21
|
|
|
21
|
-
from tinygp.helpers import JAXArray
|
|
22
|
+
from tinygp.helpers import JAXArray
|
|
22
23
|
|
|
23
24
|
if TYPE_CHECKING:
|
|
24
25
|
from tinygp.solvers.solver import Solver
|
|
25
26
|
|
|
26
|
-
|
|
27
27
|
Axis = Union[int, Sequence[int]]
|
|
28
28
|
|
|
29
29
|
|
|
30
|
-
class Kernel(
|
|
30
|
+
class Kernel(eqx.Module):
|
|
31
31
|
"""The base class for all kernel implementations
|
|
32
32
|
|
|
33
33
|
This subclass provides default implementations to add and multiply kernels.
|
|
@@ -35,11 +35,6 @@ class Kernel(metaclass=ABCMeta):
|
|
|
35
35
|
:func:`Kernel.evaluate` with custom behavior.
|
|
36
36
|
"""
|
|
37
37
|
|
|
38
|
-
if TYPE_CHECKING:
|
|
39
|
-
|
|
40
|
-
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
|
41
|
-
pass
|
|
42
|
-
|
|
43
38
|
@abstractmethod
|
|
44
39
|
def evaluate(self, X1: JAXArray, X2: JAXArray) -> JAXArray:
|
|
45
40
|
"""Evaluate the kernel at a pair of input coordinates
|
|
@@ -58,6 +53,7 @@ class Kernel(metaclass=ABCMeta):
|
|
|
58
53
|
``(n_data, n_dim)``, and you should let the :class:`Kernel` ``vmap``
|
|
59
54
|
magic handle all the broadcasting for you.
|
|
60
55
|
"""
|
|
56
|
+
del X1, X2
|
|
61
57
|
raise NotImplementedError
|
|
62
58
|
|
|
63
59
|
def evaluate_diag(self, X: JAXArray) -> JAXArray:
|
|
@@ -130,7 +126,6 @@ class Kernel(metaclass=ABCMeta):
|
|
|
130
126
|
return Product(Constant(other), self)
|
|
131
127
|
|
|
132
128
|
|
|
133
|
-
@dataclass
|
|
134
129
|
class Conditioned(Kernel):
|
|
135
130
|
"""A kernel used when conditioning a process on data
|
|
136
131
|
|
|
@@ -158,7 +153,6 @@ class Conditioned(Kernel):
|
|
|
158
153
|
return self.kernel.evaluate_diag(X) - K.transpose() @ K
|
|
159
154
|
|
|
160
155
|
|
|
161
|
-
@dataclass
|
|
162
156
|
class Custom(Kernel):
|
|
163
157
|
"""A custom kernel class implemented as a callable
|
|
164
158
|
|
|
@@ -167,13 +161,12 @@ class Custom(Kernel):
|
|
|
167
161
|
:func:`Kernel.evaluate`.
|
|
168
162
|
"""
|
|
169
163
|
|
|
170
|
-
function: Callable[[Any, Any], Any]
|
|
164
|
+
function: Callable[[Any, Any], Any] = eqx.field(static=True)
|
|
171
165
|
|
|
172
166
|
def evaluate(self, X1: JAXArray, X2: JAXArray) -> JAXArray:
|
|
173
167
|
return self.function(X1, X2)
|
|
174
168
|
|
|
175
169
|
|
|
176
|
-
@dataclass
|
|
177
170
|
class Sum(Kernel):
|
|
178
171
|
"""A helper to represent the sum of two kernels"""
|
|
179
172
|
|
|
@@ -184,7 +177,6 @@ class Sum(Kernel):
|
|
|
184
177
|
return self.kernel1.evaluate(X1, X2) + self.kernel2.evaluate(X1, X2)
|
|
185
178
|
|
|
186
179
|
|
|
187
|
-
@dataclass
|
|
188
180
|
class Product(Kernel):
|
|
189
181
|
"""A helper to represent the product of two kernels"""
|
|
190
182
|
|
|
@@ -195,7 +187,6 @@ class Product(Kernel):
|
|
|
195
187
|
return self.kernel1.evaluate(X1, X2) * self.kernel2.evaluate(X1, X2)
|
|
196
188
|
|
|
197
189
|
|
|
198
|
-
@dataclass
|
|
199
190
|
class Constant(Kernel):
|
|
200
191
|
r"""This kernel returns the constant
|
|
201
192
|
|
|
@@ -209,15 +200,15 @@ class Constant(Kernel):
|
|
|
209
200
|
c: The parameter :math:`c` in the above equation.
|
|
210
201
|
"""
|
|
211
202
|
|
|
212
|
-
value: JAXArray
|
|
203
|
+
value: JAXArray | float
|
|
213
204
|
|
|
214
205
|
def evaluate(self, X1: JAXArray, X2: JAXArray) -> JAXArray:
|
|
206
|
+
del X1, X2
|
|
215
207
|
if jnp.ndim(self.value) != 0:
|
|
216
208
|
raise ValueError("The value of a constant kernel must be a scalar")
|
|
217
|
-
return self.value
|
|
209
|
+
return jnp.asarray(self.value)
|
|
218
210
|
|
|
219
211
|
|
|
220
|
-
@dataclass
|
|
221
212
|
class DotProduct(Kernel):
|
|
222
213
|
r"""The dot product kernel
|
|
223
214
|
|
|
@@ -234,7 +225,6 @@ class DotProduct(Kernel):
|
|
|
234
225
|
return X1 @ X2
|
|
235
226
|
|
|
236
227
|
|
|
237
|
-
@dataclass
|
|
238
228
|
class Polynomial(Kernel):
|
|
239
229
|
r"""A polynomial kernel
|
|
240
230
|
|
|
@@ -249,9 +239,9 @@ class Polynomial(Kernel):
|
|
|
249
239
|
sigma: The parameter :math:`\sigma`.
|
|
250
240
|
"""
|
|
251
241
|
|
|
252
|
-
order: JAXArray
|
|
253
|
-
scale: JAXArray = field(default_factory=lambda: jnp.ones(()))
|
|
254
|
-
sigma: JAXArray = field(default_factory=lambda: jnp.zeros(()))
|
|
242
|
+
order: JAXArray | float
|
|
243
|
+
scale: JAXArray | float = eqx.field(default_factory=lambda: jnp.ones(()))
|
|
244
|
+
sigma: JAXArray | float = eqx.field(default_factory=lambda: jnp.zeros(()))
|
|
255
245
|
|
|
256
246
|
def evaluate(self, X1: JAXArray, X2: JAXArray) -> JAXArray:
|
|
257
247
|
return (
|
|
@@ -11,14 +11,15 @@ from __future__ import annotations
|
|
|
11
11
|
|
|
12
12
|
__all__ = ["Distance", "L1Distance", "L2Distance"]
|
|
13
13
|
|
|
14
|
-
from abc import
|
|
14
|
+
from abc import abstractmethod
|
|
15
15
|
|
|
16
|
+
import equinox as eqx
|
|
16
17
|
import jax.numpy as jnp
|
|
17
18
|
|
|
18
|
-
from tinygp.helpers import JAXArray
|
|
19
|
+
from tinygp.helpers import JAXArray
|
|
19
20
|
|
|
20
21
|
|
|
21
|
-
class Distance(
|
|
22
|
+
class Distance(eqx.Module):
|
|
22
23
|
"""An abstract base class defining a distance metric interface"""
|
|
23
24
|
|
|
24
25
|
@abstractmethod
|
|
@@ -37,7 +38,6 @@ class Distance(metaclass=ABCMeta):
|
|
|
37
38
|
return jnp.square(self.distance(X1, X2))
|
|
38
39
|
|
|
39
40
|
|
|
40
|
-
@dataclass
|
|
41
41
|
class L1Distance(Distance):
|
|
42
42
|
"""The L1 or Manhattan distance between two coordinates"""
|
|
43
43
|
|
|
@@ -45,7 +45,6 @@ class L1Distance(Distance):
|
|
|
45
45
|
return jnp.sum(jnp.abs(X1 - X2))
|
|
46
46
|
|
|
47
47
|
|
|
48
|
-
@dataclass
|
|
49
48
|
class L2Distance(Distance):
|
|
50
49
|
"""The L2 or Euclidean distance between two coordinates"""
|
|
51
50
|
|