tinygp 0.2.4__tar.gz → 0.3.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (98)
  1. {tinygp-0.2.4 → tinygp-0.3.0}/.github/workflows/news.yml +2 -2
  2. {tinygp-0.2.4 → tinygp-0.3.0}/.github/workflows/tests.yml +17 -7
  3. {tinygp-0.2.4 → tinygp-0.3.0}/.pre-commit-config.yaml +3 -3
  4. {tinygp-0.2.4 → tinygp-0.3.0}/.zenodo.json +20 -10
  5. {tinygp-0.2.4 → tinygp-0.3.0}/PKG-INFO +5 -3
  6. {tinygp-0.2.4 → tinygp-0.3.0}/docs/news.rst +17 -0
  7. {tinygp-0.2.4 → tinygp-0.3.0}/docs/tutorials/derivative.ipynb +8 -5
  8. {tinygp-0.2.4 → tinygp-0.3.0}/docs/tutorials/kernels.ipynb +4 -5
  9. {tinygp-0.2.4 → tinygp-0.3.0}/noxfile.py +6 -0
  10. {tinygp-0.2.4 → tinygp-0.3.0}/pyproject.toml +3 -2
  11. {tinygp-0.2.4 → tinygp-0.3.0}/src/tinygp/gp.py +25 -10
  12. tinygp-0.3.0/src/tinygp/helpers.py +19 -0
  13. {tinygp-0.2.4 → tinygp-0.3.0}/src/tinygp/kernels/base.py +12 -22
  14. {tinygp-0.2.4 → tinygp-0.3.0}/src/tinygp/kernels/distance.py +4 -5
  15. {tinygp-0.2.4 → tinygp-0.3.0}/src/tinygp/kernels/quasisep.py +353 -133
  16. {tinygp-0.2.4 → tinygp-0.3.0}/src/tinygp/kernels/stationary.py +13 -16
  17. {tinygp-0.2.4 → tinygp-0.3.0}/src/tinygp/means.py +23 -19
  18. {tinygp-0.2.4 → tinygp-0.3.0}/src/tinygp/noise.py +5 -7
  19. {tinygp-0.2.4 → tinygp-0.3.0}/src/tinygp/solvers/direct.py +8 -14
  20. {tinygp-0.2.4 → tinygp-0.3.0}/src/tinygp/solvers/kalman.py +11 -11
  21. {tinygp-0.2.4 → tinygp-0.3.0}/src/tinygp/solvers/quasisep/core.py +6 -21
  22. {tinygp-0.2.4 → tinygp-0.3.0}/src/tinygp/solvers/quasisep/general.py +4 -9
  23. {tinygp-0.2.4 → tinygp-0.3.0}/src/tinygp/solvers/quasisep/ops.py +6 -3
  24. {tinygp-0.2.4 → tinygp-0.3.0}/src/tinygp/solvers/quasisep/solver.py +7 -8
  25. {tinygp-0.2.4 → tinygp-0.3.0}/src/tinygp/solvers/solver.py +18 -3
  26. tinygp-0.3.0/src/tinygp/test_utils.py +32 -0
  27. tinygp-0.3.0/src/tinygp/tinygp_version.py +16 -0
  28. {tinygp-0.2.4 → tinygp-0.3.0}/src/tinygp/transforms.py +4 -7
  29. {tinygp-0.2.4 → tinygp-0.3.0}/tests/test_george_compat.py +30 -30
  30. {tinygp-0.2.4 → tinygp-0.3.0}/tests/test_gp.py +9 -8
  31. {tinygp-0.2.4 → tinygp-0.3.0}/tests/test_kernels/test_distance.py +5 -5
  32. {tinygp-0.2.4 → tinygp-0.3.0}/tests/test_kernels/test_kernels.py +56 -16
  33. tinygp-0.3.0/tests/test_kernels/test_quasisep.py +153 -0
  34. tinygp-0.3.0/tests/test_noise.py +71 -0
  35. {tinygp-0.2.4 → tinygp-0.3.0}/tests/test_solvers/test_kalman.py +9 -8
  36. {tinygp-0.2.4 → tinygp-0.3.0}/tests/test_solvers/test_quasisep/test_core.py +94 -92
  37. tinygp-0.3.0/tests/test_solvers/test_quasisep/test_general.py +20 -0
  38. {tinygp-0.2.4 → tinygp-0.3.0}/tests/test_solvers/test_quasisep/test_solver.py +29 -31
  39. {tinygp-0.2.4 → tinygp-0.3.0}/tests/test_transforms.py +8 -8
  40. tinygp-0.2.4/src/tinygp/helpers.py +0 -74
  41. tinygp-0.2.4/src/tinygp/tinygp_version.py +0 -8
  42. tinygp-0.2.4/tests/test_kernels/test_quasisep.py +0 -77
  43. tinygp-0.2.4/tests/test_noise.py +0 -68
  44. tinygp-0.2.4/tests/test_solvers/test_quasisep/test_general.py +0 -18
  45. {tinygp-0.2.4 → tinygp-0.3.0}/.gitattributes +0 -0
  46. {tinygp-0.2.4 → tinygp-0.3.0}/.github/dependabot.yml +0 -0
  47. {tinygp-0.2.4 → tinygp-0.3.0}/.gitignore +0 -0
  48. {tinygp-0.2.4 → tinygp-0.3.0}/.readthedocs.yaml +0 -0
  49. {tinygp-0.2.4 → tinygp-0.3.0}/CODE_OF_CONDUCT.md +0 -0
  50. {tinygp-0.2.4 → tinygp-0.3.0}/CONTRIBUTING.md +0 -0
  51. {tinygp-0.2.4 → tinygp-0.3.0}/LICENSE +0 -0
  52. {tinygp-0.2.4 → tinygp-0.3.0}/MANIFEST.in +0 -0
  53. {tinygp-0.2.4 → tinygp-0.3.0}/README.md +0 -0
  54. {tinygp-0.2.4 → tinygp-0.3.0}/docs/.gitignore +0 -0
  55. {tinygp-0.2.4 → tinygp-0.3.0}/docs/Makefile +0 -0
  56. {tinygp-0.2.4 → tinygp-0.3.0}/docs/_static/favicon.png +0 -0
  57. {tinygp-0.2.4 → tinygp-0.3.0}/docs/_static/zap.png +0 -0
  58. {tinygp-0.2.4 → tinygp-0.3.0}/docs/_static/zap.svg +0 -0
  59. {tinygp-0.2.4 → tinygp-0.3.0}/docs/_templates/autosummary/class.rst +0 -0
  60. {tinygp-0.2.4 → tinygp-0.3.0}/docs/api/index.rst +0 -0
  61. {tinygp-0.2.4 → tinygp-0.3.0}/docs/api/kernels.quasisep.rst +0 -0
  62. {tinygp-0.2.4 → tinygp-0.3.0}/docs/api/kernels.rst +0 -0
  63. {tinygp-0.2.4 → tinygp-0.3.0}/docs/api/means.rst +0 -0
  64. {tinygp-0.2.4 → tinygp-0.3.0}/docs/api/noise.rst +0 -0
  65. {tinygp-0.2.4 → tinygp-0.3.0}/docs/api/solvers.quasisep.rst +0 -0
  66. {tinygp-0.2.4 → tinygp-0.3.0}/docs/api/solvers.rst +0 -0
  67. {tinygp-0.2.4 → tinygp-0.3.0}/docs/api/transforms.rst +0 -0
  68. {tinygp-0.2.4 → tinygp-0.3.0}/docs/benchmarks.ipynb +0 -0
  69. {tinygp-0.2.4 → tinygp-0.3.0}/docs/code-of-conduct.md +0 -0
  70. {tinygp-0.2.4 → tinygp-0.3.0}/docs/conf.py +0 -0
  71. {tinygp-0.2.4 → tinygp-0.3.0}/docs/contributing.md +0 -0
  72. {tinygp-0.2.4 → tinygp-0.3.0}/docs/guide.md +0 -0
  73. {tinygp-0.2.4 → tinygp-0.3.0}/docs/index.md +0 -0
  74. {tinygp-0.2.4 → tinygp-0.3.0}/docs/install.md +0 -0
  75. {tinygp-0.2.4 → tinygp-0.3.0}/docs/motivation.md +0 -0
  76. {tinygp-0.2.4 → tinygp-0.3.0}/docs/troubleshooting.md +0 -0
  77. {tinygp-0.2.4 → tinygp-0.3.0}/docs/tutorials/geometry.ipynb +0 -0
  78. {tinygp-0.2.4 → tinygp-0.3.0}/docs/tutorials/intro.ipynb +0 -0
  79. {tinygp-0.2.4 → tinygp-0.3.0}/docs/tutorials/ipython_kernel_config.py +0 -0
  80. {tinygp-0.2.4 → tinygp-0.3.0}/docs/tutorials/likelihoods.ipynb +0 -0
  81. {tinygp-0.2.4 → tinygp-0.3.0}/docs/tutorials/matplotlibrc +0 -0
  82. {tinygp-0.2.4 → tinygp-0.3.0}/docs/tutorials/means.ipynb +0 -0
  83. {tinygp-0.2.4 → tinygp-0.3.0}/docs/tutorials/mixture.ipynb +0 -0
  84. {tinygp-0.2.4 → tinygp-0.3.0}/docs/tutorials/modeling.ipynb +0 -0
  85. {tinygp-0.2.4 → tinygp-0.3.0}/docs/tutorials/multivariate.ipynb +0 -0
  86. {tinygp-0.2.4 → tinygp-0.3.0}/docs/tutorials/quasisep-custom.ipynb +0 -0
  87. {tinygp-0.2.4 → tinygp-0.3.0}/docs/tutorials/quasisep.ipynb +0 -0
  88. {tinygp-0.2.4 → tinygp-0.3.0}/docs/tutorials/quickstart.ipynb +0 -0
  89. {tinygp-0.2.4 → tinygp-0.3.0}/docs/tutorials/transforms.ipynb +0 -0
  90. {tinygp-0.2.4 → tinygp-0.3.0}/docs/tutorials.md +0 -0
  91. {tinygp-0.2.4 → tinygp-0.3.0}/news/.gitignore +0 -0
  92. {tinygp-0.2.4 → tinygp-0.3.0}/requirements.txt +0 -0
  93. {tinygp-0.2.4 → tinygp-0.3.0}/src/tinygp/__init__.py +0 -0
  94. {tinygp-0.2.4 → tinygp-0.3.0}/src/tinygp/kernels/__init__.py +0 -0
  95. {tinygp-0.2.4 → tinygp-0.3.0}/src/tinygp/numpyro_support.py +0 -0
  96. {tinygp-0.2.4 → tinygp-0.3.0}/src/tinygp/py.typed +0 -0
  97. {tinygp-0.2.4 → tinygp-0.3.0}/src/tinygp/solvers/__init__.py +0 -0
  98. {tinygp-0.2.4 → tinygp-0.3.0}/src/tinygp/solvers/quasisep/__init__.py +0 -0
@@ -8,11 +8,11 @@ jobs:
8
8
  if: ${{ github.actor != 'dependabot[bot]' && github.actor != 'pre-commit-ci[bot]' }}
9
9
  runs-on: ubuntu-latest
10
10
  steps:
11
- - uses: actions/checkout@v3
11
+ - uses: actions/checkout@v4
12
12
  with:
13
13
  fetch-depth: 0
14
14
  - name: Setup Python
15
- uses: actions/setup-python@v4
15
+ uses: actions/setup-python@v5
16
16
  with:
17
17
  python-version: "3.10"
18
18
  - name: Install dependencies
@@ -15,17 +15,25 @@ jobs:
15
15
  matrix:
16
16
  python-version: ["3.9", "3.10", "3.11"]
17
17
  nox-session: ["test"]
18
+ x64: ["1"]
18
19
  include:
20
+ - python-version: "3.10"
21
+ nox-session: "test"
22
+ x64: "0"
23
+ - python-version: "3.10"
24
+ nox-session: "comparison"
25
+ x64: "1"
19
26
  - python-version: "3.10"
20
27
  nox-session: "doctest"
28
+ x64: "1"
21
29
 
22
30
  steps:
23
31
  - name: Checkout
24
- uses: actions/checkout@v3
32
+ uses: actions/checkout@v4
25
33
  with:
26
34
  fetch-depth: 0
27
35
  - name: Setup Python
28
- uses: actions/setup-python@v4
36
+ uses: actions/setup-python@v5
29
37
  with:
30
38
  python-version: ${{ matrix.python-version }}
31
39
  - name: Install dependencies
@@ -36,14 +44,16 @@ jobs:
36
44
  run: |
37
45
  python -m nox --non-interactive --error-on-missing-interpreter \
38
46
  --session ${{ matrix.nox-session }} --python ${{ matrix.python-version }}
47
+ env:
48
+ JAX_ENABLE_X64: ${{ matrix.x64 }}
39
49
 
40
50
  build:
41
51
  runs-on: ubuntu-latest
42
52
  steps:
43
- - uses: actions/checkout@v3
53
+ - uses: actions/checkout@v4
44
54
  with:
45
55
  fetch-depth: 0
46
- - uses: actions/setup-python@v4
56
+ - uses: actions/setup-python@v5
47
57
  name: Install Python
48
58
  with:
49
59
  python-version: "3.10"
@@ -55,7 +65,7 @@ jobs:
55
65
  run: python -m build .
56
66
  - name: Check the distribution
57
67
  run: python -m twine check --strict dist/*
58
- - uses: actions/upload-artifact@v3
68
+ - uses: actions/upload-artifact@v4
59
69
  with:
60
70
  path: dist/*
61
71
 
@@ -69,8 +79,8 @@ jobs:
69
79
  runs-on: ubuntu-latest
70
80
  if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/')
71
81
  steps:
72
- - uses: actions/download-artifact@v3
82
+ - uses: actions/download-artifact@v4
73
83
  with:
74
84
  name: artifact
75
85
  path: dist
76
- - uses: pypa/gh-action-pypi-publish@v1.8.10
86
+ - uses: pypa/gh-action-pypi-publish@v1.8.11
@@ -3,17 +3,17 @@ ci:
3
3
 
4
4
  repos:
5
5
  - repo: https://github.com/pre-commit/pre-commit-hooks
6
- rev: "v4.4.0"
6
+ rev: "v4.5.0"
7
7
  hooks:
8
8
  - id: trailing-whitespace
9
9
  - id: end-of-file-fixer
10
10
  exclude_types: [json, binary]
11
11
  - repo: https://github.com/psf/black
12
- rev: "23.9.1"
12
+ rev: "23.12.1"
13
13
  hooks:
14
14
  - id: black-jupyter
15
15
  - repo: https://github.com/astral-sh/ruff-pre-commit
16
- rev: "v0.0.291"
16
+ rev: "v0.1.9"
17
17
  hooks:
18
18
  - id: ruff
19
19
  args: [--fix, --exit-non-zero-on-fix]
@@ -5,6 +5,11 @@
5
5
  "affiliation": "Center for Computational Astrophysics, Flatiron Institute, New York, NY, USA",
6
6
  "name": "Foreman-Mackey, Daniel"
7
7
  },
8
+ {
9
+ "orcid": "0000-0003-1262-2897",
10
+ "affiliation": "Department of Physics and Astronomy, Bishop's University, Canada",
11
+ "name": "Weixiang Yu"
12
+ },
8
13
  {
9
14
  "orcid": "0000-0003-0048-1118",
10
15
  "affiliation": "Indian Institute of Technology Gandhinagar: Gandhinagar, Gujarat, IN",
@@ -15,11 +20,26 @@
15
20
  "affiliation": "Massachusetts Institute of Technology, Probabilistic Computing Project, Cambridge, MA, USA",
16
21
  "name": "Becker, McCoy Reynolds"
17
22
  },
23
+ {
24
+ "orcid": "0000-0003-3287-5250",
25
+ "affiliation": "Department of Astronomy and the DiRAC Institute, University of Washington, Seattle, WA, USA",
26
+ "name": "Caplar, Neven"
27
+ },
28
+ {
29
+ "orcid": "0000-0002-1169-7486",
30
+ "affiliation": "SRON Netherlands Institute for Space Research, Leiden, The Netherlands",
31
+ "name": "Huppenkothen, Daniela"
32
+ },
18
33
  {
19
34
  "orcid": "0000-0002-0440-9597",
20
35
  "name": "Killestein, Thomas",
21
36
  "affiliation": "Department of Physics, University of Warwick, Coventry, UK"
22
37
  },
38
+ {
39
+ "orcid": "0000-0003-1001-0707",
40
+ "affiliation": "Department of Physics and Astronomy, Aarhus University, DK",
41
+ "name": "Tronsgaard, René"
42
+ },
23
43
  {
24
44
  "affiliation": "School of Public Health, Imperial College London, UK",
25
45
  "name": "Rashid, Theo"
@@ -28,16 +48,6 @@
28
48
  "orcid": "0000-0003-1354-0578",
29
49
  "affiliation": "Helmholtz-Zentrum Dresden-Rossendorf e.V.",
30
50
  "name": "Schmerler, Steve"
31
- },
32
- {
33
- "orcid": "0000-0003-1001-0707",
34
- "affiliation": "Department of Physics and Astronomy, Aarhus University, DK",
35
- "name": "Tronsgaard, René"
36
- },
37
- {
38
- "orcid": "0000-0003-3287-5250",
39
- "affiliation": "Department of Astronomy and the DiRAC Institute, University of Washington, Seattle, WA, USA",
40
- "name": "Caplar, Neven"
41
51
  }
42
52
  ],
43
53
  "license": "MIT",
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: tinygp
3
- Version: 0.2.4
3
+ Version: 0.3.0
4
4
  Summary: The tiniest of Gaussian Process libraries
5
5
  Author-email: Dan Foreman-Mackey <foreman.mackey@gmail.com>
6
6
  License: MIT
@@ -13,8 +13,12 @@ Classifier: Operating System :: OS Independent
13
13
  Classifier: Programming Language :: Python
14
14
  Classifier: Programming Language :: Python :: 3
15
15
  Requires-Python: >=3.9
16
+ Requires-Dist: equinox
16
17
  Requires-Dist: jax
17
18
  Requires-Dist: jaxlib
19
+ Provides-Extra: comparison
20
+ Requires-Dist: celerite; extra == 'comparison'
21
+ Requires-Dist: george; extra == 'comparison'
18
22
  Provides-Extra: docs
19
23
  Requires-Dist: arviz; extra == 'docs'
20
24
  Requires-Dist: flax; extra == 'docs'
@@ -27,8 +31,6 @@ Requires-Dist: optax; extra == 'docs'
27
31
  Requires-Dist: sphinx-book-theme; extra == 'docs'
28
32
  Requires-Dist: statsmodels; extra == 'docs'
29
33
  Provides-Extra: test
30
- Requires-Dist: celerite; extra == 'test'
31
- Requires-Dist: george; extra == 'test'
32
34
  Requires-Dist: pytest; extra == 'test'
33
35
  Description-Content-Type: text/markdown
34
36
 
@@ -5,6 +5,23 @@ Release Notes
5
5
 
6
6
  .. towncrier release notes start
7
7
 
8
+ tinygp 0.3.0 (2024-01-05)
9
+ -------------------------
10
+
11
+ Features
12
+ ~~~~~~~~
13
+
14
+ - Added a more robust and better tested implementation of the ``CARMA`` kernel for
15
+ use with the ``QuasisepSolver``. (`#90 <https://github.com/dfm/tinygp/issues/90>`_)
16
+ - Switched all base classes to `equinox.Module <https://docs.kidger.site/equinox/api/module/module/>`_ objects to simplify dataclass handling. (`#200 <https://github.com/dfm/tinygp/issues/200>`_)
17
+
18
+
19
+ Bugfixes
20
+ ~~~~~~~~
21
+
22
+ - Fixed use of `jnp.roots` and `np.roll` to make CARMA kernel jit-compliant. (`#188 <https://github.com/dfm/tinygp/issues/188>`_)
23
+
24
+
8
25
  tinygp 0.2.4 (2023-09-29)
9
26
  -------------------------
10
27
 
@@ -105,8 +105,7 @@
105
105
  "\n",
106
106
  "\n",
107
107
  "class DerivativeKernel(tinygp.kernels.Kernel):\n",
108
- " def __init__(self, kernel):\n",
109
- " self.kernel = kernel\n",
108
+ " kernel: tinygp.kernels.Kernel\n",
110
109
  "\n",
111
110
  " def evaluate(self, X1, X2):\n",
112
111
  " t1, d1 = X1\n",
@@ -301,6 +300,10 @@
301
300
  " shape as ``coeff_prim``.\n",
302
301
  " \"\"\"\n",
303
302
  "\n",
303
+ " kernel: tinygp.kernels.Kernel\n",
304
+ " coeff_prim: jax.Array\n",
305
+ " coeff_deriv: jax.Array\n",
306
+ "\n",
304
307
  " def __init__(self, kernel, coeff_prim, coeff_deriv):\n",
305
308
  " self.kernel = kernel\n",
306
309
  " self.coeff_prim, self.coeff_deriv = jnp.broadcast_arrays(\n",
@@ -497,7 +500,7 @@
497
500
  "hash": "d20ea8a315da34b3e8fab0dbd7b542a0ef3c8cf12937343660e6bc10a20768e3"
498
501
  },
499
502
  "kernelspec": {
500
- "display_name": "Python 3.9.9 ('tinygp')",
503
+ "display_name": "Python 3 (ipykernel)",
501
504
  "language": "python",
502
505
  "name": "python3"
503
506
  },
@@ -511,9 +514,9 @@
511
514
  "name": "python",
512
515
  "nbconvert_exporter": "python",
513
516
  "pygments_lexer": "ipython3",
514
- "version": "3.9.9"
517
+ "version": "3.10.6"
515
518
  }
516
519
  },
517
520
  "nbformat": 4,
518
- "nbformat_minor": 2
521
+ "nbformat_minor": 4
519
522
  }
@@ -54,10 +54,9 @@
54
54
  "\n",
55
55
  "\n",
56
56
  "class SpectralMixture(tinygp.kernels.Kernel):\n",
57
- " def __init__(self, weight, scale, freq):\n",
58
- " self.weight = jnp.atleast_1d(weight)\n",
59
- " self.scale = jnp.atleast_1d(scale)\n",
60
- " self.freq = jnp.atleast_1d(freq)\n",
57
+ " weight: jax.Array\n",
58
+ " scale: jax.Array\n",
59
+ " freq: jax.Array\n",
61
60
  "\n",
62
61
  " def evaluate(self, X1, X2):\n",
63
62
  " tau = jnp.atleast_1d(jnp.abs(X1 - X2))[..., None]\n",
@@ -210,7 +209,7 @@
210
209
  ],
211
210
  "metadata": {
212
211
  "kernelspec": {
213
- "display_name": "tinygp",
212
+ "display_name": "Python 3 (ipykernel)",
214
213
  "language": "python",
215
214
  "name": "python3"
216
215
  },
@@ -9,6 +9,12 @@ PYTHON_VERSIONS = ["3.9", "3.10", "3.11", "3.12"]
9
9
  @nox.session(python=PYTHON_VERSIONS)
10
10
  def test(session: nox.Session) -> None:
11
11
  session.install(".[test]")
12
+ session.run("pytest", *session.posargs)
13
+
14
+
15
+ @nox.session(python=PYTHON_VERSIONS)
16
+ def comparison(session: nox.Session) -> None:
17
+ session.install(".[test,comparison]")
12
18
  session.run("pytest", *session.posargs, env={"JAX_ENABLE_X64": "1"})
13
19
 
14
20
 
@@ -15,10 +15,11 @@ classifiers = [
15
15
  "Programming Language :: Python :: 3",
16
16
  ]
17
17
  dynamic = ["version"]
18
- dependencies = ["jax", "jaxlib"]
18
+ dependencies = ["jax", "jaxlib", "equinox"]
19
19
 
20
20
  [project.optional-dependencies]
21
- test = ["pytest", "george", "celerite"]
21
+ test = ["pytest"]
22
+ comparison = ["george", "celerite"]
22
23
  docs = [
23
24
  "sphinx-book-theme",
24
25
  "myst-nb",
@@ -11,8 +11,10 @@ from typing import (
11
11
  NamedTuple,
12
12
  )
13
13
 
14
+ import equinox as eqx
14
15
  import jax
15
16
  import jax.numpy as jnp
17
+ import numpy as np
16
18
 
17
19
  from tinygp import kernels, means
18
20
  from tinygp.helpers import JAXArray
@@ -20,12 +22,13 @@ from tinygp.kernels.quasisep import Quasisep
20
22
  from tinygp.noise import Diagonal, Noise
21
23
  from tinygp.solvers import DirectSolver, QuasisepSolver
22
24
  from tinygp.solvers.quasisep.core import SymmQSM
25
+ from tinygp.solvers.solver import Solver
23
26
 
24
27
  if TYPE_CHECKING:
25
28
  from tinygp.numpyro_support import TinyDistribution
26
29
 
27
30
 
28
- class GaussianProcess:
31
+ class GaussianProcess(eqx.Module):
29
32
  """An interface for designing a Gaussian Process regression model
30
33
 
31
34
  Args:
@@ -50,6 +53,15 @@ class GaussianProcess:
50
53
  algebra.
51
54
  """
52
55
 
56
+ num_data: int = eqx.field(static=True)
57
+ dtype: np.dtype = eqx.field(static=True)
58
+ kernel: kernels.Kernel
59
+ X: JAXArray
60
+ mean_function: means.MeanBase
61
+ mean: JAXArray
62
+ noise: Noise
63
+ solver: Solver
64
+
53
65
  def __init__(
54
66
  self,
55
67
  kernel: kernels.Kernel,
@@ -57,7 +69,7 @@ class GaussianProcess:
57
69
  *,
58
70
  diag: JAXArray | None = None,
59
71
  noise: Noise | None = None,
60
- mean: Callable[[JAXArray], JAXArray] | JAXArray | None = None,
72
+ mean: means.MeanBase | Callable[[JAXArray], JAXArray] | JAXArray | None = None,
61
73
  solver: Any | None = None,
62
74
  mean_value: JAXArray | None = None,
63
75
  covariance_value: Any | None = None,
@@ -66,7 +78,7 @@ class GaussianProcess:
66
78
  self.kernel = kernel
67
79
  self.X = X
68
80
 
69
- if callable(mean):
81
+ if isinstance(mean, means.MeanBase):
70
82
  self.mean_function = mean
71
83
  elif mean is None:
72
84
  self.mean_function = means.Mean(jnp.zeros(()))
@@ -76,7 +88,7 @@ class GaussianProcess:
76
88
  mean_value = jax.vmap(self.mean_function)(self.X)
77
89
  self.num_data = mean_value.shape[0]
78
90
  self.dtype = mean_value.dtype
79
- self.loc = self.mean = mean_value
91
+ self.mean = mean_value
80
92
  if self.mean.ndim != 1:
81
93
  raise ValueError(
82
94
  "Invalid mean shape: " f"expected ndim = 1, got ndim={self.mean.ndim}"
@@ -92,7 +104,7 @@ class GaussianProcess:
92
104
  solver = QuasisepSolver
93
105
  else:
94
106
  solver = DirectSolver
95
- self.solver = solver.init(
107
+ self.solver = solver(
96
108
  kernel,
97
109
  self.X,
98
110
  self.noise,
@@ -100,6 +112,10 @@ class GaussianProcess:
100
112
  **solver_kwargs,
101
113
  )
102
114
 
115
+ @property
116
+ def loc(self) -> JAXArray:
117
+ return self.mean
118
+
103
119
  @property
104
120
  def variance(self) -> JAXArray:
105
121
  return self.solver.variance()
@@ -209,7 +225,6 @@ class GaussianProcess:
209
225
 
210
226
  @partial(
211
227
  jax.jit,
212
- static_argnums=(0,),
213
228
  static_argnames=("include_mean", "return_var", "return_cov"),
214
229
  )
215
230
  def predict(
@@ -281,7 +296,7 @@ class GaussianProcess:
281
296
 
282
297
  return TinyDistribution(self, **kwargs)
283
298
 
284
- @partial(jax.jit, static_argnums=(0, 2))
299
+ @partial(jax.jit, static_argnums=(2,))
285
300
  def _sample(
286
301
  self,
287
302
  key: jax.random.KeyArray,
@@ -296,16 +311,16 @@ class GaussianProcess:
296
311
  self.solver.dot_triangular(normal_samples), 0, -1
297
312
  )
298
313
 
299
- @partial(jax.jit, static_argnums=0)
314
+ @jax.jit
300
315
  def _compute_log_prob(self, alpha: JAXArray) -> JAXArray:
301
316
  loglike = -0.5 * jnp.sum(jnp.square(alpha)) - self.solver.normalization()
302
317
  return jnp.where(jnp.isfinite(loglike), loglike, -jnp.inf)
303
318
 
304
- @partial(jax.jit, static_argnums=0)
319
+ @jax.jit
305
320
  def _get_alpha(self, y: JAXArray) -> JAXArray:
306
321
  return self.solver.solve_triangular(y - self.loc)
307
322
 
308
- @partial(jax.jit, static_argnums=(0, 3))
323
+ @partial(jax.jit, static_argnums=(3,))
309
324
  def _condition(
310
325
  self,
311
326
  y: JAXArray,
@@ -0,0 +1,19 @@
1
+ from __future__ import annotations
2
+
3
+ __all__ = ["JAXArray", "dataclass", "field"]
4
+
5
+ from typing import Any
6
+
7
+ import equinox as eqx
8
+ import jax
9
+
10
+ JAXArray = jax.Array
11
+
12
+
13
+ # The following is just for backwards compatibility since tinygp used to provide a
14
+ # custom dataclass implementation
15
+ field = eqx.field
16
+
17
+
18
+ def dataclass(clz: type[Any]) -> type[Any]:
19
+ return clz
@@ -11,23 +11,23 @@ __all__ = [
11
11
  "Polynomial",
12
12
  ]
13
13
 
14
- from abc import ABCMeta, abstractmethod
14
+ from abc import abstractmethod
15
15
  from collections.abc import Sequence
16
16
  from typing import TYPE_CHECKING, Any, Callable, Union
17
17
 
18
+ import equinox as eqx
18
19
  import jax
19
20
  import jax.numpy as jnp
20
21
 
21
- from tinygp.helpers import JAXArray, dataclass, field
22
+ from tinygp.helpers import JAXArray
22
23
 
23
24
  if TYPE_CHECKING:
24
25
  from tinygp.solvers.solver import Solver
25
26
 
26
-
27
27
  Axis = Union[int, Sequence[int]]
28
28
 
29
29
 
30
- class Kernel(metaclass=ABCMeta):
30
+ class Kernel(eqx.Module):
31
31
  """The base class for all kernel implementations
32
32
 
33
33
  This subclass provides default implementations to add and multiply kernels.
@@ -35,11 +35,6 @@ class Kernel(metaclass=ABCMeta):
35
35
  :func:`Kernel.evaluate` with custom behavior.
36
36
  """
37
37
 
38
- if TYPE_CHECKING:
39
-
40
- def __init__(self, *args: Any, **kwargs: Any) -> None:
41
- pass
42
-
43
38
  @abstractmethod
44
39
  def evaluate(self, X1: JAXArray, X2: JAXArray) -> JAXArray:
45
40
  """Evaluate the kernel at a pair of input coordinates
@@ -58,6 +53,7 @@ class Kernel(metaclass=ABCMeta):
58
53
  ``(n_data, n_dim)``, and you should let the :class:`Kernel` ``vmap``
59
54
  magic handle all the broadcasting for you.
60
55
  """
56
+ del X1, X2
61
57
  raise NotImplementedError
62
58
 
63
59
  def evaluate_diag(self, X: JAXArray) -> JAXArray:
@@ -130,7 +126,6 @@ class Kernel(metaclass=ABCMeta):
130
126
  return Product(Constant(other), self)
131
127
 
132
128
 
133
- @dataclass
134
129
  class Conditioned(Kernel):
135
130
  """A kernel used when conditioning a process on data
136
131
 
@@ -158,7 +153,6 @@ class Conditioned(Kernel):
158
153
  return self.kernel.evaluate_diag(X) - K.transpose() @ K
159
154
 
160
155
 
161
- @dataclass
162
156
  class Custom(Kernel):
163
157
  """A custom kernel class implemented as a callable
164
158
 
@@ -167,13 +161,12 @@ class Custom(Kernel):
167
161
  :func:`Kernel.evaluate`.
168
162
  """
169
163
 
170
- function: Callable[[Any, Any], Any]
164
+ function: Callable[[Any, Any], Any] = eqx.field(static=True)
171
165
 
172
166
  def evaluate(self, X1: JAXArray, X2: JAXArray) -> JAXArray:
173
167
  return self.function(X1, X2)
174
168
 
175
169
 
176
- @dataclass
177
170
  class Sum(Kernel):
178
171
  """A helper to represent the sum of two kernels"""
179
172
 
@@ -184,7 +177,6 @@ class Sum(Kernel):
184
177
  return self.kernel1.evaluate(X1, X2) + self.kernel2.evaluate(X1, X2)
185
178
 
186
179
 
187
- @dataclass
188
180
  class Product(Kernel):
189
181
  """A helper to represent the product of two kernels"""
190
182
 
@@ -195,7 +187,6 @@ class Product(Kernel):
195
187
  return self.kernel1.evaluate(X1, X2) * self.kernel2.evaluate(X1, X2)
196
188
 
197
189
 
198
- @dataclass
199
190
  class Constant(Kernel):
200
191
  r"""This kernel returns the constant
201
192
 
@@ -209,15 +200,15 @@ class Constant(Kernel):
209
200
  c: The parameter :math:`c` in the above equation.
210
201
  """
211
202
 
212
- value: JAXArray
203
+ value: JAXArray | float
213
204
 
214
205
  def evaluate(self, X1: JAXArray, X2: JAXArray) -> JAXArray:
206
+ del X1, X2
215
207
  if jnp.ndim(self.value) != 0:
216
208
  raise ValueError("The value of a constant kernel must be a scalar")
217
- return self.value
209
+ return jnp.asarray(self.value)
218
210
 
219
211
 
220
- @dataclass
221
212
  class DotProduct(Kernel):
222
213
  r"""The dot product kernel
223
214
 
@@ -234,7 +225,6 @@ class DotProduct(Kernel):
234
225
  return X1 @ X2
235
226
 
236
227
 
237
- @dataclass
238
228
  class Polynomial(Kernel):
239
229
  r"""A polynomial kernel
240
230
 
@@ -249,9 +239,9 @@ class Polynomial(Kernel):
249
239
  sigma: The parameter :math:`\sigma`.
250
240
  """
251
241
 
252
- order: JAXArray
253
- scale: JAXArray = field(default_factory=lambda: jnp.ones(()))
254
- sigma: JAXArray = field(default_factory=lambda: jnp.zeros(()))
242
+ order: JAXArray | float
243
+ scale: JAXArray | float = eqx.field(default_factory=lambda: jnp.ones(()))
244
+ sigma: JAXArray | float = eqx.field(default_factory=lambda: jnp.zeros(()))
255
245
 
256
246
  def evaluate(self, X1: JAXArray, X2: JAXArray) -> JAXArray:
257
247
  return (
@@ -11,14 +11,15 @@ from __future__ import annotations
11
11
 
12
12
  __all__ = ["Distance", "L1Distance", "L2Distance"]
13
13
 
14
- from abc import ABCMeta, abstractmethod
14
+ from abc import abstractmethod
15
15
 
16
+ import equinox as eqx
16
17
  import jax.numpy as jnp
17
18
 
18
- from tinygp.helpers import JAXArray, dataclass
19
+ from tinygp.helpers import JAXArray
19
20
 
20
21
 
21
- class Distance(metaclass=ABCMeta):
22
+ class Distance(eqx.Module):
22
23
  """An abstract base class defining a distance metric interface"""
23
24
 
24
25
  @abstractmethod
@@ -37,7 +38,6 @@ class Distance(metaclass=ABCMeta):
37
38
  return jnp.square(self.distance(X1, X2))
38
39
 
39
40
 
40
- @dataclass
41
41
  class L1Distance(Distance):
42
42
  """The L1 or Manhattan distance between two coordinates"""
43
43
 
@@ -45,7 +45,6 @@ class L1Distance(Distance):
45
45
  return jnp.sum(jnp.abs(X1 - X2))
46
46
 
47
47
 
48
- @dataclass
49
48
  class L2Distance(Distance):
50
49
  """The L2 or Euclidean distance between two coordinates"""
51
50