brainstate-0.1.3-py2.py3-none-any.whl → brainstate-0.1.5-py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +1 -1
- brainstate/_compatible_import.py +1 -16
- brainstate/_state.py +1 -0
- brainstate/augment/_mapping.py +9 -9
- brainstate/augment/_mapping_test.py +162 -0
- brainstate/compile/_jit.py +14 -5
- brainstate/compile/_make_jaxpr.py +78 -22
- brainstate/compile/_make_jaxpr_test.py +13 -2
- brainstate/graph/_graph_node.py +1 -1
- brainstate/graph/_graph_operation.py +4 -4
- brainstate/mixin.py +31 -2
- brainstate/nn/__init__.py +8 -5
- brainstate/nn/_common.py +7 -19
- brainstate/nn/_delay.py +13 -1
- brainstate/nn/_dropout.py +5 -4
- brainstate/nn/_dynamics.py +39 -44
- brainstate/nn/_exp_euler.py +13 -16
- brainstate/nn/{_fixedprob_mv.py → _fixedprob.py} +95 -24
- brainstate/nn/_inputs.py +1 -1
- brainstate/nn/_linear_mv.py +1 -1
- brainstate/nn/_module.py +5 -5
- brainstate/nn/_projection.py +190 -98
- brainstate/nn/_synapse.py +5 -9
- brainstate/nn/_synaptic_projection.py +376 -86
- brainstate/random/_rand_state.py +13 -7
- brainstate/surrogate.py +1 -1
- brainstate/typing.py +1 -1
- brainstate/util/__init__.py +14 -14
- brainstate/util/{_pretty_pytree.py → pretty_pytree.py} +2 -2
- {brainstate-0.1.3.dist-info → brainstate-0.1.5.dist-info}/METADATA +1 -1
- {brainstate-0.1.3.dist-info → brainstate-0.1.5.dist-info}/RECORD +42 -42
- /brainstate/nn/{_fixedprob_mv_test.py → _fixedprob_test.py} +0 -0
- /brainstate/util/{_caller.py → caller.py} +0 -0
- /brainstate/util/{_error.py → error.py} +0 -0
- /brainstate/util/{_others.py → others.py} +0 -0
- /brainstate/util/{_pretty_repr.py → pretty_repr.py} +0 -0
- /brainstate/util/{_pretty_table.py → pretty_table.py} +0 -0
- /brainstate/util/{_scaling.py → scaling.py} +0 -0
- /brainstate/util/{_struct.py → struct.py} +0 -0
- {brainstate-0.1.3.dist-info → brainstate-0.1.5.dist-info}/LICENSE +0 -0
- {brainstate-0.1.3.dist-info → brainstate-0.1.5.dist-info}/WHEEL +0 -0
- {brainstate-0.1.3.dist-info → brainstate-0.1.5.dist-info}/top_level.txt +0 -0
brainstate/__init__.py
CHANGED
brainstate/_compatible_import.py
CHANGED
@@ -16,10 +16,9 @@
 # -*- coding: utf-8 -*-
 
 
-import importlib.util
 from contextlib import contextmanager
 from functools import partial
-from typing import Iterable, Hashable, TypeVar, Callable
+from typing import Iterable, Hashable, TypeVar, Callable
 
 import jax
 
@@ -31,7 +30,6 @@ __all__ = [
     'get_aval',
     'Tracer',
     'to_concrete_aval',
-    'brainevent',
     'safe_map',
     'safe_zip',
     'unzip2',
@@ -47,8 +45,6 @@ T3 = TypeVar("T3")
 
 from saiunit._compatible_import import wrap_init
 
-brainevent_installed = importlib.util.find_spec('brainevent') is not None
-
 from jax.core import get_aval, Tracer
 
 if jax.__version_info__ < (0, 5, 0):
@@ -150,14 +146,3 @@ def to_concrete_aval(aval):
         return aval.to_concrete_value()
     return aval
 
-
-if not brainevent_installed:
-    if not TYPE_CHECKING:
-        class BrainEvent:
-            def __getattr__(self, item):
-                raise ImportError('brainevent is not installed, please install brainevent first.')
-
-        brainevent = BrainEvent()
-
-else:
-    import brainevent
brainstate/_state.py
CHANGED
brainstate/augment/_mapping.py
CHANGED
@@ -185,10 +185,10 @@ def _compile_stateful_function(
     if isinstance(in_axes, int):
         args = jax.tree.map(lambda x: _remove_axis(x, in_axes), args)
     elif isinstance(in_axes, tuple):
-        args = tuple(
-            arg if in_axis is None else _remove_axis(arg, in_axis)
-            for arg, in_axis in zip(args, in_axes)
-        )
+        args = tuple([
+            arg if in_axis is None else _remove_axis(arg, in_axis)
+            for arg, in_axis in zip(args, in_axes)
+        ])
     stateful_fn.make_jaxpr(state_vals, args)
     return stateful_fn.get_arg_cache_key(state_vals, args)
 
@@ -383,10 +383,7 @@ def _vmap_transform(
     stateful_fn.axis_env = axis_env
 
     # stateful function
-    stateful_fn = StatefulFunction(
-        _vmap_fn_for_compilation,
-        name='vmap',
-    )
+    stateful_fn = StatefulFunction(_vmap_fn_for_compilation, name='vmap')
 
     @functools.wraps(f)
     def new_fn_for_vmap(
@@ -460,7 +457,10 @@ def _vmap_transform(
         # analyze vmapping axis error
         for state in state_trace.get_write_states():
             leaves = jax.tree.leaves(state.value)
-            if any([isinstance(leaf, BatchTracer) and (leaf.batch_dim is not None) for leaf in leaves]) and state not in out_state_to_axis:
+            if (
+                any([isinstance(leaf, BatchTracer) and (leaf.batch_dim is not None) for leaf in leaves])
+                and state not in out_state_to_axis
+            ):
                 if isinstance(state, RandomState) and state in rng_sets:
                     continue
                 state.raise_error_with_source_info(
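Note: the per-argument handling of a tuple ``in_axes`` above follows jax.vmap semantics. A minimal illustrative sketch; the function and data below are assumptions for illustration, not taken from this diff.

import jax.numpy as jnp
import brainstate

def weighted(xs, w):
    return xs * w

# Batch only the first argument; `w` is broadcast (in_axes=None), as in jax.vmap.
batched = brainstate.transform.vmap(weighted, in_axes=(0, None))
out = batched(jnp.arange(6.0).reshape(3, 2), jnp.array([1.0, 2.0]))
print(out.shape)  # (3, 2)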
brainstate/augment/_mapping_test.py
CHANGED
@@ -19,6 +19,8 @@ import unittest
 import jax
 import jax.numpy as jnp
 import numpy as np
+from jax import vmap
+from jax.lax import psum, pmean, pmax
 
 import brainstate
 import brainstate.augment
@@ -433,3 +435,163 @@ class TestVMAPNewStatesEdgeCases(unittest.TestCase):
             foo.c = brainstate.State(jnp.arange(3))  # Original expected shape is (4,)
 
         faulty_init()
+
+
+class TestAxisName:
+    def test1(self):
+        def compute_stats_with_axis_name(x):
+            """Compute statistics using named axis operations"""
+            # Sum across the named axis 'batch'
+            total_sum = psum(x, axis_name='batch')
+
+            # Mean across the named axis 'batch'
+            mean_val = pmean(x, axis_name='batch')
+
+            # Max across the named axis 'batch'
+            max_val = pmax(x, axis_name='batch')
+
+            return {
+                'sum': total_sum,
+                'mean': mean_val,
+                'max': max_val,
+                'original': x
+            }
+
+        batch_data = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
+        print("Input batch data:", batch_data)
+
+        # vmap with axis name 'batch'
+        vectorized_stats_jax = jax.jit(vmap(compute_stats_with_axis_name, axis_name='batch'))
+        result_jax = vectorized_stats_jax(batch_data)
+
+        # vmap with axis name 'batch'
+        vectorized_stats = brainstate.transform.vmap(compute_stats_with_axis_name, axis_name='batch')
+        result = vectorized_stats(batch_data)
+
+        # vmap with axis name 'batch'
+        vectorized_stats_v2 = brainstate.transform.jit(
+            brainstate.transform.vmap(compute_stats_with_axis_name, axis_name='batch')
+        )
+        result_v2 = vectorized_stats_v2(batch_data)
+
+        for key in result_jax.keys():
+            print(f" {key}: {result_jax[key]}")
+            assert jnp.allclose(result_jax[key], result[key]), f"Mismatch in {key}"
+            assert jnp.allclose(result_jax[key], result_v2[key]), f"Mismatch in {key}"
+
+    def test_nested_vmap(self):
+        def nested_computation(x):
+            """Computation with multiple named axes"""
+            # Sum over 'inner' axis, then mean over 'outer' axis
+            inner_sum = psum(x, axis_name='inner')
+            outer_mean = pmean(inner_sum, axis_name='outer')
+            return outer_mean
+
+        # Create 2D batch data
+        data_2d = jnp.arange(12.0).reshape(3, 4)  # Shape: [outer_batch=3, inner_batch=4]
+        print("Input 2D data shape:", data_2d.shape)
+        print("Input 2D data:\n", data_2d)
+
+        # Nested vmap: first over inner dimension, then outer dimension
+        inner_vmap = vmap(nested_computation, axis_name='inner')
+        nested_vmap = vmap(inner_vmap, axis_name='outer')
+
+        result_2d = nested_vmap(data_2d)
+        print("Result after nested vmap:", result_2d)
+
+        inner_vmap_bst = brainstate.transform.vmap(nested_computation, axis_name='inner')
+        nested_vmap_bst = brainstate.transform.vmap(inner_vmap_bst, axis_name='outer')
+        result_2d_bst = nested_vmap_bst(data_2d)
+        print("Result after nested vmap:", result_2d_bst)
+
+        assert jnp.allclose(result_2d, result_2d_bst)
+
+    def _gradient_averaging_simulation_bst(self):
+        def loss_function(params, x, y):
+            """Simple quadratic loss"""
+            pred = params * x
+            return (pred - y) ** 2
+
+        def compute_gradients_with_averaging(params, batch_x, batch_y):
+            """Compute gradients and average them across the batch"""
+            # Compute per-sample gradients
+            grad_fn = jax.grad(loss_function, argnums=0)
+            per_sample_grads = vmap(grad_fn, in_axes=(None, 0, 0))(params, batch_x, batch_y)
+
+            # Average gradients across batch using named axis
+            def average_grads(grads):
+                return pmean(grads, axis_name='batch')
+
+            # Apply averaging with named axis
+            averaged_grads = vmap(average_grads, axis_name='batch')(per_sample_grads)
+            return averaged_grads
+
+        # Example data
+        params = 2.0
+        batch_x = jnp.array([1.0, 2.0, 3.0, 4.0])
+        batch_y = jnp.array([2.0, 4.0, 7.0, 8.0])
+
+        print("Parameters:", params)
+        print("Batch X:", batch_x)
+        print("Batch Y:", batch_y)
+
+        # Compute individual gradients first
+        grad_fn = jax.grad(loss_function, argnums=0)
+        individual_grads = vmap(grad_fn, in_axes=(None, 0, 0))(params, batch_x, batch_y)
+        print("Individual gradients:", individual_grads)
+
+        # Now compute averaged gradients using axis names
+        averaged_grads = compute_gradients_with_averaging(params, batch_x, batch_y)
+        print("Averaged gradients:", averaged_grads)
+
+        return individual_grads, averaged_grads
+
+    def _gradient_averaging_simulation_jax(self):
+        def loss_function(params, x, y):
+            """Simple quadratic loss"""
+            pred = params * x
+            return (pred - y) ** 2
+
+        def compute_gradients_with_averaging(params, batch_x, batch_y):
+            """Compute gradients and average them across the batch"""
+            # Compute per-sample gradients
+            grad_fn = jax.grad(loss_function, argnums=0)
+            per_sample_grads = brainstate.transform.vmap(grad_fn, in_axes=(None, 0, 0))(params, batch_x, batch_y)
+
+            # Average gradients across batch using named axis
+            def average_grads(grads):
+                return pmean(grads, axis_name='batch')
+
+            # Apply averaging with named axis
+            averaged_grads = brainstate.transform.vmap(average_grads, axis_name='batch')(per_sample_grads)
+            return averaged_grads
+
+        # Example data
+        params = 2.0
+        batch_x = jnp.array([1.0, 2.0, 3.0, 4.0])
+        batch_y = jnp.array([2.0, 4.0, 7.0, 8.0])
+
+        print("Parameters:", params)
+        print("Batch X:", batch_x)
+        print("Batch Y:", batch_y)
+
+        # Compute individual gradients first
+        grad_fn = jax.grad(loss_function, argnums=0)
+        individual_grads = brainstate.transform.vmap(grad_fn, in_axes=(None, 0, 0))(params, batch_x, batch_y)
+        print("Individual gradients:", individual_grads)
+
+        # Now compute averaged gradients using axis names
+        averaged_grads = compute_gradients_with_averaging(params, batch_x, batch_y)
+        print("Averaged gradients:", averaged_grads)
+
+        return individual_grads, averaged_grads
+
+    def test_gradient_averaging_simulation(self):
+        individual_grads, averaged_grads = self._gradient_averaging_simulation_bst()
+        individual_grads_jax, averaged_grads_jax = self._gradient_averaging_simulation_jax()
+        assert jnp.allclose(individual_grads, individual_grads_jax)
+        assert jnp.allclose(averaged_grads, averaged_grads_jax)
+
+
+
+
brainstate/compile/_jit.py
CHANGED
@@ -51,6 +51,7 @@ def _get_jitted_fun(
     out_shardings,
     static_argnums,
     donate_argnums,
+    static_argnames,
     donate_argnames,
     keep_unused,
     device,
@@ -59,10 +60,12 @@ def _get_jitted_fun(
     abstracted_axes,
     **kwargs
 ) -> JittedFunction:
-    static_argnums =
+    static_argnums = tuple() if static_argnums is None else _ensure_index_tuple(static_argnums)
+    donate_argnums = tuple() if donate_argnums is None else _ensure_index_tuple(donate_argnums)
     fun = StatefulFunction(
         fun,
         static_argnums=static_argnums,
+        static_argnames=static_argnames,
         abstracted_axes=abstracted_axes,
         cache_type='jit',
         name='jit'
@@ -70,7 +73,8 @@ def _get_jitted_fun(
     jit_fun = jax.jit(
         fun.jaxpr_call,
         static_argnums=tuple(i + 1 for i in static_argnums),
-
+        static_argnames=static_argnames,
+        donate_argnums=tuple(i + 1 for i in donate_argnums),
         donate_argnames=donate_argnames,
         keep_unused=keep_unused,
         device=device,
@@ -179,6 +183,7 @@ def jit(
     out_shardings=sharding_impls.UNSPECIFIED,
     static_argnums: int | Sequence[int] | None = None,
     donate_argnums: int | Sequence[int] | None = None,
+    static_argnames: str | Sequence[str] | None = None,
    donate_argnames: str | Iterable[str] | None = None,
    keep_unused: bool = False,
    device: Device | None = None,
@@ -190,9 +195,6 @@ def jit(
     """
     Sets up ``fun`` for just-in-time compilation with XLA.
 
-    Does not support setting ``static_argnames`` as in ``jax.jit()``.
-
-
     Args:
       fun: Function to be jitted.
       in_shardings: Pytree of structure matching that of arguments to ``fun``,
@@ -246,6 +248,11 @@ def jit(
         provided, ``inspect.signature`` is not used, and only actual
         parameters listed in either ``static_argnums`` or ``static_argnames`` will
         be treated as static.
+      static_argnames: An optional string or collection of strings specifying
+        which named arguments are treated as static (compile-time constant).
+        Operations that only depend on static arguments will be constant-folded in
+        Python (during tracing), and so the corresponding argument values can be
+        any Python object.
       donate_argnums: Specify which positional argument buffers are "donated" to
         the computation. It is safe to donate argument buffers if you no longer
         need them once the computation has finished. In some cases XLA can make
@@ -309,6 +316,7 @@ def jit(
         out_shardings=out_shardings,
         static_argnums=static_argnums,
         donate_argnums=donate_argnums,
+        static_argnames=static_argnames,
         donate_argnames=donate_argnames,
         keep_unused=keep_unused,
         device=device,
@@ -327,6 +335,7 @@ def jit(
             out_shardings,
             static_argnums,
             donate_argnums,
+            static_argnames,
             donate_argnames,
             keep_unused,
             device,
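Note: a minimal usage sketch of the ``static_argnames`` option added to brainstate's ``jit`` in this release; the function, argument names, and values below are illustrative assumptions, not taken from this diff.

import jax.numpy as jnp
import brainstate

def step(x, mode='train'):
    # Branching on `mode` is only valid because it is marked static (a compile-time constant).
    return x * 2.0 if mode == 'train' else x

jitted_step = brainstate.transform.jit(step, static_argnames='mode')
y_train = jitted_step(jnp.ones(3))               # traced and compiled with mode='train'
y_eval = jitted_step(jnp.ones(3), mode='eval')   # retraced for the new static value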
brainstate/compile/_make_jaxpr.py
CHANGED
@@ -88,6 +88,12 @@ __all__ = [
 ]
 
 
+def _ensure_str(x: str) -> str:
+    if not isinstance(x, str):
+        raise TypeError(f"argument is not a string: {x}")
+    return x
+
+
 def _ensure_index_tuple(x: Any) -> tuple[int, ...]:
     """Convert x to a tuple of indices."""
     x = jax.core.concrete_or_error(None, x, "expected a static index or sequence of indices.")
@@ -97,6 +103,14 @@ def _ensure_index_tuple(x: Any) -> tuple[int, ...]:
     return tuple(safe_map(operator.index, x))
 
 
+def _ensure_str_tuple(x: str | Iterable[str]) -> tuple[str, ...]:
+    """Convert x to a tuple of strings."""
+    if isinstance(x, str):
+        return (x,)
+    else:
+        return tuple(safe_map(_ensure_str, x))
+
+
 def _jax_v04_new_arg_fn(frame, trace, aval):
     """
     Transform a new argument to a tracer.
@@ -155,6 +169,9 @@ def _init_state_trace_stack(name) -> StateTraceStack:
     return state_trace
 
 
+default_cache_key = ((), ())
+
+
 class StatefulFunction(PrettyObject):
     """
     A wrapper class for a function that collects the states that are read and written by the function. The states are
@@ -170,6 +187,7 @@ class StatefulFunction(PrettyObject):
        arguments and return value should be arrays, scalars, or standard Python
        containers (tuple/list/dict) thereof.
      static_argnums: See the :py:func:`jax.jit` docstring.
+      static_argnames: See the :py:func:`jax.jit` docstring.
      axis_env: Optional, a sequence of pairs where the first element is an axis
        name and the second element is a positive integer representing the size of
        the mapped axis with that name. This parameter is useful when lowering
@@ -199,6 +217,7 @@ class StatefulFunction(PrettyObject):
         self,
         fun: Callable,
         static_argnums: Union[int, Iterable[int]] = (),
+        static_argnames: Union[str, Iterable[str]] = (),
         axis_env: Optional[Sequence[tuple[Hashable, int]]] = None,
         abstracted_axes: Optional[Any] = None,
         state_returns: Union[str, Tuple[str, ...]] = ('read', 'write'),
@@ -207,11 +226,12 @@ class StatefulFunction(PrettyObject):
     ):
         # explicit parameters
         self.fun = fun
-        self.static_argnums =
+        self.static_argnums = tuple() if static_argnums is None else _ensure_index_tuple(static_argnums)
+        self.static_argnames = tuple() if static_argnames is None else _ensure_str_tuple(static_argnames)
         self.axis_env = axis_env
         self.abstracted_axes = abstracted_axes
         self.state_returns = tuple(state_returns) if isinstance(state_returns, (tuple, list)) else (state_returns,)
-        assert cache_type in [None, 'jit']
+        assert cache_type in [None, 'jit'], f"Invalid cache type: {cache_type}"
         self.name = name
 
         # implicit parameters
@@ -226,7 +246,7 @@ class StatefulFunction(PrettyObject):
             return None
         return k, v
 
-    def get_jaxpr(self, cache_key: Hashable =
+    def get_jaxpr(self, cache_key: Hashable = None) -> ClosedJaxpr:
         """
         Read the JAX Jaxpr representation of the function.
 
@@ -236,11 +256,13 @@ class StatefulFunction(PrettyObject):
         Returns:
           The JAX Jaxpr representation of the function.
         """
+        if cache_key is None:
+            cache_key = default_cache_key
         if cache_key not in self._cached_jaxpr:
             raise ValueError(f"the function is not called with the static arguments: {cache_key}")
         return self._cached_jaxpr[cache_key]
 
-    def get_out_shapes(self, cache_key: Hashable =
+    def get_out_shapes(self, cache_key: Hashable = None) -> PyTree:
         """
         Read the output shapes of the function.
 
@@ -250,11 +272,13 @@ class StatefulFunction(PrettyObject):
         Returns:
          The output shapes of the function.
         """
+        if cache_key is None:
+            cache_key = default_cache_key
         if cache_key not in self._cached_out_shapes:
             raise ValueError(f"the function is not called with the static arguments: {cache_key}")
         return self._cached_out_shapes[cache_key]
 
-    def get_out_treedef(self, cache_key: Hashable =
+    def get_out_treedef(self, cache_key: Hashable = None) -> PyTree:
         """
         Read the output tree of the function.
 
@@ -264,11 +288,13 @@ class StatefulFunction(PrettyObject):
         Returns:
          The output tree of the function.
         """
+        if cache_key is None:
+            cache_key = default_cache_key
         if cache_key not in self._cached_jaxpr_out_tree:
             raise ValueError(f"the function is not called with the static arguments: {cache_key}")
         return self._cached_jaxpr_out_tree[cache_key]
 
-    def get_state_trace(self, cache_key: Hashable =
+    def get_state_trace(self, cache_key: Hashable = None) -> StateTraceStack:
         """
         Read the state trace of the function.
 
@@ -278,11 +304,13 @@ class StatefulFunction(PrettyObject):
         Returns:
          The state trace of the function.
         """
+        if cache_key is None:
+            cache_key = default_cache_key
         if cache_key not in self._cached_state_trace:
             raise ValueError(f"the function is not called with the static arguments: {cache_key}")
         return self._cached_state_trace[cache_key]
 
-    def get_states(self, cache_key: Hashable =
+    def get_states(self, cache_key: Hashable = None) -> Tuple[State, ...]:
         """
         Read the states that are read and written by the function.
 
@@ -292,9 +320,11 @@ class StatefulFunction(PrettyObject):
         Returns:
          The states that are read and written by the function.
         """
+        if cache_key is None:
+            cache_key = default_cache_key
         return tuple(self.get_state_trace(cache_key).states)
 
-    def get_read_states(self, cache_key: Hashable =
+    def get_read_states(self, cache_key: Hashable = None) -> Tuple[State, ...]:
         """
         Read the states that are read by the function.
 
@@ -304,9 +334,11 @@ class StatefulFunction(PrettyObject):
         Returns:
         The states that are read by the function.
         """
+        if cache_key is None:
+            cache_key = default_cache_key
         return self.get_state_trace(cache_key).get_read_states()
 
-    def get_write_states(self, cache_key: Hashable =
+    def get_write_states(self, cache_key: Hashable = None) -> Tuple[State, ...]:
         """
         Read the states that are written by the function.
 
@@ -316,6 +348,8 @@ class StatefulFunction(PrettyObject):
         Returns:
         The states that are written by the function.
         """
+        if cache_key is None:
+            cache_key = default_cache_key
         return self.get_state_trace(cache_key).get_write_states()
 
     def get_arg_cache_key(self, *args, **kwargs) -> Tuple:
@@ -323,10 +357,11 @@ class StatefulFunction(PrettyObject):
         Get the static arguments from the arguments.
 
         Args:
-
+          *args: The arguments to the function.
+          **kwargs: The keyword arguments to the function.
 
         Returns:
-          The static arguments.
+          The static arguments and keyword arguments as a tuple.
         """
         if self.cache_type == 'jit':
             static_args, dyn_args = [], []
@@ -336,11 +371,18 @@ class StatefulFunction(PrettyObject):
             else:
                 dyn_args.append(arg)
             dyn_args = jax.tree.map(shaped_abstractify, jax.tree.leaves(dyn_args))
-            dyn_kwargs =
-
+            static_kwargs, dyn_kwargs = [], []
+            for k, v in kwargs.items():
+                if k in self.static_argnames:
+                    static_kwargs.append((k, v))
+                else:
+                    dyn_kwargs.append((k, jax.tree.map(shaped_abstractify, v)))
+            return tuple([tuple(static_args), tuple(dyn_args), tuple(static_kwargs), tuple(dyn_kwargs)])
         elif self.cache_type is None:
             num_arg = len(args)
-
+            static_args = tuple(args[i] for i in self.static_argnums if i < num_arg)
+            static_kwargs = tuple((k, v) for k, v in kwargs.items() if k in self.static_argnames)
+            return tuple([static_args, static_kwargs])
         else:
             raise ValueError(f"Invalid cache type: {self.cache_type}")
 
@@ -389,7 +431,7 @@ class StatefulFunction(PrettyObject):
         self._cached_state_trace.clear()
 
     def _wrapped_fun_to_eval(
-        self, cache_key, *args, return_only_write: bool = False, **
+        self, cache_key, static_kwargs: dict, *args, return_only_write: bool = False, **dyn_kwargs,
     ) -> Tuple[Any, Tuple[State, ...]]:
         """
         Wrap the function and return the states that are read and written by the function and the output of the function.
@@ -405,7 +447,7 @@ class StatefulFunction(PrettyObject):
         state_trace = _init_state_trace_stack(self.name)
         self._cached_state_trace[cache_key] = state_trace
         with state_trace:
-            out = self.fun(*args, **
+            out = self.fun(*args, **dyn_kwargs, **static_kwargs)
         state_values = (
             state_trace.get_write_state_values(True)
             if return_only_write else
@@ -430,8 +472,9 @@ class StatefulFunction(PrettyObject):
         the structure, shape, dtypes, and named shapes of the output of ``fun``.
 
         Args:
-
-
+          *args: The arguments to the function.
+          **kwargs: The keyword arguments to the function.
+          return_only_write: If True, only return the states that are written by the function.
         """
 
         # static args
@@ -440,17 +483,24 @@ class StatefulFunction(PrettyObject):
         if cache_key not in self._cached_state_trace:
             try:
                 # jaxpr
+                static_kwargs, dyn_kwargs = {}, {}
+                for k, v in kwargs.items():
+                    if k in self.static_argnames:
+                        static_kwargs[k] = v
+                    else:
+                        dyn_kwargs[k] = v
                 jaxpr, (out_shapes, state_shapes) = _make_jaxpr(
                     functools.partial(
                         self._wrapped_fun_to_eval,
                         cache_key,
+                        static_kwargs,
                         return_only_write=return_only_write
                     ),
                     static_argnums=self.static_argnums,
                     axis_env=self.axis_env,
                     return_shape=True,
                     abstracted_axes=self.abstracted_axes
-                )(*args, **
+                )(*args, **dyn_kwargs)
                 # returns
                 self._cached_jaxpr_out_tree[cache_key] = jax.tree.structure((out_shapes, state_shapes))
                 self._cached_out_shapes[cache_key] = (out_shapes, state_shapes)
@@ -483,6 +533,7 @@ class StatefulFunction(PrettyObject):
         assert len(state_vals) == len(states), 'State length mismatch.'
 
         # parameters
+        kwargs = {k: v for k, v in kwargs.items() if k not in self.static_argnames}  # remove static kwargs
         args = tuple(args[i] for i in range(len(args)) if i not in self.static_argnums)
         args = jax.tree.flatten((args, kwargs, state_vals))[0]
 
@@ -519,12 +570,16 @@
 def make_jaxpr(
     fun: Callable,
     static_argnums: Union[int, Iterable[int]] = (),
+    static_argnames: Union[str, Iterable[str]] = (),
     axis_env: Optional[Sequence[tuple[Hashable, int]]] = None,
     return_shape: bool = False,
     abstracted_axes: Optional[Any] = None,
     state_returns: Union[str, Tuple[str, ...]] = ('read', 'write')
-) -> Callable[
-
+) -> Callable[
+    ...,
+    (Tuple[ClosedJaxpr, Tuple[State, ...]] |
+     Tuple[ClosedJaxpr, Tuple[State, ...], PyTree])
+]:
     """
     Creates a function that produces its jaxpr given example args.
 
@@ -533,6 +588,7 @@ def make_jaxpr(
        arguments and return value should be arrays, scalars, or standard Python
       containers (tuple/list/dict) thereof.
      static_argnums: See the :py:func:`jax.jit` docstring.
+      static_argnames: See the :py:func:`jax.jit` docstring.
      axis_env: Optional, a sequence of pairs where the first element is an axis
        name and the second element is a positive integer representing the size of
        the mapped axis with that name. This parameter is useful when lowering
@@ -605,11 +661,11 @@ def make_jaxpr(
     stateful_fun = StatefulFunction(
         fun,
         static_argnums=static_argnums,
+        static_argnames=static_argnames,
         axis_env=axis_env,
         abstracted_axes=abstracted_axes,
         state_returns=state_returns,
         name='make_jaxpr'
-
     )
 
     @wraps(fun)
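Note: a minimal sketch of the reworked ``make_jaxpr`` entry point with the new ``static_argnames`` parameter and the default cache key. The import path ``brainstate.compile.make_jaxpr`` and the toy function are assumptions for illustration, not taken from this diff.

import jax.numpy as jnp
import brainstate

counter = brainstate.State(jnp.zeros(()))

def update(x, scale=2.0):
    # Reads and writes the module-level State; `scale` is kept static.
    counter.value = counter.value + x
    return x * scale

# With return_shape left at its default, the wrapped call returns (ClosedJaxpr, states).
jaxpr, states = brainstate.compile.make_jaxpr(update, static_argnames='scale')(jnp.ones(3), scale=3.0)
print(jaxpr)
print(states)  # the State objects touched while tracing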