brainstate 0.1.0.post20250104__py2.py3-none-any.whl → 0.1.0.post20250120__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. brainstate/_state.py +77 -44
  2. brainstate/_state_test.py +0 -17
  3. brainstate/augment/_eval_shape.py +9 -10
  4. brainstate/augment/_eval_shape_test.py +1 -1
  5. brainstate/augment/_mapping.py +265 -277
  6. brainstate/augment/_mapping_test.py +147 -175
  7. brainstate/compile/_ad_checkpoint.py +6 -4
  8. brainstate/compile/_error_if_test.py +1 -0
  9. brainstate/compile/_jit.py +37 -28
  10. brainstate/compile/_loop_collect_return.py +8 -5
  11. brainstate/compile/_loop_no_collection.py +2 -0
  12. brainstate/compile/_make_jaxpr.py +7 -3
  13. brainstate/compile/_make_jaxpr_test.py +2 -1
  14. brainstate/compile/_progress_bar.py +68 -40
  15. brainstate/compile/_unvmap.py +6 -2
  16. brainstate/environ.py +28 -18
  17. brainstate/environ_test.py +4 -0
  18. brainstate/event/__init__.py +0 -2
  19. brainstate/event/_csr.py +266 -23
  20. brainstate/event/_csr_test.py +187 -0
  21. brainstate/event/_fixedprob_mv.py +4 -2
  22. brainstate/event/_fixedprob_mv_test.py +2 -1
  23. brainstate/event/_xla_custom_op.py +16 -5
  24. brainstate/graph/__init__.py +8 -12
  25. brainstate/graph/_graph_node.py +1 -23
  26. brainstate/graph/_graph_operation.py +1 -1
  27. brainstate/graph/_graph_operation_test.py +0 -159
  28. brainstate/nn/_dyn_impl/_inputs.py +124 -39
  29. brainstate/nn/_interaction/_conv.py +4 -2
  30. brainstate/nn/_interaction/_linear.py +84 -10
  31. brainstate/random/_rand_funs.py +9 -2
  32. brainstate/random/_rand_seed.py +12 -2
  33. brainstate/random/_rand_state.py +50 -179
  34. brainstate/surrogate.py +5 -1
  35. brainstate/util/__init__.py +0 -4
  36. brainstate/util/_caller.py +1 -1
  37. brainstate/util/_dict.py +4 -1
  38. brainstate/util/_filter.py +1 -1
  39. brainstate/util/_pretty_repr.py +1 -1
  40. brainstate/util/_struct.py +1 -1
  41. {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/METADATA +2 -1
  42. {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/RECORD +46 -52
  43. brainstate/event/_csr_mv_test.py +0 -118
  44. brainstate/graph/_graph_context.py +0 -443
  45. brainstate/graph/_graph_context_test.py +0 -65
  46. brainstate/graph/_graph_convert.py +0 -246
  47. brainstate/util/_tracers.py +0 -68
  48. brainstate/util/_visualization.py +0 -47
  49. /brainstate/event/{_csr_mv_benchmark.py → _csr_benchmark.py} +0 -0
  50. {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/LICENSE +0 -0
  51. {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/WHEEL +0 -0
  52. {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/top_level.txt +0 -0
brainstate/event/_xla_custom_op.py

@@ -9,12 +9,20 @@ from typing import Callable, Sequence, Tuple, Protocol
 import jax
 import numpy as np
 from jax import tree_util
-from jax.core import Primitive
 from jax.interpreters import batching, ad
 from jax.interpreters import xla, mlir
-from jax.lib import xla_client
 from jaxlib.hlo_helpers import custom_call
 
+if jax.__version_info__ < (0, 4, 35):
+    from jax.lib import xla_client
+else:
+    import jax.extend as je
+
+if jax.__version_info__ < (0, 4, 38):
+    from jax.core import Primitive
+else:
+    from jax.extend.core import Primitive
+
 numba_installed = importlib.util.find_spec('numba') is not None
 
 __all__ = [
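Note on the version gates above: `Primitive` moved out of `jax.core` (gated here at jax 0.4.38, with `jax.extend.core` as the new home), and `jax.lib.xla_client` is superseded by `jax.extend` (gated at 0.4.35). A minimal sketch of how the gated `Primitive` symbol is then consumed downstream; the primitive name here is hypothetical:

    import jax

    if jax.__version_info__ < (0, 4, 38):
        from jax.core import Primitive
    else:
        from jax.extend.core import Primitive

    # The primitive is defined identically on either side of the version split.
    my_prim = Primitive('my_custom_op')
    my_prim.multiple_results = True
    my_prim.def_abstract_eval(lambda *avals, **params: avals)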
@@ -143,7 +151,10 @@ def numba_cpu_custom_call_target(output_ptrs, input_ptrs):
     xla_c_rule = cfunc(sig)(new_f)
     target_name = f'numba_custom_call_{str(xla_c_rule.address)}'
     capsule = ctypes.pythonapi.PyCapsule_New(xla_c_rule.address, b"xla._CUSTOM_CALL_TARGET", None)
-    xla_client.register_custom_call_target(target_name, capsule, "cpu")
+    if jax.__version_info__ < (0, 4, 35):
+        xla_client.register_custom_call_target(target_name, capsule, "cpu")
+    else:
+        je.ffi.register_ffi_target(target_name, capsule, "cpu", api_version=0)
 
     # call
     return custom_call(
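The registration path changes together with the import: on newer JAX the capsule is registered through the FFI registry, where `api_version=0` selects the legacy custom-call calling convention. A small helper wrapping the two paths, reusing exactly the calls from this hunk (the helper name itself is hypothetical):

    import jax

    def register_cpu_capsule(target_name: str, capsule) -> None:
        """Register a PyCapsule as a CPU custom-call target on any JAX version."""
        if jax.__version_info__ < (0, 4, 35):
            from jax.lib import xla_client
            xla_client.register_custom_call_target(target_name, capsule, 'cpu')
        else:
            import jax.extend as je
            # api_version=0 keeps the legacy (untyped) custom-call convention.
            je.ffi.register_ffi_target(target_name, capsule, 'cpu', api_version=0)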
@@ -157,7 +168,7 @@ def numba_cpu_custom_call_target(output_ptrs, input_ptrs):
 
 
 def register_numba_mlir_cpu_translation_rule(
-    primitive: jax.core.Primitive,
+    primitive: Primitive,
     cpu_kernel: Callable,
     debug: bool = False
 ):
@@ -198,7 +209,7 @@ class XLACustomOp:
         transpose_translation: Callable = None,
     ):
         # primitive
-        self.primitive = jax.core.Primitive(name)
+        self.primitive = Primitive(name)
         self.primitive.multiple_results = True
 
         # abstract evaluation
brainstate/graph/__init__.py

@@ -14,20 +14,16 @@
 # ==============================================================================
 
 
-from ._graph_context import *
-from ._graph_context import __all__ as _graph_context__all__
-from ._graph_convert import *
-from ._graph_convert import __all__ as _graph_convert__all__
 from ._graph_node import *
 from ._graph_node import __all__ as _graph_node__all__
 from ._graph_operation import *
 from ._graph_operation import __all__ as _graph_operation__all__
 
-__all__ = (_graph_context__all__ +
-           _graph_convert__all__ +
-           _graph_node__all__ +
-           _graph_operation__all__)
-del (_graph_context__all__,
-     _graph_convert__all__,
-     _graph_node__all__,
-     _graph_operation__all__)
+__all__ = (
+    _graph_node__all__ +
+    _graph_operation__all__
+)
+del (
+    _graph_node__all__,
+    _graph_operation__all__
+)
brainstate/graph/_graph_node.py

@@ -1,7 +1,7 @@
 # The file is adapted from the Flax library (https://github.com/google/flax).
 # The credit should go to the Flax authors.
 #
-# Copyright 2024 The Flax Authors & 2024 BDP Ecosystem.
+# Copyright 2024 The Flax Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -27,9 +27,7 @@ import numpy as np
 
 from brainstate._state import State, TreefyState
 from brainstate.typing import Key
-from brainstate.util._error import TraceContextError
 from brainstate.util._pretty_repr import PrettyRepr, pretty_repr_avoid_duplicate, PrettyType, PrettyAttr
-from brainstate.util._tracers import StateJaxTracer
 from ._graph_operation import register_graph_node_type
 
 __all__ = [
@@ -44,7 +42,6 @@ class GraphNodeMeta(ABCMeta):
     if not TYPE_CHECKING:
         def __call__(cls, *args: Any, **kwargs: Any) -> Any:
             node = cls.__new__(cls, *args, **kwargs)
-            vars(node)['_trace_state'] = StateJaxTracer()
             node.__init__(*args, **kwargs)
             return node
 
@@ -64,9 +61,6 @@ class Node(PrettyRepr, metaclass=GraphNodeMeta):
 
     graph_invisible_attrs = ()
 
-    if TYPE_CHECKING:
-        _trace_state: StateJaxTracer
-
     def __init_subclass__(cls) -> None:
         super().__init_subclass__()
 
@@ -79,21 +73,6 @@ class Node(PrettyRepr, metaclass=GraphNodeMeta):
             clear=_node_clear,
         )
 
-    # if not TYPE_CHECKING:
-    #     def __setattr__(self, name: str, value: Any) -> None:
-    #         self._setattr(name, value)
-
-    # def _setattr(self, name: str, value: Any) -> None:
-    #     self.check_valid_context(lambda: f"Cannot mutate '{type(self).__name__}' from different trace level")
-    #     object.__setattr__(self, name, value)
-
-    def check_valid_context(self, error_msg: Callable[[], str]) -> None:
-        """
-        Check if the current context is valid for the object to be mutated.
-        """
-        if not self._trace_state.is_valid():
-            raise TraceContextError(error_msg())
-
     def __deepcopy__(self: G, memo=None) -> G:
         """
         Deepcopy the object.
@@ -214,7 +193,6 @@ def _node_create_empty(
 ) -> G:
     node_type, = static
     node = object.__new__(node_type)
-    vars(node).update(_trace_state=StateJaxTracer())
     return node
 
 
brainstate/graph/_graph_operation.py

@@ -1,7 +1,7 @@
 # The file is adapted from the Flax library (https://github.com/google/flax).
 # The credit should go to the Flax authors.
 #
-# Copyright 2024 The Flax Authors & 2024 BDP Ecosystem.
+# Copyright 2024 The Flax Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
brainstate/graph/_graph_operation_test.py

@@ -17,13 +17,10 @@ from __future__ import annotations
 
 import unittest
 from collections.abc import Callable
-from functools import partial
 from threading import Thread
-from typing import Any
 
 import jax
 import jax.numpy as jnp
-import pytest
 from absl.testing import absltest, parameterized
 
 import brainstate as bst
@@ -354,125 +351,6 @@ class TestGraphUtils(absltest.TestCase):
         assert m2.tree.a is not m.tree.a
         assert m2.tree is not m.tree
 
-    @pytest.mark.skip(reason='Not implemented')
-    def test_cached_unflatten(self):
-        class Foo(bst.graph.Node):
-            def __init__(self, ):
-                self.a = bst.nn.Linear(2, 2)
-                self.b = bst.nn.BatchNorm1d([10, 2])
-
-        def f(m: Foo):
-            m.a, m.b = m.b, m.a  # type: ignore
-
-        m = Foo()
-        a = m.a
-        b = m.b
-
-        ref_out_idx_out = bst.graph.RefMap()
-        graphdef: bst.graph.GraphDef[Foo]
-        graphdef, state = bst.graph.flatten(m, ref_index=ref_out_idx_out)
-
-        @partial(jax.jit, static_argnums=(0,))
-        def f_pure(graphdef: bst.graph.GraphDef[Foo], state):
-            idx_out_ref_in: dict[int, Any] = {}
-            m = bst.graph.unflatten(graphdef, state, index_ref=idx_out_ref_in)
-            f(m)
-            ref_in_idx_in = bst.graph.RefMap[Any, int]()
-            graphdef, state = bst.graph.flatten(m, ref_index=ref_in_idx_in)
-            idx_out_idx_in = bst.graph.compose_mapping(idx_out_ref_in, ref_in_idx_in)
-            static_out = bst.graph.Static((graphdef, idx_out_idx_in))
-            return state, static_out
-
-        static_out: bst.graph.Static
-        state, static_out = f_pure(graphdef, state)
-        idx_out_idx_in: dict[int, int]
-        graphdef, idx_out_idx_in = static_out.value
-        idx_in_ref_out = bst.graph.compose_mapping_reversed(
-            ref_out_idx_out, idx_out_idx_in
-        )
-        m2 = bst.graph.unflatten(graphdef, state, index_ref_cache=idx_in_ref_out)
-        assert m2 is m
-        assert m2.a is b
-        assert m2.b is a
-
-    @pytest.mark.skip(reason='Not implemented')
-    def test_cached_unflatten_swap_variables(self):
-        class Foo(bst.graph.Node):
-            def __init__(self):
-                self.a = bst.ParamState(1)
-                self.b = bst.ParamState(2)
-
-        def f(m: Foo):
-            m.a, m.b = m.b, m.a
-
-        m = Foo()
-        a = m.a
-        b = m.b
-
-        ref_out_idx_out = bst.graph.RefMap[Any, int]()
-        graphdef: bst.graph.GraphDef[Foo]
-        graphdef, state = bst.graph.flatten(m, ref_index=ref_out_idx_out)
-
-        @partial(jax.jit, static_argnums=(0,))
-        def f_pure(graphdef: bst.graph.GraphDef[Foo], state):
-            idx_out_ref_in: dict[int, Any] = {}
-            m = bst.graph.unflatten(graphdef, state, index_ref=idx_out_ref_in)
-            f(m)
-            ref_in_idx_in = bst.graph.RefMap[Any, int]()
-            graphdef, state = bst.graph.flatten(m, ref_index=ref_in_idx_in)
-            idx_out_idx_in = bst.graph.compose_mapping(idx_out_ref_in, ref_in_idx_in)
-            static_out = bst.graph.Static((graphdef, idx_out_idx_in))
-            return state, static_out
-
-        static_out: bst.graph.Static
-        state, static_out = f_pure(graphdef, state)
-        idx_out_idx_in: dict[int, int]
-        graphdef, idx_out_idx_in = static_out.value
-        idx_in_ref_out = bst.graph.compose_mapping_reversed(
-            ref_out_idx_out, idx_out_idx_in
-        )
-        m2 = bst.graph.unflatten(graphdef, state, index_ref_cache=idx_in_ref_out)
-        assert m2 is m
-        assert m2.a is b
-        assert m2.b is a
-
-    @pytest.mark.skip(reason='Not implemented')
-    def test_cached_unflatten_add_self_reference(self):
-        class Foo(bst.graph.Node):
-            def __init__(self):
-                self.ref = None
-
-        def f(m: Foo):
-            m.ref = m
-
-        m = Foo()
-
-        ref_out_idx_out = bst.graph.RefMap()
-        graphdef: bst.graph.GraphDef[Foo]
-        graphdef, state = bst.graph.flatten(m, ref_index=ref_out_idx_out)
-
-        @partial(jax.jit, static_argnums=(0,))
-        def f_pure(graphdef: bst.graph.GraphDef[Foo], state):
-            idx_out_ref_in: dict[int, Any] = {}
-            m = bst.graph.unflatten(graphdef, state, index_ref=idx_out_ref_in)
-            f(m)
-            ref_in_idx_in = bst.graph.RefMap[Any, int]()
-            graphdef, state = bst.graph.flatten(m, ref_index=ref_in_idx_in)
-            idx_out_idx_in = bst.graph.compose_mapping(idx_out_ref_in, ref_in_idx_in)
-            static_out = bst.graph.Static((graphdef, idx_out_idx_in))
-            return state, static_out
-
-        static_out: bst.graph.Static
-        state, static_out = f_pure(graphdef, state)
-        idx_out_idx_in: dict[int, int]
-        graphdef, idx_out_idx_in = static_out.value
-        idx_in_ref_out = bst.graph.compose_mapping_reversed(
-            ref_out_idx_out, idx_out_idx_in
-        )
-        m2 = bst.graph.unflatten(graphdef, state, index_ref_cache=idx_in_ref_out)
-        assert m2 is m
-        assert m2.ref is m2
-
     def test_call_jit_update(self):
         class Counter(bst.graph.Node):
             def __init__(self):
@@ -527,43 +405,6 @@ class TestGraphUtils(absltest.TestCase):
         self.assertEqual(nodes['a'].count.value, 0)
         self.assertEqual(nodes['b'].count.value, 1)
 
-    def test_to_tree_simple(self):
-        m = bst.nn.Linear(2, 3, )
-        impure_tree = (m, 1, {'b': m})
-
-        pure_tree = bst.graph.graph_to_tree(impure_tree)
-
-        t1 = pure_tree[0]
-        t2 = pure_tree[2]['b']
-
-        self.assertEqual(pure_tree[1], 1)
-        self.assertIsInstance(t1, bst.graph.NodeStates)
-        assert isinstance(t1, bst.graph.NodeStates)
-        self.assertIsInstance(t2, bst.graph.NodeStates)
-        assert isinstance(t2, bst.graph.NodeStates)
-        self.assertIsInstance(t1.graphdef, bst.graph.NodeDef)
-        self.assertIsInstance(t2.graphdef, bst.graph.NodeRef)
-        self.assertLen(t1.states[0].to_flat(), 1)
-        self.assertLen(t2.states[0].to_flat(), 0)
-
-        impure_tree2 = bst.graph.tree_to_graph(pure_tree)
-
-        m1_out = impure_tree2[0]
-        m2_out = impure_tree2[2]['b']
-
-        self.assertIs(m1_out, m2_out)
-        self.assertEqual(impure_tree2[1], 1)
-
-    def test_to_tree_consistent_prefix(self):
-        m = bst.nn.Linear(2, 3, )
-        impure_tree = (m, 1, {'b': m})
-        prefix = (0, None, 0)
-        pure_tree = bst.graph.graph_to_tree(impure_tree, prefix=prefix)
-
-        prefix = (0, None, 1)
-        with self.assertRaisesRegex(ValueError, 'Inconsistent aliasing detected'):
-            bst.graph.graph_to_tree(impure_tree, prefix=prefix)
-
 
 class SimpleModule(bst.nn.Module):
     pass
brainstate/nn/_dyn_impl/_inputs.py

@@ -22,8 +22,8 @@ import numpy as np
 
 from brainstate import environ, init, random
 from brainstate._state import ShortTermState
-from brainstate._state import State
-from brainstate.compile import while_loop, cond
+from brainstate._state import State, maybe_state
+from brainstate.compile import while_loop
 from brainstate.nn._dynamics._dynamics_base import Dynamics, Prefetch
 from brainstate.nn._module import Module
 from brainstate.typing import ArrayLike, Size, DTypeLike
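Note: `maybe_state`, newly imported here, lets `poisson_input` below accept either raw values or `State`-wrapped values for `freq` and `weight`. Judging from its use, its behavior is presumably equivalent to this sketch:

    from brainstate._state import State

    def maybe_state_sketch(value):
        """Unwrap a State to its underlying value; pass plain values through."""
        return value.value if isinstance(value, State) else value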
@@ -198,55 +198,97 @@ class PoissonInput(Module):
         self.weight = weight
 
     def update(self):
-        p = self.freq * environ.get_dt()
-        a = self.num_input * p
-        b = self.num_input * (1 - p)
-
-        target = self.target()
         target_state = getattr(self.target.module, self.target.item)
 
         # generate Poisson input
-        inp = cond(
-            u.math.logical_and(a > 5, b > 5),
-            lambda: random.normal(a, b * p, self.indices.shape),
-            lambda: random.binomial(self.num_input, p, self.indices.shape).astype(float)
+        poisson_input(
+            self.freq,
+            self.num_input,
+            self.weight,
+            target_state,
+            self.indices,
         )
 
-        # update target variable
-        target_state.value = target.at[self.indices].add(inp * self.weight)
-
 
 def poisson_input(
-    freq: ArrayLike,
+    freq: u.Quantity[u.Hz],
     num_input: int,
-    weight: ArrayLike,
+    weight: u.Quantity,
     target: State,
     indices: Optional[Union[np.ndarray, jax.Array]] = None,
 ):
     """
     Poisson Input to the given :py:class:`brainstate.State`.
     """
+    freq = maybe_state(freq)
+    weight = maybe_state(weight)
+
     assert isinstance(target, State), 'The target must be a State.'
-    p = freq * environ.get_dt()
+    p = (freq * environ.get_dt()).to_decimal()
     a = num_input * p
     b = num_input * (1 - p)
     tar_val = target.value
+    cond = u.math.logical_and(a > 5, b > 5)
+
     if indices is None:
         # generate Poisson input
-        inp = cond(
-            u.math.logical_and(a > 5, b > 5),
-            lambda: jax.tree.map(
-                lambda tar: random.normal(a, b * p, tar.shape),
-                tar_val,
-                is_leaf=u.math.is_quantity
+        branch1 = jax.tree.map(
+            lambda tar: random.normal(
+                a,
+                b * p,
+                tar.shape,
+                dtype=tar.dtype
             ),
-            lambda: jax.tree.map(
-                lambda tar: random.binomial(num_input, p, tar.shape).astype(float),
-                tar_val,
-                is_leaf=u.math.is_quantity
-            )
+            tar_val,
+            is_leaf=u.math.is_quantity
+        )
+        branch2 = jax.tree.map(
+            lambda tar: random.binomial(
+                num_input,
+                p,
+                tar.shape,
+                check_valid=False,
+                dtype=tar.dtype
+            ),
+            tar_val,
+            is_leaf=u.math.is_quantity,
+        )
+
+        inp = jax.tree.map(
+            lambda b1, b2: u.math.where(cond, b1, b2),
+            branch1,
+            branch2,
+            is_leaf=u.math.is_quantity,
         )
 
+        # inp = jax.lax.cond(
+        #     cond,
+        #     lambda rand_key: jax.tree.map(
+        #         lambda tar: random.normal(
+        #             a,
+        #             b * p,
+        #             tar.shape,
+        #             key=rand_key,
+        #             dtype=tar.dtype
+        #         ),
+        #         tar_val,
+        #         is_leaf=u.math.is_quantity
+        #     ),
+        #     lambda rand_key: jax.tree.map(
+        #         lambda tar: random.binomial(
+        #             num_input,
+        #             p,
+        #             tar.shape,
+        #             key=rand_key,
+        #             check_valid=False,
+        #             dtype=tar.dtype
+        #         ),
+        #         tar_val,
+        #         is_leaf=u.math.is_quantity,
+        #     ),
+        #     random.split_key()
+        # )
+
         # update target variable
         target.value = jax.tree.map(
             lambda x: x * weight,
@@ -256,19 +298,62 @@ def poisson_input(
 
     else:
         # generate Poisson input
-        inp = cond(
-            u.math.logical_and(a > 5, b > 5),
-            lambda: jax.tree.map(
-                lambda tar: random.normal(a, b * p, tar[indices].shape),
-                tar_val,
-                is_leaf=u.math.is_quantity
+        branch1 = jax.tree.map(
+            lambda tar: random.normal(
+                a,
+                b * p,
+                tar[indices].shape,
+                dtype=tar.dtype
             ),
-            lambda: jax.tree.map(
-                lambda tar: random.binomial(num_input, p, tar[indices].shape).astype(float),
-                tar_val,
-                is_leaf=u.math.is_quantity
-            )
+            tar_val,
+            is_leaf=u.math.is_quantity
         )
+        branch2 = jax.tree.map(
+            lambda tar: random.binomial(
+                num_input,
+                p,
+                tar[indices].shape,
+                # check_valid=False,
+                dtype=tar.dtype
+            ),
+            tar_val,
+            is_leaf=u.math.is_quantity
+        )
+
+        inp = jax.tree.map(
+            lambda b1, b2: u.math.where(cond, b1, b2),
+            branch1,
+            branch2,
+            is_leaf=u.math.is_quantity,
+        )
+
+        # inp = jax.lax.cond(
+        #     cond,
+        #     lambda rand_key: jax.tree.map(
+        #         lambda tar: random.normal(
+        #             a,
+        #             b * p,
+        #             tar[indices].shape,
+        #             key=rand_key,
+        #             dtype=tar.dtype
+        #         ),
+        #         tar_val,
+        #         is_leaf=u.math.is_quantity
+        #     ),
+        #     lambda rand_key: jax.tree.map(
+        #         lambda tar: random.binomial(
+        #             num_input,
+        #             p,
+        #             tar[indices].shape,
+        #             key=rand_key,
+        #             check_valid=False,
+        #             dtype=tar.dtype
+        #         ),
+        #         tar_val,
+        #         is_leaf=u.math.is_quantity
+        #     ),
+        #     random.split_key()
+        # )
 
     # update target variable
     target.value = jax.tree.map(
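Note on the branching rewrite above: the old code dispatched through `cond`, while the new code evaluates both candidate samplers and selects elementwise with `u.math.where` (the commented-out `jax.lax.cond` variant shows the random-key threading this avoids). The predicate is the usual rule of thumb for the normal approximation to the binomial: when `num_input * p` and `num_input * (1 - p)` both exceed 5, the normal draw stands in for the binomial one, so computing both branches and selecting is cheap and stays traceable under `jit` and `vmap`. A standalone sketch of the pattern in plain JAX; function and argument names are illustrative:

    from functools import partial

    import jax
    import jax.numpy as jnp

    @partial(jax.jit, static_argnames=('n', 'shape'))
    def sample_poisson_like(key, use_normal, loc, scale, n, p, shape):
        """Compute both branches, then select; valid even when use_normal is traced."""
        k1, k2 = jax.random.split(key)
        normal_branch = loc + scale * jax.random.normal(k1, shape)
        binomial_branch = jax.random.binomial(k2, n, p, shape=shape)
        return jnp.where(use_normal, normal_branch,
                         binomial_branch.astype(normal_branch.dtype))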
brainstate/nn/_interaction/_conv.py

@@ -191,6 +191,7 @@ class _Conv(_BaseConv):
         b_init: Optional[Union[Callable, ArrayLike]] = None,
         w_mask: Optional[Union[ArrayLike, Callable]] = None,
         name: str = None,
+        param_type: type = ParamState,
     ):
         super().__init__(in_size=in_size,
                          out_channels=out_channels,
@@ -215,7 +216,7 @@ class _Conv(_BaseConv):
             params['bias'] = bias
 
         # The weight operation
-        self.weight = ParamState(params)
+        self.weight = param_type(params)
 
         # Evaluate the output shape
         abstract_y = jax.eval_shape(
@@ -346,6 +347,7 @@ class _ScaledWSConv(_BaseConv):
         b_init: Optional[Union[Callable, ArrayLike]] = None,
         w_mask: Optional[Union[ArrayLike, Callable]] = None,
         name: str = None,
+        param_type: type = ParamState,
    ):
         super().__init__(in_size=in_size,
                          out_channels=out_channels,
@@ -379,7 +381,7 @@ class _ScaledWSConv(_BaseConv):
         self.eps = eps
 
         # The weight operation
-        self.weight = ParamState(params)
+        self.weight = param_type(params)
 
         # Evaluate the output shape
         abstract_y = jax.eval_shape(
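Note: the new `param_type` argument (defaulting to `ParamState`) lets callers store a layer's weights in a custom state subclass, for example to tag parameters for state filtering. A hedged usage sketch, assuming a public conv layer such as `bst.nn.Conv2d` forwards `param_type` and accepts a `kernel_size` argument (both assumptions here):

    import brainstate as bst

    class FrozenParam(bst.ParamState):
        """Hypothetical marker subclass for parameters excluded from updates."""

    # Weights and bias are stored in FrozenParam instead of the default ParamState.
    layer = bst.nn.Conv2d(
        in_size=(28, 28, 3),
        out_channels=16,
        kernel_size=(3, 3),
        param_type=FrozenParam,
    )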