brainstate 0.1.0.post20250104__py2.py3-none-any.whl → 0.1.0.post20250120__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/_state.py +77 -44
- brainstate/_state_test.py +0 -17
- brainstate/augment/_eval_shape.py +9 -10
- brainstate/augment/_eval_shape_test.py +1 -1
- brainstate/augment/_mapping.py +265 -277
- brainstate/augment/_mapping_test.py +147 -175
- brainstate/compile/_ad_checkpoint.py +6 -4
- brainstate/compile/_error_if_test.py +1 -0
- brainstate/compile/_jit.py +37 -28
- brainstate/compile/_loop_collect_return.py +8 -5
- brainstate/compile/_loop_no_collection.py +2 -0
- brainstate/compile/_make_jaxpr.py +7 -3
- brainstate/compile/_make_jaxpr_test.py +2 -1
- brainstate/compile/_progress_bar.py +68 -40
- brainstate/compile/_unvmap.py +6 -2
- brainstate/environ.py +28 -18
- brainstate/environ_test.py +4 -0
- brainstate/event/__init__.py +0 -2
- brainstate/event/_csr.py +266 -23
- brainstate/event/_csr_test.py +187 -0
- brainstate/event/_fixedprob_mv.py +4 -2
- brainstate/event/_fixedprob_mv_test.py +2 -1
- brainstate/event/_xla_custom_op.py +16 -5
- brainstate/graph/__init__.py +8 -12
- brainstate/graph/_graph_node.py +1 -23
- brainstate/graph/_graph_operation.py +1 -1
- brainstate/graph/_graph_operation_test.py +0 -159
- brainstate/nn/_dyn_impl/_inputs.py +124 -39
- brainstate/nn/_interaction/_conv.py +4 -2
- brainstate/nn/_interaction/_linear.py +84 -10
- brainstate/random/_rand_funs.py +9 -2
- brainstate/random/_rand_seed.py +12 -2
- brainstate/random/_rand_state.py +50 -179
- brainstate/surrogate.py +5 -1
- brainstate/util/__init__.py +0 -4
- brainstate/util/_caller.py +1 -1
- brainstate/util/_dict.py +4 -1
- brainstate/util/_filter.py +1 -1
- brainstate/util/_pretty_repr.py +1 -1
- brainstate/util/_struct.py +1 -1
- {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/METADATA +2 -1
- {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/RECORD +46 -52
- brainstate/event/_csr_mv_test.py +0 -118
- brainstate/graph/_graph_context.py +0 -443
- brainstate/graph/_graph_context_test.py +0 -65
- brainstate/graph/_graph_convert.py +0 -246
- brainstate/util/_tracers.py +0 -68
- brainstate/util/_visualization.py +0 -47
- /brainstate/event/{_csr_mv_benchmark.py → _csr_benchmark.py} +0 -0
- {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/top_level.txt +0 -0
@@ -9,12 +9,20 @@ from typing import Callable, Sequence, Tuple, Protocol
|
|
9
9
|
import jax
|
10
10
|
import numpy as np
|
11
11
|
from jax import tree_util
|
12
|
-
from jax.core import Primitive
|
13
12
|
from jax.interpreters import batching, ad
|
14
13
|
from jax.interpreters import xla, mlir
|
15
|
-
from jax.lib import xla_client
|
16
14
|
from jaxlib.hlo_helpers import custom_call
|
17
15
|
|
16
|
+
if jax.__version_info__ < (0, 4, 35):
|
17
|
+
from jax.lib import xla_client
|
18
|
+
else:
|
19
|
+
import jax.extend as je
|
20
|
+
|
21
|
+
if jax.__version_info__ < (0, 4, 38):
|
22
|
+
from jax.core import Primitive
|
23
|
+
else:
|
24
|
+
from jax.extend.core import Primitive
|
25
|
+
|
18
26
|
numba_installed = importlib.util.find_spec('numba') is not None
|
19
27
|
|
20
28
|
__all__ = [
|
@@ -143,7 +151,10 @@ def numba_cpu_custom_call_target(output_ptrs, input_ptrs):
|
|
143
151
|
xla_c_rule = cfunc(sig)(new_f)
|
144
152
|
target_name = f'numba_custom_call_{str(xla_c_rule.address)}'
|
145
153
|
capsule = ctypes.pythonapi.PyCapsule_New(xla_c_rule.address, b"xla._CUSTOM_CALL_TARGET", None)
|
146
|
-
|
154
|
+
if jax.__version_info__ < (0, 4, 35):
|
155
|
+
xla_client.register_custom_call_target(target_name, capsule, "cpu")
|
156
|
+
else:
|
157
|
+
je.ffi.register_ffi_target(target_name, capsule, "cpu", api_version=0)
|
147
158
|
|
148
159
|
# call
|
149
160
|
return custom_call(
|
@@ -157,7 +168,7 @@ def numba_cpu_custom_call_target(output_ptrs, input_ptrs):
|
|
157
168
|
|
158
169
|
|
159
170
|
def register_numba_mlir_cpu_translation_rule(
|
160
|
-
primitive:
|
171
|
+
primitive: Primitive,
|
161
172
|
cpu_kernel: Callable,
|
162
173
|
debug: bool = False
|
163
174
|
):
|
@@ -198,7 +209,7 @@ class XLACustomOp:
|
|
198
209
|
transpose_translation: Callable = None,
|
199
210
|
):
|
200
211
|
# primitive
|
201
|
-
self.primitive =
|
212
|
+
self.primitive = Primitive(name)
|
202
213
|
self.primitive.multiple_results = True
|
203
214
|
|
204
215
|
# abstract evaluation
|
brainstate/graph/__init__.py
CHANGED
@@ -14,20 +14,16 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
16
|
|
17
|
-
from ._graph_context import *
|
18
|
-
from ._graph_context import __all__ as _graph_context__all__
|
19
|
-
from ._graph_convert import *
|
20
|
-
from ._graph_convert import __all__ as _graph_convert__all__
|
21
17
|
from ._graph_node import *
|
22
18
|
from ._graph_node import __all__ as _graph_node__all__
|
23
19
|
from ._graph_operation import *
|
24
20
|
from ._graph_operation import __all__ as _graph_operation__all__
|
25
21
|
|
26
|
-
__all__ = (
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
del (
|
31
|
-
|
32
|
-
|
33
|
-
|
22
|
+
__all__ = (
|
23
|
+
_graph_node__all__ +
|
24
|
+
_graph_operation__all__
|
25
|
+
)
|
26
|
+
del (
|
27
|
+
_graph_node__all__,
|
28
|
+
_graph_operation__all__
|
29
|
+
)
|
brainstate/graph/_graph_node.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
# The file is adapted from the Flax library (https://github.com/google/flax).
|
2
2
|
# The credit should go to the Flax authors.
|
3
3
|
#
|
4
|
-
# Copyright 2024 The Flax Authors
|
4
|
+
# Copyright 2024 The Flax Authors.
|
5
5
|
#
|
6
6
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
7
7
|
# you may not use this file except in compliance with the License.
|
@@ -27,9 +27,7 @@ import numpy as np
|
|
27
27
|
|
28
28
|
from brainstate._state import State, TreefyState
|
29
29
|
from brainstate.typing import Key
|
30
|
-
from brainstate.util._error import TraceContextError
|
31
30
|
from brainstate.util._pretty_repr import PrettyRepr, pretty_repr_avoid_duplicate, PrettyType, PrettyAttr
|
32
|
-
from brainstate.util._tracers import StateJaxTracer
|
33
31
|
from ._graph_operation import register_graph_node_type
|
34
32
|
|
35
33
|
__all__ = [
|
@@ -44,7 +42,6 @@ class GraphNodeMeta(ABCMeta):
|
|
44
42
|
if not TYPE_CHECKING:
|
45
43
|
def __call__(cls, *args: Any, **kwargs: Any) -> Any:
|
46
44
|
node = cls.__new__(cls, *args, **kwargs)
|
47
|
-
vars(node)['_trace_state'] = StateJaxTracer()
|
48
45
|
node.__init__(*args, **kwargs)
|
49
46
|
return node
|
50
47
|
|
@@ -64,9 +61,6 @@ class Node(PrettyRepr, metaclass=GraphNodeMeta):
|
|
64
61
|
|
65
62
|
graph_invisible_attrs = ()
|
66
63
|
|
67
|
-
if TYPE_CHECKING:
|
68
|
-
_trace_state: StateJaxTracer
|
69
|
-
|
70
64
|
def __init_subclass__(cls) -> None:
|
71
65
|
super().__init_subclass__()
|
72
66
|
|
@@ -79,21 +73,6 @@ class Node(PrettyRepr, metaclass=GraphNodeMeta):
|
|
79
73
|
clear=_node_clear,
|
80
74
|
)
|
81
75
|
|
82
|
-
# if not TYPE_CHECKING:
|
83
|
-
# def __setattr__(self, name: str, value: Any) -> None:
|
84
|
-
# self._setattr(name, value)
|
85
|
-
|
86
|
-
# def _setattr(self, name: str, value: Any) -> None:
|
87
|
-
# self.check_valid_context(lambda: f"Cannot mutate '{type(self).__name__}' from different trace level")
|
88
|
-
# object.__setattr__(self, name, value)
|
89
|
-
|
90
|
-
def check_valid_context(self, error_msg: Callable[[], str]) -> None:
|
91
|
-
"""
|
92
|
-
Check if the current context is valid for the object to be mutated.
|
93
|
-
"""
|
94
|
-
if not self._trace_state.is_valid():
|
95
|
-
raise TraceContextError(error_msg())
|
96
|
-
|
97
76
|
def __deepcopy__(self: G, memo=None) -> G:
|
98
77
|
"""
|
99
78
|
Deepcopy the object.
|
@@ -214,7 +193,6 @@ def _node_create_empty(
|
|
214
193
|
) -> G:
|
215
194
|
node_type, = static
|
216
195
|
node = object.__new__(node_type)
|
217
|
-
vars(node).update(_trace_state=StateJaxTracer())
|
218
196
|
return node
|
219
197
|
|
220
198
|
|
@@ -1,7 +1,7 @@
|
|
1
1
|
# The file is adapted from the Flax library (https://github.com/google/flax).
|
2
2
|
# The credit should go to the Flax authors.
|
3
3
|
#
|
4
|
-
# Copyright 2024 The Flax Authors
|
4
|
+
# Copyright 2024 The Flax Authors.
|
5
5
|
#
|
6
6
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
7
7
|
# you may not use this file except in compliance with the License.
|
@@ -17,13 +17,10 @@ from __future__ import annotations
|
|
17
17
|
|
18
18
|
import unittest
|
19
19
|
from collections.abc import Callable
|
20
|
-
from functools import partial
|
21
20
|
from threading import Thread
|
22
|
-
from typing import Any
|
23
21
|
|
24
22
|
import jax
|
25
23
|
import jax.numpy as jnp
|
26
|
-
import pytest
|
27
24
|
from absl.testing import absltest, parameterized
|
28
25
|
|
29
26
|
import brainstate as bst
|
@@ -354,125 +351,6 @@ class TestGraphUtils(absltest.TestCase):
|
|
354
351
|
assert m2.tree.a is not m.tree.a
|
355
352
|
assert m2.tree is not m.tree
|
356
353
|
|
357
|
-
@pytest.mark.skip(reason='Not implemented')
|
358
|
-
def test_cached_unflatten(self):
|
359
|
-
class Foo(bst.graph.Node):
|
360
|
-
def __init__(self, ):
|
361
|
-
self.a = bst.nn.Linear(2, 2)
|
362
|
-
self.b = bst.nn.BatchNorm1d([10, 2])
|
363
|
-
|
364
|
-
def f(m: Foo):
|
365
|
-
m.a, m.b = m.b, m.a # type: ignore
|
366
|
-
|
367
|
-
m = Foo()
|
368
|
-
a = m.a
|
369
|
-
b = m.b
|
370
|
-
|
371
|
-
ref_out_idx_out = bst.graph.RefMap()
|
372
|
-
graphdef: bst.graph.GraphDef[Foo]
|
373
|
-
graphdef, state = bst.graph.flatten(m, ref_index=ref_out_idx_out)
|
374
|
-
|
375
|
-
@partial(jax.jit, static_argnums=(0,))
|
376
|
-
def f_pure(graphdef: bst.graph.GraphDef[Foo], state):
|
377
|
-
idx_out_ref_in: dict[int, Any] = {}
|
378
|
-
m = bst.graph.unflatten(graphdef, state, index_ref=idx_out_ref_in)
|
379
|
-
f(m)
|
380
|
-
ref_in_idx_in = bst.graph.RefMap[Any, int]()
|
381
|
-
graphdef, state = bst.graph.flatten(m, ref_index=ref_in_idx_in)
|
382
|
-
idx_out_idx_in = bst.graph.compose_mapping(idx_out_ref_in, ref_in_idx_in)
|
383
|
-
static_out = bst.graph.Static((graphdef, idx_out_idx_in))
|
384
|
-
return state, static_out
|
385
|
-
|
386
|
-
static_out: bst.graph.Static
|
387
|
-
state, static_out = f_pure(graphdef, state)
|
388
|
-
idx_out_idx_in: dict[int, int]
|
389
|
-
graphdef, idx_out_idx_in = static_out.value
|
390
|
-
idx_in_ref_out = bst.graph.compose_mapping_reversed(
|
391
|
-
ref_out_idx_out, idx_out_idx_in
|
392
|
-
)
|
393
|
-
m2 = bst.graph.unflatten(graphdef, state, index_ref_cache=idx_in_ref_out)
|
394
|
-
assert m2 is m
|
395
|
-
assert m2.a is b
|
396
|
-
assert m2.b is a
|
397
|
-
|
398
|
-
@pytest.mark.skip(reason='Not implemented')
|
399
|
-
def test_cached_unflatten_swap_variables(self):
|
400
|
-
class Foo(bst.graph.Node):
|
401
|
-
def __init__(self):
|
402
|
-
self.a = bst.ParamState(1)
|
403
|
-
self.b = bst.ParamState(2)
|
404
|
-
|
405
|
-
def f(m: Foo):
|
406
|
-
m.a, m.b = m.b, m.a
|
407
|
-
|
408
|
-
m = Foo()
|
409
|
-
a = m.a
|
410
|
-
b = m.b
|
411
|
-
|
412
|
-
ref_out_idx_out = bst.graph.RefMap[Any, int]()
|
413
|
-
graphdef: bst.graph.GraphDef[Foo]
|
414
|
-
graphdef, state = bst.graph.flatten(m, ref_index=ref_out_idx_out)
|
415
|
-
|
416
|
-
@partial(jax.jit, static_argnums=(0,))
|
417
|
-
def f_pure(graphdef: bst.graph.GraphDef[Foo], state):
|
418
|
-
idx_out_ref_in: dict[int, Any] = {}
|
419
|
-
m = bst.graph.unflatten(graphdef, state, index_ref=idx_out_ref_in)
|
420
|
-
f(m)
|
421
|
-
ref_in_idx_in = bst.graph.RefMap[Any, int]()
|
422
|
-
graphdef, state = bst.graph.flatten(m, ref_index=ref_in_idx_in)
|
423
|
-
idx_out_idx_in = bst.graph.compose_mapping(idx_out_ref_in, ref_in_idx_in)
|
424
|
-
static_out = bst.graph.Static((graphdef, idx_out_idx_in))
|
425
|
-
return state, static_out
|
426
|
-
|
427
|
-
static_out: bst.graph.Static
|
428
|
-
state, static_out = f_pure(graphdef, state)
|
429
|
-
idx_out_idx_in: dict[int, int]
|
430
|
-
graphdef, idx_out_idx_in = static_out.value
|
431
|
-
idx_in_ref_out = bst.graph.compose_mapping_reversed(
|
432
|
-
ref_out_idx_out, idx_out_idx_in
|
433
|
-
)
|
434
|
-
m2 = bst.graph.unflatten(graphdef, state, index_ref_cache=idx_in_ref_out)
|
435
|
-
assert m2 is m
|
436
|
-
assert m2.a is b
|
437
|
-
assert m2.b is a
|
438
|
-
|
439
|
-
@pytest.mark.skip(reason='Not implemented')
|
440
|
-
def test_cached_unflatten_add_self_reference(self):
|
441
|
-
class Foo(bst.graph.Node):
|
442
|
-
def __init__(self):
|
443
|
-
self.ref = None
|
444
|
-
|
445
|
-
def f(m: Foo):
|
446
|
-
m.ref = m
|
447
|
-
|
448
|
-
m = Foo()
|
449
|
-
|
450
|
-
ref_out_idx_out = bst.graph.RefMap()
|
451
|
-
graphdef: bst.graph.GraphDef[Foo]
|
452
|
-
graphdef, state = bst.graph.flatten(m, ref_index=ref_out_idx_out)
|
453
|
-
|
454
|
-
@partial(jax.jit, static_argnums=(0,))
|
455
|
-
def f_pure(graphdef: bst.graph.GraphDef[Foo], state):
|
456
|
-
idx_out_ref_in: dict[int, Any] = {}
|
457
|
-
m = bst.graph.unflatten(graphdef, state, index_ref=idx_out_ref_in)
|
458
|
-
f(m)
|
459
|
-
ref_in_idx_in = bst.graph.RefMap[Any, int]()
|
460
|
-
graphdef, state = bst.graph.flatten(m, ref_index=ref_in_idx_in)
|
461
|
-
idx_out_idx_in = bst.graph.compose_mapping(idx_out_ref_in, ref_in_idx_in)
|
462
|
-
static_out = bst.graph.Static((graphdef, idx_out_idx_in))
|
463
|
-
return state, static_out
|
464
|
-
|
465
|
-
static_out: bst.graph.Static
|
466
|
-
state, static_out = f_pure(graphdef, state)
|
467
|
-
idx_out_idx_in: dict[int, int]
|
468
|
-
graphdef, idx_out_idx_in = static_out.value
|
469
|
-
idx_in_ref_out = bst.graph.compose_mapping_reversed(
|
470
|
-
ref_out_idx_out, idx_out_idx_in
|
471
|
-
)
|
472
|
-
m2 = bst.graph.unflatten(graphdef, state, index_ref_cache=idx_in_ref_out)
|
473
|
-
assert m2 is m
|
474
|
-
assert m2.ref is m2
|
475
|
-
|
476
354
|
def test_call_jit_update(self):
|
477
355
|
class Counter(bst.graph.Node):
|
478
356
|
def __init__(self):
|
@@ -527,43 +405,6 @@ class TestGraphUtils(absltest.TestCase):
|
|
527
405
|
self.assertEqual(nodes['a'].count.value, 0)
|
528
406
|
self.assertEqual(nodes['b'].count.value, 1)
|
529
407
|
|
530
|
-
def test_to_tree_simple(self):
|
531
|
-
m = bst.nn.Linear(2, 3, )
|
532
|
-
impure_tree = (m, 1, {'b': m})
|
533
|
-
|
534
|
-
pure_tree = bst.graph.graph_to_tree(impure_tree)
|
535
|
-
|
536
|
-
t1 = pure_tree[0]
|
537
|
-
t2 = pure_tree[2]['b']
|
538
|
-
|
539
|
-
self.assertEqual(pure_tree[1], 1)
|
540
|
-
self.assertIsInstance(t1, bst.graph.NodeStates)
|
541
|
-
assert isinstance(t1, bst.graph.NodeStates)
|
542
|
-
self.assertIsInstance(t2, bst.graph.NodeStates)
|
543
|
-
assert isinstance(t2, bst.graph.NodeStates)
|
544
|
-
self.assertIsInstance(t1.graphdef, bst.graph.NodeDef)
|
545
|
-
self.assertIsInstance(t2.graphdef, bst.graph.NodeRef)
|
546
|
-
self.assertLen(t1.states[0].to_flat(), 1)
|
547
|
-
self.assertLen(t2.states[0].to_flat(), 0)
|
548
|
-
|
549
|
-
impure_tree2 = bst.graph.tree_to_graph(pure_tree)
|
550
|
-
|
551
|
-
m1_out = impure_tree2[0]
|
552
|
-
m2_out = impure_tree2[2]['b']
|
553
|
-
|
554
|
-
self.assertIs(m1_out, m2_out)
|
555
|
-
self.assertEqual(impure_tree2[1], 1)
|
556
|
-
|
557
|
-
def test_to_tree_consistent_prefix(self):
|
558
|
-
m = bst.nn.Linear(2, 3, )
|
559
|
-
impure_tree = (m, 1, {'b': m})
|
560
|
-
prefix = (0, None, 0)
|
561
|
-
pure_tree = bst.graph.graph_to_tree(impure_tree, prefix=prefix)
|
562
|
-
|
563
|
-
prefix = (0, None, 1)
|
564
|
-
with self.assertRaisesRegex(ValueError, 'Inconsistent aliasing detected'):
|
565
|
-
bst.graph.graph_to_tree(impure_tree, prefix=prefix)
|
566
|
-
|
567
408
|
|
568
409
|
class SimpleModule(bst.nn.Module):
|
569
410
|
pass
|
@@ -22,8 +22,8 @@ import numpy as np
|
|
22
22
|
|
23
23
|
from brainstate import environ, init, random
|
24
24
|
from brainstate._state import ShortTermState
|
25
|
-
from brainstate._state import State
|
26
|
-
from brainstate.compile import while_loop
|
25
|
+
from brainstate._state import State, maybe_state
|
26
|
+
from brainstate.compile import while_loop
|
27
27
|
from brainstate.nn._dynamics._dynamics_base import Dynamics, Prefetch
|
28
28
|
from brainstate.nn._module import Module
|
29
29
|
from brainstate.typing import ArrayLike, Size, DTypeLike
|
@@ -198,55 +198,97 @@ class PoissonInput(Module):
|
|
198
198
|
self.weight = weight
|
199
199
|
|
200
200
|
def update(self):
|
201
|
-
p = self.freq * environ.get_dt()
|
202
|
-
a = self.num_input * p
|
203
|
-
b = self.num_input * (1 - p)
|
204
|
-
|
205
|
-
target = self.target()
|
206
201
|
target_state = getattr(self.target.module, self.target.item)
|
207
202
|
|
208
203
|
# generate Poisson input
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
204
|
+
poisson_input(
|
205
|
+
self.freq,
|
206
|
+
self.num_input,
|
207
|
+
self.weight,
|
208
|
+
target_state,
|
209
|
+
self.indices,
|
213
210
|
)
|
214
211
|
|
215
|
-
# update target variable
|
216
|
-
target_state.value = target.at[self.indices].add(inp * self.weight)
|
217
|
-
|
218
212
|
|
219
213
|
def poisson_input(
|
220
|
-
freq:
|
214
|
+
freq: u.Quantity[u.Hz],
|
221
215
|
num_input: int,
|
222
|
-
weight:
|
216
|
+
weight: u.Quantity,
|
223
217
|
target: State,
|
224
218
|
indices: Optional[Union[np.ndarray, jax.Array]] = None,
|
225
219
|
):
|
226
220
|
"""
|
227
221
|
Poisson Input to the given :py:class:`brainstate.State`.
|
228
222
|
"""
|
223
|
+
freq = maybe_state(freq)
|
224
|
+
weight = maybe_state(weight)
|
225
|
+
|
229
226
|
assert isinstance(target, State), 'The target must be a State.'
|
230
|
-
p = freq * environ.get_dt()
|
227
|
+
p = (freq * environ.get_dt()).to_decimal()
|
231
228
|
a = num_input * p
|
232
229
|
b = num_input * (1 - p)
|
233
230
|
tar_val = target.value
|
231
|
+
cond = u.math.logical_and(a > 5, b > 5)
|
232
|
+
|
234
233
|
if indices is None:
|
235
234
|
# generate Poisson input
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
|
235
|
+
branch1 = jax.tree.map(
|
236
|
+
lambda tar: random.normal(
|
237
|
+
a,
|
238
|
+
b * p,
|
239
|
+
tar.shape,
|
240
|
+
dtype=tar.dtype
|
242
241
|
),
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
242
|
+
tar_val,
|
243
|
+
is_leaf=u.math.is_quantity
|
244
|
+
)
|
245
|
+
branch2 = jax.tree.map(
|
246
|
+
lambda tar: random.binomial(
|
247
|
+
num_input,
|
248
|
+
p,
|
249
|
+
tar.shape,
|
250
|
+
check_valid=False,
|
251
|
+
dtype=tar.dtype
|
252
|
+
),
|
253
|
+
tar_val,
|
254
|
+
is_leaf=u.math.is_quantity,
|
255
|
+
)
|
256
|
+
|
257
|
+
inp = jax.tree.map(
|
258
|
+
lambda b1, b2: u.math.where(cond, b1, b2),
|
259
|
+
branch1,
|
260
|
+
branch2,
|
261
|
+
is_leaf=u.math.is_quantity,
|
248
262
|
)
|
249
263
|
|
264
|
+
# inp = jax.lax.cond(
|
265
|
+
# cond,
|
266
|
+
# lambda rand_key: jax.tree.map(
|
267
|
+
# lambda tar: random.normal(
|
268
|
+
# a,
|
269
|
+
# b * p,
|
270
|
+
# tar.shape,
|
271
|
+
# key=rand_key,
|
272
|
+
# dtype=tar.dtype
|
273
|
+
# ),
|
274
|
+
# tar_val,
|
275
|
+
# is_leaf=u.math.is_quantity
|
276
|
+
# ),
|
277
|
+
# lambda rand_key: jax.tree.map(
|
278
|
+
# lambda tar: random.binomial(
|
279
|
+
# num_input,
|
280
|
+
# p,
|
281
|
+
# tar.shape,
|
282
|
+
# key=rand_key,
|
283
|
+
# check_valid=False,
|
284
|
+
# dtype=tar.dtype
|
285
|
+
# ),
|
286
|
+
# tar_val,
|
287
|
+
# is_leaf=u.math.is_quantity,
|
288
|
+
# ),
|
289
|
+
# random.split_key()
|
290
|
+
# )
|
291
|
+
|
250
292
|
# update target variable
|
251
293
|
target.value = jax.tree.map(
|
252
294
|
lambda x: x * weight,
|
@@ -256,19 +298,62 @@ def poisson_input(
|
|
256
298
|
|
257
299
|
else:
|
258
300
|
# generate Poisson input
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
301
|
+
branch1 = jax.tree.map(
|
302
|
+
lambda tar: random.normal(
|
303
|
+
a,
|
304
|
+
b * p,
|
305
|
+
tar[indices].shape,
|
306
|
+
dtype=tar.dtype
|
265
307
|
),
|
266
|
-
|
267
|
-
|
268
|
-
tar_val,
|
269
|
-
is_leaf=u.math.is_quantity
|
270
|
-
)
|
308
|
+
tar_val,
|
309
|
+
is_leaf=u.math.is_quantity
|
271
310
|
)
|
311
|
+
branch2 = jax.tree.map(
|
312
|
+
lambda tar: random.binomial(
|
313
|
+
num_input,
|
314
|
+
p,
|
315
|
+
tar[indices].shape,
|
316
|
+
# check_valid=False,
|
317
|
+
dtype=tar.dtype
|
318
|
+
),
|
319
|
+
tar_val,
|
320
|
+
is_leaf=u.math.is_quantity
|
321
|
+
)
|
322
|
+
|
323
|
+
inp = jax.tree.map(
|
324
|
+
lambda b1, b2: u.math.where(cond, b1, b2),
|
325
|
+
branch1,
|
326
|
+
branch2,
|
327
|
+
is_leaf=u.math.is_quantity,
|
328
|
+
)
|
329
|
+
|
330
|
+
# inp = jax.lax.cond(
|
331
|
+
# cond,
|
332
|
+
# lambda rand_key: jax.tree.map(
|
333
|
+
# lambda tar: random.normal(
|
334
|
+
# a,
|
335
|
+
# b * p,
|
336
|
+
# tar[indices].shape,
|
337
|
+
# key=rand_key,
|
338
|
+
# dtype=tar.dtype
|
339
|
+
# ),
|
340
|
+
# tar_val,
|
341
|
+
# is_leaf=u.math.is_quantity
|
342
|
+
# ),
|
343
|
+
# lambda rand_key: jax.tree.map(
|
344
|
+
# lambda tar: random.binomial(
|
345
|
+
# num_input,
|
346
|
+
# p,
|
347
|
+
# tar[indices].shape,
|
348
|
+
# key=rand_key,
|
349
|
+
# check_valid=False,
|
350
|
+
# dtype=tar.dtype
|
351
|
+
# ),
|
352
|
+
# tar_val,
|
353
|
+
# is_leaf=u.math.is_quantity
|
354
|
+
# ),
|
355
|
+
# random.split_key()
|
356
|
+
# )
|
272
357
|
|
273
358
|
# update target variable
|
274
359
|
target.value = jax.tree.map(
|
@@ -191,6 +191,7 @@ class _Conv(_BaseConv):
|
|
191
191
|
b_init: Optional[Union[Callable, ArrayLike]] = None,
|
192
192
|
w_mask: Optional[Union[ArrayLike, Callable]] = None,
|
193
193
|
name: str = None,
|
194
|
+
param_type: type = ParamState,
|
194
195
|
):
|
195
196
|
super().__init__(in_size=in_size,
|
196
197
|
out_channels=out_channels,
|
@@ -215,7 +216,7 @@ class _Conv(_BaseConv):
|
|
215
216
|
params['bias'] = bias
|
216
217
|
|
217
218
|
# The weight operation
|
218
|
-
self.weight =
|
219
|
+
self.weight = param_type(params)
|
219
220
|
|
220
221
|
# Evaluate the output shape
|
221
222
|
abstract_y = jax.eval_shape(
|
@@ -346,6 +347,7 @@ class _ScaledWSConv(_BaseConv):
|
|
346
347
|
b_init: Optional[Union[Callable, ArrayLike]] = None,
|
347
348
|
w_mask: Optional[Union[ArrayLike, Callable]] = None,
|
348
349
|
name: str = None,
|
350
|
+
param_type: type = ParamState,
|
349
351
|
):
|
350
352
|
super().__init__(in_size=in_size,
|
351
353
|
out_channels=out_channels,
|
@@ -379,7 +381,7 @@ class _ScaledWSConv(_BaseConv):
|
|
379
381
|
self.eps = eps
|
380
382
|
|
381
383
|
# The weight operation
|
382
|
-
self.weight =
|
384
|
+
self.weight = param_type(params)
|
383
385
|
|
384
386
|
# Evaluate the output shape
|
385
387
|
abstract_y = jax.eval_shape(
|