brainstate 0.1.0.post20241210__py2.py3-none-any.whl → 0.1.0.post20241220__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -17,17 +17,23 @@ from __future__ import annotations
  from typing import Union, Optional, Sequence, Callable
 
  import brainunit as u
+ import jax
+ import numpy as np
 
  from brainstate import environ, init, random
  from brainstate._state import ShortTermState
- from brainstate.compile import while_loop
- from brainstate.nn._dynamics._dynamics_base import Dynamics
+ from brainstate._state import State
+ from brainstate.compile import while_loop, cond
+ from brainstate.nn._dynamics._dynamics_base import Dynamics, Prefetch
+ from brainstate.nn._module import Module
  from brainstate.typing import ArrayLike, Size, DTypeLike
 
  __all__ = [
      'SpikeTime',
      'PoissonSpike',
      'PoissonEncoder',
+     'PoissonInput',
+     'poisson_input',
  ]
 
 
@@ -152,3 +158,122 @@ class PoissonEncoder(Dynamics):
          spikes = random.rand(*self.varshape) <= (freqs * environ.get_dt())
          spikes = u.math.asarray(spikes, dtype=self.spk_type)
          return spikes
+
+
+ class PoissonInput(Module):
+     """
+     Poisson Input to the given :py:class:`brainstate.State`.
+
+     Adds independent Poisson input to a target variable. For large
+     numbers of inputs, this is much more efficient than creating a
+     `PoissonGroup`. The synaptic events are generated randomly during the
+     simulation and are not preloaded and stored in memory. All the inputs must
+     target the same variable, have the same frequency and same synaptic weight.
+     All neurons in the target variable receive independent realizations of
+     Poisson spike trains.
+
+     Args:
+         target: The variable that is targeted by this input. Should be an instance of :py:class:`~.Prefetch`.
+         num_input: The number of inputs.
+         freq: The frequency of each of the inputs. Must be a scalar.
+         weight: The synaptic weight. Must be a scalar.
+         name: The target name.
+     """
+
+     def __init__(
+         self,
+         target: Prefetch,
+         indices: Union[np.ndarray, jax.Array],
+         num_input: int,
+         freq: Union[int, float],
+         weight: Union[int, float],
+         name: Optional[str] = None,
+     ):
+         super().__init__(name=name)
+
+         self.target = target
+         self.indices = indices
+         self.num_input = num_input
+         self.freq = freq
+         self.weight = weight
+
+     def update(self):
+         p = self.freq * environ.get_dt()
+         a = self.num_input * p
+         b = self.num_input * (1 - p)
+
+         target = self.target()
+         target_state = getattr(self.target.module, self.target.item)
+
+         # generate Poisson input
+         inp = cond(
+             u.math.logical_and(a > 5, b > 5),
+             lambda: random.normal(a, b * p, self.indices.shape),
+             lambda: random.binomial(self.num_input, p, self.indices.shape).astype(float)
+         )
+
+         # update target variable
+         target_state.value = target.at[self.indices].add(inp * self.weight)
+
+
+ def poisson_input(
+     freq: ArrayLike,
+     num_input: int,
+     weight: ArrayLike,
+     target: State,
+     indices: Optional[Union[np.ndarray, jax.Array]] = None,
+ ):
+     """
+     Poisson Input to the given :py:class:`brainstate.State`.
+     """
+     assert isinstance(target, State), 'The target must be a State.'
+     p = freq * environ.get_dt()
+     a = num_input * p
+     b = num_input * (1 - p)
+     tar_val = target.value
+     if indices is None:
+         # generate Poisson input
+         inp = cond(
+             u.math.logical_and(a > 5, b > 5),
+             lambda: jax.tree.map(
+                 lambda tar: random.normal(a, b * p, tar.shape),
+                 tar_val,
+                 is_leaf=u.math.is_quantity
+             ),
+             lambda: jax.tree.map(
+                 lambda tar: random.binomial(num_input, p, tar.shape).astype(float),
+                 tar_val,
+                 is_leaf=u.math.is_quantity
+             )
+         )
+
+         # update target variable
+         target.value = jax.tree.map(
+             lambda x: x * weight,
+             inp,
+             is_leaf=u.math.is_quantity
+         )
+
+     else:
+         # generate Poisson input
+         inp = cond(
+             u.math.logical_and(a > 5, b > 5),
+             lambda: jax.tree.map(
+                 lambda tar: random.normal(a, b * p, tar[indices].shape),
+                 tar_val,
+                 is_leaf=u.math.is_quantity
+             ),
+             lambda: jax.tree.map(
+                 lambda tar: random.binomial(num_input, p, tar[indices].shape).astype(float),
+                 tar_val,
+                 is_leaf=u.math.is_quantity
+             )
+         )
+
+         # update target variable
+         target.value = jax.tree.map(
+             lambda x, tar: tar.at[indices].add(x * weight),
+             inp,
+             tar_val,
+             is_leaf=u.math.is_quantity
+         )
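
Example (editor's sketch, not part of the package): driving a state with the new `poisson_input` helper. It assumes `poisson_input` is re-exported as `brainstate.nn.poisson_input`, that `brainstate.State` is the public alias of the `State` imported above, and that `brainstate.environ.set(dt=...)` registers the time step read by `environ.get_dt()`.

    import jax.numpy as jnp
    import brainstate as bst

    bst.environ.set(dt=0.1)        # time step returned by environ.get_dt()
    v = bst.State(jnp.zeros(100))  # target state receiving the input

    # One step of Poisson input: 500 independent sources, per-step spike
    # probability p = freq * dt, weight 0.5, delivered only to the first
    # 10 elements (the `indices` branch updates via .at[indices].add).
    bst.nn.poisson_input(
        freq=0.02,
        num_input=500,
        weight=0.5,
        target=v,
        indices=jnp.arange(10),
    )
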
@@ -107,6 +107,8 @@ class Dynamics(Module):
 
      __module__ = 'brainstate.nn'
 
+     graph_invisible_attrs = ('_before_updates', '_after_updates', '_current_inputs', '_delta_inputs')
+
      # before updates
      _before_updates: Optional[Dict[Hashable, Callable]]
 
@@ -443,6 +445,16 @@ class Prefetch(Node):
          item = _get_prefetch_item(self)
          return item.value if isinstance(item, State) else item
 
+     def get_item_value(self):
+         item = _get_prefetch_item(self)
+         return item.value if isinstance(item, State) else item
+
+     def get_item(self):
+         """
+         Get the prefetched item.
+         """
+         return _get_prefetch_item(self)
+
 
  class PrefetchDelay(Node):
      def __init__(self, module: Dynamics, item: str):
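
The two accessors added to `Prefetch` differ only in what they unwrap: `get_item_value()` mirrors `__call__` and returns the current value, while `get_item()` returns the underlying object itself (typically a `State`). A short sketch; the `LIF` population and the `prefetch('V')` route to a `Prefetch` node are assumptions for illustration, not shown in this diff:

    import brainstate as bst

    neuron = bst.nn.LIF(10)      # some Dynamics with a membrane state 'V'
    ref = neuron.prefetch('V')   # assumed way of obtaining a Prefetch node

    ref()                  # current value, unwrapping the State (as before)
    ref.get_item_value()   # same result as ref(), under an explicit name
    ref.get_item()         # the State object itself rather than its value
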
@@ -14,7 +14,7 @@
  # ==============================================================================
  from __future__ import annotations
 
- from typing import Union, Callable
+ from typing import Union, Callable, Optional
 
  from brainstate._state import State
  from brainstate.mixin import AlignPost, ParamDescriber, BindCondData, JointTypes
@@ -60,24 +60,28 @@ def is_instance(x, cls) -> bool:
      return isinstance(x, cls)
 
 
- def get_post_repr(syn, out):
-     return f'{syn.identifier} // {out.identifier}'
+ def get_post_repr(label, syn, out):
+     if label is None:
+         return f'{syn.identifier} // {out.identifier}'
+     else:
+         return f'{label}{syn.identifier} // {out.identifier}'
 
 
  def align_post_add_bef_update(
      syn_desc: ParamDescriber[AlignPost],
      out_desc: ParamDescriber[BindCondData],
      post: Dynamics,
-     proj_name: str
+     proj_name: str,
+     label: str,
  ):
      # synapse and output initialization
-     _post_repr = get_post_repr(syn_desc, out_desc)
+     _post_repr = get_post_repr(label, syn_desc, out_desc)
      if not post._has_before_update(_post_repr):
          syn_cls = syn_desc()
          out_cls = out_desc()
 
          # synapse and output initialization
-         post.add_current_input(proj_name, out_cls)
+         post.add_current_input(proj_name, out_cls, label=label)
          post._add_before_update(_post_repr, _AlignPost(syn_cls, out_cls))
      syn = post._get_before_update(_post_repr).syn
      out = post._get_before_update(_post_repr).out
@@ -139,6 +143,7 @@ class AlignPostProj(Interaction):
          syn: Union[ParamDescriber[AlignPost], AlignPost],
          out: Union[ParamDescriber[SynOut], SynOut],
          post: Dynamics,
+         label: Optional[str] = None,
      ):
          super().__init__(name=get_unique_name(self.__class__.__name__))
 
@@ -154,12 +159,21 @@ class AlignPostProj(Interaction):
          # checking synapse and output models
          if is_instance(syn, ParamDescriber[AlignPost]):
              if not is_instance(out, ParamDescriber[SynOut]):
+                 if is_instance(out, ParamDescriber):
+                     raise TypeError(
+                         f'The output should be an instance of describer {ParamDescriber[SynOut]} when '
+                         f'the synapse is an instance of {AlignPost}, but got {out}.'
+                     )
                  raise TypeError(
                      f'The output should be an instance of describer {ParamDescriber[SynOut]} when '
                      f'the synapse is a describer, but we got {out}.'
                  )
              merging = True
          else:
+             if is_instance(syn, ParamDescriber):
+                 raise TypeError(
+                     f'The synapse should be an instance of describer {ParamDescriber[AlignPost]}, but got {syn}.'
+                 )
              if not is_instance(out, SynOut):
                  raise TypeError(
                      f'The output should be an instance of {SynOut} when the synapse is '
@@ -176,7 +190,11 @@ class AlignPostProj(Interaction):
 
          if merging:
              # synapse and output initialization
-             syn, out = align_post_add_bef_update(syn_desc=syn, out_desc=out, post=post, proj_name=self.name)
+             syn, out = align_post_add_bef_update(syn_desc=syn,
+                                                  out_desc=out,
+                                                  post=post,
+                                                  proj_name=self.name,
+                                                  label=label)
          else:
              post.add_current_input(self.name, out)
 
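The effect of the new `label` argument is easiest to see in the registry key built by `get_post_repr` above: two `AlignPostProj` instances merge onto one shared synapse/output pair exactly when their keys collide, and distinct labels force distinct entries. A self-contained sketch (the identifier strings are made up; in the real code they come from `ParamDescriber.identifier`):

    # Key construction mirroring get_post_repr in this diff.
    def post_repr(label, syn_id, out_id):
        if label is None:
            return f'{syn_id} // {out_id}'
        return f'{label}{syn_id} // {out_id}'

    k_default = post_repr(None, 'Expon(tau=5.0ms)', 'CUBA')
    k_exc = post_repr('exc', 'Expon(tau=5.0ms)', 'CUBA')
    k_inh = post_repr('inh', 'Expon(tau=5.0ms)', 'CUBA')
    # Different labels yield different keys, hence separate before-update
    # entries (and separate merged synapse instances) on the same post.
    assert len({k_default, k_exc, k_inh}) == 3
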
@@ -59,17 +59,17 @@ class TestDropout(unittest.TestCase):
          expected_non_zero_elements = input_data[output_data != 0] * scale_factor
          np.testing.assert_almost_equal(non_zero_elements, expected_non_zero_elements)
 
-     def test_Dropout1d(self):
-         dropout_layer = bst.nn.Dropout1d(prob=0.5)
-         input_data = np.random.randn(2, 3, 4)
-         with bst.environ.context(fit=True):
-             output_data = dropout_layer(input_data)
-         self.assertEqual(input_data.shape, output_data.shape)
-         self.assertTrue(np.any(output_data == 0))
-         scale_factor = 1 / (1 - 0.5)
-         non_zero_elements = output_data[output_data != 0]
-         expected_non_zero_elements = input_data[output_data != 0] * scale_factor
-         np.testing.assert_almost_equal(non_zero_elements, expected_non_zero_elements, decimal=4)
+     # def test_Dropout1d(self):
+     #     dropout_layer = bst.nn.Dropout1d(prob=0.5)
+     #     input_data = np.random.randn(2, 3, 4)
+     #     with bst.environ.context(fit=True):
+     #         output_data = dropout_layer(input_data)
+     #     self.assertEqual(input_data.shape, output_data.shape)
+     #     self.assertTrue(np.any(output_data == 0))
+     #     scale_factor = 1 / (1 - 0.5)
+     #     non_zero_elements = output_data[output_data != 0]
+     #     expected_non_zero_elements = input_data[output_data != 0] * scale_factor
+     #     np.testing.assert_almost_equal(non_zero_elements, expected_non_zero_elements, decimal=4)
 
      def test_Dropout2d(self):
          dropout_layer = bst.nn.Dropout2d(prob=0.5)
@@ -20,10 +20,7 @@ from __future__ import annotations
  from typing import Callable, Union, Optional
 
  import brainunit as u
- import jax
  import jax.numpy as jnp
- from jax.experimental.sparse.coo import coo_matvec_p, coo_matmat_p, COOInfo
- from jax.experimental.sparse.csr import csr_matvec_p, csr_matmat_p
 
  from brainstate import init, functional
  from brainstate._state import ParamState
@@ -34,9 +31,7 @@ __all__ = [
      'Linear',
      'ScaledWSLinear',
      'SignedWLinear',
-     'CSRLinear',
-     'CSCLinear',
-     'COOLinear',
+     'SparseLinear',
      'AllToAll',
      'OneToOne',
  ]
@@ -198,270 +193,48 @@ class ScaledWSLinear(Module):
          return y
 
 
- def csr_matmat(data, indices, indptr, B: jax.Array, *, shape, transpose: bool = False) -> jax.Array:
-     """Product of CSR sparse matrix and a dense matrix.
-
-     Args:
-         data : array of shape ``(nse,)``.
-         indices : array of shape ``(nse,)``
-         indptr : array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``
-         B : array of shape ``(mat.shape[0] if transpose else mat.shape[1], cols)`` and
-             dtype ``mat.dtype``
-         transpose : boolean specifying whether to transpose the sparse matrix
-             before computing.
-
-     Returns:
-         C : array of shape ``(mat.shape[1] if transpose else mat.shape[0], cols)``
-             representing the matrix vector product.
+ class SparseLinear(Module):
      """
-     return csr_matmat_p.bind(data, indices, indptr, B, shape=shape, transpose=transpose)
-
-
- def csr_matvec(data, indices, indptr, v, *, shape, transpose=False) -> jax.Array:
-     """Product of CSR sparse matrix and a dense vector.
+     Linear layer with Sparse Matrix (can be ``brainunit.sparse.CSR``,
+     ``brainunit.sparse.CSC``, ``brainunit.sparse.COO``, or any other sparse matrix).
 
      Args:
-         data : array of shape ``(nse,)``.
-         indices : array of shape ``(nse,)``
-         indptr : array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``
-         v : array of shape ``(shape[0] if transpose else shape[1],)``
-             and dtype ``data.dtype``
-         shape : length-2 tuple representing the matrix shape
-         transpose : boolean specifying whether to transpose the sparse matrix
-             before computing.
-
-     Returns:
-         y : array of shape ``(shape[1] if transpose else shape[0],)`` representing
-             the matrix vector product.
-     """
-     return csr_matvec_p.bind(data, indices, indptr, v, shape=shape, transpose=transpose)
-
-
- class CSRLinear(Module):
-     """
-     Linear layer with Compressed Sparse Row (CSR) matrix.
-     """
-     __module__ = 'brainstate.nn'
-
-     def __init__(
-         self,
-         in_size: Size,
-         out_size: Size,
-         indptr: ArrayLike,
-         indices: ArrayLike,
-         weight: Union[Callable, ArrayLike],
-         b_init: Optional[Union[Callable, ArrayLike]] = None,
-         name: Optional[str] = None,
-     ):
-         super().__init__(name=name)
-
-         # input and output shape
-         self.in_size = in_size
-         self.out_size = out_size
-         assert self.in_size[:-1] == self.out_size[:-1], ('The first n-1 dimensions of "in_size" '
-                                                          'and "out_size" must be the same.')
-
-         # CSR data structure
-         indptr = jnp.asarray(indptr)
-         indices = jnp.asarray(indices)
-         assert indptr.ndim == 1, f"indptr must be 1D. Got: {indptr.ndim}"
-         assert indices.ndim == 1, f"indices must be 1D. Got: {indices.ndim}"
-         assert indptr.size == self.in_size[-1] + 1, f"indptr must have size {self.in_size[-1] + 1}. Got: {indptr.size}"
-         with jax.ensure_compile_time_eval():
-             self.indptr = u.math.asarray(indptr)
-             self.indices = u.math.asarray(indices)
-
-         # weights
-         weight = init.param(weight, (len(indices),), allow_none=False, allow_scalar=False)
-         params = dict(weight=weight)
-         if b_init is not None:
-             params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False)
-         self.weight = ParamState(params)
-
-     def update(self, x):
-         data = self.weight.value['weight']
-         data, w_unit = u.get_mantissa(data), u.get_unit(data)
-         x, x_unit = u.get_mantissa(x), u.get_unit(x)
-         shape = [self.in_size[-1], self.out_size[-1]]
-         if x.ndim == 1:
-             y = csr_matvec(data, self.indices, self.indptr, x, shape=shape)
-         elif x.ndim == 2:
-             y = csr_matmat(data, self.indices, self.indptr, x, shape=shape)
-         else:
-             raise NotImplementedError(f"matmul with object of shape {x.shape}")
-         y = u.maybe_decimal(u.Quantity(y, unit=w_unit * x_unit))
-         if 'bias' in self.weight.value:
-             y = y + self.weight.value['bias']
-         return y
-
-
- class CSCLinear(Module):
-     """
-     Linear layer with Compressed Sparse Column (CSC) matrix.
+         spar_mat: SparseMatrix. The sparse weight matrix.
+         in_size: Size. The input size.
+         name: str. The object name.
      """
      __module__ = 'brainstate.nn'
 
      def __init__(
          self,
-         in_size: Size,
-         out_size: Size,
-         indptr: ArrayLike,
-         indices: ArrayLike,
-         weight: Union[Callable, ArrayLike],
-         b_init: Optional[Union[Callable, ArrayLike]] = None,
-         name: Optional[str] = None,
-     ):
-         super().__init__(name=name)
-
-         # input and output shape
-         self.in_size = in_size
-         self.out_size = out_size
-         assert self.in_size[:-1] == self.out_size[:-1], ('The first n-1 dimensions of "in_size" '
-                                                          'and "out_size" must be the same.')
-
-         # CSR data structure
-         indptr = jnp.asarray(indptr)
-         indices = jnp.asarray(indices)
-         assert indptr.ndim == 1, f"indptr must be 1D. Got: {indptr.ndim}"
-         assert indices.ndim == 1, f"indices must be 1D. Got: {indices.ndim}"
-         assert indptr.size == self.in_size[-1] + 1, f"indptr must have size {self.in_size[-1] + 1}. Got: {indptr.size}"
-         with jax.ensure_compile_time_eval():
-             self.indptr = u.math.asarray(indptr)
-             self.indices = u.math.asarray(indices)
-
-         # weights
-         weight = init.param(weight, (len(indices),), allow_none=False, allow_scalar=False)
-         params = dict(weight=weight)
-         if b_init is not None:
-             params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False)
-         self.weight = ParamState(params)
-
-     def update(self, x):
-         data = self.weight.value['weight']
-         data, w_unit = u.get_mantissa(data), u.get_unit(data)
-         x, x_unit = u.get_mantissa(x), u.get_unit(x)
-         shape = [self.out_size[-1], self.in_size[-1]]
-         if x.ndim == 1:
-             y = csr_matvec(data, self.indices, self.indptr, x, shape=shape, transpose=True)
-         elif x.ndim == 2:
-             y = csr_matmat(data, self.indices, self.indptr, x, shape=shape, transpose=True)
-         else:
-             raise NotImplementedError(f"matmul with object of shape {x.shape}")
-         y = u.maybe_decimal(u.Quantity(y, unit=w_unit * x_unit))
-         if 'bias' in self.weight.value:
-             y = y + self.weight.value['bias']
-         return y
-
-
- def coo_matvec(
-     data: jax.Array,
-     row: jax.Array,
-     col: jax.Array,
-     v: jax.Array, *,
-     spinfo: COOInfo,
-     transpose: bool = False
- ) -> jax.Array:
-     """Product of COO sparse matrix and a dense vector.
-
-     Args:
-         data : array of shape ``(nse,)``.
-         row : array of shape ``(nse,)``
-         col : array of shape ``(nse,)`` and dtype ``row.dtype``
-         v : array of shape ``(shape[0] if transpose else shape[1],)`` and
-             dtype ``data.dtype``
-         spinfo : COOInfo object containing the shape of the matrix and the dtype
-         transpose : boolean specifying whether to transpose the sparse matrix
-             before computing.
-
-     Returns:
-         y : array of shape ``(shape[1] if transpose else shape[0],)`` representing
-             the matrix vector product.
-     """
-     return coo_matvec_p.bind(data, row, col, v, spinfo=spinfo, transpose=transpose)
-
-
- def coo_matmat(
-     data: jax.Array, row: jax.Array, col: jax.Array, B: jax.Array, *,
-     spinfo: COOInfo, transpose: bool = False
- ) -> jax.Array:
-     """Product of COO sparse matrix and a dense matrix.
-
-     Args:
-         data : array of shape ``(nse,)``.
-         row : array of shape ``(nse,)``
-         col : array of shape ``(nse,)`` and dtype ``row.dtype``
-         B : array of shape ``(shape[0] if transpose else shape[1], cols)`` and
-             dtype ``data.dtype``
-         spinfo : COOInfo object containing the shape of the matrix and the dtype
-         transpose : boolean specifying whether to transpose the sparse matrix
-             before computing.
-
-     Returns:
-         C : array of shape ``(shape[1] if transpose else shape[0], cols)``
-             representing the matrix vector product.
-     """
-     return coo_matmat_p.bind(data, row, col, B, spinfo=spinfo, transpose=transpose)
-
-
- class COOLinear(Module):
-
-     def __init__(
-         self,
-         in_size: Size,
-         out_size: Size,
-         row: ArrayLike,
-         col: ArrayLike,
-         weight: Union[Callable, ArrayLike],
+         spar_mat: u.sparse.SparseMatrix,
          b_init: Optional[Union[Callable, ArrayLike]] = None,
-         rows_sorted: bool = False,
-         cols_sorted: bool = False,
+         in_size: Size = None,
          name: Optional[str] = None,
      ):
          super().__init__(name=name)
 
          # input and output shape
-         self.in_size = in_size
-         self.out_size = out_size
-         assert self.in_size[:-1] == self.out_size[:-1], ('The first n-1 dimensions of "in_size" '
-                                                          'and "out_size" must be the same.')
-
-         # COO data structure
-         row = jnp.asarray(row)
-         col = jnp.asarray(col)
-         assert row.ndim == 1, f"row must be 1D. Got: {row.ndim}"
-         assert col.ndim == 1, f"col must be 1D. Got: {col.ndim}"
-         assert row.size == col.size, f"row and col must have the same size. Got: {row.size} and {col.size}"
-         with jax.ensure_compile_time_eval():
-             self.row = u.math.asarray(row)
-             self.col = u.math.asarray(col)
-
-         # COO structure information
-         self.rows_sorted = rows_sorted
-         self.cols_sorted = cols_sorted
+         if in_size is not None:
+             self.in_size = in_size
+         self.out_size = spar_mat.shape[-1]
+         if in_size is not None:
+             assert self.in_size[:-1] == self.out_size[:-1], (
+                 'The first n-1 dimensions of "in_size" '
+                 'and "out_size" must be the same.'
+             )
 
          # weights
-         weight = init.param(weight, (len(row),), allow_none=False, allow_scalar=False)
-         params = dict(weight=weight)
+         assert isinstance(spar_mat, u.sparse.SparseMatrix), '"weight" must be a SparseMatrix.'
+         self.spar_mat = spar_mat
+         params = dict(weight=spar_mat.data)
          if b_init is not None:
              params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False)
          self.weight = ParamState(params)
 
      def update(self, x):
          data = self.weight.value['weight']
-         data, w_unit = u.get_mantissa(data), u.get_unit(data)
-         x, x_unit = u.get_mantissa(x), u.get_unit(x)
-         spinfo = COOInfo(
-             shape=(self.in_size[-1], self.out_size[-1]),
-             rows_sorted=self.rows_sorted,
-             cols_sorted=self.cols_sorted
-         )
-         if x.ndim == 1:
-             y = coo_matvec(data, self.row, self.col, x, spinfo=spinfo, transpose=False)
-         elif x.ndim == 2:
-             y = coo_matmat(data, self.row, self.col, x, spinfo=spinfo, transpose=False)
-         else:
-             raise NotImplementedError(f"matmul with object of shape {x.shape}")
-         y = u.maybe_decimal(u.Quantity(y, unit=w_unit * x_unit))
+         y = x @ self.spar_mat.with_data(data)
          if 'bias' in self.weight.value:
              y = y + self.weight.value['bias']
          return y
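
Usage of the consolidated layer, sketched under two assumptions: that `brainunit.sparse.CSR.fromdense` exists as a constructor (in the style of the `jax.experimental.sparse` API the deleted wrappers exposed), and that `SparseLinear` is invoked like the other layers in this file:

    import jax.numpy as jnp
    import brainunit as u
    import brainstate as bst

    # A 2x3 sparse weight matrix; fromdense is an assumed constructor.
    dense = jnp.array([[0., 1., 0.],
                       [2., 0., 3.]])
    w = u.sparse.CSR.fromdense(dense)

    layer = bst.nn.SparseLinear(w)  # out_size is taken from w.shape[-1]
    y = layer(jnp.ones(2))          # computes x @ W with sparse weights
    print(y.shape)                  # (3,)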