brainstate-0.1.3-py2.py3-none-any.whl → brainstate-0.1.5-py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. brainstate/__init__.py +1 -1
  2. brainstate/_compatible_import.py +1 -16
  3. brainstate/_state.py +1 -0
  4. brainstate/augment/_mapping.py +9 -9
  5. brainstate/augment/_mapping_test.py +162 -0
  6. brainstate/compile/_jit.py +14 -5
  7. brainstate/compile/_make_jaxpr.py +78 -22
  8. brainstate/compile/_make_jaxpr_test.py +13 -2
  9. brainstate/graph/_graph_node.py +1 -1
  10. brainstate/graph/_graph_operation.py +4 -4
  11. brainstate/mixin.py +31 -2
  12. brainstate/nn/__init__.py +8 -5
  13. brainstate/nn/_common.py +7 -19
  14. brainstate/nn/_delay.py +13 -1
  15. brainstate/nn/_dropout.py +5 -4
  16. brainstate/nn/_dynamics.py +39 -44
  17. brainstate/nn/_exp_euler.py +13 -16
  18. brainstate/nn/{_fixedprob_mv.py → _fixedprob.py} +95 -24
  19. brainstate/nn/_inputs.py +1 -1
  20. brainstate/nn/_linear_mv.py +1 -1
  21. brainstate/nn/_module.py +5 -5
  22. brainstate/nn/_projection.py +190 -98
  23. brainstate/nn/_synapse.py +5 -9
  24. brainstate/nn/_synaptic_projection.py +376 -86
  25. brainstate/random/_rand_state.py +13 -7
  26. brainstate/surrogate.py +1 -1
  27. brainstate/typing.py +1 -1
  28. brainstate/util/__init__.py +14 -14
  29. brainstate/util/{_pretty_pytree.py → pretty_pytree.py} +2 -2
  30. {brainstate-0.1.3.dist-info → brainstate-0.1.5.dist-info}/METADATA +1 -1
  31. {brainstate-0.1.3.dist-info → brainstate-0.1.5.dist-info}/RECORD +42 -42
  32. /brainstate/nn/{_fixedprob_mv_test.py → _fixedprob_test.py} +0 -0
  33. /brainstate/util/{_caller.py → caller.py} +0 -0
  34. /brainstate/util/{_error.py → error.py} +0 -0
  35. /brainstate/util/{_others.py → others.py} +0 -0
  36. /brainstate/util/{_pretty_repr.py → pretty_repr.py} +0 -0
  37. /brainstate/util/{_pretty_table.py → pretty_table.py} +0 -0
  38. /brainstate/util/{_scaling.py → scaling.py} +0 -0
  39. /brainstate/util/{_struct.py → struct.py} +0 -0
  40. {brainstate-0.1.3.dist-info → brainstate-0.1.5.dist-info}/LICENSE +0 -0
  41. {brainstate-0.1.3.dist-info → brainstate-0.1.5.dist-info}/WHEEL +0 -0
  42. {brainstate-0.1.3.dist-info → brainstate-0.1.5.dist-info}/top_level.txt +0 -0
brainstate/__init__.py CHANGED
@@ -17,7 +17,7 @@
  A ``State``-based Transformation System for Program Compilation and Augmentation
  """

- __version__ = "0.1.3"
+ __version__ = "0.1.5"

  from . import augment
  from . import compile
brainstate/_compatible_import.py CHANGED
@@ -16,10 +16,9 @@
  # -*- coding: utf-8 -*-


- import importlib.util
  from contextlib import contextmanager
  from functools import partial
- from typing import Iterable, Hashable, TypeVar, Callable, TYPE_CHECKING
+ from typing import Iterable, Hashable, TypeVar, Callable

  import jax

@@ -31,7 +30,6 @@ __all__ = [
  'get_aval',
  'Tracer',
  'to_concrete_aval',
- 'brainevent',
  'safe_map',
  'safe_zip',
  'unzip2',
@@ -47,8 +45,6 @@ T3 = TypeVar("T3")

  from saiunit._compatible_import import wrap_init

- brainevent_installed = importlib.util.find_spec('brainevent') is not None
-
  from jax.core import get_aval, Tracer

  if jax.__version_info__ < (0, 5, 0):
@@ -150,14 +146,3 @@ def to_concrete_aval(aval):
  return aval.to_concrete_value()
  return aval

-
- if not brainevent_installed:
- if not TYPE_CHECKING:
- class BrainEvent:
- def __getattr__(self, item):
- raise ImportError('brainevent is not installed, please install brainevent first.')
-
- brainevent = BrainEvent()
-
- else:
- import brainevent
brainstate/_state.py CHANGED
@@ -50,6 +50,7 @@ __all__ = [
  'LongTermState',
  'HiddenState',
  'ParamState',
+ 'BatchState',
  'TreefyState',
  'FakeState',

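The only change visible here in brainstate/_state.py is the new ``BatchState`` export. A minimal, speculative sketch of its use, assuming it is constructed from an initial value like the other ``State`` subclasses exported alongside it (its actual signature and semantics are not shown in this diff):

import jax.numpy as jnp
import brainstate

# Assumption: BatchState accepts an initial value, like brainstate.State or ParamState.
batch_state = brainstate.BatchState(jnp.zeros((4, 8)))
print(batch_state.value.shape)  # State objects expose .value; here (4, 8)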
brainstate/augment/_mapping.py CHANGED
@@ -185,10 +185,10 @@ def _compile_stateful_function(
  if isinstance(in_axes, int):
  args = jax.tree.map(lambda x: _remove_axis(x, in_axes), args)
  elif isinstance(in_axes, tuple):
- args = tuple(
- [arg if in_axis is None else _remove_axis(arg, in_axis)
- for arg, in_axis in zip(args, in_axes)]
- )
+ args = tuple([
+ arg if in_axis is None else _remove_axis(arg, in_axis)
+ for arg, in_axis in zip(args, in_axes)
+ ])
  stateful_fn.make_jaxpr(state_vals, args)
  return stateful_fn.get_arg_cache_key(state_vals, args)

@@ -383,10 +383,7 @@ def _vmap_transform(
  stateful_fn.axis_env = axis_env

  # stateful function
- stateful_fn = StatefulFunction(
- _vmap_fn_for_compilation,
- name='vmap',
- )
+ stateful_fn = StatefulFunction(_vmap_fn_for_compilation, name='vmap')

  @functools.wraps(f)
  def new_fn_for_vmap(
@@ -460,7 +457,10 @@ def _vmap_transform(
  # analyze vmapping axis error
  for state in state_trace.get_write_states():
  leaves = jax.tree.leaves(state.value)
- if any([isinstance(leaf, BatchTracer) for leaf in leaves]) and state not in out_state_to_axis:
+ if (
+ any([isinstance(leaf, BatchTracer) and (leaf.batch_dim is not None) for leaf in leaves])
+ and state not in out_state_to_axis
+ ):
  if isinstance(state, RandomState) and state in rng_sets:
  continue
  state.raise_error_with_source_info(
brainstate/augment/_mapping_test.py CHANGED
@@ -19,6 +19,8 @@ import unittest
  import jax
  import jax.numpy as jnp
  import numpy as np
+ from jax import vmap
+ from jax.lax import psum, pmean, pmax

  import brainstate
  import brainstate.augment
@@ -433,3 +435,163 @@ class TestVMAPNewStatesEdgeCases(unittest.TestCase):
  foo.c = brainstate.State(jnp.arange(3)) # Original expected shape is (4,)

  faulty_init()
+
+
+ class TestAxisName:
+ def test1(self):
+ def compute_stats_with_axis_name(x):
+ """Compute statistics using named axis operations"""
+ # Sum across the named axis 'batch'
+ total_sum = psum(x, axis_name='batch')
+
+ # Mean across the named axis 'batch'
+ mean_val = pmean(x, axis_name='batch')
+
+ # Max across the named axis 'batch'
+ max_val = pmax(x, axis_name='batch')
+
+ return {
+ 'sum': total_sum,
+ 'mean': mean_val,
+ 'max': max_val,
+ 'original': x
+ }
+
+ batch_data = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
+ print("Input batch data:", batch_data)
+
+ # vmap with axis name 'batch'
+ vectorized_stats_jax = jax.jit(vmap(compute_stats_with_axis_name, axis_name='batch'))
+ result_jax = vectorized_stats_jax(batch_data)
+
+ # vmap with axis name 'batch'
+ vectorized_stats = brainstate.transform.vmap(compute_stats_with_axis_name, axis_name='batch')
+ result = vectorized_stats(batch_data)
+
+ # vmap with axis name 'batch'
+ vectorized_stats_v2 = brainstate.transform.jit(
+ brainstate.transform.vmap(compute_stats_with_axis_name, axis_name='batch')
+ )
+ result_v2 = vectorized_stats_v2(batch_data)
+
+ for key in result_jax.keys():
+ print(f" {key}: {result_jax[key]}")
+ assert jnp.allclose(result_jax[key], result[key]), f"Mismatch in {key}"
+ assert jnp.allclose(result_jax[key], result_v2[key]), f"Mismatch in {key}"
+
+ def test_nested_vmap(self):
+ def nested_computation(x):
+ """Computation with multiple named axes"""
+ # Sum over 'inner' axis, then mean over 'outer' axis
+ inner_sum = psum(x, axis_name='inner')
+ outer_mean = pmean(inner_sum, axis_name='outer')
+ return outer_mean
+
+ # Create 2D batch data
+ data_2d = jnp.arange(12.0).reshape(3, 4) # Shape: [outer_batch=3, inner_batch=4]
+ print("Input 2D data shape:", data_2d.shape)
+ print("Input 2D data:\n", data_2d)
+
+ # Nested vmap: first over inner dimension, then outer dimension
+ inner_vmap = vmap(nested_computation, axis_name='inner')
+ nested_vmap = vmap(inner_vmap, axis_name='outer')
+
+ result_2d = nested_vmap(data_2d)
+ print("Result after nested vmap:", result_2d)
+
+ inner_vmap_bst = brainstate.transform.vmap(nested_computation, axis_name='inner')
+ nested_vmap_bst = brainstate.transform.vmap(inner_vmap_bst, axis_name='outer')
+ result_2d_bst = nested_vmap_bst(data_2d)
+ print("Result after nested vmap:", result_2d_bst)
+
+ assert jnp.allclose(result_2d, result_2d_bst)
+
+ def _gradient_averaging_simulation_bst(self):
+ def loss_function(params, x, y):
+ """Simple quadratic loss"""
+ pred = params * x
+ return (pred - y) ** 2
+
+ def compute_gradients_with_averaging(params, batch_x, batch_y):
+ """Compute gradients and average them across the batch"""
+ # Compute per-sample gradients
+ grad_fn = jax.grad(loss_function, argnums=0)
+ per_sample_grads = vmap(grad_fn, in_axes=(None, 0, 0))(params, batch_x, batch_y)
+
+ # Average gradients across batch using named axis
+ def average_grads(grads):
+ return pmean(grads, axis_name='batch')
+
+ # Apply averaging with named axis
+ averaged_grads = vmap(average_grads, axis_name='batch')(per_sample_grads)
+ return averaged_grads
+
+ # Example data
+ params = 2.0
+ batch_x = jnp.array([1.0, 2.0, 3.0, 4.0])
+ batch_y = jnp.array([2.0, 4.0, 7.0, 8.0])
+
+ print("Parameters:", params)
+ print("Batch X:", batch_x)
+ print("Batch Y:", batch_y)
+
+ # Compute individual gradients first
+ grad_fn = jax.grad(loss_function, argnums=0)
+ individual_grads = vmap(grad_fn, in_axes=(None, 0, 0))(params, batch_x, batch_y)
+ print("Individual gradients:", individual_grads)
+
+ # Now compute averaged gradients using axis names
+ averaged_grads = compute_gradients_with_averaging(params, batch_x, batch_y)
+ print("Averaged gradients:", averaged_grads)
+
+ return individual_grads, averaged_grads
+
+ def _gradient_averaging_simulation_jax(self):
+ def loss_function(params, x, y):
+ """Simple quadratic loss"""
+ pred = params * x
+ return (pred - y) ** 2
+
+ def compute_gradients_with_averaging(params, batch_x, batch_y):
+ """Compute gradients and average them across the batch"""
+ # Compute per-sample gradients
+ grad_fn = jax.grad(loss_function, argnums=0)
+ per_sample_grads = brainstate.transform.vmap(grad_fn, in_axes=(None, 0, 0))(params, batch_x, batch_y)
+
+ # Average gradients across batch using named axis
+ def average_grads(grads):
+ return pmean(grads, axis_name='batch')
+
+ # Apply averaging with named axis
+ averaged_grads = brainstate.transform.vmap(average_grads, axis_name='batch')(per_sample_grads)
+ return averaged_grads
+
+ # Example data
+ params = 2.0
+ batch_x = jnp.array([1.0, 2.0, 3.0, 4.0])
+ batch_y = jnp.array([2.0, 4.0, 7.0, 8.0])
+
+ print("Parameters:", params)
+ print("Batch X:", batch_x)
+ print("Batch Y:", batch_y)
+
+ # Compute individual gradients first
+ grad_fn = jax.grad(loss_function, argnums=0)
+ individual_grads = brainstate.transform.vmap(grad_fn, in_axes=(None, 0, 0))(params, batch_x, batch_y)
+ print("Individual gradients:", individual_grads)
+
+ # Now compute averaged gradients using axis names
+ averaged_grads = compute_gradients_with_averaging(params, batch_x, batch_y)
+ print("Averaged gradients:", averaged_grads)
+
+ return individual_grads, averaged_grads
+
+ def test_gradient_averaging_simulation(self):
+ individual_grads, averaged_grads = self._gradient_averaging_simulation_bst()
+ individual_grads_jax, averaged_grads_jax = self._gradient_averaging_simulation_jax()
+ assert jnp.allclose(individual_grads, individual_grads_jax)
+ assert jnp.allclose(averaged_grads, averaged_grads_jax)
+
+
+
+
brainstate/compile/_jit.py CHANGED
@@ -51,6 +51,7 @@ def _get_jitted_fun(
  out_shardings,
  static_argnums,
  donate_argnums,
+ static_argnames,
  donate_argnames,
  keep_unused,
  device,
@@ -59,10 +60,12 @@
  abstracted_axes,
  **kwargs
  ) -> JittedFunction:
- static_argnums = _ensure_index_tuple(tuple() if static_argnums is None else static_argnums)
+ static_argnums = tuple() if static_argnums is None else _ensure_index_tuple(static_argnums)
+ donate_argnums = tuple() if donate_argnums is None else _ensure_index_tuple(donate_argnums)
  fun = StatefulFunction(
  fun,
  static_argnums=static_argnums,
+ static_argnames=static_argnames,
  abstracted_axes=abstracted_axes,
  cache_type='jit',
  name='jit'
@@ -70,7 +73,8 @@
  jit_fun = jax.jit(
  fun.jaxpr_call,
  static_argnums=tuple(i + 1 for i in static_argnums),
- donate_argnums=donate_argnums,
+ static_argnames=static_argnames,
+ donate_argnums=tuple(i + 1 for i in donate_argnums),
  donate_argnames=donate_argnames,
  keep_unused=keep_unused,
  device=device,
@@ -179,6 +183,7 @@ def jit(
  out_shardings=sharding_impls.UNSPECIFIED,
  static_argnums: int | Sequence[int] | None = None,
  donate_argnums: int | Sequence[int] | None = None,
+ static_argnames: str | Sequence[str] | None = None,
  donate_argnames: str | Iterable[str] | None = None,
  keep_unused: bool = False,
  device: Device | None = None,
@@ -190,9 +195,6 @@
  """
  Sets up ``fun`` for just-in-time compilation with XLA.

- Does not support setting ``static_argnames`` as in ``jax.jit()``.
-
-
  Args:
  fun: Function to be jitted.
  in_shardings: Pytree of structure matching that of arguments to ``fun``,
@@ -246,6 +248,11 @@
  provided, ``inspect.signature`` is not used, and only actual
  parameters listed in either ``static_argnums`` or ``static_argnames`` will
  be treated as static.
+ static_argnames: An optional string or collection of strings specifying
+ which named arguments are treated as static (compile-time constant).
+ Operations that only depend on static arguments will be constant-folded in
+ Python (during tracing), and so the corresponding argument values can be
+ any Python object.
  donate_argnums: Specify which positional argument buffers are "donated" to
  the computation. It is safe to donate argument buffers if you no longer
  need them once the computation has finished. In some cases XLA can make
@@ -309,6 +316,7 @@
  out_shardings=out_shardings,
  static_argnums=static_argnums,
  donate_argnums=donate_argnums,
+ static_argnames=static_argnames,
  donate_argnames=donate_argnames,
  keep_unused=keep_unused,
  device=device,
@@ -327,6 +335,7 @@
  out_shardings,
  static_argnums,
  donate_argnums,
+ static_argnames,
  donate_argnames,
  keep_unused,
  device,
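With ``static_argnames`` now threaded through ``_get_jitted_fun``, ``StatefulFunction``, and the underlying ``jax.jit`` call, the brainstate ``jit`` accepts named static arguments just like ``jax.jit``. A minimal usage sketch, assuming the function is exposed as ``brainstate.compile.jit`` (the toy ``apply_activation`` and its ``mode`` switch are made up for illustration):

import jax.numpy as jnp
import brainstate

def apply_activation(x, mode='relu'):
    # `mode` only selects a Python-level branch, so it is a natural static argument.
    if mode == 'relu':
        return jnp.maximum(x, 0.0)
    return jnp.tanh(x)

# Each distinct value of `mode` triggers its own trace/compile, as with jax.jit.
jitted = brainstate.compile.jit(apply_activation, static_argnames='mode')
print(jitted(jnp.array([-1.0, 2.0]), mode='relu'))
print(jitted(jnp.array([-1.0, 2.0]), mode='tanh'))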
brainstate/compile/_make_jaxpr.py CHANGED
@@ -88,6 +88,12 @@ __all__ = [
  ]


+ def _ensure_str(x: str) -> str:
+ if not isinstance(x, str):
+ raise TypeError(f"argument is not a string: {x}")
+ return x
+
+
  def _ensure_index_tuple(x: Any) -> tuple[int, ...]:
  """Convert x to a tuple of indices."""
  x = jax.core.concrete_or_error(None, x, "expected a static index or sequence of indices.")
@@ -97,6 +103,14 @@ def _ensure_index_tuple(x: Any) -> tuple[int, ...]:
  return tuple(safe_map(operator.index, x))


+ def _ensure_str_tuple(x: str | Iterable[str]) -> tuple[str, ...]:
+ """Convert x to a tuple of strings."""
+ if isinstance(x, str):
+ return (x,)
+ else:
+ return tuple(safe_map(_ensure_str, x))
+
+
  def _jax_v04_new_arg_fn(frame, trace, aval):
  """
  Transform a new argument to a tracer.
@@ -155,6 +169,9 @@ def _init_state_trace_stack(name) -> StateTraceStack:
  return state_trace


+ default_cache_key = ((), ())
+
+
  class StatefulFunction(PrettyObject):
  """
  A wrapper class for a function that collects the states that are read and written by the function. The states are
@@ -170,6 +187,7 @@ class StatefulFunction(PrettyObject):
  arguments and return value should be arrays, scalars, or standard Python
  containers (tuple/list/dict) thereof.
  static_argnums: See the :py:func:`jax.jit` docstring.
+ static_argnames: See the :py:func:`jax.jit` docstring.
  axis_env: Optional, a sequence of pairs where the first element is an axis
  name and the second element is a positive integer representing the size of
  the mapped axis with that name. This parameter is useful when lowering
@@ -199,6 +217,7 @@
  self,
  fun: Callable,
  static_argnums: Union[int, Iterable[int]] = (),
+ static_argnames: Union[str, Iterable[str]] = (),
  axis_env: Optional[Sequence[tuple[Hashable, int]]] = None,
  abstracted_axes: Optional[Any] = None,
  state_returns: Union[str, Tuple[str, ...]] = ('read', 'write'),
@@ -207,11 +226,12 @@
  ):
  # explicit parameters
  self.fun = fun
- self.static_argnums = _ensure_index_tuple(tuple() if static_argnums is None else static_argnums)
+ self.static_argnums = tuple() if static_argnums is None else _ensure_index_tuple(static_argnums)
+ self.static_argnames = tuple() if static_argnames is None else _ensure_str_tuple(static_argnames)
  self.axis_env = axis_env
  self.abstracted_axes = abstracted_axes
  self.state_returns = tuple(state_returns) if isinstance(state_returns, (tuple, list)) else (state_returns,)
- assert cache_type in [None, 'jit']
+ assert cache_type in [None, 'jit'], f"Invalid cache type: {cache_type}"
  self.name = name

  # implicit parameters
@@ -226,7 +246,7 @@
  return None
  return k, v

- def get_jaxpr(self, cache_key: Hashable = ()) -> ClosedJaxpr:
+ def get_jaxpr(self, cache_key: Hashable = None) -> ClosedJaxpr:
  """
  Read the JAX Jaxpr representation of the function.

@@ -236,11 +256,13 @@
  Returns:
  The JAX Jaxpr representation of the function.
  """
+ if cache_key is None:
+ cache_key = default_cache_key
  if cache_key not in self._cached_jaxpr:
  raise ValueError(f"the function is not called with the static arguments: {cache_key}")
  return self._cached_jaxpr[cache_key]

- def get_out_shapes(self, cache_key: Hashable = ()) -> PyTree:
+ def get_out_shapes(self, cache_key: Hashable = None) -> PyTree:
  """
  Read the output shapes of the function.

@@ -250,11 +272,13 @@
  Returns:
  The output shapes of the function.
  """
+ if cache_key is None:
+ cache_key = default_cache_key
  if cache_key not in self._cached_out_shapes:
  raise ValueError(f"the function is not called with the static arguments: {cache_key}")
  return self._cached_out_shapes[cache_key]

- def get_out_treedef(self, cache_key: Hashable = ()) -> PyTree:
+ def get_out_treedef(self, cache_key: Hashable = None) -> PyTree:
  """
  Read the output tree of the function.

@@ -264,11 +288,13 @@
  Returns:
  The output tree of the function.
  """
+ if cache_key is None:
+ cache_key = default_cache_key
  if cache_key not in self._cached_jaxpr_out_tree:
  raise ValueError(f"the function is not called with the static arguments: {cache_key}")
  return self._cached_jaxpr_out_tree[cache_key]

- def get_state_trace(self, cache_key: Hashable = ()) -> StateTraceStack:
+ def get_state_trace(self, cache_key: Hashable = None) -> StateTraceStack:
  """
  Read the state trace of the function.

@@ -278,11 +304,13 @@
  Returns:
  The state trace of the function.
  """
+ if cache_key is None:
+ cache_key = default_cache_key
  if cache_key not in self._cached_state_trace:
  raise ValueError(f"the function is not called with the static arguments: {cache_key}")
  return self._cached_state_trace[cache_key]

- def get_states(self, cache_key: Hashable = ()) -> Tuple[State, ...]:
+ def get_states(self, cache_key: Hashable = None) -> Tuple[State, ...]:
  """
  Read the states that are read and written by the function.

@@ -292,9 +320,11 @@
  Returns:
  The states that are read and written by the function.
  """
+ if cache_key is None:
+ cache_key = default_cache_key
  return tuple(self.get_state_trace(cache_key).states)

- def get_read_states(self, cache_key: Hashable = ()) -> Tuple[State, ...]:
+ def get_read_states(self, cache_key: Hashable = None) -> Tuple[State, ...]:
  """
  Read the states that are read by the function.

@@ -304,9 +334,11 @@
  Returns:
  The states that are read by the function.
  """
+ if cache_key is None:
+ cache_key = default_cache_key
  return self.get_state_trace(cache_key).get_read_states()

- def get_write_states(self, cache_key: Hashable = ()) -> Tuple[State, ...]:
+ def get_write_states(self, cache_key: Hashable = None) -> Tuple[State, ...]:
  """
  Read the states that are written by the function.

@@ -316,6 +348,8 @@
  Returns:
  The states that are written by the function.
  """
+ if cache_key is None:
+ cache_key = default_cache_key
  return self.get_state_trace(cache_key).get_write_states()

  def get_arg_cache_key(self, *args, **kwargs) -> Tuple:
@@ -323,10 +357,11 @@
  Get the static arguments from the arguments.

  Args:
- *args: The arguments to the function.
+ *args: The arguments to the function.
+ **kwargs: The keyword arguments to the function.

  Returns:
- The static arguments.
+ The static arguments and keyword arguments as a tuple.
  """
  if self.cache_type == 'jit':
  static_args, dyn_args = [], []
@@ -336,11 +371,18 @@
  else:
  dyn_args.append(arg)
  dyn_args = jax.tree.map(shaped_abstractify, jax.tree.leaves(dyn_args))
- dyn_kwargs = jax.tree.map(shaped_abstractify, jax.tree.leaves(kwargs))
- return tuple([tuple(static_args), tuple(dyn_args), tuple(dyn_kwargs)])
+ static_kwargs, dyn_kwargs = [], []
+ for k, v in kwargs.items():
+ if k in self.static_argnames:
+ static_kwargs.append((k, v))
+ else:
+ dyn_kwargs.append((k, jax.tree.map(shaped_abstractify, v)))
+ return tuple([tuple(static_args), tuple(dyn_args), tuple(static_kwargs), tuple(dyn_kwargs)])
  elif self.cache_type is None:
  num_arg = len(args)
- return tuple(args[i] for i in self.static_argnums if i < num_arg)
+ static_args = tuple(args[i] for i in self.static_argnums if i < num_arg)
+ static_kwargs = tuple((k, v) for k, v in kwargs.items() if k in self.static_argnames)
+ return tuple([static_args, static_kwargs])
  else:
  raise ValueError(f"Invalid cache type: {self.cache_type}")

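The ``default_cache_key = ((), ())`` introduced above is exactly what ``get_arg_cache_key`` produces for a ``StatefulFunction`` with ``cache_type=None`` and no static positional or keyword arguments, which is why the getters can now be called without an explicit key. A small sketch under that assumption (the import path ``brainstate.compile._make_jaxpr`` and the default ``cache_type`` are inferred from this diff, not documented API):

import jax.numpy as jnp
from brainstate.compile._make_jaxpr import StatefulFunction  # private path, inferred from this diff

def double(x):
    return x * 2.0

sf = StatefulFunction(double)           # no static_argnums / static_argnames
sf.make_jaxpr(jnp.ones(3))              # trace once to populate the caches
key = sf.get_arg_cache_key(jnp.ones(3))
print(key)                              # ((), ()) -- matches default_cache_key
print(sf.get_jaxpr())                   # cache_key=None now falls back to default_cache_key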
@@ -389,7 +431,7 @@
  self._cached_state_trace.clear()

  def _wrapped_fun_to_eval(
- self, cache_key, *args, return_only_write: bool = False, **kwargs,
+ self, cache_key, static_kwargs: dict, *args, return_only_write: bool = False, **dyn_kwargs,
  ) -> Tuple[Any, Tuple[State, ...]]:
  """
  Wrap the function and return the states that are read and written by the function and the output of the function.
@@ -405,7 +447,7 @@
  state_trace = _init_state_trace_stack(self.name)
  self._cached_state_trace[cache_key] = state_trace
  with state_trace:
- out = self.fun(*args, **kwargs)
+ out = self.fun(*args, **dyn_kwargs, **static_kwargs)
  state_values = (
  state_trace.get_write_state_values(True)
  if return_only_write else
@@ -430,8 +472,9 @@
  the structure, shape, dtypes, and named shapes of the output of ``fun``.

  Args:
- *args: The arguments to the function.
- **kwargs: The keyword arguments to the function.
+ *args: The arguments to the function.
+ **kwargs: The keyword arguments to the function.
+ return_only_write: If True, only return the states that are written by the function.
  """

  # static args
@@ -440,17 +483,24 @@
  if cache_key not in self._cached_state_trace:
  try:
  # jaxpr
+ static_kwargs, dyn_kwargs = {}, {}
+ for k, v in kwargs.items():
+ if k in self.static_argnames:
+ static_kwargs[k] = v
+ else:
+ dyn_kwargs[k] = v
  jaxpr, (out_shapes, state_shapes) = _make_jaxpr(
  functools.partial(
  self._wrapped_fun_to_eval,
  cache_key,
+ static_kwargs,
  return_only_write=return_only_write
  ),
  static_argnums=self.static_argnums,
  axis_env=self.axis_env,
  return_shape=True,
  abstracted_axes=self.abstracted_axes
- )(*args, **kwargs)
+ )(*args, **dyn_kwargs)
  # returns
  self._cached_jaxpr_out_tree[cache_key] = jax.tree.structure((out_shapes, state_shapes))
  self._cached_out_shapes[cache_key] = (out_shapes, state_shapes)
@@ -483,6 +533,7 @@
  assert len(state_vals) == len(states), 'State length mismatch.'

  # parameters
+ kwargs = {k: v for k, v in kwargs.items() if k not in self.static_argnames} # remove static kwargs
  args = tuple(args[i] for i in range(len(args)) if i not in self.static_argnums)
  args = jax.tree.flatten((args, kwargs, state_vals))[0]

@@ -519,12 +570,16 @@
  def make_jaxpr(
  fun: Callable,
  static_argnums: Union[int, Iterable[int]] = (),
+ static_argnames: Union[str, Iterable[str]] = (),
  axis_env: Optional[Sequence[tuple[Hashable, int]]] = None,
  return_shape: bool = False,
  abstracted_axes: Optional[Any] = None,
  state_returns: Union[str, Tuple[str, ...]] = ('read', 'write')
- ) -> Callable[..., (Tuple[ClosedJaxpr, Tuple[State, ...]] |
- Tuple[ClosedJaxpr, Tuple[State, ...], PyTree])]:
+ ) -> Callable[
+ ...,
+ (Tuple[ClosedJaxpr, Tuple[State, ...]] |
+ Tuple[ClosedJaxpr, Tuple[State, ...], PyTree])
+ ]:
  """
  Creates a function that produces its jaxpr given example args.

@@ -533,6 +588,7 @@ def make_jaxpr(
  arguments and return value should be arrays, scalars, or standard Python
  containers (tuple/list/dict) thereof.
  static_argnums: See the :py:func:`jax.jit` docstring.
+ static_argnames: See the :py:func:`jax.jit` docstring.
  axis_env: Optional, a sequence of pairs where the first element is an axis
  name and the second element is a positive integer representing the size of
  the mapped axis with that name. This parameter is useful when lowering
@@ -605,11 +661,11 @@
  stateful_fun = StatefulFunction(
  fun,
  static_argnums=static_argnums,
+ static_argnames=static_argnames,
  axis_env=axis_env,
  abstracted_axes=abstracted_axes,
  state_returns=state_returns,
  name='make_jaxpr'
-
  )

  @wraps(fun)
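For completeness, a minimal sketch of the updated ``make_jaxpr`` with a static keyword argument, assuming it is exported as ``brainstate.compile.make_jaxpr`` (only the signature shown in this diff is relied on; the toy ``scale`` function is made up):

import jax.numpy as jnp
import brainstate

def scale(x, factor, precision='float32'):
    # `precision` only chooses a dtype, so it can be declared static.
    return (x * factor).astype(precision)

# static_argnames keeps `precision` out of the traced (dynamic) arguments.
jaxpr_fn = brainstate.compile.make_jaxpr(scale, static_argnames='precision')
closed_jaxpr, states = jaxpr_fn(jnp.arange(3.0), 2.0, precision='float32')
print(closed_jaxpr)
print(states)  # State objects read/written during tracing (empty for this pure function)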