brainstate 0.1.0.post20241210__py2.py3-none-any.whl → 0.1.0.post20241220__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -40,7 +40,7 @@ class JittedFunction(Callable):
40
40
  jitted_fun: jax.stages.Wrapped # the jitted function
41
41
  clear_cache: Callable # clear the cache of the jitted function
42
42
  eval_shape: Callable # evaluate the shape of the jitted function
43
- lower: Callable # lower the jitted function
43
+ compile: Callable # lower the jitted function
44
44
  trace: Callable # trace the jitted
45
45
 
46
46
  def __call__(self, *args, **kwargs):
@@ -104,7 +104,18 @@ def _get_jitted_fun(
104
104
  def eval_shape():
105
105
  raise NotImplementedError
106
106
 
107
- def lower():
107
+ def trace():
108
+ """Trace this function explicitly for the given arguments.
109
+
110
+ A traced function is staged out of Python and translated to a jaxpr. It is
111
+ ready for lowering but not yet lowered.
112
+
113
+ Returns:
114
+ A ``Traced`` instance representing the tracing.
115
+ """
116
+ raise NotImplementedError
117
+
118
+ def compile(*args, **params):
108
119
  """Lower this function explicitly for the given arguments.
109
120
 
110
121
  A lowered function is staged out of Python and translated to a
@@ -114,18 +125,13 @@ def _get_jitted_fun(
114
125
  Returns:
115
126
  A ``Lowered`` instance representing the lowering.
116
127
  """
117
- raise NotImplementedError
118
-
119
- def trace():
120
- """Trace this function explicitly for the given arguments.
128
+ # compile the function and get the state trace
129
+ state_trace = fun.compile_function_and_get_state_trace(*args, **params, return_only_write=True)
130
+ read_state_vals = state_trace.get_read_state_values(True)
121
131
 
122
- A traced function is staged out of Python and translated to a jaxpr. It is
123
- ready for lowering but not yet lowered.
132
+ # call the jitted function
133
+ return jit_fun.lower(state_trace.get_state_values(), *args, **params).compile()
124
134
 
125
- Returns:
126
- A ``Traced`` instance representing the tracing.
127
- """
128
- raise NotImplementedError
129
135
 
130
136
  jitted_fun: JittedFunction
131
137
 
@@ -144,8 +150,8 @@ def _get_jitted_fun(
144
150
  # evaluate the shape of the jitted function
145
151
  jitted_fun.eval_shape = eval_shape
146
152
 
147
- # lower the jitted function
148
- jitted_fun.lower = lower
153
+ # compile the jitted function
154
+ jitted_fun.compile = compile
149
155
 
150
156
  # trace the jitted
151
157
  jitted_fun.trace = trace
@@ -63,7 +63,7 @@ def scan(
63
63
  length: int | None = None,
64
64
  reverse: bool = False,
65
65
  unroll: int | bool = 1,
66
- pbar: ProgressBar | None = None,
66
+ pbar: ProgressBar | int | None = None,
67
67
  ) -> Tuple[Carry, Y]:
68
68
  """
69
69
  Scan a function over leading array axes while carrying along state.
@@ -177,7 +177,13 @@ def scan(
177
177
  has_pbar = False
178
178
  if pbar is not None:
179
179
  has_pbar = True
180
- f = _wrap_fun_with_pbar(f, pbar.init(length))
180
+ if isinstance(pbar, ProgressBar):
181
+ pbar_runner = pbar.init(length)
182
+ elif isinstance(pbar, int):
183
+ pbar_runner = ProgressBar(freq=pbar).init(length)
184
+ else:
185
+ raise TypeError("pbar argument should be a ProgressBar instance or an integer.")
186
+ f = _wrap_fun_with_pbar(f, pbar_runner)
181
187
  init = (0, init) if pbar else init
182
188
 
183
189
  # not jit
@@ -230,7 +236,7 @@ def checkpointed_scan(
230
236
  xs: X,
231
237
  length: Optional[int] = None,
232
238
  base: int = 16,
233
- pbar: Optional[ProgressBar] = None,
239
+ pbar: Optional[ProgressBar | int] = None,
234
240
  ) -> Tuple[Carry, Y]:
235
241
  """
236
242
  Scan a function over leading array axes while carrying along state.
@@ -288,8 +294,10 @@ def checkpointed_scan(
288
294
  length, = unique_lengths
289
295
 
290
296
  # function with progress bar
291
- if pbar is not None:
297
+ if isinstance(pbar, ProgressBar):
292
298
  pbar_runner = pbar.init(length)
299
+ elif isinstance(pbar, int):
300
+ pbar_runner = ProgressBar(freq=pbar).init(length)
293
301
  else:
294
302
  pbar_runner = None
295
303
 
@@ -380,7 +388,7 @@ def for_loop(
380
388
  length: Optional[int] = None,
381
389
  reverse: bool = False,
382
390
  unroll: int | bool = 1,
383
- pbar: Optional[ProgressBar] = None
391
+ pbar: Optional[ProgressBar | int] = None
384
392
  ) -> Y:
385
393
  """
386
394
  ``for-loop`` control flow with :py:class:`~.State`.
@@ -432,7 +440,7 @@ def checkpointed_for_loop(
432
440
  *xs: X,
433
441
  length: Optional[int] = None,
434
442
  base: int = 16,
435
- pbar: Optional[ProgressBar] = None,
443
+ pbar: Optional[ProgressBar | int] = None,
436
444
  ) -> Y:
437
445
  """
438
446
  ``for-loop`` control flow with :py:class:`~.State` with a checkpointed version, similar to :py:func:`for_loop`.
@@ -35,6 +35,8 @@ class ProgressBar(object):
35
35
 
36
36
  def __init__(self, freq: Optional[int] = None, count: Optional[int] = None, **kwargs):
37
37
  self.print_freq = freq
38
+ if isinstance(freq, int):
39
+ assert freq > 0, "Print rate should be > 0."
38
40
  self.print_count = count
39
41
  if self.print_freq is not None and self.print_count is not None:
40
42
  raise ValueError("Cannot specify both count and freq.")
@@ -114,17 +116,17 @@ class ProgressBarRunner(object):
114
116
 
115
117
  _ = jax.lax.cond(
116
118
  iter_num == 0,
117
- lambda: jax.debug.callback(self._define_tqdm),
119
+ lambda: jax.debug.callback(self._define_tqdm, ordered=True),
118
120
  lambda: None,
119
121
  )
120
122
  _ = jax.lax.cond(
121
123
  iter_num % self.print_freq == (self.print_freq - 1),
122
- lambda: jax.debug.callback(self._update_tqdm),
124
+ lambda: jax.debug.callback(self._update_tqdm, ordered=True),
123
125
  lambda: None,
124
126
  )
125
127
  _ = jax.lax.cond(
126
128
  iter_num == self.n - 1,
127
- lambda: jax.debug.callback(self._close_tqdm),
129
+ lambda: jax.debug.callback(self._close_tqdm, ordered=True),
128
130
  lambda: None,
129
131
  )
130
132
 
@@ -14,14 +14,16 @@
14
14
  # ==============================================================================
15
15
 
16
16
 
17
+ from ._csr import *
17
18
  from ._csr_mv import *
18
- from ._csr_mv import __all__ as __all_csr
19
19
  from ._fixedprob_mv import *
20
- from ._fixedprob_mv import __all__ as __all_fixed_probability
21
20
  from ._linear_mv import *
22
21
  from ._xla_custom_op import *
23
- from ._xla_custom_op import __all__ as __all_xla_custom_op
24
- from ._linear_mv import __all__ as __all_linear
25
22
 
26
- __all__ = __all_fixed_probability + __all_linear + __all_csr + __all_xla_custom_op
27
- del __all_fixed_probability, __all_linear, __all_csr, __all_xla_custom_op
23
+ __all__ = [
24
+ 'CSRLinear',
25
+ 'FixedProb',
26
+ 'XLACustomOp',
27
+ 'CSR',
28
+ 'CSC',
29
+ ]