gpjax 0.13.0__tar.gz → 0.13.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162) hide show
  1. {gpjax-0.13.0 → gpjax-0.13.2}/.github/workflows/build_docs.yml +2 -2
  2. {gpjax-0.13.0 → gpjax-0.13.2}/.github/workflows/integration.yml +1 -1
  3. {gpjax-0.13.0 → gpjax-0.13.2}/.github/workflows/pr_greeting.yml +25 -2
  4. {gpjax-0.13.0 → gpjax-0.13.2}/.github/workflows/release.yml +7 -7
  5. {gpjax-0.13.0 → gpjax-0.13.2}/.github/workflows/security-analysis.yml +2 -2
  6. {gpjax-0.13.0 → gpjax-0.13.2}/.github/workflows/test_docs.yml +1 -1
  7. {gpjax-0.13.0 → gpjax-0.13.2}/.github/workflows/tests.yml +1 -1
  8. {gpjax-0.13.0 → gpjax-0.13.2}/PKG-INFO +1 -1
  9. {gpjax-0.13.0 → gpjax-0.13.2}/gpjax/__init__.py +1 -1
  10. {gpjax-0.13.0 → gpjax-0.13.2}/gpjax/kernels/computations/basis_functions.py +2 -4
  11. {gpjax-0.13.0 → gpjax-0.13.2}/gpjax/kernels/computations/eigen.py +1 -15
  12. {gpjax-0.13.0 → gpjax-0.13.2}/gpjax/kernels/non_euclidean/graph.py +7 -6
  13. {gpjax-0.13.0 → gpjax-0.13.2}/gpjax/kernels/non_euclidean/utils.py +30 -0
  14. {gpjax-0.13.0 → gpjax-0.13.2}/gpjax/variational_families.py +69 -5
  15. {gpjax-0.13.0 → gpjax-0.13.2}/tests/test_variational_families.py +59 -0
  16. {gpjax-0.13.0 → gpjax-0.13.2}/.github/CODE_OF_CONDUCT.md +0 -0
  17. {gpjax-0.13.0 → gpjax-0.13.2}/.github/ISSUE_TEMPLATE/01_BUG_REPORT.md +0 -0
  18. {gpjax-0.13.0 → gpjax-0.13.2}/.github/ISSUE_TEMPLATE/02_FEATURE_REQUEST.md +0 -0
  19. {gpjax-0.13.0 → gpjax-0.13.2}/.github/ISSUE_TEMPLATE/03_CODEBASE_IMPROVEMENT.md +0 -0
  20. {gpjax-0.13.0 → gpjax-0.13.2}/.github/ISSUE_TEMPLATE/04_DOCS_IMPROVEMENT.md +0 -0
  21. {gpjax-0.13.0 → gpjax-0.13.2}/.github/ISSUE_TEMPLATE/config.yml +0 -0
  22. {gpjax-0.13.0 → gpjax-0.13.2}/.github/codecov.yml +0 -0
  23. {gpjax-0.13.0 → gpjax-0.13.2}/.github/commitlint.config.js +0 -0
  24. {gpjax-0.13.0 → gpjax-0.13.2}/.github/dependabot.yml +0 -0
  25. {gpjax-0.13.0 → gpjax-0.13.2}/.github/labeler.yml +0 -0
  26. {gpjax-0.13.0 → gpjax-0.13.2}/.github/labels.yml +0 -0
  27. {gpjax-0.13.0 → gpjax-0.13.2}/.github/pull_request_template.md +0 -0
  28. {gpjax-0.13.0 → gpjax-0.13.2}/.github/release-drafter.yml +0 -0
  29. {gpjax-0.13.0 → gpjax-0.13.2}/.github/workflows/auto-label.yml +0 -0
  30. {gpjax-0.13.0 → gpjax-0.13.2}/.github/workflows/commit-lint.yml +0 -0
  31. {gpjax-0.13.0 → gpjax-0.13.2}/.github/workflows/ruff.yml +0 -0
  32. {gpjax-0.13.0 → gpjax-0.13.2}/.gitignore +0 -0
  33. {gpjax-0.13.0 → gpjax-0.13.2}/CITATION.bib +0 -0
  34. {gpjax-0.13.0 → gpjax-0.13.2}/LICENSE.txt +0 -0
  35. {gpjax-0.13.0 → gpjax-0.13.2}/Makefile +0 -0
  36. {gpjax-0.13.0 → gpjax-0.13.2}/README.md +0 -0
  37. {gpjax-0.13.0 → gpjax-0.13.2}/docs/CODE_OF_CONDUCT.md +0 -0
  38. {gpjax-0.13.0 → gpjax-0.13.2}/docs/GOVERNANCE.md +0 -0
  39. {gpjax-0.13.0 → gpjax-0.13.2}/docs/contributing.md +0 -0
  40. {gpjax-0.13.0 → gpjax-0.13.2}/docs/design.md +0 -0
  41. {gpjax-0.13.0 → gpjax-0.13.2}/docs/index.md +0 -0
  42. {gpjax-0.13.0 → gpjax-0.13.2}/docs/index.rst +0 -0
  43. {gpjax-0.13.0 → gpjax-0.13.2}/docs/installation.md +0 -0
  44. {gpjax-0.13.0 → gpjax-0.13.2}/docs/javascripts/katex.js +0 -0
  45. {gpjax-0.13.0 → gpjax-0.13.2}/docs/refs.bib +0 -0
  46. {gpjax-0.13.0 → gpjax-0.13.2}/docs/scripts/gen_examples.py +0 -0
  47. {gpjax-0.13.0 → gpjax-0.13.2}/docs/scripts/gen_pages.py +0 -0
  48. {gpjax-0.13.0 → gpjax-0.13.2}/docs/scripts/notebook_converter.py +0 -0
  49. {gpjax-0.13.0 → gpjax-0.13.2}/docs/scripts/sharp_bits_figure.py +0 -0
  50. {gpjax-0.13.0 → gpjax-0.13.2}/docs/sharp_bits.md +0 -0
  51. {gpjax-0.13.0 → gpjax-0.13.2}/docs/static/GP.pdf +0 -0
  52. {gpjax-0.13.0 → gpjax-0.13.2}/docs/static/GP.svg +0 -0
  53. {gpjax-0.13.0 → gpjax-0.13.2}/docs/static/bijector_figure.svg +0 -0
  54. {gpjax-0.13.0 → gpjax-0.13.2}/docs/static/css/gpjax_theme.css +0 -0
  55. {gpjax-0.13.0 → gpjax-0.13.2}/docs/static/favicon.ico +0 -0
  56. {gpjax-0.13.0 → gpjax-0.13.2}/docs/static/gpjax.mplstyle +0 -0
  57. {gpjax-0.13.0 → gpjax-0.13.2}/docs/static/gpjax_logo.pdf +0 -0
  58. {gpjax-0.13.0 → gpjax-0.13.2}/docs/static/gpjax_logo.svg +0 -0
  59. {gpjax-0.13.0 → gpjax-0.13.2}/docs/static/jaxkern/lato.ttf +0 -0
  60. {gpjax-0.13.0 → gpjax-0.13.2}/docs/static/jaxkern/logo.png +0 -0
  61. {gpjax-0.13.0 → gpjax-0.13.2}/docs/static/jaxkern/logo.svg +0 -0
  62. {gpjax-0.13.0 → gpjax-0.13.2}/docs/static/jaxkern/main.py +0 -0
  63. {gpjax-0.13.0 → gpjax-0.13.2}/docs/static/step_size_figure.png +0 -0
  64. {gpjax-0.13.0 → gpjax-0.13.2}/docs/static/step_size_figure.svg +0 -0
  65. {gpjax-0.13.0 → gpjax-0.13.2}/docs/stylesheets/extra.css +0 -0
  66. {gpjax-0.13.0 → gpjax-0.13.2}/docs/stylesheets/permalinks.css +0 -0
  67. {gpjax-0.13.0 → gpjax-0.13.2}/examples/backend.py +0 -0
  68. {gpjax-0.13.0 → gpjax-0.13.2}/examples/barycentres/barycentre_gp.gif +0 -0
  69. {gpjax-0.13.0 → gpjax-0.13.2}/examples/barycentres.py +0 -0
  70. {gpjax-0.13.0 → gpjax-0.13.2}/examples/classification.py +0 -0
  71. {gpjax-0.13.0 → gpjax-0.13.2}/examples/collapsed_vi.py +0 -0
  72. {gpjax-0.13.0 → gpjax-0.13.2}/examples/constructing_new_kernels.py +0 -0
  73. {gpjax-0.13.0 → gpjax-0.13.2}/examples/data/max_tempeature_switzerland.csv +0 -0
  74. {gpjax-0.13.0 → gpjax-0.13.2}/examples/data/yacht_hydrodynamics.data +0 -0
  75. {gpjax-0.13.0 → gpjax-0.13.2}/examples/deep_kernels.py +0 -0
  76. {gpjax-0.13.0 → gpjax-0.13.2}/examples/gpjax.mplstyle +0 -0
  77. {gpjax-0.13.0 → gpjax-0.13.2}/examples/graph_kernels.py +0 -0
  78. {gpjax-0.13.0 → gpjax-0.13.2}/examples/intro_to_gps/decomposed_mll.png +0 -0
  79. {gpjax-0.13.0 → gpjax-0.13.2}/examples/intro_to_gps/generating_process.png +0 -0
  80. {gpjax-0.13.0 → gpjax-0.13.2}/examples/intro_to_gps.py +0 -0
  81. {gpjax-0.13.0 → gpjax-0.13.2}/examples/intro_to_kernels.py +0 -0
  82. {gpjax-0.13.0 → gpjax-0.13.2}/examples/likelihoods_guide.py +0 -0
  83. {gpjax-0.13.0 → gpjax-0.13.2}/examples/oceanmodelling.py +0 -0
  84. {gpjax-0.13.0 → gpjax-0.13.2}/examples/poisson.py +0 -0
  85. {gpjax-0.13.0 → gpjax-0.13.2}/examples/regression.py +0 -0
  86. {gpjax-0.13.0 → gpjax-0.13.2}/examples/uncollapsed_vi.py +0 -0
  87. {gpjax-0.13.0 → gpjax-0.13.2}/examples/utils.py +0 -0
  88. {gpjax-0.13.0 → gpjax-0.13.2}/examples/yacht.py +0 -0
  89. {gpjax-0.13.0 → gpjax-0.13.2}/gpjax/citation.py +0 -0
  90. {gpjax-0.13.0 → gpjax-0.13.2}/gpjax/dataset.py +0 -0
  91. {gpjax-0.13.0 → gpjax-0.13.2}/gpjax/distributions.py +0 -0
  92. {gpjax-0.13.0 → gpjax-0.13.2}/gpjax/fit.py +0 -0
  93. {gpjax-0.13.0 → gpjax-0.13.2}/gpjax/gps.py +0 -0
  94. {gpjax-0.13.0 → gpjax-0.13.2}/gpjax/integrators.py +0 -0
  95. {gpjax-0.13.0 → gpjax-0.13.2}/gpjax/kernels/__init__.py +0 -0
  96. {gpjax-0.13.0 → gpjax-0.13.2}/gpjax/kernels/approximations/__init__.py +0 -0
  97. {gpjax-0.13.0 → gpjax-0.13.2}/gpjax/kernels/approximations/rff.py +0 -0
  98. {gpjax-0.13.0 → gpjax-0.13.2}/gpjax/kernels/base.py +0 -0
  99. {gpjax-0.13.0 → gpjax-0.13.2}/gpjax/kernels/computations/__init__.py +0 -0
  100. {gpjax-0.13.0 → gpjax-0.13.2}/gpjax/kernels/computations/base.py +0 -0
  101. {gpjax-0.13.0 → gpjax-0.13.2}/gpjax/kernels/computations/constant_diagonal.py +0 -0
  102. {gpjax-0.13.0 → gpjax-0.13.2}/gpjax/kernels/computations/dense.py +0 -0
  103. {gpjax-0.13.0 → gpjax-0.13.2}/gpjax/kernels/computations/diagonal.py +0 -0
  104. {gpjax-0.13.0 → gpjax-0.13.2}/gpjax/kernels/non_euclidean/__init__.py +0 -0
  105. {gpjax-0.13.0 → gpjax-0.13.2}/gpjax/kernels/nonstationary/__init__.py +0 -0
  106. {gpjax-0.13.0 → gpjax-0.13.2}/gpjax/kernels/nonstationary/arccosine.py +0 -0
  107. {gpjax-0.13.0 → gpjax-0.13.2}/gpjax/kernels/nonstationary/linear.py +0 -0
  108. {gpjax-0.13.0 → gpjax-0.13.2}/gpjax/kernels/nonstationary/polynomial.py +0 -0
  109. {gpjax-0.13.0 → gpjax-0.13.2}/gpjax/kernels/stationary/__init__.py +0 -0
  110. {gpjax-0.13.0 → gpjax-0.13.2}/gpjax/kernels/stationary/base.py +0 -0
  111. {gpjax-0.13.0 → gpjax-0.13.2}/gpjax/kernels/stationary/matern12.py +0 -0
  112. {gpjax-0.13.0 → gpjax-0.13.2}/gpjax/kernels/stationary/matern32.py +0 -0
  113. {gpjax-0.13.0 → gpjax-0.13.2}/gpjax/kernels/stationary/matern52.py +0 -0
  114. {gpjax-0.13.0 → gpjax-0.13.2}/gpjax/kernels/stationary/periodic.py +0 -0
  115. {gpjax-0.13.0 → gpjax-0.13.2}/gpjax/kernels/stationary/powered_exponential.py +0 -0
  116. {gpjax-0.13.0 → gpjax-0.13.2}/gpjax/kernels/stationary/rational_quadratic.py +0 -0
  117. {gpjax-0.13.0 → gpjax-0.13.2}/gpjax/kernels/stationary/rbf.py +0 -0
  118. {gpjax-0.13.0 → gpjax-0.13.2}/gpjax/kernels/stationary/utils.py +0 -0
  119. {gpjax-0.13.0 → gpjax-0.13.2}/gpjax/kernels/stationary/white.py +0 -0
  120. {gpjax-0.13.0 → gpjax-0.13.2}/gpjax/likelihoods.py +0 -0
  121. {gpjax-0.13.0 → gpjax-0.13.2}/gpjax/linalg/__init__.py +0 -0
  122. {gpjax-0.13.0 → gpjax-0.13.2}/gpjax/linalg/operations.py +0 -0
  123. {gpjax-0.13.0 → gpjax-0.13.2}/gpjax/linalg/operators.py +0 -0
  124. {gpjax-0.13.0 → gpjax-0.13.2}/gpjax/linalg/utils.py +0 -0
  125. {gpjax-0.13.0 → gpjax-0.13.2}/gpjax/mean_functions.py +0 -0
  126. {gpjax-0.13.0 → gpjax-0.13.2}/gpjax/numpyro_extras.py +0 -0
  127. {gpjax-0.13.0 → gpjax-0.13.2}/gpjax/objectives.py +0 -0
  128. {gpjax-0.13.0 → gpjax-0.13.2}/gpjax/parameters.py +0 -0
  129. {gpjax-0.13.0 → gpjax-0.13.2}/gpjax/scan.py +0 -0
  130. {gpjax-0.13.0 → gpjax-0.13.2}/gpjax/typing.py +0 -0
  131. {gpjax-0.13.0 → gpjax-0.13.2}/mkdocs.yml +0 -0
  132. {gpjax-0.13.0 → gpjax-0.13.2}/pyproject.toml +0 -0
  133. {gpjax-0.13.0 → gpjax-0.13.2}/static/CONTRIBUTING.md +0 -0
  134. {gpjax-0.13.0 → gpjax-0.13.2}/static/paper.bib +0 -0
  135. {gpjax-0.13.0 → gpjax-0.13.2}/static/paper.md +0 -0
  136. {gpjax-0.13.0 → gpjax-0.13.2}/static/paper.pdf +0 -0
  137. {gpjax-0.13.0 → gpjax-0.13.2}/tests/__init__.py +0 -0
  138. {gpjax-0.13.0 → gpjax-0.13.2}/tests/conftest.py +0 -0
  139. {gpjax-0.13.0 → gpjax-0.13.2}/tests/integration_tests.py +0 -0
  140. {gpjax-0.13.0 → gpjax-0.13.2}/tests/test_citations.py +0 -0
  141. {gpjax-0.13.0 → gpjax-0.13.2}/tests/test_dataset.py +0 -0
  142. {gpjax-0.13.0 → gpjax-0.13.2}/tests/test_fit.py +0 -0
  143. {gpjax-0.13.0 → gpjax-0.13.2}/tests/test_gaussian_distribution.py +0 -0
  144. {gpjax-0.13.0 → gpjax-0.13.2}/tests/test_gps.py +0 -0
  145. {gpjax-0.13.0 → gpjax-0.13.2}/tests/test_imports.py +0 -0
  146. {gpjax-0.13.0 → gpjax-0.13.2}/tests/test_integrators.py +0 -0
  147. {gpjax-0.13.0 → gpjax-0.13.2}/tests/test_kernels/__init__.py +0 -0
  148. {gpjax-0.13.0 → gpjax-0.13.2}/tests/test_kernels/test_approximations.py +0 -0
  149. {gpjax-0.13.0 → gpjax-0.13.2}/tests/test_kernels/test_base.py +0 -0
  150. {gpjax-0.13.0 → gpjax-0.13.2}/tests/test_kernels/test_computation.py +0 -0
  151. {gpjax-0.13.0 → gpjax-0.13.2}/tests/test_kernels/test_non_euclidean.py +0 -0
  152. {gpjax-0.13.0 → gpjax-0.13.2}/tests/test_kernels/test_nonstationary.py +0 -0
  153. {gpjax-0.13.0 → gpjax-0.13.2}/tests/test_kernels/test_stationary.py +0 -0
  154. {gpjax-0.13.0 → gpjax-0.13.2}/tests/test_kernels/test_utils.py +0 -0
  155. {gpjax-0.13.0 → gpjax-0.13.2}/tests/test_likelihoods.py +0 -0
  156. {gpjax-0.13.0 → gpjax-0.13.2}/tests/test_linalg.py +0 -0
  157. {gpjax-0.13.0 → gpjax-0.13.2}/tests/test_markdown.py +0 -0
  158. {gpjax-0.13.0 → gpjax-0.13.2}/tests/test_mean_functions.py +0 -0
  159. {gpjax-0.13.0 → gpjax-0.13.2}/tests/test_numpyro_extras.py +0 -0
  160. {gpjax-0.13.0 → gpjax-0.13.2}/tests/test_objectives.py +0 -0
  161. {gpjax-0.13.0 → gpjax-0.13.2}/tests/test_parameters.py +0 -0
  162. {gpjax-0.13.0 → gpjax-0.13.2}/uv.lock +0 -0
@@ -40,7 +40,7 @@ jobs:
40
40
 
41
41
  # Install katex for math support
42
42
  - name: Install NPM
43
- uses: actions/setup-node@v5
43
+ uses: actions/setup-node@v6
44
44
  with:
45
45
  node-version: 16
46
46
  - name: Install KaTeX
@@ -49,7 +49,7 @@ jobs:
49
49
 
50
50
  # Install uv
51
51
  - name: Install uv
52
- uses: astral-sh/setup-uv@v6
52
+ uses: astral-sh/setup-uv@v7
53
53
  with:
54
54
  version: "latest"
55
55
 
@@ -27,7 +27,7 @@ jobs:
27
27
 
28
28
  # Install uv
29
29
  - name: Install uv
30
- uses: astral-sh/setup-uv@v6
30
+ uses: astral-sh/setup-uv@v7
31
31
  with:
32
32
  version: "latest"
33
33
 
@@ -13,8 +13,31 @@ jobs:
13
13
  steps:
14
14
  - uses: actions/first-interaction@v3.1.0
15
15
  with:
16
- repo-token: ${{ secrets.GITHUB_TOKEN }}
17
- pr-message: >+
16
+ repo_token: ${{ secrets.GITHUB_TOKEN }}
17
+ issue_message: |
18
+ Thank you for opening your first issue into GPJax!
19
+
20
+ If you have not heard from us in a while, please feel free to ping
21
+ `@gpjax/developers` or anyone who has commented on the PR.
22
+ Most of our reviewers are volunteers and sometimes things fall
23
+ through the cracks.
24
+
25
+
26
+ You can also join us [on
27
+ Slack](https://join.slack.com/t/gpjax/shared_invite/zt-1da57pmjn-rdBCVg9kApirEEn2E5Q2Zw) for real-time
28
+ discussion.
29
+
30
+
31
+ For details on testing, writing docs, and our review process,
32
+ please see [the developer
33
+ guide](https://docs.jaxgaussianprocesses.com/contributing/)
34
+
35
+
36
+ We strive to be a welcoming and open project. Please follow our
37
+ [Code of
38
+ Conduct](https://github.com/JaxGaussianProcesses/GPJax/blob/main/.github/CODE_OF_CONDUCT.md).
39
+
40
+ pr_message: |
18
41
  Thank you for opening your first PR into GPJax!
19
42
 
20
43
 
@@ -80,7 +80,7 @@ jobs:
80
80
  python-version: ${{ matrix.python-version }}
81
81
 
82
82
  - name: Install uv
83
- uses: astral-sh/setup-uv@v6
83
+ uses: astral-sh/setup-uv@v7
84
84
 
85
85
  - name: Install dependencies
86
86
  run: |
@@ -116,7 +116,7 @@ jobs:
116
116
  python-version: '3.11'
117
117
 
118
118
  - name: Install uv
119
- uses: astral-sh/setup-uv@v6
119
+ uses: astral-sh/setup-uv@v7
120
120
 
121
121
  - name: Install dependencies
122
122
  run: |
@@ -132,7 +132,7 @@ jobs:
132
132
  uv run bandit -r gpjax/ -f json -o bandit-report.json || echo "Bandit scan completed with warnings"
133
133
 
134
134
  - name: Upload security reports
135
- uses: actions/upload-artifact@v4
135
+ uses: actions/upload-artifact@v5
136
136
  with:
137
137
  name: security-reports
138
138
  path: |
@@ -154,7 +154,7 @@ jobs:
154
154
  python-version: '3.11'
155
155
 
156
156
  - name: Install uv
157
- uses: astral-sh/setup-uv@v6
157
+ uses: astral-sh/setup-uv@v7
158
158
 
159
159
  - name: Build package
160
160
  run: |
@@ -166,7 +166,7 @@ jobs:
166
166
  uv run twine check dist/*
167
167
 
168
168
  - name: Upload build artifacts
169
- uses: actions/upload-artifact@v4
169
+ uses: actions/upload-artifact@v5
170
170
  with:
171
171
  name: dist-packages
172
172
  path: dist/
@@ -264,7 +264,7 @@ jobs:
264
264
  uses: actions/checkout@v5
265
265
 
266
266
  - name: Download build artifacts
267
- uses: actions/download-artifact@v5
267
+ uses: actions/download-artifact@v6
268
268
  with:
269
269
  name: dist-packages
270
270
  path: dist/
@@ -294,7 +294,7 @@ jobs:
294
294
 
295
295
  steps:
296
296
  - name: Download build artifacts
297
- uses: actions/download-artifact@v5
297
+ uses: actions/download-artifact@v6
298
298
  with:
299
299
  name: dist-packages
300
300
  path: dist/
@@ -28,7 +28,7 @@ jobs:
28
28
  python-version: '3.11'
29
29
 
30
30
  - name: Install uv
31
- uses: astral-sh/setup-uv@v6
31
+ uses: astral-sh/setup-uv@v7
32
32
  with:
33
33
  version: "latest"
34
34
 
@@ -47,7 +47,7 @@ jobs:
47
47
  uv run bandit -r gpjax/ -f json -o bandit-report.json || true
48
48
 
49
49
  - name: Upload dependency scan results
50
- uses: actions/upload-artifact@v4
50
+ uses: actions/upload-artifact@v5
51
51
  if: always()
52
52
  with:
53
53
  name: security-scan-results
@@ -32,7 +32,7 @@ jobs:
32
32
 
33
33
  # Install uv
34
34
  - name: Install uv
35
- uses: astral-sh/setup-uv@v6
35
+ uses: astral-sh/setup-uv@v7
36
36
  with:
37
37
  version: "latest"
38
38
 
@@ -26,7 +26,7 @@ jobs:
26
26
 
27
27
  # Install uv
28
28
  - name: Install uv
29
- uses: astral-sh/setup-uv@v6
29
+ uses: astral-sh/setup-uv@v7
30
30
  with:
31
31
  version: "latest"
32
32
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: gpjax
3
- Version: 0.13.0
3
+ Version: 0.13.2
4
4
  Summary: Gaussian processes in JAX.
5
5
  Project-URL: Documentation, https://docs.jaxgaussianprocesses.com/
6
6
  Project-URL: Issues, https://github.com/JaxGaussianProcesses/GPJax/issues
@@ -40,7 +40,7 @@ __license__ = "MIT"
40
40
  __description__ = "Gaussian processes in JAX and Flax"
41
41
  __url__ = "https://github.com/JaxGaussianProcesses/GPJax"
42
42
  __contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
43
- __version__ = "0.13.0"
43
+ __version__ = "0.13.2"
44
44
 
45
45
  __all__ = [
46
46
  "gps",
@@ -6,9 +6,7 @@ from jaxtyping import Float
6
6
  import gpjax
7
7
  from gpjax.kernels.computations.base import AbstractKernelComputation
8
8
  from gpjax.linalg import (
9
- Dense,
10
9
  Diagonal,
11
- psd,
12
10
  )
13
11
  from gpjax.typing import Array
14
12
 
@@ -27,9 +25,9 @@ class BasisFunctionComputation(AbstractKernelComputation):
27
25
  z2 = self.compute_features(kernel, y)
28
26
  return self.scaling(kernel) * jnp.matmul(z1, z2.T)
29
27
 
30
- def _gram(self, kernel: K, inputs: Float[Array, "N D"]) -> Dense:
28
+ def _gram(self, kernel: K, inputs: Float[Array, "N D"]) -> Float[Array, "N N"]:
31
29
  z1 = self.compute_features(kernel, inputs)
32
- return psd(Dense(self.scaling(kernel) * jnp.matmul(z1, z1.T)))
30
+ return self.scaling(kernel) * jnp.matmul(z1, z1.T)
33
31
 
34
32
  def diagonal(self, kernel: K, inputs: Float[Array, "N D"]) -> Diagonal:
35
33
  r"""For a given kernel, compute the elementwise diagonal of the
@@ -15,7 +15,6 @@
15
15
 
16
16
 
17
17
  import beartype.typing as tp
18
- import jax.numpy as jnp
19
18
  from jaxtyping import (
20
19
  Float,
21
20
  Num,
@@ -39,17 +38,4 @@ class EigenKernelComputation(AbstractKernelComputation):
39
38
  def _cross_covariance(
40
39
  self, kernel: Kernel, x: Num[Array, "N D"], y: Num[Array, "M D"]
41
40
  ) -> Float[Array, "N M"]:
42
- # Transform the eigenvalues of the graph Laplacian according to the
43
- # RBF kernel's SPDE form.
44
- S = jnp.power(
45
- kernel.eigenvalues
46
- + 2
47
- * kernel.smoothness.value
48
- / kernel.lengthscale.value
49
- / kernel.lengthscale.value,
50
- -kernel.smoothness.value,
51
- )
52
- S = jnp.multiply(S, kernel.num_vertex / jnp.sum(S))
53
- # Scale the transform eigenvalues by the kernel variance
54
- S = jnp.multiply(S, kernel.variance.value)
55
- return kernel(x, y, S=S)
41
+ return kernel(x, y)
@@ -25,7 +25,10 @@ from gpjax.kernels.computations import (
25
25
  AbstractKernelComputation,
26
26
  EigenKernelComputation,
27
27
  )
28
- from gpjax.kernels.non_euclidean.utils import jax_gather_nd
28
+ from gpjax.kernels.non_euclidean.utils import (
29
+ calculate_heat_semigroup,
30
+ jax_gather_nd,
31
+ )
29
32
  from gpjax.kernels.stationary.base import StationaryKernel
30
33
  from gpjax.parameters import (
31
34
  Parameter,
@@ -98,14 +101,12 @@ class GraphKernel(StationaryKernel):
98
101
 
99
102
  super().__init__(active_dims, lengthscale, variance, n_dims, compute_engine)
100
103
 
101
- def __call__( # TODO not consistent with general kernel interface
104
+ def __call__(
102
105
  self,
103
106
  x: Int[Array, "N 1"],
104
- y: Int[Array, "N 1"],
105
- *,
106
- S,
107
- **kwargs,
107
+ y: Int[Array, "M 1"],
108
108
  ):
109
+ S = calculate_heat_semigroup(self)
109
110
  Kxx = (jax_gather_nd(self.eigenvectors, x) * S.squeeze()) @ jnp.transpose(
110
111
  jax_gather_nd(self.eigenvectors, y)
111
112
  ) # shape (n,n)
@@ -13,6 +13,10 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
+ from __future__ import annotations
17
+
18
+ import beartype.typing as tp
19
+ import jax.numpy as jnp
16
20
  from jaxtyping import (
17
21
  Float,
18
22
  Int,
@@ -20,6 +24,9 @@ from jaxtyping import (
20
24
 
21
25
  from gpjax.typing import Array
22
26
 
27
+ if tp.TYPE_CHECKING:
28
+ from gpjax.kernels.non_euclidean.graph import GraphKernel
29
+
23
30
 
24
31
  def jax_gather_nd(
25
32
  params: Float[Array, " N *rest"], indices: Int[Array, " M 1"]
@@ -41,3 +48,26 @@ def jax_gather_nd(
41
48
  """
42
49
  tuple_indices = tuple(indices[..., i] for i in range(indices.shape[-1]))
43
50
  return params[tuple_indices]
51
+
52
+
53
+ def calculate_heat_semigroup(kernel: GraphKernel) -> Float[Array, "N M"]:
54
+ r"""Returns the rescaled heat semigroup, S
55
+
56
+ Args:
57
+ kernel: instance of the graph kernel
58
+
59
+ Returns:
60
+ S
61
+ """
62
+ S = jnp.power(
63
+ kernel.eigenvalues
64
+ + 2
65
+ * kernel.smoothness.value
66
+ / kernel.lengthscale.value
67
+ / kernel.lengthscale.value,
68
+ -kernel.smoothness.value,
69
+ )
70
+ S = jnp.multiply(S, kernel.num_vertex / jnp.sum(S))
71
+ # Scale the transform eigenvalues by the kernel variance
72
+ S = jnp.multiply(S, kernel.variance.value)
73
+ return S
@@ -19,7 +19,10 @@ import beartype.typing as tp
19
19
  from flax import nnx
20
20
  import jax.numpy as jnp
21
21
  import jax.scipy as jsp
22
- from jaxtyping import Float
22
+ from jaxtyping import (
23
+ Float,
24
+ Int,
25
+ )
23
26
 
24
27
  from gpjax.dataset import Dataset
25
28
  from gpjax.distributions import GaussianDistribution
@@ -108,6 +111,7 @@ class AbstractVariationalGaussian(AbstractVariationalFamily[L]):
108
111
  self,
109
112
  posterior: AbstractPosterior[P, L],
110
113
  inducing_inputs: tp.Union[
114
+ Int[Array, "N D"],
111
115
  Float[Array, "N D"],
112
116
  Real,
113
117
  ],
@@ -140,7 +144,7 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
140
144
  def __init__(
141
145
  self,
142
146
  posterior: AbstractPosterior[P, L],
143
- inducing_inputs: Float[Array, "N D"],
147
+ inducing_inputs: tp.Union[Int[Array, "N D"], Float[Array, "N D"]],
144
148
  variational_mean: tp.Union[Float[Array, "N 1"], None] = None,
145
149
  variational_root_covariance: tp.Union[Float[Array, "N N"], None] = None,
146
150
  jitter: ScalarFloat = 1e-6,
@@ -156,6 +160,12 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
156
160
  self.variational_mean = Real(variational_mean)
157
161
  self.variational_root_covariance = LowerTriangular(variational_root_covariance)
158
162
 
163
+ def _fmt_Kzt_Ktt(self, Kzt, Ktt):
164
+ return Kzt, Ktt
165
+
166
+ def _fmt_inducing_inputs(self):
167
+ return self.inducing_inputs.value
168
+
159
169
  def prior_kl(self) -> ScalarFloat:
160
170
  r"""Compute the prior KL divergence.
161
171
 
@@ -178,7 +188,7 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
178
188
  # Unpack variational parameters
179
189
  variational_mean = self.variational_mean.value
180
190
  variational_sqrt = self.variational_root_covariance.value
181
- inducing_inputs = self.inducing_inputs.value
191
+ inducing_inputs = self._fmt_inducing_inputs()
182
192
 
183
193
  # Unpack mean function and kernel
184
194
  mean_function = self.posterior.prior.mean_function
@@ -202,7 +212,9 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
202
212
 
203
213
  return q_inducing.kl_divergence(p_inducing)
204
214
 
205
- def predict(self, test_inputs: Float[Array, "N D"]) -> GaussianDistribution:
215
+ def predict(
216
+ self, test_inputs: tp.Union[Int[Array, "N D"], Float[Array, "N D"]]
217
+ ) -> GaussianDistribution:
206
218
  r"""Compute the predictive distribution of the GP at the test inputs t.
207
219
 
208
220
  This is the integral $q(f(t)) = \int p(f(t)\mid u) q(u) \mathrm{d}u$, which
@@ -222,7 +234,7 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
222
234
  # Unpack variational parameters
223
235
  variational_mean = self.variational_mean.value
224
236
  variational_sqrt = self.variational_root_covariance.value
225
- inducing_inputs = self.inducing_inputs.value
237
+ inducing_inputs = self._fmt_inducing_inputs()
226
238
 
227
239
  # Unpack mean function and kernel
228
240
  mean_function = self.posterior.prior.mean_function
@@ -241,6 +253,8 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
241
253
  Kzt = kernel.cross_covariance(inducing_inputs, test_points)
242
254
  test_mean = mean_function(test_points)
243
255
 
256
+ Kzt, Ktt = self._fmt_Kzt_Ktt(Kzt, Ktt)
257
+
244
258
  # Lz⁻¹ Kzt
245
259
  Lz_inv_Kzt = solve(Lz, Kzt)
246
260
 
@@ -259,8 +273,10 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
259
273
  - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt)
260
274
  + jnp.matmul(Ktz_Kzz_inv_sqrt, Ktz_Kzz_inv_sqrt.T)
261
275
  )
276
+
262
277
  if hasattr(covariance, "to_dense"):
263
278
  covariance = covariance.to_dense()
279
+
264
280
  covariance = add_jitter(covariance, self.jitter)
265
281
  covariance = Dense(covariance)
266
282
 
@@ -269,6 +285,53 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
269
285
  )
270
286
 
271
287
 
288
+ class GraphVariationalGaussian(VariationalGaussian[L]):
289
+ r"""A variational Gaussian defined over graph-structured inducing inputs.
290
+
291
+ This subclass adapts the :class:`VariationalGaussian` family to the
292
+ case where the inducing inputs are discrete graph node indices rather
293
+ than continuous spatial coordinates.
294
+
295
+ The main differences are:
296
+ * Inducing inputs are integer node IDs.
297
+ * Kernel matrices are ensured to be dense and 2D.
298
+ """
299
+
300
+ def __init__(
301
+ self,
302
+ posterior: AbstractPosterior[P, L],
303
+ inducing_inputs: Int[Array, "N D"],
304
+ variational_mean: tp.Union[Float[Array, "N 1"], None] = None,
305
+ variational_root_covariance: tp.Union[Float[Array, "N N"], None] = None,
306
+ jitter: ScalarFloat = 1e-6,
307
+ ):
308
+ super().__init__(
309
+ posterior,
310
+ inducing_inputs,
311
+ variational_mean,
312
+ variational_root_covariance,
313
+ jitter,
314
+ )
315
+ self.inducing_inputs = self.inducing_inputs.value.astype(jnp.int64)
316
+
317
+ def _fmt_Kzt_Ktt(self, Kzt, Ktt):
318
+ Ktt = Ktt.to_dense() if hasattr(Ktt, "to_dense") else Ktt
319
+ Kzt = Kzt.to_dense() if hasattr(Kzt, "to_dense") else Kzt
320
+ Ktt = jnp.atleast_2d(Ktt)
321
+ Kzt = (
322
+ jnp.transpose(jnp.atleast_2d(Kzt)) if Kzt.ndim < 2 else jnp.atleast_2d(Kzt)
323
+ )
324
+ return Kzt, Ktt
325
+
326
+ def _fmt_inducing_inputs(self):
327
+ return self.inducing_inputs
328
+
329
+ @property
330
+ def num_inducing(self) -> int:
331
+ """The number of inducing inputs."""
332
+ return self.inducing_inputs.shape[0]
333
+
334
+
272
335
  class WhitenedVariationalGaussian(VariationalGaussian[L]):
273
336
  r"""The whitened variational Gaussian family of probability distributions.
274
337
 
@@ -811,6 +874,7 @@ __all__ = [
811
874
  "AbstractVariationalFamily",
812
875
  "AbstractVariationalGaussian",
813
876
  "VariationalGaussian",
877
+ "GraphVariationalGaussian",
814
878
  "WhitenedVariationalGaussian",
815
879
  "NaturalVariationalGaussian",
816
880
  "ExpectationVariationalGaussian",
@@ -25,6 +25,8 @@ from jaxtyping import (
25
25
  Array,
26
26
  Float,
27
27
  )
28
+ import networkx as nx
29
+ import numpy as np
28
30
  import numpyro.distributions as npd
29
31
  from numpyro.distributions import Distribution as NumpyroDistribution
30
32
  import pytest
@@ -35,6 +37,7 @@ from gpjax.variational_families import (
35
37
  AbstractVariationalFamily,
36
38
  CollapsedVariationalGaussian,
37
39
  ExpectationVariationalGaussian,
40
+ GraphVariationalGaussian,
38
41
  NaturalVariationalGaussian,
39
42
  VariationalGaussian,
40
43
  WhitenedVariationalGaussian,
@@ -118,6 +121,7 @@ def test_variational_gaussians(
118
121
  )
119
122
  likelihood = gpx.likelihoods.Gaussian(123)
120
123
  inducing_inputs = jnp.linspace(-5.0, 5.0, n_inducing).reshape(-1, 1)
124
+
121
125
  test_inputs = jnp.linspace(-5.0, 5.0, n_test).reshape(-1, 1)
122
126
 
123
127
  posterior = prior * likelihood
@@ -174,6 +178,61 @@ def test_variational_gaussians(
174
178
  assert sigma.shape == (n_test, n_test)
175
179
 
176
180
 
181
+ @pytest.mark.parametrize("n_test", [10, 20])
182
+ @pytest.mark.parametrize("n_inducing", [10, 20])
183
+ @pytest.mark.parametrize(
184
+ "variational_family",
185
+ [
186
+ GraphVariationalGaussian,
187
+ ],
188
+ )
189
+ def test_graph_variational_gaussian(
190
+ n_test: int,
191
+ n_inducing: int,
192
+ variational_family: AbstractVariationalFamily,
193
+ ) -> None:
194
+ G = nx.barbell_graph(100, 0)
195
+ L = nx.laplacian_matrix(G).toarray()
196
+
197
+ kernel = gpx.kernels.GraphKernel(
198
+ laplacian=L,
199
+ lengthscale=2.3,
200
+ variance=3.2,
201
+ smoothness=6.1,
202
+ )
203
+ meanf = gpx.mean_functions.Constant()
204
+ prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
205
+ likelihood = gpx.likelihoods.Bernoulli(num_datapoints=G.number_of_nodes())
206
+
207
+ inducing_inputs = jnp.array(
208
+ np.random.randint(low=1, high=100, size=(n_inducing, 1))
209
+ ).astype(jnp.int64)
210
+
211
+ test_inputs = jnp.array(np.random.randint(low=0, high=1, size=(n_test, 1))).astype(
212
+ jnp.int64
213
+ )
214
+
215
+ posterior = prior * likelihood
216
+ q = variational_family(posterior=posterior, inducing_inputs=inducing_inputs)
217
+ # Test KL
218
+ kl = q.prior_kl()
219
+ assert isinstance(kl, jnp.ndarray)
220
+ assert kl.shape == ()
221
+ assert kl >= 0.0
222
+
223
+ # Test predictions
224
+ predictive_dist = q(test_inputs)
225
+ assert isinstance(predictive_dist, NumpyroDistribution)
226
+
227
+ mu = predictive_dist.mean
228
+ sigma = predictive_dist.covariance()
229
+
230
+ assert isinstance(mu, jnp.ndarray)
231
+ assert isinstance(sigma, jnp.ndarray)
232
+ assert mu.shape == (n_test,)
233
+ assert sigma.shape == (n_test, n_test)
234
+
235
+
177
236
  @pytest.mark.parametrize("n_test", [1, 10])
178
237
  @pytest.mark.parametrize("n_datapoints", [1, 10])
179
238
  @pytest.mark.parametrize("n_inducing", [1, 10, 20])
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes