gpjax 0.11.2__py3-none-any.whl → 0.12.2__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
gpjax/linalg/operators.py ADDED
@@ -0,0 +1,411 @@
+"""Linear operator abstractions for GPJax."""
+
+from abc import (
+    ABC,
+    abstractmethod,
+)
+from typing import (
+    Any,
+    List,
+    Tuple,
+    Union,
+)
+
+from jax import Array
+import jax.numpy as jnp
+import jax.tree_util as jtu
+from jaxtyping import Float
+
+
+class LinearOperator(ABC):
+    """Abstract base class for linear operators."""
+
+    def __init__(self):
+        super().__init__()
+
+    @property
+    @abstractmethod
+    def shape(self) -> Tuple[int, int]:
+        """Return the shape of the operator."""
+
+    @property
+    @abstractmethod
+    def dtype(self) -> jnp.dtype:
+        """Return the data type of the operator."""
+
+    @abstractmethod
+    def to_dense(self) -> Float[Array, "M N"]:
+        """Convert the operator to a dense JAX array."""
+
+    @property
+    def T(self) -> "LinearOperator":
+        """Return the transpose of the operator."""
+        # Default implementation: convert to dense and transpose
+        return Dense(self.to_dense().T)
+
+    def __matmul__(self, other):
+        """Matrix multiplication with another array or operator."""
+        if hasattr(other, "to_dense"):
+            # Other is a LinearOperator
+            return Dense(self.to_dense() @ other.to_dense())
+        else:
+            # Other is a JAX array
+            return self.to_dense() @ other
+
+    def __rmatmul__(self, other):
+        """Right matrix multiplication (other @ self)."""
+        if hasattr(other, "to_dense"):
+            # Other is a LinearOperator
+            return Dense(other.to_dense() @ self.to_dense())
+        else:
+            # Other is a JAX array
+            return other @ self.to_dense()
+
+    def __add__(self, other):
+        """Addition with another array or operator."""
+        if hasattr(other, "to_dense"):
+            # Other is a LinearOperator
+            return Dense(self.to_dense() + other.to_dense())
+        else:
+            # Other is a JAX array
+            return Dense(self.to_dense() + other)
+
+    def __radd__(self, other):
+        """Right addition (other + self)."""
+        if hasattr(other, "to_dense"):
+            # Other is a LinearOperator
+            return Dense(other.to_dense() + self.to_dense())
+        else:
+            # Other is a JAX array
+            return Dense(other + self.to_dense())
+
+    def __sub__(self, other):
+        """Subtraction with another array or operator."""
+        if hasattr(other, "to_dense"):
+            # Other is a LinearOperator
+            return Dense(self.to_dense() - other.to_dense())
+        else:
+            # Other is a JAX array
+            return Dense(self.to_dense() - other)
+
+    def __rsub__(self, other):
+        """Right subtraction (other - self)."""
+        if hasattr(other, "to_dense"):
+            # Other is a LinearOperator
+            return Dense(other.to_dense() - self.to_dense())
+        else:
+            # Other is a JAX array
+            return Dense(other - self.to_dense())
+
+    def __mul__(self, other):
+        """Scalar multiplication (self * scalar)."""
+        if jnp.isscalar(other):
+            return Dense(self.to_dense() * other)
+        else:
+            # Element-wise multiplication with array
+            return Dense(self.to_dense() * other)
+
+    def __rmul__(self, other):
+        """Right scalar multiplication (scalar * self)."""
+        if jnp.isscalar(other):
+            return Dense(other * self.to_dense())
+        else:
+            # Element-wise multiplication with array
+            return Dense(other * self.to_dense())
+
+
+class Dense(LinearOperator):
+    """Dense linear operator wrapping a JAX array."""
+
+    def __init__(self, array: Float[Array, "M N"]):
+        super().__init__()
+        self.array = array
+
+    @property
+    def shape(self) -> Tuple[int, int]:
+        return self.array.shape
+
+    @property
+    def dtype(self) -> jnp.dtype:
+        return self.array.dtype
+
+    def to_dense(self) -> Float[Array, "M N"]:
+        return self.array
+
+    @property
+    def T(self) -> "Dense":
+        return Dense(self.array.T)
+
+
+class Diagonal(LinearOperator):
+    """Diagonal linear operator."""
+
+    def __init__(self, diagonal: Float[Array, " N"]):
+        super().__init__()
+        self.diagonal = diagonal
+
+    @property
+    def shape(self) -> Tuple[int, int]:
+        n = self.diagonal.shape[0]
+        return (n, n)
+
+    @property
+    def dtype(self) -> jnp.dtype:
+        return self.diagonal.dtype
+
+    def to_dense(self) -> Float[Array, "N N"]:
+        return jnp.diag(self.diagonal)
+
+    @property
+    def T(self) -> "Diagonal":
+        return Diagonal(self.diagonal)
+
+
+class Identity(LinearOperator):
+    """Identity linear operator."""
+
+    def __init__(self, shape: Union[int, Tuple[int, int]], dtype=jnp.float64):
+        super().__init__()
+        if isinstance(shape, int):
+            self._shape = (shape, shape)
+        else:
+            if shape[0] != shape[1]:
+                raise ValueError(f"Identity matrix must be square, got shape {shape}")
+            self._shape = shape
+        self._dtype = dtype
+
+    @property
+    def shape(self) -> Tuple[int, int]:
+        return self._shape
+
+    @property
+    def dtype(self) -> Any:
+        return self._dtype
+
+    def to_dense(self) -> Float[Array, "N N"]:
+        n = self._shape[0]
+        return jnp.eye(n, dtype=self._dtype)
+
+    @property
+    def T(self) -> "Identity":
+        return Identity(self._shape, dtype=self._dtype)
+
+
+class Triangular(LinearOperator):
+    """Triangular linear operator."""
+
+    def __init__(self, array: Float[Array, "N N"], lower: bool = True):
+        super().__init__()
+        self.array = array
+        self.lower = lower
+
+    @property
+    def shape(self) -> Tuple[int, int]:
+        return self.array.shape
+
+    @property
+    def dtype(self) -> Any:
+        return self.array.dtype
+
+    def to_dense(self) -> Float[Array, "N N"]:
+        if self.lower:
+            return jnp.tril(self.array)
+        else:
+            return jnp.triu(self.array)
+
+    @property
+    def T(self) -> "Triangular":
+        return Triangular(self.array.T, lower=not self.lower)
+
+
+class BlockDiag(LinearOperator):
+    """Block diagonal linear operator."""
+
+    def __init__(
+        self, operators: List[LinearOperator], multiplicities: List[int] = None
+    ):
+        super().__init__()
+        self.operators = operators
+
+        # Handle multiplicities - how many times each block is repeated
+        if multiplicities is None:
+            self.multiplicities = [1] * len(operators)
+        else:
+            if len(multiplicities) != len(operators):
+                raise ValueError(
+                    f"Length of multiplicities ({len(multiplicities)}) must match operators ({len(operators)})"
+                )
+            self.multiplicities = multiplicities
+
+        # Calculate total shape with multiplicities
+        rows = sum(
+            op.shape[0] * mult
+            for op, mult in zip(operators, self.multiplicities, strict=False)
+        )
+        cols = sum(
+            op.shape[1] * mult
+            for op, mult in zip(operators, self.multiplicities, strict=False)
+        )
+        self._shape = (rows, cols)
+
+        # Use dtype of first operator (assuming all same dtype)
+        if operators:
+            self._dtype = operators[0].dtype
+        else:
+            self._dtype = jnp.float64
+
+    @property
+    def shape(self) -> Tuple[int, int]:
+        return self._shape
+
+    @property
+    def dtype(self) -> Any:
+        return self._dtype
+
+    def to_dense(self) -> Float[Array, "M N"]:
+        if not self.operators:
+            return jnp.zeros(self._shape, dtype=self._dtype)
+
+        # Convert each operator to dense and create block diagonal with multiplicities
+        expanded_blocks = []
+        for op, mult in zip(self.operators, self.multiplicities, strict=False):
+            op_dense = op.to_dense()
+            for _ in range(mult):
+                expanded_blocks.append(op_dense)
+
+        # Create the full block diagonal matrix
+        n_blocks = len(expanded_blocks)
+        if n_blocks == 0:
+            return jnp.zeros(self._shape, dtype=self._dtype)
+
+        # Build the block diagonal matrix
+        rows = []
+        for i in range(n_blocks):
+            row = []
+            for j in range(n_blocks):
+                if i == j:
+                    row.append(expanded_blocks[i])
+                else:
+                    row.append(
+                        jnp.zeros(
+                            (expanded_blocks[i].shape[0], expanded_blocks[j].shape[1]),
+                            dtype=self._dtype,
+                        )
+                    )
+            rows.append(row)
+        return jnp.block(rows)
+
+    @property
+    def T(self) -> "BlockDiag":
+        transposed_ops = [op.T for op in self.operators]
+        return BlockDiag(transposed_ops, multiplicities=self.multiplicities)
+
+
+class Kronecker(LinearOperator):
+    """Kronecker product linear operator."""
+
+    def __init__(self, operators: List[LinearOperator]):
+        super().__init__()
+        if len(operators) < 2:
+            raise ValueError("Kronecker product requires at least 2 operators")
+        self.operators = operators
+
+        # Calculate shape as product of individual shapes
+        rows = 1
+        cols = 1
+        for op in operators:
+            rows *= op.shape[0]
+            cols *= op.shape[1]
+        self._shape = (rows, cols)
+
+        # Use dtype of first operator
+        self._dtype = operators[0].dtype
+
+    @property
+    def shape(self) -> Tuple[int, int]:
+        return self._shape
+
+    @property
+    def dtype(self) -> Any:
+        return self._dtype
+
+    def to_dense(self) -> Float[Array, "M N"]:
+        # Convert to dense and compute Kronecker product
+        result = self.operators[0].to_dense()
+        for op in self.operators[1:]:
+            result = jnp.kron(result, op.to_dense())
+        return result
+
+    @property
+    def T(self) -> "Kronecker":
+        transposed_ops = [op.T for op in self.operators]
+        return Kronecker(transposed_ops)
+
+
+def _dense_tree_flatten(dense):
+    return (dense.array,), None
+
+
+def _dense_tree_unflatten(aux_data, children):
+    return Dense(children[0])
+
+
+jtu.register_pytree_node(Dense, _dense_tree_flatten, _dense_tree_unflatten)
+
+
+def _diagonal_tree_flatten(diagonal):
+    return (diagonal.diagonal,), None
+
+
+def _diagonal_tree_unflatten(aux_data, children):
+    return Diagonal(children[0])
+
+
+jtu.register_pytree_node(Diagonal, _diagonal_tree_flatten, _diagonal_tree_unflatten)
+
+
+def _identity_tree_flatten(identity):
+    return (), (identity._shape, identity._dtype)
+
+
+def _identity_tree_unflatten(aux_data, children):
+    shape, dtype = aux_data
+    return Identity(shape, dtype)
+
+
+jtu.register_pytree_node(Identity, _identity_tree_flatten, _identity_tree_unflatten)
+
+
+def _triangular_tree_flatten(triangular):
+    return (triangular.array,), triangular.lower
+
+
+def _triangular_tree_unflatten(aux_data, children):
+    return Triangular(children[0], aux_data)
+
+
+jtu.register_pytree_node(
+    Triangular, _triangular_tree_flatten, _triangular_tree_unflatten
+)
+
+
+def _blockdiag_tree_flatten(blockdiag):
+    return tuple(blockdiag.operators), blockdiag.multiplicities
+
+
+def _blockdiag_tree_unflatten(aux_data, children):
+    return BlockDiag(list(children), aux_data)
+
+
+jtu.register_pytree_node(BlockDiag, _blockdiag_tree_flatten, _blockdiag_tree_unflatten)
+
+
+def _kronecker_tree_flatten(kronecker):
+    return tuple(kronecker.operators), None
+
+
+def _kronecker_tree_unflatten(aux_data, children):
+    return Kronecker(list(children))
+
+
+jtu.register_pytree_node(Kronecker, _kronecker_tree_flatten, _kronecker_tree_unflatten)
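Taken together, this module replicates the small slice of `cola` that GPJax actually used: structured operators carry only their defining data, arithmetic falls back to `to_dense()`, and the pytree registrations let operators cross `jit`/`grad` boundaries. A minimal usage sketch (class names come from the diff above; the behaviour shown follows directly from that code):

```python
import jax
import jax.numpy as jnp

from gpjax.linalg.operators import BlockDiag, Dense, Diagonal, Identity

D = Diagonal(jnp.array([1.0, 2.0]))   # stores only the length-2 diagonal
Id = Identity(2, dtype=jnp.float32)   # stores only shape and dtype

# Arithmetic densifies: the sum is a Dense operator wrapping diag([2., 3.]).
S = D + Id
print(S.to_dense())

# BlockDiag repeats each block `multiplicities` times along the diagonal.
B = BlockDiag([Dense(jnp.ones((2, 2)))], multiplicities=[3])
print(B.shape)  # (6, 6)

# The register_pytree_node calls make operators valid jit/grad arguments.
print(jax.jit(lambda op: op.to_dense())(D))
```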
gpjax/linalg/utils.py ADDED
@@ -0,0 +1,65 @@
+"""Utility functions for the linear algebra module."""
+
+import jax.numpy as jnp
+from jaxtyping import Array
+
+from gpjax.linalg.operators import LinearOperator
+
+
+class PSDAnnotation:
+    """Marker class for PSD (Positive Semi-Definite) annotations."""
+
+    def __call__(self, A: LinearOperator) -> LinearOperator:
+        """Make PSD annotation callable."""
+        return psd(A)
+
+
+# Create the PSD marker similar to cola.PSD
+PSD = PSDAnnotation()
+
+
+def psd(A: LinearOperator) -> LinearOperator:
+    """Mark a linear operator as positive semi-definite.
+
+    This function acts as a marker/wrapper for positive semi-definite matrices.
+
+    Args:
+        A: A LinearOperator that is assumed to be positive semi-definite.
+
+    Returns:
+        The same LinearOperator, marked as PSD.
+    """
+    # Add annotations attribute if it doesn't exist
+    if not hasattr(A, "annotations"):
+        A.annotations = set()
+    A.annotations.add(PSD)
+    return A
+
+
+def add_jitter(matrix: Array, jitter: float | Array = 1e-6) -> Array:
+    """Add jitter to the diagonal of a matrix for numerical stability.
+
+    This function adds a small positive value (jitter) to the diagonal elements
+    of a square matrix to improve numerical stability, particularly for
+    Cholesky decompositions and matrix inversions.
+
+    Args:
+        matrix: A square matrix to which jitter will be added.
+        jitter: The jitter value to add to the diagonal. Defaults to 1e-6.
+
+    Returns:
+        The matrix with jitter added to its diagonal.
+
+    Examples:
+        >>> import jax.numpy as jnp
+        >>> from gpjax.linalg.utils import add_jitter
+        >>> matrix = jnp.array([[1.0, 0.5], [0.5, 1.0]])
+        >>> jittered_matrix = add_jitter(matrix, jitter=0.01)
+    """
+    if matrix.ndim != 2:
+        raise ValueError(f"Expected 2D matrix, got {matrix.ndim}D array")
+
+    if matrix.shape[0] != matrix.shape[1]:
+        raise ValueError(f"Expected square matrix, got shape {matrix.shape}")
+
+    return matrix + jnp.eye(matrix.shape[0]) * jitter
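Here `psd` replaces `cola.annotations.PSD` as a lightweight tag, and `add_jitter` centralises the diagonal-perturbation idiom used throughout `objectives.py`. A short sketch of both (note that `psd` only annotates; it performs no definiteness check):

```python
import jax.numpy as jnp

from gpjax.linalg.operators import Dense
from gpjax.linalg.utils import add_jitter, psd

K = jnp.array([[1.0, 0.99], [0.99, 1.0]])  # near-singular Gram matrix

# Jitter shifts every eigenvalue up by 1e-6, stabilising the Cholesky factor.
L = jnp.linalg.cholesky(add_jitter(K, jitter=1e-6))

# psd() merely records the PSD assumption on the operator.
Sigma = psd(Dense(add_jitter(K)))
print(hasattr(Sigma, "annotations"))  # True
```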
gpjax/mean_functions.py CHANGED
@@ -27,8 +27,6 @@ from jaxtyping import (
 
 from gpjax.parameters import (
     Parameter,
-    Real,
-    Static,
 )
 from gpjax.typing import (
     Array,
@@ -132,12 +130,12 @@ class Constant(AbstractMeanFunction):
 
     def __init__(
         self,
-        constant: tp.Union[ScalarFloat, Float[Array, " O"], Parameter, Static] = 0.0,
+        constant: tp.Union[ScalarFloat, Float[Array, " O"], Parameter] = 0.0,
     ):
-        if isinstance(constant, Parameter) or isinstance(constant, Static):
+        if isinstance(constant, Parameter):
             self.constant = constant
         else:
-            self.constant = Real(jnp.array(constant))
+            self.constant = jnp.array(constant)
 
     def __call__(self, x: Num[Array, "N D"]) -> Float[Array, "N O"]:
         r"""Evaluate the mean function at the given points.
@@ -148,7 +146,10 @@ class Constant(AbstractMeanFunction):
         Returns:
             Float[Array, "1"]: The evaluated mean function.
         """
-        return jnp.ones((x.shape[0], 1)) * self.constant.value
+        if isinstance(self.constant, Parameter):
+            return jnp.ones((x.shape[0], 1)) * self.constant.value
+        else:
+            return jnp.ones((x.shape[0], 1)) * self.constant
 
 
 class Zero(Constant):
@@ -160,7 +161,7 @@ class Zero(Constant):
     """
 
     def __init__(self):
-        super().__init__(constant=Static(jnp.array(0.0)))
+        super().__init__(constant=0.0)
 
 
 class CombinationMeanFunction(AbstractMeanFunction):
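The net effect for users: `Constant` (and therefore `Zero`) now stores a plain array when given a raw value, and `Static` is gone. A short sketch of the behaviour implied by the diff:

```python
import jax.numpy as jnp

from gpjax.mean_functions import Constant, Zero

x = jnp.ones((3, 2))  # three 2-dimensional inputs

# A raw float is stored as a plain jnp array rather than a Real parameter,
# so it is no longer picked up as a trainable Parameter.
m = Constant(constant=2.0)
print(m(x))  # shape (3, 1), filled with 2.0

# Zero no longer wraps 0.0 in Static; it passes the raw float through.
print(Zero()(x))  # shape (3, 1), filled with 0.0
```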
gpjax/objectives.py CHANGED
@@ -1,13 +1,5 @@
 from typing import TypeVar
 
-from cola.annotations import PSD
-from cola.linalg.decompositions.decompositions import Cholesky
-from cola.linalg.inverse.inv import (
-    inv,
-    solve,
-)
-from cola.linalg.trace.diag_trace import diag
-from cola.ops.operators import I_like
 from flax import nnx
 from jax import vmap
 import jax.numpy as jnp
@@ -22,7 +14,13 @@ from gpjax.gps import (
     ConjugatePosterior,
     NonConjugatePosterior,
 )
-from gpjax.lower_cholesky import lower_cholesky
+from gpjax.linalg import (
+    Dense,
+    lower_cholesky,
+    psd,
+    solve,
+)
+from gpjax.linalg.utils import add_jitter
 from gpjax.typing import (
     Array,
     ScalarFloat,
@@ -100,9 +98,9 @@ def conjugate_mll(posterior: ConjugatePosterior, data: Dataset) -> ScalarFloat:
 
     # Σ = (Kxx + Io²) = LLᵀ
     Kxx = posterior.prior.kernel.gram(x)
-    Kxx += I_like(Kxx) * posterior.prior.jitter
-    Sigma = Kxx + I_like(Kxx) * obs_noise
-    Sigma = PSD(Sigma)
+    Kxx_dense = add_jitter(Kxx.to_dense(), posterior.prior.jitter)
+    Sigma_dense = Kxx_dense + jnp.eye(Kxx.shape[0]) * obs_noise
+    Sigma = psd(Dense(Sigma_dense))
 
     # p(y | x, θ), where θ are the model hyperparameters:
     mll = GaussianDistribution(jnp.atleast_1d(mx.squeeze()), Sigma)
@@ -164,11 +162,14 @@ def conjugate_loocv(posterior: ConjugatePosterior, data: Dataset) -> ScalarFloat
 
     # Σ = (Kxx + Io²)
     Kxx = posterior.prior.kernel.gram(x)
-    Sigma = Kxx + I_like(Kxx) * (obs_var + posterior.prior.jitter)
-    Sigma = PSD(Sigma)  # [N, N]
+    Sigma_dense = Kxx.to_dense() + jnp.eye(Kxx.shape[0]) * (
+        obs_var + posterior.prior.jitter
+    )
+    Sigma = psd(Dense(Sigma_dense))  # [N, N]
 
-    Sigma_inv_y = solve(Sigma, y - mx, Cholesky())  # [N, 1]
-    Sigma_inv_diag = diag(inv(Sigma, Cholesky()))[:, None]  # [N, 1]
+    Sigma_inv_y = solve(Sigma, y - mx)  # [N, 1]
+    Sigma_inv = jnp.linalg.inv(Sigma.to_dense())
+    Sigma_inv_diag = jnp.diag(Sigma_inv)[:, None]  # [N, 1]
 
     loocv_means = mx + (y - mx) - Sigma_inv_y / Sigma_inv_diag
     loocv_stds = jnp.sqrt(1.0 / Sigma_inv_diag)
@@ -213,8 +214,8 @@ def log_posterior_density(
 
     # Gram matrix
     Kxx = posterior.prior.kernel.gram(x)
-    Kxx += I_like(Kxx) * posterior.prior.jitter
-    Kxx = PSD(Kxx)
+    Kxx_dense = add_jitter(Kxx.to_dense(), posterior.prior.jitter)
+    Kxx = psd(Dense(Kxx_dense))
     Lx = lower_cholesky(Kxx)
 
     # Compute the prior mean function
@@ -349,8 +350,8 @@ def collapsed_elbo(variational_family: VF, data: Dataset) -> ScalarFloat:
     noise = variational_family.posterior.likelihood.obs_stddev.value**2
     z = variational_family.inducing_inputs.value
     Kzz = kernel.gram(z)
-    Kzz += I_like(Kzz) * variational_family.jitter
-    Kzz = PSD(Kzz)
+    Kzz_dense = add_jitter(Kzz.to_dense(), variational_family.jitter)
+    Kzz = psd(Dense(Kzz_dense))
     Kzx = kernel.cross_covariance(z, x)
     Kxx_diag = vmap(kernel, in_axes=(0, 0))(x, x)
     μx = mean_function(x)
@@ -383,7 +384,7 @@ def collapsed_elbo(variational_family: VF, data: Dataset) -> ScalarFloat:
     #
     # with A and B defined as above.
 
-    A = solve(Lz, Kzx, Cholesky()) / jnp.sqrt(noise)
+    A = solve(Lz, Kzx) / jnp.sqrt(noise)
 
     # AAᵀ
     AAT = jnp.matmul(A, A.T)
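For orientation, the quantities computed in `conjugate_loocv` are the standard closed-form leave-one-out posteriors for conjugate GP regression (see e.g. Rasmussen & Williams, *Gaussian Processes for Machine Learning*, §5.4.2); the new code evaluates them with a dense `jnp.linalg.inv` rather than cola's lazy `inv`/`diag` operators:

```latex
\Sigma = K_{xx} + (\sigma_n^2 + \epsilon)\, I, \qquad
\mu_{-i} = y_i - \frac{\bigl[\Sigma^{-1}(y - m)\bigr]_i}{\bigl[\Sigma^{-1}\bigr]_{ii}}, \qquad
\sigma_{-i}^2 = \frac{1}{\bigl[\Sigma^{-1}\bigr]_{ii}}
```

where ε is the stabilising jitter. The dense inverse keeps the same O(N³) cost, trading cola's structure-aware dispatch for a much simpler implementation.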
gpjax/parameters.py CHANGED
@@ -21,23 +21,20 @@ def transform(
     r"""Transforms parameters using a bijector.
 
     Example:
-    ```pycon
         >>> from gpjax.parameters import PositiveReal, transform
         >>> import jax.numpy as jnp
         >>> import numpyro.distributions.transforms as npt
         >>> from flax import nnx
         >>> params = nnx.State(
-        >>>     {
-        >>>         "a": PositiveReal(jnp.array([1.0])),
-        >>>         "b": PositiveReal(jnp.array([2.0])),
-        >>>     }
-        >>> )
+        ...     {
+        ...         "a": PositiveReal(jnp.array([1.0])),
+        ...         "b": PositiveReal(jnp.array([2.0])),
+        ...     }
+        ... )
         >>> params_bijection = {'positive': npt.SoftplusTransform()}
         >>> transformed_params = transform(params, params_bijection)
        >>> print(transformed_params["a"].value)
-        [1.3132617]
-    ```
-
+        [1.3132617]
 
 
     Args:
@@ -49,7 +46,7 @@ def transform(
     """
 
     def _inner(param):
-        bijector = params_bijection.get(param._tag, npt.IdentityTransform())
+        bijector = params_bijection.get(param.tag, npt.IdentityTransform())
         if inverse:
             transformed_value = bijector.inv(param.value)
         else:
@@ -60,10 +57,11 @@ def transform(
 
     gp_params, *other_params = params.split(Parameter, ...)
 
+    # Transform each parameter in the state
     transformed_gp_params: nnx.State = jtu.tree_map(
-        lambda x: _inner(x),
+        lambda x: _inner(x) if isinstance(x, Parameter) else x,
         gp_params,
-        is_leaf=lambda x: isinstance(x, nnx.VariableState),
+        is_leaf=lambda x: isinstance(x, Parameter),
     )
     return nnx.State.merge(transformed_gp_params, *other_params)
 
@@ -79,7 +77,7 @@ class Parameter(nnx.Variable[T]):
         _check_is_arraylike(value)
 
         super().__init__(value=jnp.asarray(value), **kwargs)
-        self._tag = tag
+        self.tag = tag
 
 
 class NonNegativeReal(Parameter[T]):
@@ -124,16 +122,6 @@ class SigmoidBounded(Parameter[T]):
     )
 
 
-class Static(nnx.Variable[T]):
-    """Static parameter that is not trainable."""
-
-    def __init__(self, value: T, tag: ParameterTag = "static", **kwargs):
-        _check_is_arraylike(value)
-
-        super().__init__(value=jnp.asarray(value), tag=tag, **kwargs)
-        self._tag = tag
-
-
 class LowerTriangular(Parameter[T]):
     """Parameter that is a lower triangular matrix."""
 
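Downstream code that inspected `param._tag` must now use the public `tag` attribute, and formerly `Static` values should be passed as raw arrays (as `Constant` above now accepts). A sketch of the updated idiom, assuming the `inverse` keyword keeps the semantics visible in `_inner` above:

```python
import jax.numpy as jnp
import numpyro.distributions.transforms as npt
from flax import nnx

from gpjax.parameters import PositiveReal, transform

# `tag` is now a public attribute (previously the private `_tag`);
# PositiveReal's default tag is "positive", matching the bijection key below.
lengthscale = PositiveReal(jnp.array([1.0]))
print(lengthscale.tag)

params = nnx.State({"lengthscale": lengthscale})
bijection = {"positive": npt.SoftplusTransform()}

# Map constrained values to the unconstrained space and back again.
unconstrained = transform(params, bijection, inverse=True)
restored = transform(unconstrained, bijection)
print(restored["lengthscale"].value)  # ≈ [1.0]
```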