cvmatrix 3.2.0__tar.gz → 3.2.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. {cvmatrix-3.2.0 → cvmatrix-3.2.1}/CHANGELOG.md +5 -0
  2. {cvmatrix-3.2.0 → cvmatrix-3.2.1}/PKG-INFO +1 -1
  3. {cvmatrix-3.2.0 → cvmatrix-3.2.1}/cvmatrix/__init__.py +1 -1
  4. {cvmatrix-3.2.0 → cvmatrix-3.2.1}/cvmatrix/cvmatrix.py +42 -35
  5. {cvmatrix-3.2.0 → cvmatrix-3.2.1}/pyproject.toml +1 -1
  6. {cvmatrix-3.2.0 → cvmatrix-3.2.1}/uv.lock +1 -1
  7. {cvmatrix-3.2.0 → cvmatrix-3.2.1}/.github/actions/build/action.yml +0 -0
  8. {cvmatrix-3.2.0 → cvmatrix-3.2.1}/.github/actions/release/action.yml +0 -0
  9. {cvmatrix-3.2.0 → cvmatrix-3.2.1}/.github/actions/test/action.yml +0 -0
  10. {cvmatrix-3.2.0 → cvmatrix-3.2.1}/.github/workflows/package_workflow.yml +0 -0
  11. {cvmatrix-3.2.0 → cvmatrix-3.2.1}/.github/workflows/pull_request_package_workflow.yml +0 -0
  12. {cvmatrix-3.2.0 → cvmatrix-3.2.1}/.github/workflows/pull_request_test_workflow.yml +0 -0
  13. {cvmatrix-3.2.0 → cvmatrix-3.2.1}/.github/workflows/test_workflow.yml +0 -0
  14. {cvmatrix-3.2.0 → cvmatrix-3.2.1}/.gitignore +0 -0
  15. {cvmatrix-3.2.0 → cvmatrix-3.2.1}/.readthedocs.yaml +0 -0
  16. {cvmatrix-3.2.0 → cvmatrix-3.2.1}/CONTRIBUTING.md +0 -0
  17. {cvmatrix-3.2.0 → cvmatrix-3.2.1}/LICENSE +0 -0
  18. {cvmatrix-3.2.0 → cvmatrix-3.2.1}/README.md +0 -0
  19. {cvmatrix-3.2.0 → cvmatrix-3.2.1}/benchmarks/README.md +0 -0
  20. {cvmatrix-3.2.0 → cvmatrix-3.2.1}/benchmarks/__init__.py +0 -0
  21. {cvmatrix-3.2.0 → cvmatrix-3.2.1}/benchmarks/benchmark.py +0 -0
  22. {cvmatrix-3.2.0 → cvmatrix-3.2.1}/benchmarks/benchmark_cvmatrix.png +0 -0
  23. {cvmatrix-3.2.0 → cvmatrix-3.2.1}/benchmarks/benchmark_cvmatrix_numpy_vs_jax.png +0 -0
  24. {cvmatrix-3.2.0 → cvmatrix-3.2.1}/benchmarks/benchmark_cvmatrix_vs_naive.png +0 -0
  25. {cvmatrix-3.2.0 → cvmatrix-3.2.1}/benchmarks/benchmark_jax_variants.png +0 -0
  26. {cvmatrix-3.2.0 → cvmatrix-3.2.1}/benchmarks/benchmark_results.csv +0 -0
  27. {cvmatrix-3.2.0 → cvmatrix-3.2.1}/benchmarks/benchmark_results_jax.csv +0 -0
  28. {cvmatrix-3.2.0 → cvmatrix-3.2.1}/benchmarks/plot_benchmark.py +0 -0
  29. {cvmatrix-3.2.0 → cvmatrix-3.2.1}/cvmatrix/partitioner.py +0 -0
  30. {cvmatrix-3.2.0 → cvmatrix-3.2.1}/docs/CONTRIBUTING.md +0 -0
  31. {cvmatrix-3.2.0 → cvmatrix-3.2.1}/docs/Makefile +0 -0
  32. {cvmatrix-3.2.0 → cvmatrix-3.2.1}/docs/README.md +0 -0
  33. {cvmatrix-3.2.0 → cvmatrix-3.2.1}/docs/api.rst +0 -0
  34. {cvmatrix-3.2.0 → cvmatrix-3.2.1}/docs/conf.py +0 -0
  35. {cvmatrix-3.2.0 → cvmatrix-3.2.1}/docs/index.rst +0 -0
  36. {cvmatrix-3.2.0 → cvmatrix-3.2.1}/docs/make.bat +0 -0
  37. {cvmatrix-3.2.0 → cvmatrix-3.2.1}/docs/requirements.txt +0 -0
  38. {cvmatrix-3.2.0 → cvmatrix-3.2.1}/docs/source/cvmatrix.rst +0 -0
  39. {cvmatrix-3.2.0 → cvmatrix-3.2.1}/docs/source/cvmatrix_module.rst +0 -0
  40. {cvmatrix-3.2.0 → cvmatrix-3.2.1}/docs/source/partitioner_module.rst +0 -0
  41. {cvmatrix-3.2.0 → cvmatrix-3.2.1}/examples/__init__.py +0 -0
  42. {cvmatrix-3.2.0 → cvmatrix-3.2.1}/examples/training_matrices.py +0 -0
  43. {cvmatrix-3.2.0 → cvmatrix-3.2.1}/examples/training_matrices_jax.py +0 -0
  44. {cvmatrix-3.2.0 → cvmatrix-3.2.1}/tests/__init__.py +0 -0
  45. {cvmatrix-3.2.0 → cvmatrix-3.2.1}/tests/load_data.py +0 -0
  46. {cvmatrix-3.2.0 → cvmatrix-3.2.1}/tests/naive_cvmatrix.py +0 -0
  47. {cvmatrix-3.2.0 → cvmatrix-3.2.1}/tests/test_cvmatrix.py +0 -0
@@ -5,6 +5,11 @@ All notable changes to this project will be documented in this file.
5
5
  The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
6
6
  and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
7
7
 
8
+ ## [3.2.1] - 2026-07-01
9
+
10
+ ### Changed
11
+ - `import cvmatrix` no longer imports JAX, even when JAX is installed. JAX was previously imported at module load time to broaden the array/scalar type hints so that `backend="jax"` values satisfy runtime type checking; it is now imported lazily, only when the JAX backend is actually resolved (`backend="jax"`). The array/scalar type aliases start NumPy-only and are broadened by rebinding the module-level names (the alias objects themselves are not mutated) to also admit `jax.Array` the first time the JAX backend is used (annotations are deferred via `from __future__ import annotations`, so `typeguard` resolves them against the broadened aliases at call time). This keeps NumPy-only import paths — and downstream packages that use only the NumPy backend — free of the JAX import cost. The `numpy` and `jax` backends are behavior- and result-identical to 3.2.0.
12
+
8
13
  ## [3.2.0] - 2026-06-29
9
14
 
10
15
  ### Added
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: cvmatrix
3
- Version: 3.2.0
3
+ Version: 3.2.1
4
4
  Summary: Fast computation of possibly weighted and possibly centered/scaled training set kernel matrices in a cross-validation setting.
5
5
  Project-URL: Homepage, https://cvmatrix.readthedocs.io/en/latest/
6
6
  Project-URL: Repository, https://github.com/Sm00thix/CVMatrix
@@ -1,4 +1,4 @@
1
- __version__ = "3.2.0"
1
+ __version__ = "3.2.1"
2
2
  __all__ = ["CVMatrix", "Partitioner"]
3
3
  from .cvmatrix import CVMatrix
4
4
  from .partitioner import Partitioner
@@ -10,45 +10,49 @@ Author: Ole-Christian Galbo Engstrøm
10
10
  E-mail: ocge@foss.dk
11
11
  """
12
12
 
13
- from typing import TYPE_CHECKING, Literal, Optional, Tuple, Union
13
+ from __future__ import annotations
14
+
15
+ from typing import Literal, Optional, Tuple, Union
14
16
 
15
17
  import numpy as np
16
18
  from numpy import typing as npt
17
19
 
18
- try:
19
- # Broaden the array/scalar type hints so that, with backend="jax", jax.numpy values
20
- # (which are jax.Array instances, including under jit/vmap tracing) satisfy runtime
21
- # type checking. numpy remains the only required dependency; this is skipped if jax
22
- # is not installed, leaving the hints numpy-only.
23
- import jax as _jax
24
-
25
- Array = Union[np.ndarray, _jax.Array]
26
- Scalar = Union[np.floating, _jax.Array, float, int]
27
- # float dtype types accepted by `dtype`: numpy float types (`type[np.floating]`) and
28
- # JAX scalar dtypes. The latter (e.g. jnp.float64/float32/bfloat16) are instances of
29
- # JAX's scalar-type metaclass -- NOT numpy.floating subclasses -- so they must be
30
- # admitted via that metaclass.
31
- _JaxScalarMeta = type(_jax.numpy.float64)
32
- FloatDType = Union[type[np.floating], _JaxScalarMeta]
33
- # Abstract values produced while a function is traced by jax.jit/jax.vmap. Used to
34
- # skip data-dependent validity raises that cannot run under tracing (the host-side
35
- # caller is expected to validate folds before vmap); eager (concrete) jax values are
36
- # NOT tracers, so they still validate like the numpy backend.
37
- _TRACER_TYPES: tuple = (_jax.core.Tracer,)
38
- except ImportError: # pragma: no cover - exercised only without jax
39
- Array = np.ndarray
40
- Scalar = Union[np.floating, float, int]
41
- FloatDType = type[np.floating]
42
- _TRACER_TYPES = ()
43
-
44
- if TYPE_CHECKING:
45
- # JAX's scalar-dtype metaclass (`type(jnp.float64)`) has no public *static* type, so
46
- # the precise runtime `FloatDType` union above is built from a value (`_JaxScalarMeta`)
47
- # that static type checkers reject in a type expression ("variables are not allowed in
48
- # type expressions"). For static analysis only, expose numpy's standard dtype-specifier
49
- # alias, which cleanly accepts both numpy float types and JAX scalar dtypes. At run
50
- # time `TYPE_CHECKING` is False, so typeguard still validates against the precise union.
51
- FloatDType = npt.DTypeLike
20
+ # Array/scalar type aliases. NumPy is the only required dependency, so JAX is NEVER
21
+ # imported at module load -- ``import cvmatrix`` stays JAX-free even when JAX is
22
+ # installed. The aliases are numpy-only by default and are broadened in-place to
23
+ # also admit ``jax.Array`` values by ``_enable_jax_typing()``, which runs the first
24
+ # time the JAX backend is resolved. Because ``from __future__ import annotations``
25
+ # keeps annotations as strings, typeguard resolves them against the *current* module
26
+ # globals at call time, so the broadened unions take effect once JAX is in use.
27
+ Array = np.ndarray
28
+ Scalar = Union[np.floating, float, int]
29
+ # ``dtype`` specifier: ``npt.DTypeLike`` cleanly admits both numpy float types and
30
+ # JAX scalar dtypes (e.g. ``jnp.float64``) without importing JAX.
31
+ FloatDType = npt.DTypeLike
32
+ # Abstract values produced while a function is traced by jax.jit/jax.vmap. Used to
33
+ # skip data-dependent validity raises that cannot run under tracing (the host-side
34
+ # caller is expected to validate folds before vmap). Empty until the JAX backend is
35
+ # used, so the numpy backend's isinstance checks are always False (numpy values are
36
+ # never tracers); eager (concrete) jax values are NOT tracers, so they still
37
+ # validate like the numpy backend.
38
+ _TRACER_TYPES: tuple = ()
39
+
40
+
41
+ def _enable_jax_typing() -> None:
42
+ """Broaden the array/scalar aliases and tracer types to admit JAX values.
43
+
44
+ Called (idempotently) when the JAX backend is resolved. Kept out of module
45
+ import so that ``import cvmatrix`` never imports JAX. Because
46
+ ``from __future__ import annotations`` keeps annotations as strings, typeguard
47
+ resolves annotations against these (now broadened) module globals at call time,
48
+ so ``jax.Array`` values pass the runtime type checks once JAX is in use.
49
+ """
50
+ import jax
51
+
52
+ global Array, Scalar, _TRACER_TYPES
53
+ Array = Union[np.ndarray, jax.Array]
54
+ Scalar = Union[np.floating, jax.Array, float, int]
55
+ _TRACER_TYPES = (jax.core.Tracer,)
52
56
 
53
57
 
54
58
  def _resolve_backend(backend: str):
@@ -85,6 +89,9 @@ def _resolve_backend(backend: str):
85
89
  "backend='jax' requires the optional JAX dependency. Install it with "
86
90
  "`pip install cvmatrix[jax]`."
87
91
  ) from e
92
+ # Broaden the array/scalar type aliases to admit jax.Array now that JAX is
93
+ # in use (kept out of module import so importing cvmatrix stays JAX-free).
94
+ _enable_jax_typing()
88
95
  return jnp
89
96
  raise ValueError(f"Invalid backend: {backend!r}. Must be 'numpy' or 'jax'.")
90
97
 
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "cvmatrix"
3
- version = "3.2.0"
3
+ version = "3.2.1"
4
4
  description = "Fast computation of possibly weighted and possibly centered/scaled training set kernel matrices in a cross-validation setting."
5
5
  authors = [{ name = "Sm00thix", email = "oleemail@icloud.com" }]
6
6
  maintainers = [{ name = "Sm00thix", email = "oleemail@icloud.com" }]
@@ -116,7 +116,7 @@ toml = [
116
116
 
117
117
  [[package]]
118
118
  name = "cvmatrix"
119
- version = "3.2.0"
119
+ version = "3.2.1"
120
120
  source = { editable = "." }
121
121
  dependencies = [
122
122
  { name = "numpy", version = "2.4.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.12'" },
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes