gpjax 0.13.4__tar.gz → 0.13.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (165) hide show
  1. {gpjax-0.13.4 → gpjax-0.13.5}/.github/workflows/auto-label.yml +1 -1
  2. {gpjax-0.13.4 → gpjax-0.13.5}/.github/workflows/build_docs.yml +2 -2
  3. {gpjax-0.13.4 → gpjax-0.13.5}/.github/workflows/commit-lint.yml +1 -1
  4. {gpjax-0.13.4 → gpjax-0.13.5}/.github/workflows/integration.yml +1 -1
  5. {gpjax-0.13.4 → gpjax-0.13.5}/.github/workflows/release.yml +10 -10
  6. {gpjax-0.13.4 → gpjax-0.13.5}/.github/workflows/ruff.yml +1 -1
  7. {gpjax-0.13.4 → gpjax-0.13.5}/.github/workflows/security-analysis.yml +3 -3
  8. {gpjax-0.13.4 → gpjax-0.13.5}/.github/workflows/test_docs.yml +1 -1
  9. {gpjax-0.13.4 → gpjax-0.13.5}/.github/workflows/tests.yml +1 -1
  10. {gpjax-0.13.4 → gpjax-0.13.5}/PKG-INFO +3 -2
  11. {gpjax-0.13.4 → gpjax-0.13.5}/docs/sharp_bits.md +32 -28
  12. {gpjax-0.13.4 → gpjax-0.13.5}/examples/heteroscedastic_inference.py +29 -24
  13. {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/__init__.py +1 -1
  14. {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/parameters.py +2 -1
  15. {gpjax-0.13.4 → gpjax-0.13.5}/mkdocs.yml +1 -2
  16. {gpjax-0.13.4 → gpjax-0.13.5}/pyproject.toml +5 -4
  17. {gpjax-0.13.4 → gpjax-0.13.5}/tests/test_parameters.py +20 -0
  18. gpjax-0.13.5/uv.lock +4279 -0
  19. gpjax-0.13.4/uv.lock +0 -3558
  20. {gpjax-0.13.4 → gpjax-0.13.5}/.github/CODE_OF_CONDUCT.md +0 -0
  21. {gpjax-0.13.4 → gpjax-0.13.5}/.github/FUNDING.yml +0 -0
  22. {gpjax-0.13.4 → gpjax-0.13.5}/.github/ISSUE_TEMPLATE/01_BUG_REPORT.md +0 -0
  23. {gpjax-0.13.4 → gpjax-0.13.5}/.github/ISSUE_TEMPLATE/02_FEATURE_REQUEST.md +0 -0
  24. {gpjax-0.13.4 → gpjax-0.13.5}/.github/ISSUE_TEMPLATE/03_CODEBASE_IMPROVEMENT.md +0 -0
  25. {gpjax-0.13.4 → gpjax-0.13.5}/.github/ISSUE_TEMPLATE/04_DOCS_IMPROVEMENT.md +0 -0
  26. {gpjax-0.13.4 → gpjax-0.13.5}/.github/ISSUE_TEMPLATE/config.yml +0 -0
  27. {gpjax-0.13.4 → gpjax-0.13.5}/.github/codecov.yml +0 -0
  28. {gpjax-0.13.4 → gpjax-0.13.5}/.github/commitlint.config.js +0 -0
  29. {gpjax-0.13.4 → gpjax-0.13.5}/.github/dependabot.yml +0 -0
  30. {gpjax-0.13.4 → gpjax-0.13.5}/.github/labeler.yml +0 -0
  31. {gpjax-0.13.4 → gpjax-0.13.5}/.github/labels.yml +0 -0
  32. {gpjax-0.13.4 → gpjax-0.13.5}/.github/pull_request_template.md +0 -0
  33. {gpjax-0.13.4 → gpjax-0.13.5}/.github/release-drafter.yml +0 -0
  34. {gpjax-0.13.4 → gpjax-0.13.5}/.gitignore +0 -0
  35. {gpjax-0.13.4 → gpjax-0.13.5}/CITATION.bib +0 -0
  36. {gpjax-0.13.4 → gpjax-0.13.5}/LICENSE.txt +0 -0
  37. {gpjax-0.13.4 → gpjax-0.13.5}/Makefile +0 -0
  38. {gpjax-0.13.4 → gpjax-0.13.5}/README.md +0 -0
  39. {gpjax-0.13.4 → gpjax-0.13.5}/docs/CODE_OF_CONDUCT.md +0 -0
  40. {gpjax-0.13.4 → gpjax-0.13.5}/docs/GOVERNANCE.md +0 -0
  41. {gpjax-0.13.4 → gpjax-0.13.5}/docs/contributing.md +0 -0
  42. {gpjax-0.13.4 → gpjax-0.13.5}/docs/design.md +0 -0
  43. {gpjax-0.13.4 → gpjax-0.13.5}/docs/index.md +0 -0
  44. {gpjax-0.13.4 → gpjax-0.13.5}/docs/index.rst +0 -0
  45. {gpjax-0.13.4 → gpjax-0.13.5}/docs/installation.md +0 -0
  46. {gpjax-0.13.4 → gpjax-0.13.5}/docs/javascripts/katex.js +0 -0
  47. {gpjax-0.13.4 → gpjax-0.13.5}/docs/refs.bib +0 -0
  48. {gpjax-0.13.4 → gpjax-0.13.5}/docs/scripts/gen_examples.py +0 -0
  49. {gpjax-0.13.4 → gpjax-0.13.5}/docs/scripts/gen_pages.py +0 -0
  50. {gpjax-0.13.4 → gpjax-0.13.5}/docs/scripts/notebook_converter.py +0 -0
  51. {gpjax-0.13.4 → gpjax-0.13.5}/docs/scripts/sharp_bits_figure.py +0 -0
  52. {gpjax-0.13.4 → gpjax-0.13.5}/docs/static/GP.pdf +0 -0
  53. {gpjax-0.13.4 → gpjax-0.13.5}/docs/static/GP.svg +0 -0
  54. {gpjax-0.13.4 → gpjax-0.13.5}/docs/static/bijector_figure.svg +0 -0
  55. {gpjax-0.13.4 → gpjax-0.13.5}/docs/static/css/gpjax_theme.css +0 -0
  56. {gpjax-0.13.4 → gpjax-0.13.5}/docs/static/favicon.ico +0 -0
  57. {gpjax-0.13.4 → gpjax-0.13.5}/docs/static/gpjax.mplstyle +0 -0
  58. {gpjax-0.13.4 → gpjax-0.13.5}/docs/static/gpjax_logo.pdf +0 -0
  59. {gpjax-0.13.4 → gpjax-0.13.5}/docs/static/gpjax_logo.svg +0 -0
  60. {gpjax-0.13.4 → gpjax-0.13.5}/docs/static/jaxkern/lato.ttf +0 -0
  61. {gpjax-0.13.4 → gpjax-0.13.5}/docs/static/jaxkern/logo.png +0 -0
  62. {gpjax-0.13.4 → gpjax-0.13.5}/docs/static/jaxkern/logo.svg +0 -0
  63. {gpjax-0.13.4 → gpjax-0.13.5}/docs/static/jaxkern/main.py +0 -0
  64. {gpjax-0.13.4 → gpjax-0.13.5}/docs/static/step_size_figure.png +0 -0
  65. {gpjax-0.13.4 → gpjax-0.13.5}/docs/static/step_size_figure.svg +0 -0
  66. {gpjax-0.13.4 → gpjax-0.13.5}/docs/stylesheets/extra.css +0 -0
  67. {gpjax-0.13.4 → gpjax-0.13.5}/docs/stylesheets/permalinks.css +0 -0
  68. {gpjax-0.13.4 → gpjax-0.13.5}/examples/backend.py +0 -0
  69. {gpjax-0.13.4 → gpjax-0.13.5}/examples/barycentres/barycentre_gp.gif +0 -0
  70. {gpjax-0.13.4 → gpjax-0.13.5}/examples/barycentres.py +0 -0
  71. {gpjax-0.13.4 → gpjax-0.13.5}/examples/classification.py +0 -0
  72. {gpjax-0.13.4 → gpjax-0.13.5}/examples/collapsed_vi.py +0 -0
  73. {gpjax-0.13.4 → gpjax-0.13.5}/examples/constructing_new_kernels.py +0 -0
  74. {gpjax-0.13.4 → gpjax-0.13.5}/examples/data/max_tempeature_switzerland.csv +0 -0
  75. {gpjax-0.13.4 → gpjax-0.13.5}/examples/data/yacht_hydrodynamics.data +0 -0
  76. {gpjax-0.13.4 → gpjax-0.13.5}/examples/deep_kernels.py +0 -0
  77. {gpjax-0.13.4 → gpjax-0.13.5}/examples/gpjax.mplstyle +0 -0
  78. {gpjax-0.13.4 → gpjax-0.13.5}/examples/graph_kernels.py +0 -0
  79. {gpjax-0.13.4 → gpjax-0.13.5}/examples/intro_to_gps/decomposed_mll.png +0 -0
  80. {gpjax-0.13.4 → gpjax-0.13.5}/examples/intro_to_gps/generating_process.png +0 -0
  81. {gpjax-0.13.4 → gpjax-0.13.5}/examples/intro_to_gps.py +0 -0
  82. {gpjax-0.13.4 → gpjax-0.13.5}/examples/intro_to_kernels.py +0 -0
  83. {gpjax-0.13.4 → gpjax-0.13.5}/examples/likelihoods_guide.py +0 -0
  84. {gpjax-0.13.4 → gpjax-0.13.5}/examples/oceanmodelling.py +0 -0
  85. {gpjax-0.13.4 → gpjax-0.13.5}/examples/poisson.py +0 -0
  86. {gpjax-0.13.4 → gpjax-0.13.5}/examples/regression.py +0 -0
  87. {gpjax-0.13.4 → gpjax-0.13.5}/examples/uncollapsed_vi.py +0 -0
  88. {gpjax-0.13.4 → gpjax-0.13.5}/examples/utils.py +0 -0
  89. {gpjax-0.13.4 → gpjax-0.13.5}/examples/yacht.py +0 -0
  90. {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/citation.py +0 -0
  91. {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/dataset.py +0 -0
  92. {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/distributions.py +0 -0
  93. {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/fit.py +0 -0
  94. {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/gps.py +0 -0
  95. {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/integrators.py +0 -0
  96. {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/kernels/__init__.py +0 -0
  97. {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/kernels/approximations/__init__.py +0 -0
  98. {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/kernels/approximations/rff.py +0 -0
  99. {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/kernels/base.py +0 -0
  100. {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/kernels/computations/__init__.py +0 -0
  101. {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/kernels/computations/base.py +0 -0
  102. {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/kernels/computations/basis_functions.py +0 -0
  103. {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/kernels/computations/constant_diagonal.py +0 -0
  104. {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/kernels/computations/dense.py +0 -0
  105. {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/kernels/computations/diagonal.py +0 -0
  106. {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/kernels/computations/eigen.py +0 -0
  107. {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/kernels/non_euclidean/__init__.py +0 -0
  108. {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/kernels/non_euclidean/graph.py +0 -0
  109. {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/kernels/non_euclidean/utils.py +0 -0
  110. {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/kernels/nonstationary/__init__.py +0 -0
  111. {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/kernels/nonstationary/arccosine.py +0 -0
  112. {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/kernels/nonstationary/linear.py +0 -0
  113. {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/kernels/nonstationary/polynomial.py +0 -0
  114. {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/kernels/stationary/__init__.py +0 -0
  115. {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/kernels/stationary/base.py +0 -0
  116. {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/kernels/stationary/matern12.py +0 -0
  117. {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/kernels/stationary/matern32.py +0 -0
  118. {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/kernels/stationary/matern52.py +0 -0
  119. {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/kernels/stationary/periodic.py +0 -0
  120. {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/kernels/stationary/powered_exponential.py +0 -0
  121. {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/kernels/stationary/rational_quadratic.py +0 -0
  122. {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/kernels/stationary/rbf.py +0 -0
  123. {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/kernels/stationary/utils.py +0 -0
  124. {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/kernels/stationary/white.py +0 -0
  125. {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/likelihoods.py +0 -0
  126. {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/linalg/__init__.py +0 -0
  127. {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/linalg/operations.py +0 -0
  128. {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/linalg/operators.py +0 -0
  129. {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/linalg/utils.py +0 -0
  130. {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/mean_functions.py +0 -0
  131. {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/numpyro_extras.py +0 -0
  132. {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/objectives.py +0 -0
  133. {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/scan.py +0 -0
  134. {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/typing.py +0 -0
  135. {gpjax-0.13.4 → gpjax-0.13.5}/gpjax/variational_families.py +0 -0
  136. {gpjax-0.13.4 → gpjax-0.13.5}/static/CONTRIBUTING.md +0 -0
  137. {gpjax-0.13.4 → gpjax-0.13.5}/static/paper.bib +0 -0
  138. {gpjax-0.13.4 → gpjax-0.13.5}/static/paper.md +0 -0
  139. {gpjax-0.13.4 → gpjax-0.13.5}/static/paper.pdf +0 -0
  140. {gpjax-0.13.4 → gpjax-0.13.5}/tests/__init__.py +0 -0
  141. {gpjax-0.13.4 → gpjax-0.13.5}/tests/conftest.py +0 -0
  142. {gpjax-0.13.4 → gpjax-0.13.5}/tests/integration_tests.py +0 -0
  143. {gpjax-0.13.4 → gpjax-0.13.5}/tests/test_citations.py +0 -0
  144. {gpjax-0.13.4 → gpjax-0.13.5}/tests/test_dataset.py +0 -0
  145. {gpjax-0.13.4 → gpjax-0.13.5}/tests/test_fit.py +0 -0
  146. {gpjax-0.13.4 → gpjax-0.13.5}/tests/test_gaussian_distribution.py +0 -0
  147. {gpjax-0.13.4 → gpjax-0.13.5}/tests/test_gps.py +0 -0
  148. {gpjax-0.13.4 → gpjax-0.13.5}/tests/test_heteroscedastic.py +0 -0
  149. {gpjax-0.13.4 → gpjax-0.13.5}/tests/test_imports.py +0 -0
  150. {gpjax-0.13.4 → gpjax-0.13.5}/tests/test_integrators.py +0 -0
  151. {gpjax-0.13.4 → gpjax-0.13.5}/tests/test_kernels/__init__.py +0 -0
  152. {gpjax-0.13.4 → gpjax-0.13.5}/tests/test_kernels/test_approximations.py +0 -0
  153. {gpjax-0.13.4 → gpjax-0.13.5}/tests/test_kernels/test_base.py +0 -0
  154. {gpjax-0.13.4 → gpjax-0.13.5}/tests/test_kernels/test_computation.py +0 -0
  155. {gpjax-0.13.4 → gpjax-0.13.5}/tests/test_kernels/test_non_euclidean.py +0 -0
  156. {gpjax-0.13.4 → gpjax-0.13.5}/tests/test_kernels/test_nonstationary.py +0 -0
  157. {gpjax-0.13.4 → gpjax-0.13.5}/tests/test_kernels/test_stationary.py +0 -0
  158. {gpjax-0.13.4 → gpjax-0.13.5}/tests/test_kernels/test_utils.py +0 -0
  159. {gpjax-0.13.4 → gpjax-0.13.5}/tests/test_likelihoods.py +0 -0
  160. {gpjax-0.13.4 → gpjax-0.13.5}/tests/test_linalg.py +0 -0
  161. {gpjax-0.13.4 → gpjax-0.13.5}/tests/test_markdown.py +0 -0
  162. {gpjax-0.13.4 → gpjax-0.13.5}/tests/test_mean_functions.py +0 -0
  163. {gpjax-0.13.4 → gpjax-0.13.5}/tests/test_numpyro_extras.py +0 -0
  164. {gpjax-0.13.4 → gpjax-0.13.5}/tests/test_objectives.py +0 -0
  165. {gpjax-0.13.4 → gpjax-0.13.5}/tests/test_variational_families.py +0 -0
@@ -16,7 +16,7 @@ jobs:
16
16
 
17
17
  steps:
18
18
  - name: Checkout repository
19
- uses: actions/checkout@v5
19
+ uses: actions/checkout@v6
20
20
 
21
21
  - name: Auto-label based on files changed
22
22
  if: github.event_name == 'pull_request'
@@ -27,7 +27,7 @@ jobs:
27
27
  steps:
28
28
  # Grab the latest commit from the branch
29
29
  - name: Checkout the branch
30
- uses: actions/checkout@v5
30
+ uses: actions/checkout@v6
31
31
  with:
32
32
  persist-credentials: false
33
33
 
@@ -61,7 +61,7 @@ jobs:
61
61
  uv run mkdocs build
62
62
 
63
63
  - name: Deploy Page 🚀
64
- uses: JamesIves/github-pages-deploy-action@v4.7.4
64
+ uses: JamesIves/github-pages-deploy-action@v4.8.0
65
65
  with:
66
66
  branch: gh-pages
67
67
  folder: site
@@ -13,7 +13,7 @@ jobs:
13
13
 
14
14
  steps:
15
15
  - name: Checkout repository
16
- uses: actions/checkout@v5
16
+ uses: actions/checkout@v6
17
17
  with:
18
18
  fetch-depth: 0
19
19
 
@@ -17,7 +17,7 @@ jobs:
17
17
  fail-fast: true
18
18
  steps:
19
19
  - name: Check out the code
20
- uses: actions/checkout@v5
20
+ uses: actions/checkout@v6
21
21
  with:
22
22
  fetch-depth: 1
23
23
  - name: Set up Python ${{ matrix.python-version }}
@@ -29,7 +29,7 @@ jobs:
29
29
 
30
30
  steps:
31
31
  - name: Checkout repository
32
- uses: actions/checkout@v5
32
+ uses: actions/checkout@v6
33
33
  with:
34
34
  fetch-depth: 0
35
35
 
@@ -72,7 +72,7 @@ jobs:
72
72
 
73
73
  steps:
74
74
  - name: Checkout repository
75
- uses: actions/checkout@v5
75
+ uses: actions/checkout@v6
76
76
 
77
77
  - name: Set up Python ${{ matrix.python-version }}
78
78
  uses: actions/setup-python@v6
@@ -108,7 +108,7 @@ jobs:
108
108
 
109
109
  steps:
110
110
  - name: Checkout repository
111
- uses: actions/checkout@v5
111
+ uses: actions/checkout@v6
112
112
 
113
113
  - name: Set up Python
114
114
  uses: actions/setup-python@v6
@@ -132,7 +132,7 @@ jobs:
132
132
  uv run bandit -r gpjax/ -f json -o bandit-report.json || echo "Bandit scan completed with warnings"
133
133
 
134
134
  - name: Upload security reports
135
- uses: actions/upload-artifact@v5
135
+ uses: actions/upload-artifact@v6
136
136
  with:
137
137
  name: security-reports
138
138
  path: |
@@ -146,7 +146,7 @@ jobs:
146
146
 
147
147
  steps:
148
148
  - name: Checkout repository
149
- uses: actions/checkout@v5
149
+ uses: actions/checkout@v6
150
150
 
151
151
  - name: Set up Python
152
152
  uses: actions/setup-python@v6
@@ -166,7 +166,7 @@ jobs:
166
166
  uv run twine check dist/*
167
167
 
168
168
  - name: Upload build artifacts
169
- uses: actions/upload-artifact@v5
169
+ uses: actions/upload-artifact@v6
170
170
  with:
171
171
  name: dist-packages
172
172
  path: dist/
@@ -181,7 +181,7 @@ jobs:
181
181
 
182
182
  steps:
183
183
  - name: Checkout repository
184
- uses: actions/checkout@v5
184
+ uses: actions/checkout@v6
185
185
  with:
186
186
  fetch-depth: 0
187
187
 
@@ -261,10 +261,10 @@ jobs:
261
261
 
262
262
  steps:
263
263
  - name: Checkout repository
264
- uses: actions/checkout@v5
264
+ uses: actions/checkout@v6
265
265
 
266
266
  - name: Download build artifacts
267
- uses: actions/download-artifact@v6
267
+ uses: actions/download-artifact@v7
268
268
  with:
269
269
  name: dist-packages
270
270
  path: dist/
@@ -294,7 +294,7 @@ jobs:
294
294
 
295
295
  steps:
296
296
  - name: Download build artifacts
297
- uses: actions/download-artifact@v6
297
+ uses: actions/download-artifact@v7
298
298
  with:
299
299
  name: dist-packages
300
300
  path: dist/
@@ -8,5 +8,5 @@ jobs:
8
8
  ruff:
9
9
  runs-on: ubuntu-latest
10
10
  steps:
11
- - uses: actions/checkout@v5
11
+ - uses: actions/checkout@v6
12
12
  - uses: chartboost/ruff-action@v1
@@ -20,7 +20,7 @@ jobs:
20
20
 
21
21
  steps:
22
22
  - name: Checkout repository
23
- uses: actions/checkout@v5
23
+ uses: actions/checkout@v6
24
24
 
25
25
  - name: Set up Python
26
26
  uses: actions/setup-python@v6
@@ -47,7 +47,7 @@ jobs:
47
47
  uv run bandit -r gpjax/ -f json -o bandit-report.json || true
48
48
 
49
49
  - name: Upload dependency scan results
50
- uses: actions/upload-artifact@v5
50
+ uses: actions/upload-artifact@v6
51
51
  if: always()
52
52
  with:
53
53
  name: security-scan-results
@@ -62,7 +62,7 @@ jobs:
62
62
 
63
63
  steps:
64
64
  - name: Checkout repository
65
- uses: actions/checkout@v5
65
+ uses: actions/checkout@v6
66
66
  with:
67
67
  fetch-depth: 0 # Fetch full history for comprehensive scanning
68
68
 
@@ -18,7 +18,7 @@ jobs:
18
18
  steps:
19
19
  # Grab the latest commit from the branch
20
20
  - name: Checkout the branch
21
- uses: actions/checkout@v5
21
+ uses: actions/checkout@v6
22
22
  with:
23
23
  persist-credentials: false
24
24
 
@@ -16,7 +16,7 @@ jobs:
16
16
  runs-on: ${{ matrix.os }}
17
17
  steps:
18
18
  - name: Check out the code
19
- uses: actions/checkout@v5
19
+ uses: actions/checkout@v6
20
20
  with:
21
21
  fetch-depth: 1
22
22
  - name: Set up Python ${{ matrix.python-version }}
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: gpjax
3
- Version: 0.13.4
3
+ Version: 0.13.5
4
4
  Summary: Gaussian processes in JAX.
5
5
  Project-URL: Documentation, https://docs.jaxgaussianprocesses.com/
6
6
  Project-URL: Issues, https://github.com/thomaspinder/GPJax/issues
@@ -25,6 +25,7 @@ Requires-Dist: jaxtyping>0.2.10
25
25
  Requires-Dist: numpy>=2.0.0
26
26
  Requires-Dist: numpyro
27
27
  Requires-Dist: optax>0.2.1
28
+ Requires-Dist: tensorstore!=0.1.76; sys_platform == 'darwin'
28
29
  Requires-Dist: tqdm>4.66.2
29
30
  Provides-Extra: dev
30
31
  Requires-Dist: absolufy-imports>=0.3.1; extra == 'dev'
@@ -59,7 +60,7 @@ Requires-Dist: mkdocs-jupyter>=0.24.3; extra == 'docs'
59
60
  Requires-Dist: mkdocs-literate-nav>=0.6.0; extra == 'docs'
60
61
  Requires-Dist: mkdocs-material>=9.5.12; extra == 'docs'
61
62
  Requires-Dist: mkdocs>=1.5.3; extra == 'docs'
62
- Requires-Dist: mkdocstrings[python]<0.31.0; extra == 'docs'
63
+ Requires-Dist: mkdocstrings[python]<1.1.0; extra == 'docs'
63
64
  Requires-Dist: nbconvert>=7.16.2; extra == 'docs'
64
65
  Requires-Dist: networkx>=3.0; extra == 'docs'
65
66
  Requires-Dist: pandas>=1.5.3; extra == 'docs'
@@ -178,12 +178,19 @@ points. We demonstrate its use in
178
178
 
179
179
  ## JIT compilation
180
180
 
181
- There are a subset of operations in GPJax that are not JIT compatible by default. This
182
- is because we have assertions in place to check the properties of the parameters. For
183
- example, we check that the lengthscale parameter that a user provides is positive. This
184
- makes for a better user experience as we can provide more informative error messages;
185
- however, JIT compiling functions wherein these assertions are made will break the code.
186
- As an example, consider the following code:
181
+ GPJax validates parameters at construction time using two kinds of checks:
182
+
183
+ 1. **Type checks** — plain Python `isinstance` checks that verify values are array-like.
184
+ 2. **Value checks** — JAX-compatible assertions (via `checkify`) that verify constraints
185
+ like positivity or bounds.
186
+
187
+ During JIT tracing, concrete values are replaced by abstract tracers. The type checks
188
+ use `isinstance`, which is a pure Python operation that cannot be intercepted by JAX's
189
+ `checkify` transformation. This means that constructing GPJax objects (kernels, mean
190
+ functions, likelihoods, etc.) **inside** a JIT boundary will fail.
191
+
192
+ As an example, consider the following code that constructs a kernel inside a
193
+ JIT-compiled function:
187
194
 
188
195
  ```python
189
196
  import jax
@@ -192,43 +199,40 @@ import gpjax as gpx
192
199
 
193
200
  x = jnp.linspace(0, 1, 10)[:, None]
194
201
 
195
- def compute_gram(lengthscale):
202
+ def compute_gram_bad(lengthscale):
196
203
  k = gpx.kernels.RBF(active_dims=[0], lengthscale=lengthscale, variance=jnp.array(1.0))
197
204
  return k.gram(x)
198
205
 
199
- compute_gram(1.0)
206
+ compute_gram_bad(1.0) # works fine outside JIT
200
207
  ```
201
208
 
202
- so far so good. However, if we try to JIT compile this function, we will get an error:
209
+ If we try to JIT compile this function, we get a `TypeError` because the kernel
210
+ constructor receives a JAX tracer instead of a concrete array:
203
211
 
204
212
  ```python
205
- jit_compute_gram = jax.jit(compute_gram)
213
+ jit_compute_gram_bad = jax.jit(compute_gram_bad)
206
214
  try:
207
- jit_compute_gram(1.0)
215
+ jit_compute_gram_bad(1.0)
208
216
  except Exception as e:
209
217
  print(e)
210
218
  ```
211
219
 
212
- This error is due to the fact that the `RBF` kernel contains an assertion that checks
213
- that the lengthscale is positive. It does not matter that the assertion is satisfied;
214
- the very presence of the assertion will break JIT compilation.
220
+ ### The fix: construct objects outside JIT
215
221
 
216
- To resolve this, we can use the `checkify` decorator to remove the assertion. This will
217
- allow the function to be JIT compiled.
222
+ The solution is to construct GPJax objects **outside** the JIT boundary and only JIT the
223
+ computation itself. This follows the standard JAX pattern of keeping object construction
224
+ separate from traced computation:
218
225
 
219
226
  ```python
220
- from jax.experimental import checkify
227
+ k = gpx.kernels.RBF(active_dims=[0], lengthscale=1.0, variance=jnp.array(1.0))
221
228
 
222
- jit_compute_gram = jax.jit(checkify.checkify(compute_gram))
223
- error, value = jit_compute_gram(1.0)
224
- ```
225
- By virtue of the `checkify.checkify`, a tuple is returned where the first element is the
226
- output of the assertion, and the second element is the value of the function.
229
+ @jax.jit
230
+ def compute_gram(x):
231
+ return k.gram(x)
227
232
 
228
- This design is not perfect, and in an ideal world we would not enforce the user to wrap
229
- their code in `checkify.checkify`. We are actively looking into cleaner ways to provide
230
- guardrails in a less intrusive manner. However, for now, should you try to JIT compile
231
- a component of GPJax wherein there is an assertion, you will need to wrap the function
232
- in `checkify.checkify` as shown above.
233
+ result = compute_gram(x)
234
+ ```
233
235
 
234
- For more on `checkify`, please see the [JAX Checkify Doc](https://docs.jax.dev/en/latest/debugging/checkify_guide.html).
236
+ More generally, any GPJax object should be constructed outside of `jax.jit`, `jax.vmap`,
237
+ or `jax.grad` boundaries. Once constructed, their methods can be freely used inside
238
+ these JAX transformations.
@@ -16,7 +16,7 @@
16
16
  # ---
17
17
 
18
18
  # %% [markdown]
19
- # # Heteroscedastic inference for regression and classification
19
+ # # Heteroscedastic Inference
20
20
  #
21
21
  # This notebook shows how to fit a heteroscedastic Gaussian process (GP) that
22
22
  # allows one to perform regression where there exists non-constant, or
@@ -293,13 +293,13 @@ posterior_adv = mean_prior * likelihood_adv
293
293
  # The signal requires a richer inducing set to capture its oscillations, whereas the
294
294
  # noise process can be summarised with fewer points because the burst is localised.
295
295
  z_signal = jnp.linspace(-2.0, 2.0, 30)[:, None]
296
- z_noise = jnp.linspace(-2.0, 2.0, 15)[:, None]
296
+ z_noise = jnp.linspace(-2.0, 2.0, 20)[:, None]
297
297
 
298
298
  # Use VariationalGaussianInit to pass specific configurations
299
299
  q_init_f = VariationalGaussianInit(inducing_inputs=z_signal)
300
300
  q_init_g = VariationalGaussianInit(inducing_inputs=z_noise)
301
301
 
302
- q_adv = HeteroscedasticVariationalFamily(
302
+ q_sparse = HeteroscedasticVariationalFamily(
303
303
  posterior=posterior_adv,
304
304
  signal_init=q_init_f,
305
305
  noise_init=q_init_g,
@@ -315,19 +315,19 @@ q_adv = HeteroscedasticVariationalFamily(
315
315
  # Optimize
316
316
  objective_adv = lambda model, data: -gpx.objectives.heteroscedastic_elbo(model, data)
317
317
  optimiser_adv = ox.adam(1e-2)
318
- q_adv_trained, _ = gpx.fit(
319
- model=q_adv,
318
+ q_sparse_trained, _ = gpx.fit(
319
+ model=q_sparse,
320
320
  objective=objective_adv,
321
321
  train_data=data_adv,
322
322
  optim=optimiser_adv,
323
- num_iters=8000,
323
+ num_iters=10000,
324
324
  verbose=False,
325
325
  )
326
326
 
327
327
  # %%
328
328
  # Plotting
329
- xtest = jnp.linspace(-2.2, 2.2, 200)[:, None]
330
- pred = q_adv_trained.predict(xtest)
329
+ xtest = jnp.linspace(-2.2, 2.2, 300)[:, None]
330
+ pred = q_sparse_trained.predict(xtest)
331
331
 
332
332
  # Unpack the named tuple
333
333
  mf = pred.mean_f
@@ -339,36 +339,41 @@ vg = pred.variance_g
339
339
  # The likelihood expects the *latent* noise distribution to compute the predictive
340
340
  # but here we can just use the transformed expected variance for plotting.
341
341
  # For accurate predictive intervals, we should use likelihood.predict.
342
- signal_dist, noise_dist = q_adv_trained.predict_latents(xtest)
342
+ signal_dist, noise_dist = q_sparse_trained.predict_latents(xtest)
343
343
  predictive_dist = likelihood_adv.predict(signal_dist, noise_dist)
344
344
  predictive_mean = predictive_dist.mean
345
345
  predictive_std = jnp.sqrt(jnp.diag(predictive_dist.covariance_matrix))
346
346
 
347
- fig, ax = plt.subplots()
348
- ax.plot(x, y, "o", color="black", alpha=0.3, label="Data")
349
- ax.plot(xtest, mf, color="C0", label="Signal Mean")
347
+ fig, ax = plt.subplots(figsize=(6, 2.5))
348
+ ax.plot(x, y, "x", color="black", alpha=0.5, label="Data")
349
+
350
+ # Plot total uncertainty (signal + noise)
351
+ ax.plot(xtest, predictive_mean, "--", color=cols[1], linewidth=2)
350
352
  ax.fill_between(
351
353
  xtest.squeeze(),
352
- mf.squeeze() - 2 * jnp.sqrt(vf.squeeze()),
353
- mf.squeeze() + 2 * jnp.sqrt(vf.squeeze()),
354
- color="C0",
355
- alpha=0.2,
356
- label="Signal Uncertainty",
354
+ predictive_mean - predictive_std,
355
+ predictive_mean + predictive_std,
356
+ color=cols[1],
357
+ alpha=0.3,
358
+ label="One std. dev.",
357
359
  )
358
-
359
- # Plot total uncertainty (signal + noise)
360
- ax.plot(xtest, predictive_mean, "--", color="C1", alpha=0.5)
360
+ ax.plot(xtest.squeeze(), predictive_mean - predictive_std, "--", color=cols[1], alpha=0.5, linewidth=0.75)
361
+ ax.plot(xtest.squeeze(), predictive_mean + predictive_std, "--", color=cols[1], alpha=0.5, linewidth=0.75)
361
362
  ax.fill_between(
362
363
  xtest.squeeze(),
363
364
  predictive_mean - 2 * predictive_std,
364
365
  predictive_mean + 2 * predictive_std,
365
- color="C1",
366
+ color=cols[1],
366
367
  alpha=0.1,
367
- label="Predictive Uncertainty (95%)",
368
+ label="Two std. dev.",
368
369
  )
370
+ ax.plot(xtest.squeeze(), predictive_mean - 2 * predictive_std, "--", color=cols[1], alpha=0.5, linewidth=0.75)
371
+ ax.plot(xtest.squeeze(), predictive_mean + 2 * predictive_std, "--", color=cols[1], alpha=0.5, linewidth=0.75)
369
372
 
370
- ax.set_title("Heteroscedastic Regression with Custom Inducing Points")
371
- ax.legend(loc="upper left", fontsize="small")
373
+ ax.set_title("Sparse Heteroscedastic Regression")
374
+ ax.legend(loc="best", fontsize="small")
375
+ ax.set_xlabel("$x$")
376
+ ax.set_ylabel("$y$")
372
377
 
373
378
  # %% [markdown]
374
379
  # ## Takeaways
@@ -40,7 +40,7 @@ __license__ = "MIT"
40
40
  __description__ = "Gaussian processes in JAX and Flax"
41
41
  __url__ = "https://github.com/thomaspinder/GPJax"
42
42
  __contributors__ = "https://github.com/thomaspinder/GPJax/graphs/contributors"
43
- __version__ = "0.13.4"
43
+ __version__ = "0.13.5"
44
44
 
45
45
  __all__ = [
46
46
  "gps",
@@ -1,6 +1,7 @@
1
1
  import typing as tp
2
2
 
3
3
  from flax import nnx
4
+ import jax
4
5
  from jax.experimental import checkify
5
6
  import jax.numpy as jnp
6
7
  import jax.tree_util as jtu
@@ -162,7 +163,7 @@ def _check_is_arraylike(value: T) -> None:
162
163
  Raises:
163
164
  TypeError: If the value is not array-like.
164
165
  """
165
- if not isinstance(value, (ArrayLike, list)):
166
+ if not isinstance(value, (jax.Array, ArrayLike, list)):
166
167
  raise TypeError(
167
168
  f"Expected parameter value to be an array-like type. Got {value}."
168
169
  )
@@ -118,10 +118,9 @@ plugins:
118
118
  handlers:
119
119
  python:
120
120
  paths: ["gpjax"]
121
- rendering:
121
+ options:
122
122
  show_symbol_type_toc: true
123
123
  show_signature_annotations: true
124
- options:
125
124
  members_order: source
126
125
  inherited_members: true
127
126
  show_source: false
@@ -30,13 +30,14 @@ dependencies = [
30
30
  "beartype>0.16.1",
31
31
  "flax>=0.12.0",
32
32
  "numpy>=2.0.0",
33
+ "tensorstore!=0.1.76; sys_platform == 'darwin'",
33
34
  ]
34
35
 
35
36
  [project.optional-dependencies]
36
37
  docs = [
37
38
  "mkdocs>=1.5.3",
38
39
  "mkdocs-material>=9.5.12",
39
- "mkdocstrings[python]<0.31.0",
40
+ "mkdocstrings[python]<1.1.0",
40
41
  "mkdocs-jupyter>=0.24.3",
41
42
  "mkdocs-gen-files>=0.5.0",
42
43
  "mkdocs-literate-nav>=0.6.0",
@@ -78,6 +79,7 @@ dev = [
78
79
  ]
79
80
 
80
81
  [tool.uv]
82
+ exclude-newer = "7 days"
81
83
  managed = true
82
84
  dev-dependencies = [
83
85
  "ruff>=0.6",
@@ -260,12 +262,11 @@ convention = "numpy"
260
262
 
261
263
  [tool.ruff.lint.per-file-ignores]
262
264
  "gpjax/__init__.py" = ['I', 'F401', 'E402', 'D104']
263
- "gpjax/progress_bar.py" = ["TCH004"]
264
265
  "gpjax/scan.py" = ["PLR0913"]
265
266
  "gpjax/citation.py" = ["F811"]
266
- "tests/test_base/test_module.py" = ["PLR0915"]
267
267
  "tests/test_objectives.py" = ["PLR0913"]
268
- "docs/examples/barycentres.py" = ["PLR0913"]
268
+ "examples/barycentres.py" = ["PLR0913"]
269
+ "tests/*.py" = ["PLW0108"]
269
270
 
270
271
  [tool.isort]
271
272
  profile = "black"
@@ -1,3 +1,4 @@
1
+ import jax
1
2
  from flax import nnx
2
3
  from jax import jit
3
4
  from jax.experimental import checkify
@@ -109,3 +110,22 @@ def test_check_in_bounds():
109
110
  _safe_assert(
110
111
  _check_in_bounds, jnp.array(1.5), low=jnp.array(0.0), high=jnp.array(1.0)
111
112
  )
113
+
114
+
115
+ @pytest.mark.parametrize(
116
+ "param_cls, value",
117
+ [
118
+ (PositiveReal, jnp.array(1.0)),
119
+ (PositiveReal, jnp.array([1.0, 2.0])),
120
+ (Real, jnp.array(1.0)),
121
+ (NonNegativeReal, jnp.array(1.0)),
122
+ ],
123
+ )
124
+ def test_parameter_construction_under_grad(param_cls, value):
125
+ """Regression test for #592: parameter construction must accept JAX tracers."""
126
+
127
+ def f(x):
128
+ return param_cls(x).value.sum()
129
+
130
+ grad = jax.grad(f)(value)
131
+ assert grad.shape == value.shape