gpjax 0.13.3__tar.gz → 0.13.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. {gpjax-0.13.3 → gpjax-0.13.5}/.github/workflows/auto-label.yml +1 -1
  2. {gpjax-0.13.3 → gpjax-0.13.5}/.github/workflows/build_docs.yml +2 -2
  3. {gpjax-0.13.3 → gpjax-0.13.5}/.github/workflows/commit-lint.yml +1 -1
  4. {gpjax-0.13.3 → gpjax-0.13.5}/.github/workflows/integration.yml +1 -1
  5. {gpjax-0.13.3 → gpjax-0.13.5}/.github/workflows/release.yml +10 -10
  6. {gpjax-0.13.3 → gpjax-0.13.5}/.github/workflows/ruff.yml +1 -1
  7. {gpjax-0.13.3 → gpjax-0.13.5}/.github/workflows/security-analysis.yml +3 -3
  8. {gpjax-0.13.3 → gpjax-0.13.5}/.github/workflows/test_docs.yml +1 -1
  9. {gpjax-0.13.3 → gpjax-0.13.5}/.github/workflows/tests.yml +1 -1
  10. {gpjax-0.13.3 → gpjax-0.13.5}/.gitignore +5 -0
  11. {gpjax-0.13.3 → gpjax-0.13.5}/PKG-INFO +4 -3
  12. {gpjax-0.13.3 → gpjax-0.13.5}/README.md +1 -1
  13. {gpjax-0.13.3 → gpjax-0.13.5}/docs/sharp_bits.md +32 -28
  14. gpjax-0.13.5/examples/heteroscedastic_inference.py +394 -0
  15. {gpjax-0.13.3 → gpjax-0.13.5}/examples/regression.py +24 -23
  16. {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/__init__.py +1 -1
  17. {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/citation.py +13 -0
  18. {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/gps.py +77 -0
  19. {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/likelihoods.py +234 -0
  20. {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/mean_functions.py +2 -2
  21. {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/objectives.py +56 -1
  22. {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/parameters.py +10 -2
  23. {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/variational_families.py +129 -0
  24. {gpjax-0.13.3 → gpjax-0.13.5}/mkdocs.yml +2 -2
  25. {gpjax-0.13.3 → gpjax-0.13.5}/pyproject.toml +7 -5
  26. {gpjax-0.13.3 → gpjax-0.13.5}/tests/conftest.py +7 -0
  27. {gpjax-0.13.3 → gpjax-0.13.5}/tests/integration_tests.py +9 -2
  28. {gpjax-0.13.3 → gpjax-0.13.5}/tests/test_citations.py +16 -0
  29. gpjax-0.13.5/tests/test_heteroscedastic.py +407 -0
  30. {gpjax-0.13.3 → gpjax-0.13.5}/tests/test_mean_functions.py +16 -1
  31. {gpjax-0.13.3 → gpjax-0.13.5}/tests/test_parameters.py +20 -0
  32. gpjax-0.13.5/uv.lock +4279 -0
  33. gpjax-0.13.3/.github/workflows/pr_greeting.yml +0 -62
  34. gpjax-0.13.3/uv.lock +0 -3535
  35. {gpjax-0.13.3 → gpjax-0.13.5}/.github/CODE_OF_CONDUCT.md +0 -0
  36. {gpjax-0.13.3 → gpjax-0.13.5}/.github/FUNDING.yml +0 -0
  37. {gpjax-0.13.3 → gpjax-0.13.5}/.github/ISSUE_TEMPLATE/01_BUG_REPORT.md +0 -0
  38. {gpjax-0.13.3 → gpjax-0.13.5}/.github/ISSUE_TEMPLATE/02_FEATURE_REQUEST.md +0 -0
  39. {gpjax-0.13.3 → gpjax-0.13.5}/.github/ISSUE_TEMPLATE/03_CODEBASE_IMPROVEMENT.md +0 -0
  40. {gpjax-0.13.3 → gpjax-0.13.5}/.github/ISSUE_TEMPLATE/04_DOCS_IMPROVEMENT.md +0 -0
  41. {gpjax-0.13.3 → gpjax-0.13.5}/.github/ISSUE_TEMPLATE/config.yml +0 -0
  42. {gpjax-0.13.3 → gpjax-0.13.5}/.github/codecov.yml +0 -0
  43. {gpjax-0.13.3 → gpjax-0.13.5}/.github/commitlint.config.js +0 -0
  44. {gpjax-0.13.3 → gpjax-0.13.5}/.github/dependabot.yml +0 -0
  45. {gpjax-0.13.3 → gpjax-0.13.5}/.github/labeler.yml +0 -0
  46. {gpjax-0.13.3 → gpjax-0.13.5}/.github/labels.yml +0 -0
  47. {gpjax-0.13.3 → gpjax-0.13.5}/.github/pull_request_template.md +0 -0
  48. {gpjax-0.13.3 → gpjax-0.13.5}/.github/release-drafter.yml +0 -0
  49. {gpjax-0.13.3 → gpjax-0.13.5}/CITATION.bib +0 -0
  50. {gpjax-0.13.3 → gpjax-0.13.5}/LICENSE.txt +0 -0
  51. {gpjax-0.13.3 → gpjax-0.13.5}/Makefile +0 -0
  52. {gpjax-0.13.3 → gpjax-0.13.5}/docs/CODE_OF_CONDUCT.md +0 -0
  53. {gpjax-0.13.3 → gpjax-0.13.5}/docs/GOVERNANCE.md +0 -0
  54. {gpjax-0.13.3 → gpjax-0.13.5}/docs/contributing.md +0 -0
  55. {gpjax-0.13.3 → gpjax-0.13.5}/docs/design.md +0 -0
  56. {gpjax-0.13.3 → gpjax-0.13.5}/docs/index.md +0 -0
  57. {gpjax-0.13.3 → gpjax-0.13.5}/docs/index.rst +0 -0
  58. {gpjax-0.13.3 → gpjax-0.13.5}/docs/installation.md +0 -0
  59. {gpjax-0.13.3 → gpjax-0.13.5}/docs/javascripts/katex.js +0 -0
  60. {gpjax-0.13.3 → gpjax-0.13.5}/docs/refs.bib +0 -0
  61. {gpjax-0.13.3 → gpjax-0.13.5}/docs/scripts/gen_examples.py +0 -0
  62. {gpjax-0.13.3 → gpjax-0.13.5}/docs/scripts/gen_pages.py +0 -0
  63. {gpjax-0.13.3 → gpjax-0.13.5}/docs/scripts/notebook_converter.py +0 -0
  64. {gpjax-0.13.3 → gpjax-0.13.5}/docs/scripts/sharp_bits_figure.py +0 -0
  65. {gpjax-0.13.3 → gpjax-0.13.5}/docs/static/GP.pdf +0 -0
  66. {gpjax-0.13.3 → gpjax-0.13.5}/docs/static/GP.svg +0 -0
  67. {gpjax-0.13.3 → gpjax-0.13.5}/docs/static/bijector_figure.svg +0 -0
  68. {gpjax-0.13.3 → gpjax-0.13.5}/docs/static/css/gpjax_theme.css +0 -0
  69. {gpjax-0.13.3 → gpjax-0.13.5}/docs/static/favicon.ico +0 -0
  70. {gpjax-0.13.3 → gpjax-0.13.5}/docs/static/gpjax.mplstyle +0 -0
  71. {gpjax-0.13.3 → gpjax-0.13.5}/docs/static/gpjax_logo.pdf +0 -0
  72. {gpjax-0.13.3 → gpjax-0.13.5}/docs/static/gpjax_logo.svg +0 -0
  73. {gpjax-0.13.3 → gpjax-0.13.5}/docs/static/jaxkern/lato.ttf +0 -0
  74. {gpjax-0.13.3 → gpjax-0.13.5}/docs/static/jaxkern/logo.png +0 -0
  75. {gpjax-0.13.3 → gpjax-0.13.5}/docs/static/jaxkern/logo.svg +0 -0
  76. {gpjax-0.13.3 → gpjax-0.13.5}/docs/static/jaxkern/main.py +0 -0
  77. {gpjax-0.13.3 → gpjax-0.13.5}/docs/static/step_size_figure.png +0 -0
  78. {gpjax-0.13.3 → gpjax-0.13.5}/docs/static/step_size_figure.svg +0 -0
  79. {gpjax-0.13.3 → gpjax-0.13.5}/docs/stylesheets/extra.css +0 -0
  80. {gpjax-0.13.3 → gpjax-0.13.5}/docs/stylesheets/permalinks.css +0 -0
  81. {gpjax-0.13.3 → gpjax-0.13.5}/examples/backend.py +0 -0
  82. {gpjax-0.13.3 → gpjax-0.13.5}/examples/barycentres/barycentre_gp.gif +0 -0
  83. {gpjax-0.13.3 → gpjax-0.13.5}/examples/barycentres.py +0 -0
  84. {gpjax-0.13.3 → gpjax-0.13.5}/examples/classification.py +0 -0
  85. {gpjax-0.13.3 → gpjax-0.13.5}/examples/collapsed_vi.py +0 -0
  86. {gpjax-0.13.3 → gpjax-0.13.5}/examples/constructing_new_kernels.py +0 -0
  87. {gpjax-0.13.3 → gpjax-0.13.5}/examples/data/max_tempeature_switzerland.csv +0 -0
  88. {gpjax-0.13.3 → gpjax-0.13.5}/examples/data/yacht_hydrodynamics.data +0 -0
  89. {gpjax-0.13.3 → gpjax-0.13.5}/examples/deep_kernels.py +0 -0
  90. {gpjax-0.13.3 → gpjax-0.13.5}/examples/gpjax.mplstyle +0 -0
  91. {gpjax-0.13.3 → gpjax-0.13.5}/examples/graph_kernels.py +0 -0
  92. {gpjax-0.13.3 → gpjax-0.13.5}/examples/intro_to_gps/decomposed_mll.png +0 -0
  93. {gpjax-0.13.3 → gpjax-0.13.5}/examples/intro_to_gps/generating_process.png +0 -0
  94. {gpjax-0.13.3 → gpjax-0.13.5}/examples/intro_to_gps.py +0 -0
  95. {gpjax-0.13.3 → gpjax-0.13.5}/examples/intro_to_kernels.py +0 -0
  96. {gpjax-0.13.3 → gpjax-0.13.5}/examples/likelihoods_guide.py +0 -0
  97. {gpjax-0.13.3 → gpjax-0.13.5}/examples/oceanmodelling.py +0 -0
  98. {gpjax-0.13.3 → gpjax-0.13.5}/examples/poisson.py +0 -0
  99. {gpjax-0.13.3 → gpjax-0.13.5}/examples/uncollapsed_vi.py +0 -0
  100. {gpjax-0.13.3 → gpjax-0.13.5}/examples/utils.py +0 -0
  101. {gpjax-0.13.3 → gpjax-0.13.5}/examples/yacht.py +0 -0
  102. {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/dataset.py +0 -0
  103. {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/distributions.py +0 -0
  104. {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/fit.py +0 -0
  105. {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/integrators.py +0 -0
  106. {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/kernels/__init__.py +0 -0
  107. {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/kernels/approximations/__init__.py +0 -0
  108. {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/kernels/approximations/rff.py +0 -0
  109. {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/kernels/base.py +0 -0
  110. {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/kernels/computations/__init__.py +0 -0
  111. {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/kernels/computations/base.py +0 -0
  112. {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/kernels/computations/basis_functions.py +0 -0
  113. {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/kernels/computations/constant_diagonal.py +0 -0
  114. {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/kernels/computations/dense.py +0 -0
  115. {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/kernels/computations/diagonal.py +0 -0
  116. {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/kernels/computations/eigen.py +0 -0
  117. {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/kernels/non_euclidean/__init__.py +0 -0
  118. {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/kernels/non_euclidean/graph.py +0 -0
  119. {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/kernels/non_euclidean/utils.py +0 -0
  120. {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/kernels/nonstationary/__init__.py +0 -0
  121. {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/kernels/nonstationary/arccosine.py +0 -0
  122. {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/kernels/nonstationary/linear.py +0 -0
  123. {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/kernels/nonstationary/polynomial.py +0 -0
  124. {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/kernels/stationary/__init__.py +0 -0
  125. {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/kernels/stationary/base.py +0 -0
  126. {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/kernels/stationary/matern12.py +0 -0
  127. {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/kernels/stationary/matern32.py +0 -0
  128. {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/kernels/stationary/matern52.py +0 -0
  129. {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/kernels/stationary/periodic.py +0 -0
  130. {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/kernels/stationary/powered_exponential.py +0 -0
  131. {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/kernels/stationary/rational_quadratic.py +0 -0
  132. {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/kernels/stationary/rbf.py +0 -0
  133. {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/kernels/stationary/utils.py +0 -0
  134. {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/kernels/stationary/white.py +0 -0
  135. {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/linalg/__init__.py +0 -0
  136. {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/linalg/operations.py +0 -0
  137. {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/linalg/operators.py +0 -0
  138. {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/linalg/utils.py +0 -0
  139. {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/numpyro_extras.py +0 -0
  140. {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/scan.py +0 -0
  141. {gpjax-0.13.3 → gpjax-0.13.5}/gpjax/typing.py +0 -0
  142. {gpjax-0.13.3 → gpjax-0.13.5}/static/CONTRIBUTING.md +0 -0
  143. {gpjax-0.13.3 → gpjax-0.13.5}/static/paper.bib +0 -0
  144. {gpjax-0.13.3 → gpjax-0.13.5}/static/paper.md +0 -0
  145. {gpjax-0.13.3 → gpjax-0.13.5}/static/paper.pdf +0 -0
  146. {gpjax-0.13.3 → gpjax-0.13.5}/tests/__init__.py +0 -0
  147. {gpjax-0.13.3 → gpjax-0.13.5}/tests/test_dataset.py +0 -0
  148. {gpjax-0.13.3 → gpjax-0.13.5}/tests/test_fit.py +0 -0
  149. {gpjax-0.13.3 → gpjax-0.13.5}/tests/test_gaussian_distribution.py +0 -0
  150. {gpjax-0.13.3 → gpjax-0.13.5}/tests/test_gps.py +0 -0
  151. {gpjax-0.13.3 → gpjax-0.13.5}/tests/test_imports.py +0 -0
  152. {gpjax-0.13.3 → gpjax-0.13.5}/tests/test_integrators.py +0 -0
  153. {gpjax-0.13.3 → gpjax-0.13.5}/tests/test_kernels/__init__.py +0 -0
  154. {gpjax-0.13.3 → gpjax-0.13.5}/tests/test_kernels/test_approximations.py +0 -0
  155. {gpjax-0.13.3 → gpjax-0.13.5}/tests/test_kernels/test_base.py +0 -0
  156. {gpjax-0.13.3 → gpjax-0.13.5}/tests/test_kernels/test_computation.py +0 -0
  157. {gpjax-0.13.3 → gpjax-0.13.5}/tests/test_kernels/test_non_euclidean.py +0 -0
  158. {gpjax-0.13.3 → gpjax-0.13.5}/tests/test_kernels/test_nonstationary.py +0 -0
  159. {gpjax-0.13.3 → gpjax-0.13.5}/tests/test_kernels/test_stationary.py +0 -0
  160. {gpjax-0.13.3 → gpjax-0.13.5}/tests/test_kernels/test_utils.py +0 -0
  161. {gpjax-0.13.3 → gpjax-0.13.5}/tests/test_likelihoods.py +0 -0
  162. {gpjax-0.13.3 → gpjax-0.13.5}/tests/test_linalg.py +0 -0
  163. {gpjax-0.13.3 → gpjax-0.13.5}/tests/test_markdown.py +0 -0
  164. {gpjax-0.13.3 → gpjax-0.13.5}/tests/test_numpyro_extras.py +0 -0
  165. {gpjax-0.13.3 → gpjax-0.13.5}/tests/test_objectives.py +0 -0
  166. {gpjax-0.13.3 → gpjax-0.13.5}/tests/test_variational_families.py +0 -0
@@ -16,7 +16,7 @@ jobs:
 
  steps:
  - name: Checkout repository
- uses: actions/checkout@v5
+ uses: actions/checkout@v6
 
  - name: Auto-label based on files changed
  if: github.event_name == 'pull_request'
@@ -27,7 +27,7 @@ jobs:
  steps:
  # Grap the latest commit from the branch
  - name: Checkout the branch
- uses: actions/checkout@v5
+ uses: actions/checkout@v6
  with:
  persist-credentials: false
 
@@ -61,7 +61,7 @@ jobs:
  uv run mkdocs build
 
  - name: Deploy Page 🚀
- uses: JamesIves/github-pages-deploy-action@v4.7.4
+ uses: JamesIves/github-pages-deploy-action@v4.8.0
  with:
  branch: gh-pages
  folder: site
@@ -13,7 +13,7 @@ jobs:
 
  steps:
  - name: Checkout repository
- uses: actions/checkout@v5
+ uses: actions/checkout@v6
  with:
  fetch-depth: 0
 
@@ -17,7 +17,7 @@ jobs:
  fail-fast: true
  steps:
  - name: Check out the code
- uses: actions/checkout@v5
+ uses: actions/checkout@v6
  with:
  fetch-depth: 1
  - name: Set up Python ${{ matrix.python-version }}
@@ -29,7 +29,7 @@ jobs:
 
  steps:
  - name: Checkout repository
- uses: actions/checkout@v5
+ uses: actions/checkout@v6
  with:
  fetch-depth: 0
 
@@ -72,7 +72,7 @@ jobs:
 
  steps:
  - name: Checkout repository
- uses: actions/checkout@v5
+ uses: actions/checkout@v6
 
  - name: Set up Python ${{ matrix.python-version }}
  uses: actions/setup-python@v6
@@ -108,7 +108,7 @@ jobs:
 
  steps:
  - name: Checkout repository
- uses: actions/checkout@v5
+ uses: actions/checkout@v6
 
  - name: Set up Python
  uses: actions/setup-python@v6
@@ -132,7 +132,7 @@ jobs:
  uv run bandit -r gpjax/ -f json -o bandit-report.json || echo "Bandit scan completed with warnings"
 
  - name: Upload security reports
- uses: actions/upload-artifact@v5
+ uses: actions/upload-artifact@v6
  with:
  name: security-reports
  path: |
@@ -146,7 +146,7 @@ jobs:
 
  steps:
  - name: Checkout repository
- uses: actions/checkout@v5
+ uses: actions/checkout@v6
 
  - name: Set up Python
  uses: actions/setup-python@v6
@@ -166,7 +166,7 @@ jobs:
  uv run twine check dist/*
 
  - name: Upload build artifacts
- uses: actions/upload-artifact@v5
+ uses: actions/upload-artifact@v6
  with:
  name: dist-packages
  path: dist/
@@ -181,7 +181,7 @@ jobs:
 
  steps:
  - name: Checkout repository
- uses: actions/checkout@v5
+ uses: actions/checkout@v6
  with:
  fetch-depth: 0
 
@@ -261,10 +261,10 @@ jobs:
 
  steps:
  - name: Checkout repository
- uses: actions/checkout@v5
+ uses: actions/checkout@v6
 
  - name: Download build artifacts
- uses: actions/download-artifact@v6
+ uses: actions/download-artifact@v7
  with:
  name: dist-packages
  path: dist/
@@ -294,7 +294,7 @@ jobs:
 
  steps:
  - name: Download build artifacts
- uses: actions/download-artifact@v6
+ uses: actions/download-artifact@v7
  with:
  name: dist-packages
  path: dist/
@@ -8,5 +8,5 @@ jobs:
  ruff:
  runs-on: ubuntu-latest
  steps:
- - uses: actions/checkout@v5
+ - uses: actions/checkout@v6
  - uses: chartboost/ruff-action@v1
@@ -20,7 +20,7 @@ jobs:
 
  steps:
  - name: Checkout repository
- uses: actions/checkout@v5
+ uses: actions/checkout@v6
 
  - name: Set up Python
  uses: actions/setup-python@v6
@@ -47,7 +47,7 @@ jobs:
  uv run bandit -r gpjax/ -f json -o bandit-report.json || true
 
  - name: Upload dependency scan results
- uses: actions/upload-artifact@v5
+ uses: actions/upload-artifact@v6
  if: always()
  with:
  name: security-scan-results
@@ -62,7 +62,7 @@ jobs:
 
  steps:
  - name: Checkout repository
- uses: actions/checkout@v5
+ uses: actions/checkout@v6
  with:
  fetch-depth: 0 # Fetch full history for comprehensive scanning
 
@@ -18,7 +18,7 @@ jobs:
  steps:
  # Grap the latest commit from the branch
  - name: Checkout the branch
- uses: actions/checkout@v5
+ uses: actions/checkout@v6
  with:
  persist-credentials: false
 
@@ -16,7 +16,7 @@ jobs:
  runs-on: ${{ matrix.os }}
  steps:
  - name: Check out the code
- uses: actions/checkout@v5
+ uses: actions/checkout@v6
  with:
  fetch-depth: 1
  - name: Set up Python ${{ matrix.python-version }}
@@ -153,3 +153,8 @@ node_modules/
 
  docs/api
  docs/_examples
+ local_libs/
+ local_papers/
+ GEMINI.md
+ AGENTS.md
+ plans/
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: gpjax
- Version: 0.13.3
+ Version: 0.13.5
  Summary: Gaussian processes in JAX.
  Project-URL: Documentation, https://docs.jaxgaussianprocesses.com/
  Project-URL: Issues, https://github.com/thomaspinder/GPJax/issues
@@ -25,6 +25,7 @@ Requires-Dist: jaxtyping>0.2.10
  Requires-Dist: numpy>=2.0.0
  Requires-Dist: numpyro
  Requires-Dist: optax>0.2.1
+ Requires-Dist: tensorstore!=0.1.76; sys_platform == 'darwin'
  Requires-Dist: tqdm>4.66.2
  Provides-Extra: dev
  Requires-Dist: absolufy-imports>=0.3.1; extra == 'dev'
@@ -59,7 +60,7 @@ Requires-Dist: mkdocs-jupyter>=0.24.3; extra == 'docs'
  Requires-Dist: mkdocs-literate-nav>=0.6.0; extra == 'docs'
  Requires-Dist: mkdocs-material>=9.5.12; extra == 'docs'
  Requires-Dist: mkdocs>=1.5.3; extra == 'docs'
- Requires-Dist: mkdocstrings[python]<0.31.0; extra == 'docs'
+ Requires-Dist: mkdocstrings[python]<1.1.0; extra == 'docs'
  Requires-Dist: nbconvert>=7.16.2; extra == 'docs'
  Requires-Dist: networkx>=3.0; extra == 'docs'
  Requires-Dist: pandas>=1.5.3; extra == 'docs'
@@ -141,7 +142,7 @@ GPJax into the package it is today.
  > - [**Laplace Approximation**](https://docs.jaxgaussianprocesses.com/_examples/classification/#laplace-approximation)
  > - [**Inference on Non-Euclidean Spaces**](https://docs.jaxgaussianprocesses.com/_examples/constructing_new_kernels/#custom-kernel)
  > - [**Inference on Graphs**](https://docs.jaxgaussianprocesses.com/_examples/graph_kernels/)
- > - [**Pathwise Sampling**](https://docs.jaxgaussianprocesses.com/_examples/spatial/)
+ > - [**Heteroscedastic Inference**](https://docs.jaxgaussianprocesses.com/_examples/heteroscedastic_inference/)
  > - [**Learning Gaussian Process Barycentres**](https://docs.jaxgaussianprocesses.com/_examples/barycentres/)
  > - [**Deep Kernel Regression**](https://docs.jaxgaussianprocesses.com/_examples/deep_kernels/)
  > - [**Poisson Regression**](https://docs.jaxgaussianprocesses.com/_examples/poisson/)
@@ -70,7 +70,7 @@ GPJax into the package it is today.
  > - [**Laplace Approximation**](https://docs.jaxgaussianprocesses.com/_examples/classification/#laplace-approximation)
  > - [**Inference on Non-Euclidean Spaces**](https://docs.jaxgaussianprocesses.com/_examples/constructing_new_kernels/#custom-kernel)
  > - [**Inference on Graphs**](https://docs.jaxgaussianprocesses.com/_examples/graph_kernels/)
- > - [**Pathwise Sampling**](https://docs.jaxgaussianprocesses.com/_examples/spatial/)
+ > - [**Heteroscedastic Inference**](https://docs.jaxgaussianprocesses.com/_examples/heteroscedastic_inference/)
  > - [**Learning Gaussian Process Barycentres**](https://docs.jaxgaussianprocesses.com/_examples/barycentres/)
  > - [**Deep Kernel Regression**](https://docs.jaxgaussianprocesses.com/_examples/deep_kernels/)
  > - [**Poisson Regression**](https://docs.jaxgaussianprocesses.com/_examples/poisson/)
@@ -178,12 +178,19 @@ points. We demonstrate its use in
 
  ## JIT compilation
 
- There are a subset of operations in GPJax that are not JIT compatible by default. This
- is because we have assertions in place to check the properties of the parameters. For
- example, we check that the lengthscale parameter that a user provides is positive. This
- makes for a better user experience as we can provide more informative error messages;
- however, JIT compiling functions wherein these assertions are made will break the code.
- As an example, consider the following code:
+ GPJax validates parameters at construction time using two kinds of checks:
+
+ 1. **Type checks**: plain Python `isinstance` checks that verify values are array-like.
+ 2. **Value checks**: JAX-compatible assertions (via `checkify`) that verify constraints
+    like positivity or bounds.
+
+ During JIT tracing, concrete values are replaced by abstract tracers. The type checks
+ use `isinstance`, which is a pure Python operation that cannot be intercepted by JAX's
+ `checkify` transformation. This means that constructing GPJax objects (kernels, mean
+ functions, likelihoods, etc.) **inside** a JIT boundary will fail.
+
+ As an example, consider the following function, which constructs a kernel inside its
+ body:
 
  ```python
  import jax
@@ -192,43 +199,40 @@ import gpjax as gpx
 
  x = jnp.linspace(0, 1, 10)[:, None]
 
- def compute_gram(lengthscale):
+ def compute_gram_bad(lengthscale):
      k = gpx.kernels.RBF(active_dims=[0], lengthscale=lengthscale, variance=jnp.array(1.0))
      return k.gram(x)
 
- compute_gram(1.0)
+ compute_gram_bad(1.0) # works fine outside JIT
  ```
 
- so far so good. However, if we try to JIT compile this function, we will get an error:
+ If we try to JIT compile this function, we get a `TypeError` because the kernel
+ constructor receives a JAX tracer instead of a concrete array:
 
  ```python
- jit_compute_gram = jax.jit(compute_gram)
+ jit_compute_gram_bad = jax.jit(compute_gram_bad)
  try:
-     jit_compute_gram(1.0)
+     jit_compute_gram_bad(1.0)
  except Exception as e:
      print(e)
  ```
 
- This error is due to the fact that the `RBF` kernel contains an assertion that checks
- that the lengthscale is positive. It does not matter that the assertion is satisfied;
- the very presence of the assertion will break JIT compilation.
+ ### The fix: construct objects outside JIT
 
- To resolve this, we can use the `checkify` decorator to remove the assertion. This will
- allow the function to be JIT compiled.
+ The solution is to construct GPJax objects **outside** the JIT boundary and only JIT the
+ computation itself. This follows the standard JAX pattern of keeping object construction
+ separate from traced computation:
 
  ```python
- from jax.experimental import checkify
+ k = gpx.kernels.RBF(active_dims=[0], lengthscale=1.0, variance=jnp.array(1.0))
 
- jit_compute_gram = jax.jit(checkify.checkify(compute_gram))
- error, value = jit_compute_gram(1.0)
- ```
- By virtue of the `checkify.checkify`, a tuple is returned where the first element is the
- output of the assertion, and the second element is the value of the function.
+ @jax.jit
+ def compute_gram(x):
+     return k.gram(x)
 
- This design is not perfect, and in an ideal world we would not enforce the user to wrap
- their code in `checkify.checkify`. We are actively looking into cleaner ways to provide
- guardrails in a less intrusive manner. However, for now, should you try to JIT compile
- a component of GPJax wherein there is an assertion, you will need to wrap the function
- in `checkify.checkify` as shown above.
+ result = compute_gram(x)
+ ```
 
- For more on `checkify`, please see the [JAX Checkify Doc](https://docs.jax.dev/en/latest/debugging/checkify_guide.html).
+ More generally, any GPJax object should be constructed outside of `jax.jit`, `jax.vmap`,
+ or `jax.grad` boundaries. Once constructed, its methods can be freely used inside
+ these JAX transformations.
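The construct-outside-JIT pattern described in the updated JIT compilation notes can be illustrated without GPJax itself. The sketch below is a minimal, hypothetical stand-in: `ToyRBF` is not a GPJax class, and its deliberately strict constructor accepts only Python scalars and NumPy values, so a JAX tracer fails the same kind of plain-Python `isinstance` check. It assumes only `jax` and `numpy` are installed.

```python
import jax
import jax.numpy as jnp
import numpy as np


class ToyRBF:
    """Hypothetical stand-in for a GPJax kernel; not part of the GPJax API."""

    def __init__(self, lengthscale, variance=1.0):
        for name, value in (("lengthscale", lengthscale), ("variance", variance)):
            # Plain-Python type check: under jax.jit the argument is a tracer,
            # which is neither a Python scalar nor a NumPy value.
            if not isinstance(value, (int, float, np.ndarray, np.generic)):
                raise TypeError(f"{name} must be concrete, got {type(value).__name__}")
        self.lengthscale = jnp.asarray(lengthscale, dtype=jnp.float32)
        self.variance = jnp.asarray(variance, dtype=jnp.float32)

    def gram(self, x):
        # Squared-exponential Gram matrix for inputs of shape (n, d).
        sq_dist = jnp.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1)
        return self.variance * jnp.exp(-0.5 * sq_dist / self.lengthscale**2)


x = jnp.linspace(0.0, 1.0, 10)[:, None]


# Anti-pattern: the object is constructed inside the traced function.
@jax.jit
def gram_bad(lengthscale):
    return ToyRBF(lengthscale).gram(x)


try:
    gram_bad(1.0)
    construction_failed = False
except TypeError as err:
    construction_failed = True
    print(err)  # reports the tracer type the constructor received


# Pattern: construct once with concrete values, then trace only the computation.
kernel = ToyRBF(lengthscale=1.0)


@jax.jit
def gram_good(inputs):
    return kernel.gram(inputs)


K = gram_good(x)  # shape (10, 10)

# The closed-over object can also be used under vmap and grad.
batched = jax.vmap(gram_good)(jnp.stack([x, 2.0 * x]))  # shape (2, 10, 10)
grad_wrt_x = jax.grad(lambda inputs: gram_good(inputs).sum())(x)  # shape (10, 1)
```

One consequence of this sketch's design: `gram_good` captures `kernel`'s arrays as constants, so `jax.grad` here differentiates with respect to the inputs only, not the lengthscale.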