brainstate 0.1.0.post20241125__py2.py3-none-any.whl → 0.1.0.post20241209__py2.py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in the public registry. It is provided for informational purposes only.
brainstate/_state.py CHANGED
@@ -99,6 +99,9 @@ def catch_new_states(tag: str = None) -> List:
 
 
 class Catcher:
+    """
+    The catcher to catch the new states.
+    """
     def __init__(self, tag: str):
         self.tag = tag
         self.state_ids = set()
@@ -231,6 +234,7 @@ class State(Generic[A], PrettyRepr):
         # avoid using self._setattr to avoid the check
         vars(self).update(metadata)
 
+        # record the state initialization
         record_state_init(self)
 
     if not TYPE_CHECKING:
@@ -290,7 +294,6 @@ class State(Generic[A], PrettyRepr):
           v: The value.
         """
         self.write_value(v)
-        self._been_writen = True
 
     def write_value(self, v) -> None:
         # value checking
@@ -301,6 +304,8 @@ class State(Generic[A], PrettyRepr):
         record_state_value_write(self)
         # set the value
         self._value = v
+        # set flag
+        self._been_writen = True
 
     def restore_value(self, v) -> None:
         """
@@ -511,6 +516,15 @@ class LongTermState(State):
     __module__ = 'brainstate'
 
 
+class BatchState(LongTermState):
+    """
+    The batch state, which is used to store the batch data in the program.
+    """
+
+    __module__ = 'brainstate'
+
+
+
 class HiddenState(ShortTermState):
     """
     The hidden state, which is used to store the hidden data in a dynamic model.
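
Note: `BatchState` is introduced as a bare subclass of `LongTermState`, so it stores values like any other state and mainly exists so that batch-related storage can be selected (or excluded) by type-based filters. A hedged usage sketch (assuming it is re-exported at the top level as `bst.BatchState`, which the `__module__ = 'brainstate'` line suggests):

    import jax.numpy as jnp
    import brainstate as bst

    class Buffer(bst.nn.Module):
        def __init__(self):
            super().__init__()
            self.w = bst.ParamState(jnp.ones(10))            # trainable parameter
            self.data = bst.BatchState(jnp.zeros((32, 10)))  # per-batch storage

    buf = Buffer()
    batch_states = buf.states(bst.BatchState)   # selects only the BatchState entries
    param_states = buf.states(bst.ParamState)
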
@@ -204,7 +204,91 @@ class TestVmap(unittest.TestCase):
         self.assertTrue(jnp.allclose(foo.a.value, jnp.arange(4)))
         self.assertTrue(jnp.allclose(foo.b.value, jnp.arange(4) * jnp.arange(4)))
 
+        write_state_ids = [id(st) for st in trace.get_write_states()]
+        read_state_ids = [id(st) for st in trace.get_read_states()]
+
+        assert id(foo.a) in read_state_ids
+        assert id(foo.b) in write_state_ids
+
+        print(trace.get_write_states())
+        print(trace.get_read_states())
+
+
+
+    def test_vmap_jit(self):
+        class Foo(bst.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.a = bst.ParamState(jnp.arange(4))
+                self.b = bst.ShortTermState(jnp.arange(4))
+
+            def __call__(self):
+                self.b.value = self.a.value * self.b.value
+
+        @bst.augment.vmap
+        def mul(foo):
+            foo()
+
+        @bst.compile.jit
+        def mul_jit(inp):
+            mul(foo)
+            foo.a.value += inp
+
+        foo = Foo()
+        with bst.StateTraceStack() as trace:
+            mul_jit(1.)
+
+        print(foo.a.value)
+        print(foo.b.value)
+        self.assertTrue(jnp.allclose(foo.a.value, jnp.arange(4) + 1.))
+        self.assertTrue(jnp.allclose(foo.b.value, jnp.arange(4) * jnp.arange(4)))
+
+        write_state_ids = [id(st) for st in trace.get_write_states()]
+        read_state_ids = [id(st) for st in trace.get_read_states()]
+
+        assert id(foo.a) in write_state_ids
+        assert id(foo.b) in write_state_ids
+
         print(trace.get_write_states())
         print(trace.get_read_states())
 
 
+    def test_vmap_grad(self):
+        class Foo(bst.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.a = bst.ParamState(jnp.arange(4.))
+                self.b = bst.ShortTermState(jnp.arange(4.))
+
+            def __call__(self):
+                self.b.value = self.a.value * self.b.value
+
+        @bst.augment.vmap
+        def mul(foo):
+            foo()
+
+        def loss():
+            mul(foo)
+            return jnp.sum(foo.b.value)
+
+        foo = Foo()
+        with bst.StateTraceStack() as trace:
+            grads, loss = bst.augment.grad(loss, foo.states(bst.ParamState), return_value=True)()
+            print(grads)
+            print(loss)
+
+        # print(foo.a.value)
+        # print(foo.b.value)
+        # self.assertTrue(jnp.allclose(foo.a.value, jnp.arange(4) + 1.))
+        # self.assertTrue(jnp.allclose(foo.b.value, jnp.arange(4) * jnp.arange(4)))
+        #
+        # write_state_ids = [id(st) for st in trace.get_write_states()]
+        # read_state_ids = [id(st) for st in trace.get_read_states()]
+        #
+        # assert id(foo.a) in write_state_ids
+        # assert id(foo.b) in write_state_ids
+        #
+        # print(trace.get_write_states())
+        # print(trace.get_read_states())
+
+
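
Note: these tests exercise the bookkeeping that the `_been_writen` change above enables: because the flag is now set inside `write_value`, any state written during a traced call shows up in `trace.get_write_states()`. A reduced, hedged sketch of the same pattern outside `unittest` (names follow the tests above; the exact tracking semantics under `jit` are assumed to match):

    import jax.numpy as jnp
    import brainstate as bst

    a = bst.ParamState(jnp.arange(4.))
    b = bst.ShortTermState(jnp.arange(4.))

    @bst.compile.jit
    def step(inp):
        b.value = a.value * b.value + inp   # reads a, writes b via write_value()

    with bst.StateTraceStack() as trace:
        step(1.)

    write_ids = [id(st) for st in trace.get_write_states()]
    assert id(b) in write_ids   # b was written, so it carries the write flag
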
@@ -211,7 +211,11 @@ def scan(
 
     # scan
     init = (all_writen_state_vals, init)
-    (all_writen_state_vals, carry), ys = jax.lax.scan(wrapped_f, init, xs, length=length, reverse=reverse,
+    (all_writen_state_vals, carry), ys = jax.lax.scan(wrapped_f,
+                                                      init,
+                                                      xs,
+                                                      length=length,
+                                                      reverse=reverse,
                                                       unroll=unroll)
     # assign the written state values and restore the read state values
     write_back_state_values(state_trace, all_read_state_vals, all_writen_state_vals)
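
Note: the `jax.lax.scan` call is only re-wrapped across several lines; behaviour is unchanged, and the written state values are still threaded through the carry. A hedged usage sketch of the stateful wrapper (assuming it is exposed as `bst.compile.scan` and follows `jax.lax.scan`'s `(f, init, xs)` convention, as this hunk suggests):

    import jax.numpy as jnp
    import brainstate as bst

    total = bst.ShortTermState(jnp.float32(0.))

    def step(carry, x):
        total.value = total.value + x    # state write is carried across iterations
        return carry + 1, total.value

    carry, ys = bst.compile.scan(step, 0, jnp.arange(5, dtype=jnp.float32))
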
@@ -72,7 +72,6 @@ from jax.util import wraps
 from brainstate._state import State, StateTraceStack
 from brainstate._utils import set_module_as
 from brainstate.typing import PyTree
-from brainstate.util._tracers import new_jax_trace
 
 AxisName = Hashable
 
@@ -112,28 +111,27 @@ def _new_arg_fn(frame, trace, aval):
     return tracer
 
 
-def _init_state_trace() -> StateTraceStack:
-    # Should be within the calling of ``jax.make_jaxpr()``
-    frame, trace = new_jax_trace()
+def _new_jax_trace():
+    main = jax.core.thread_local_state.trace_state.trace_stack.stack[-1]
+    frame = main.jaxpr_stack[-1]
+    trace = pe.DynamicJaxprTrace(main, jax.core.cur_sublevel())
+    return frame, trace
+
+
+def _init_state_trace_stack() -> StateTraceStack:
     state_trace: StateTraceStack = StateTraceStack()
-    # Set the function to transform the new argument to a tracer
-    state_trace.set_new_arg(functools.partial(_new_arg_fn, frame, trace))
-    return state_trace
 
+    if jax.__version_info__ < (0, 4, 36):
+        # Should be within the calling of ``jax.make_jaxpr()``
+        frame, trace = _new_jax_trace()
+        # Set the function to transform the new argument to a tracer
+        state_trace.set_new_arg(functools.partial(_new_arg_fn, frame, trace))
+        return state_trace
 
-# def wrapped_abstractify(x: Any) -> Any:
-#     """
-#     Abstractify the input.
-#
-#     Args:
-#       x: The input.
-#
-#     Returns:
-#       The abstractified input.
-#     """
-#     if isinstance(x, pe.DynamicJaxprTracer):
-#         return jax.core.ShapedArray(x.aval.shape, x.aval.dtype, weak_type=x.aval.weak_type)
-#     return shaped_abstractify(x)
+    else:
+        trace = jax.core.trace_ctx.trace
+        state_trace.set_new_arg(trace.new_arg)
+        return state_trace
 
 
 class StatefulFunction(object):
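
Note: the state-trace initializer is now split in two: `_new_jax_trace()` reproduces the old behaviour of fishing the current frame and `DynamicJaxprTrace` out of JAX's thread-local trace stack, while on JAX >= 0.4.36 (where the tracing internals were reworked) the active trace is taken from `jax.core.trace_ctx` and its `new_arg` is handed to the `StateTraceStack` directly. A hedged sketch of the same version-gate idiom, using only names that appear in the hunk; both branches touch private JAX internals and are shown for illustration, not for reuse:

    import jax

    if jax.__version_info__ < (0, 4, 36):
        def _current_dynamic_trace():
            # legacy tracing machinery: innermost main trace on the thread-local stack
            return jax.core.thread_local_state.trace_state.trace_stack.stack[-1]
    else:
        def _current_dynamic_trace():
            # newer JAX exposes the active trace on the tracing context
            return jax.core.trace_ctx.trace
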
@@ -383,12 +381,15 @@ class StatefulFunction(object):
           A tuple of the states that are read and written by the function and the output of the function.
         """
         # state trace
-        state_trace = _init_state_trace()
+        state_trace = _init_state_trace_stack()
         self._cached_state_trace[cache_key] = state_trace
         with state_trace:
             out = self.fun(*args, **kwargs)
-            state_values = state_trace.get_write_state_values(
-                True) if return_only_write else state_trace.get_state_values()
+            state_values = (
+                state_trace.get_write_state_values(True)
+                if return_only_write else
+                state_trace.get_state_values()
+            )
         state_trace.recovery_original_values()
 
         # State instance as functional returns is not allowed.
@@ -419,17 +420,21 @@ class StatefulFunction(object):
         try:
             # jaxpr
             jaxpr, (out_shapes, state_shapes) = _make_jaxpr(
-                functools.partial(self._wrapped_fun_to_eval, cache_key, return_only_write=return_only_write),
+                functools.partial(
+                    self._wrapped_fun_to_eval,
+                    cache_key,
+                    return_only_write=return_only_write
+                ),
                 static_argnums=self.static_argnums,
                 axis_env=self.axis_env,
                 return_shape=True,
                 abstracted_axes=self.abstracted_axes
             )(*args, **kwargs)
-
             # returns
             self._cached_jaxpr_out_tree[cache_key] = jax.tree.structure((out_shapes, state_shapes))
             self._cached_out_shapes[cache_key] = (out_shapes, state_shapes)
             self._cached_jaxpr[cache_key] = jaxpr
+
         except Exception as e:
             try:
                 self._cached_state_trace.pop(cache_key)
@@ -93,19 +93,37 @@ class ProgressBarRunner(object):
             self.tqdm_bars[0].update(self.remainder)
         self.tqdm_bars[0].close()
 
+    def _tqdm(self, is_init, is_print, is_final):
+        if is_init:
+            self.tqdm_bars[0] = tqdm(range(self.n), **self.kwargs)
+            self.tqdm_bars[0].set_description(self.message, refresh=False)
+        if is_print:
+            self.tqdm_bars[0].update(self.print_freq)
+        if is_final:
+            if self.remainder > 0:
+                self.tqdm_bars[0].update(self.remainder)
+            self.tqdm_bars[0].close()
+
     def __call__(self, iter_num, *args, **kwargs):
-        _ = jax.lax.cond(
+        jax.debug.callback(
+            self._tqdm,
             iter_num == 0,
-            lambda: jax.debug.callback(self._define_tqdm),
-            lambda: None,
-        )
-        _ = jax.lax.cond(
             (iter_num + 1) % self.print_freq == 0,
-            lambda: jax.debug.callback(self._update_tqdm),
-            lambda: None,
-        )
-        _ = jax.lax.cond(
-            iter_num == self.n - 1,
-            lambda: jax.debug.callback(self._close_tqdm),
-            lambda: None,
+            iter_num == self.n - 1
         )
+
+        # _ = jax.lax.cond(
+        #     iter_num == 0,
+        #     lambda: jax.debug.callback(self._define_tqdm, ordered=True),
+        #     lambda: None,
+        # )
+        # _ = jax.lax.cond(
+        #     (iter_num + 1) % self.print_freq == 0,
+        #     lambda: jax.debug.callback(self._update_tqdm, ordered=True),
+        #     lambda: None,
+        # )
+        # _ = jax.lax.cond(
+        #     iter_num == self.n - 1,
+        #     lambda: jax.debug.callback(self._close_tqdm, ordered=True),
+        #     lambda: None,
+        # )
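
Note: the three `jax.lax.cond`-guarded callbacks are collapsed into a single `jax.debug.callback` that receives three traced booleans; the host-side `_tqdm` helper then decides whether to create, update, or close the bar. A hedged, self-contained sketch of the same pattern with a toy printer instead of tqdm:

    import jax
    import jax.numpy as jnp

    N, PRINT_FREQ = 10, 2

    def _report(is_init, is_print, is_final):
        # runs on the host with concrete boolean values
        if is_init:
            print('start')
        if is_print:
            print('tick')
        if is_final:
            print('done')

    def body(i, carry):
        jax.debug.callback(
            _report,
            i == 0,
            (i + 1) % PRINT_FREQ == 0,
            i == N - 1,
        )
        return carry + 1.0

    jax.lax.fori_loop(0, N, body, jnp.float32(0.0))

`jax.debug.callback` is unordered by default, so the prints may interleave with other host output; the commented-out `ordered=True` variants kept above preserve ordering at some extra cost.
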
@@ -588,8 +588,7 @@ def glu(x: ArrayLike, axis: int = -1) -> Union[jax.Array, u.Quantity]:
 
 def log_softmax(x: ArrayLike,
                 axis: int | tuple[int, ...] | None = -1,
-                where: ArrayLike | None = None,
-                initial: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]:
+                where: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]:
     r"""Log-Softmax function.
 
     Computes the logarithm of the :code:`softmax` function, which rescales
@@ -604,8 +603,6 @@ def log_softmax(x: ArrayLike,
       axis: the axis or axes along which the :code:`log_softmax` should be
         computed. Either an integer or a tuple of integers.
       where: Elements to include in the :code:`log_softmax`.
-      initial: The minimum value used to shift the input array. Must be present
-        when :code:`where` is not None.
 
     Returns:
       An array.
@@ -613,15 +610,12 @@ def log_softmax(x: ArrayLike,
     See also:
       :func:`softmax`
     """
-    if initial is not None:
-        initial = u.Quantity(initial).in_unit(u.get_unit(x)).mantissa
-    return _keep_unit(jax.nn.log_softmax, x, axis=axis, where=where, initial=initial)
+    return _keep_unit(jax.nn.log_softmax, x, axis=axis, where=where)
 
 
 def softmax(x: ArrayLike,
             axis: int | tuple[int, ...] | None = -1,
-            where: ArrayLike | None = None,
-            initial: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]:
+            where: ArrayLike | None = None) -> Union[jax.Array, u.Quantity]:
     r"""Softmax function.
 
     Computes the function which rescales elements to the range :math:`[0, 1]`
@@ -645,9 +639,7 @@ def softmax(x: ArrayLike,
     See also:
       :func:`log_softmax`
     """
-    if initial is not None:
-        initial = u.Quantity(initial).in_unit(u.get_unit(x)).mantissa
-    return _keep_unit(jax.nn.softmax, x, axis=axis, where=where, initial=initial)
+    return _keep_unit(jax.nn.softmax, x, axis=axis, where=where)
 
 
 def standardize(x: ArrayLike,
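
Note: dropping the `initial` parameter appears to track its removal from `jax.nn.softmax` / `jax.nn.log_softmax` in recent JAX versions; masked calls now pass only `where`, and the internal max-shift is handled by JAX. A hedged usage sketch (assuming these unit-aware wrappers are exposed under `brainstate.functional`):

    import jax.numpy as jnp
    import brainstate as bst

    x = jnp.array([1.0, 2.0, 3.0, 4.0])
    mask = jnp.array([True, True, False, True])

    p = bst.functional.softmax(x, where=mask)        # masked entries get zero probability
    logp = bst.functional.log_softmax(x, where=mask)
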
@@ -608,7 +608,10 @@ def _get_children(graph_def, state_mapping, index_ref, index_ref_cache):
             if isinstance(value, TreefyState):
                 variable.update_from_ref(value)
             elif isinstance(value, State):
-                variable.restore_value(value.value)
+                if value._been_writen:
+                    variable.write_value(value.value)
+                else:
+                    variable.restore_value(value.value)
             else:
                 raise ValueError(f'Expected a State type for {key!r}, but got {type(value)}.')
         else:  # if it doesn't, create a new variable
@@ -20,8 +20,10 @@ from typing import Dict, Callable, TypeVar
 
 import jax
 
+from brainstate._state import catch_new_states
 from brainstate._utils import set_module_as
 from brainstate.graph import nodes
+from brainstate.util._filter import Filter
 from ._module import Module
 
 # the maximum order
@@ -74,16 +76,29 @@ def call_order(level: int = 0, check_order_boundary: bool = True):
 
 
 @set_module_as('brainstate.nn')
-def init_all_states(target: T, *args, exclude=None, **kwargs) -> T:
+def init_all_states(
+    target: T,
+    *args,
+    exclude: Filter = None,
+    **kwargs
+) -> T:
     """
     Collectively initialize states of all children nodes in the given target.
 
     Args:
       target: The target Module.
+      exclude: The filter to exclude some nodes.
+      tag: The tag for the new states.
+      args: The positional arguments for the initialization, which will be passed to the `init_state` method
+        of each node.
+      kwargs: The keyword arguments for the initialization, which will be passed to the `init_state` method
+        of each node.
 
     Returns:
       The target Module.
     """
+
+    # node that has `call_order` decorated
     nodes_with_order = []
 
     nodes_ = nodes(target).filter(Module)
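
Note: `init_all_states` now spells out its signature and types `exclude` as a `Filter`, so parts of a model can be skipped during collective initialization. A hedged usage sketch (the type-based filter is illustrative; any value accepted by the imported `Filter` type should work):

    import brainstate as bst

    class Net(bst.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = bst.nn.Linear(4, 4)
            self.drop = bst.nn.Dropout(0.5)

    net = Net()
    # initialize every child module except Dropout layers
    bst.nn.init_all_states(net, exclude=bst.nn.Dropout)
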
@@ -97,7 +112,7 @@ def init_all_states(target: T, *args, exclude=None, **kwargs) -> T:
         else:
             node.init_state(*args, **kwargs)
 
-    # reset the node's states
+    # reset the node's states with `call_order`
     for node in sorted(nodes_with_order, key=lambda x: x.init_state.call_order):
         node.init_state(*args, **kwargs)
 
@@ -115,6 +130,7 @@ def reset_all_states(target: Module, *args, **kwargs) -> Module:
     Returns:
       The target Module.
     """
+
     nodes_with_order = []
 
     # reset node whose `init_state` has no `call_order`
@@ -112,7 +112,7 @@ class STP(Synapse):
         self.u.value = init.param(init.Constant(self.U), self.varshape, batch_size)
 
     def update(self, pre_spike):
-        u = exp_euler_step(lambda u: self.U - u / self.tau_f, self.u.value)
+        u = exp_euler_step(lambda u: - u / self.tau_f, self.u.value)
         x = exp_euler_step(lambda x: (1 - x) / self.tau_d, self.x.value)
 
         # --- original code:
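
Note: the drift passed to `exp_euler_step` for the facilitation variable `u` loses the constant `self.U` term, so between presynaptic spikes `u` now follows pure exponential decay toward zero with time constant `tau_f`; in the standard Tsodyks-Markram formulation the `U`-dependent increment enters only at spike times. A hedged sketch of those between-spike dynamics, written against the textbook model rather than brainstate's integrator:

    import jax.numpy as jnp

    def facilitation_between_spikes(u, dt, tau_f):
        # exact solution of du/dt = -u / tau_f over one step: decay toward zero
        return u * jnp.exp(-dt / tau_f)

    def facilitation_on_spike(u, U):
        # spike-triggered jump of the Tsodyks-Markram facilitation variable
        return u + U * (1.0 - u)

    u = jnp.float32(0.2)
    u = facilitation_between_spikes(u, dt=0.1, tau_f=200.0)   # illustrative values, in ms
    u = facilitation_on_spike(u, U=0.15)
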
@@ -17,7 +17,7 @@
 from __future__ import annotations
 
 from functools import partial
-from typing import Optional
+from typing import Optional, Sequence
 
 import brainunit as u
 import jax.numpy as jnp
@@ -29,7 +29,6 @@ from brainstate.typing import Size
 
 __all__ = [
     'DropoutFixed', 'Dropout', 'Dropout1d', 'Dropout2d', 'Dropout3d',
-    'AlphaDropout', 'FeatureAlphaDropout',
 ]
 
 
@@ -47,9 +46,9 @@ class Dropout(ElementWiseBlock):
       research 15.1 (2014): 1929-1958.
 
     Args:
-      prob: Probability to keep element of the tensor.
-      mode: Mode. The computation mode of the object.
-      name: str. The name of the dynamic system.
+      prob: Probability to keep element of the tensor.
+      broadcast_dims: dimensions that will share the same dropout mask.
+      name: str. The name of the dynamic system.
 
     """
     __module__ = 'brainstate.nn'
@@ -57,20 +56,28 @@ class Dropout(ElementWiseBlock):
     def __init__(
         self,
         prob: float = 0.5,
+        broadcast_dims: Sequence[int] = (),
         name: Optional[str] = None
     ) -> None:
         super().__init__(name=name)
         assert 0. <= prob <= 1., f"Dropout probability must be in the range [0, 1]. But got {prob}."
         self.prob = prob
+        self.broadcast_dims = broadcast_dims
 
     def __call__(self, x):
         dtype = u.math.get_dtype(x)
         fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
         if fit_phase and self.prob < 1.:
-            keep_mask = random.bernoulli(self.prob, x.shape)
-            return jnp.where(keep_mask,
-                             jnp.asarray(x / self.prob, dtype=dtype),
-                             jnp.asarray(0., dtype=dtype))
+            broadcast_shape = list(x.shape)
+            for dim in self.broadcast_dims:
+                broadcast_shape[dim] = 1
+            keep_mask = random.bernoulli(self.prob, broadcast_shape)
+            keep_mask = jnp.broadcast_to(keep_mask, x.shape)
+            return jnp.where(
+                keep_mask,
+                jnp.asarray(x / self.prob, dtype=dtype),
+                jnp.asarray(0., dtype=dtype)
+            )
         else:
             return x
 
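
Note: with `broadcast_dims`, the keep mask is sampled on a shape whose listed axes are collapsed to 1 and then broadcast back to `x.shape`, so every position along those axes shares one keep/drop decision. A hedged usage sketch (assuming the training flag is supplied through `bst.environ.context(fit=True)`, matching the `environ.get('fit', ...)` lookup above; shapes are illustrative):

    import jax.numpy as jnp
    import brainstate as bst

    layer = bst.nn.Dropout(prob=0.8, broadcast_dims=(1,))   # share the mask along axis 1

    with bst.environ.context(fit=True):
        x = jnp.ones((4, 16, 8))        # (batch, time, features)
        y = layer(x)

    dropped = (y == 0)
    # every time step shares the same drop pattern
    assert bool(jnp.all(dropped == dropped[:, :1, :]))
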
@@ -88,12 +95,11 @@ class _DropoutNd(ElementWiseBlock):
         name: Optional[str] = None
     ) -> None:
         super().__init__(name=name)
-        assert 0. <= prob < 1., f"Dropout probability must be in the range [0, 1). But got {prob}."
+        assert 0. <= prob <= 1., f"Dropout probability must be in the range [0, 1]. But got {prob}."
         self.prob = prob
         self.channel_axis = channel_axis
 
     def __call__(self, x):
-
         # check input shape
         inp_dim = u.math.ndim(x)
         if inp_dim not in (self.minimal_dim, self.minimal_dim + 1):
@@ -112,12 +118,15 @@ class _DropoutNd(ElementWiseBlock):
         fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
 
         # generate mask
-        if fit_phase:
+        if fit_phase and self.prob < 1.:
             dtype = u.math.get_dtype(x)
-            keep_mask = jnp.broadcast_to(random.bernoulli(self.prob, mask_shape), x.shape)
-            return jnp.where(keep_mask,
-                             jnp.asarray(x / self.prob, dtype=dtype),
-                             jnp.asarray(0., dtype=dtype))
+            keep_mask = random.bernoulli(self.prob, mask_shape)
+            keep_mask = jnp.broadcast_to(keep_mask, x.shape)
+            return jnp.where(
+                keep_mask,
+                jnp.asarray(x / self.prob, dtype=dtype),
+                jnp.asarray(0., dtype=dtype)
+            )
         else:
             return x
 
@@ -296,8 +305,8 @@ class AlphaDropout(_DropoutNd):
     """
     __module__ = 'brainstate.nn'
 
-    def forward(self, x):
-        return F.alpha_dropout(x, self.p, self.training)
+    def update(self, *args, **kwargs):
+        raise NotImplementedError("AlphaDropout is not supported in the current version.")
 
 
 class FeatureAlphaDropout(_DropoutNd):
@@ -344,8 +353,8 @@ class FeatureAlphaDropout(_DropoutNd):
     """
     __module__ = 'brainstate.nn'
 
-    def forward(self, x):
-        return F.feature_alpha_dropout(x, self.p, self.training)
+    def update(self, *args, **kwargs):
+        raise NotImplementedError("FeatureAlphaDropout is not supported in the current version.")
 
 
 class DropoutFixed(ElementWiseBlock):
@@ -396,7 +405,7 @@ class DropoutFixed(ElementWiseBlock):
         name: Optional[str] = None
     ) -> None:
         super().__init__(name=name)
-        assert 0. <= prob < 1., f"Dropout probability must be in the range [0, 1). But got {prob}."
+        assert 0. <= prob <= 1., f"Dropout probability must be in the range [0, 1]. But got {prob}."
         self.prob = prob
         self.in_size = in_size
         self.out_size = in_size
@@ -407,7 +416,7 @@ class DropoutFixed(ElementWiseBlock):
     def update(self, x):
         dtype = u.math.get_dtype(x)
         fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
-        if fit_phase:
+        if fit_phase and self.prob < 1.:
             if self.mask.value.shape != x.shape:
                 raise ValueError(f"Input shape {x.shape} does not match the mask shape {self.mask.value.shape}. "
                                  f"Please call `init_state()` method first.")