brainstate 0.1.0.post20241210__py2.py3-none-any.whl → 0.1.0.post20241220__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/compile/_jit.py +20 -14
- brainstate/compile/_loop_collect_return.py +14 -6
- brainstate/compile/_progress_bar.py +5 -3
- brainstate/event/__init__.py +8 -6
- brainstate/event/_csr.py +906 -0
- brainstate/event/_csr_mv.py +12 -25
- brainstate/event/_csr_mv_test.py +76 -76
- brainstate/event/_csr_test.py +90 -0
- brainstate/event/_fixedprob_mv.py +52 -32
- brainstate/event/_linear_mv.py +2 -2
- brainstate/event/_xla_custom_op.py +8 -11
- brainstate/graph/_graph_node.py +10 -1
- brainstate/graph/_graph_operation.py +8 -6
- brainstate/nn/_dyn_impl/_inputs.py +127 -2
- brainstate/nn/_dynamics/_dynamics_base.py +12 -0
- brainstate/nn/_dynamics/_projection_base.py +25 -7
- brainstate/nn/_elementwise/_dropout_test.py +11 -11
- brainstate/nn/_interaction/_linear.py +21 -248
- brainstate/nn/_interaction/_linear_test.py +73 -6
- brainstate/random/_rand_funs.py +7 -3
- brainstate/typing.py +3 -0
- {brainstate-0.1.0.post20241210.dist-info → brainstate-0.1.0.post20241220.dist-info}/METADATA +3 -2
- {brainstate-0.1.0.post20241210.dist-info → brainstate-0.1.0.post20241220.dist-info}/RECORD +26 -25
- brainstate/event/_csr_benchmark.py +0 -14
- {brainstate-0.1.0.post20241210.dist-info → brainstate-0.1.0.post20241220.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20241210.dist-info → brainstate-0.1.0.post20241220.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20241210.dist-info → brainstate-0.1.0.post20241220.dist-info}/top_level.txt +0 -0
@@ -17,17 +17,23 @@ from __future__ import annotations
|
|
17
17
|
from typing import Union, Optional, Sequence, Callable
|
18
18
|
|
19
19
|
import brainunit as u
|
20
|
+
import jax
|
21
|
+
import numpy as np
|
20
22
|
|
21
23
|
from brainstate import environ, init, random
|
22
24
|
from brainstate._state import ShortTermState
|
23
|
-
from brainstate.
|
24
|
-
from brainstate.
|
25
|
+
from brainstate._state import State
|
26
|
+
from brainstate.compile import while_loop, cond
|
27
|
+
from brainstate.nn._dynamics._dynamics_base import Dynamics, Prefetch
|
28
|
+
from brainstate.nn._module import Module
|
25
29
|
from brainstate.typing import ArrayLike, Size, DTypeLike
|
26
30
|
|
27
31
|
__all__ = [
|
28
32
|
'SpikeTime',
|
29
33
|
'PoissonSpike',
|
30
34
|
'PoissonEncoder',
|
35
|
+
'PoissonInput',
|
36
|
+
'poisson_input',
|
31
37
|
]
|
32
38
|
|
33
39
|
|
@@ -152,3 +158,122 @@ class PoissonEncoder(Dynamics):
|
|
152
158
|
spikes = random.rand(*self.varshape) <= (freqs * environ.get_dt())
|
153
159
|
spikes = u.math.asarray(spikes, dtype=self.spk_type)
|
154
160
|
return spikes
|
161
|
+
|
162
|
+
|
163
|
+
class PoissonInput(Module):
    """
    Poisson input to a :py:class:`brainstate.State` reached through a :py:class:`Prefetch`.

    Adds independent Poisson input to a target variable. For large
    numbers of inputs, this is much more efficient than creating a
    `PoissonGroup`. The synaptic events are generated randomly during the
    simulation and are not preloaded and stored in memory. All the inputs must
    target the same variable, have the same frequency and same synaptic weight.
    All neurons in the target variable receive independent realizations of
    Poisson spike trains.

    On every :py:meth:`update`, each indexed element receives the number of
    events emitted by ``num_input`` inputs that each fire with probability
    ``p = freq * dt``, i.e. a ``Binomial(num_input, p)`` sample. When both
    ``num_input * p`` and ``num_input * (1 - p)`` exceed 5, the binomial is
    replaced by its normal approximation
    ``N(num_input * p, num_input * p * (1 - p))``. The sample is multiplied by
    ``weight`` and added to ``target_value[indices]``.

    Args:
        target: A :py:class:`Prefetch` pointing at the state that receives the input.
        indices: Indices into the target value that receive input; the generated
            input has shape ``indices.shape``.
        num_input: The number of inputs.
        freq: The frequency of each of the inputs. Must be a scalar.
        weight: The synaptic weight. Must be a scalar.
        name: The name of this module.
    """

    def __init__(
        self,
        target: Prefetch,
        indices: Union[np.ndarray, jax.Array],
        num_input: int,
        freq: Union[int, float],
        weight: Union[int, float],
        name: Optional[str] = None,
    ):
        super().__init__(name=name)

        self.target = target
        self.indices = indices
        self.num_input = num_input
        self.freq = freq
        self.weight = weight

    def update(self):
        # probability that a single input fires during one time step
        p = self.freq * environ.get_dt()
        a = self.num_input * p  # mean of Binomial(num_input, p)
        b = self.num_input * (1 - p)

        target = self.target()
        target_state = getattr(self.target.module, self.target.item)

        # Generate Poisson input: normal approximation of the binomial when it is
        # accurate, exact binomial sampling otherwise. ``random.normal`` takes a
        # standard deviation, so the binomial variance ``a * (1 - p) == b * p``
        # must be square-rooted.
        inp = cond(
            u.math.logical_and(a > 5, b > 5),
            lambda: random.normal(a, u.math.sqrt(b * p), self.indices.shape),
            lambda: random.binomial(self.num_input, p, self.indices.shape).astype(float)
        )

        # update target variable (accumulate onto the indexed elements)
        target_state.value = target.at[self.indices].add(inp * self.weight)
|
217
|
+
|
218
|
+
|
219
|
+
def poisson_input(
    freq: ArrayLike,
    num_input: int,
    weight: ArrayLike,
    target: State,
    indices: Optional[Union[np.ndarray, jax.Array]] = None,
):
    """
    Add Poisson input to the given :py:class:`brainstate.State` in place.

    For every targeted element, draws the number of events emitted during the
    current step by ``num_input`` independent inputs that each fire with
    probability ``p = freq * dt``. That count is multiplied by ``weight`` and
    added to ``target.value``. When both ``num_input * p`` and
    ``num_input * (1 - p)`` exceed 5, the ``Binomial(num_input, p)`` count is
    replaced by its normal approximation
    ``N(num_input * p, num_input * p * (1 - p))``.

    Args:
        freq: Firing frequency of each input.
        num_input: Number of independent inputs per targeted element.
        weight: Synaptic weight applied to every event.
        target: The state to update. Its value may be a pytree; each leaf
            (array or quantity) receives independent input.
        indices: Optional index into each leaf of ``target.value``. When given,
            only ``leaf[indices]`` receives input; otherwise every element does.

    Raises:
        AssertionError: If ``target`` is not a :py:class:`State`.
    """
    assert isinstance(target, State), 'The target must be a State.'
    p = freq * environ.get_dt()
    a = num_input * p  # mean of Binomial(num_input, p)
    b = num_input * (1 - p)
    # ``random.normal`` expects a standard deviation, so take the square root
    # of the binomial variance ``a * (1 - p) == b * p``.
    std = u.math.sqrt(b * p)
    use_normal = u.math.logical_and(a > 5, b > 5)
    tar_val = target.value

    def _shape(tar):
        # Shape of the region of ``tar`` that receives input.
        return tar.shape if indices is None else tar[indices].shape

    # generate Poisson input
    inp = cond(
        use_normal,
        lambda: jax.tree.map(
            lambda tar: random.normal(a, std, _shape(tar)),
            tar_val,
            is_leaf=u.math.is_quantity
        ),
        lambda: jax.tree.map(
            lambda tar: random.binomial(num_input, p, _shape(tar)).astype(float),
            tar_val,
            is_leaf=u.math.is_quantity
        )
    )

    # Update the target variable. The input is accumulated onto the current
    # value in both cases; previously the ``indices is None`` path overwrote it.
    if indices is None:
        target.value = jax.tree.map(
            lambda x, tar: tar + x * weight,
            inp,
            tar_val,
            is_leaf=u.math.is_quantity
        )
    else:
        target.value = jax.tree.map(
            lambda x, tar: tar.at[indices].add(x * weight),
            inp,
            tar_val,
            is_leaf=u.math.is_quantity
        )
|
@@ -107,6 +107,8 @@ class Dynamics(Module):
|
|
107
107
|
|
108
108
|
__module__ = 'brainstate.nn'
|
109
109
|
|
110
|
+
graph_invisible_attrs = ('_before_updates', '_after_updates', '_current_inputs', '_delta_inputs')
|
111
|
+
|
110
112
|
# before updates
|
111
113
|
_before_updates: Optional[Dict[Hashable, Callable]]
|
112
114
|
|
@@ -443,6 +445,16 @@ class Prefetch(Node):
|
|
443
445
|
item = _get_prefetch_item(self)
|
444
446
|
return item.value if isinstance(item, State) else item
|
445
447
|
|
448
|
+
def get_item_value(self):
    """Return the current value of the prefetched item.

    The item is resolved with ``_get_prefetch_item``. If it is a
    :py:class:`State`, its ``.value`` is returned; otherwise the item itself
    is returned unchanged.
    """
    resolved = _get_prefetch_item(self)
    if isinstance(resolved, State):
        return resolved.value
    return resolved
|
451
|
+
|
452
|
+
def get_item(self):
    """Return the prefetched item itself, not its value.

    This is whatever ``_get_prefetch_item`` resolves for this prefetch. Unlike
    ``get_item_value``, a :py:class:`State` is returned as the State object
    rather than its ``.value``.
    """
    return _get_prefetch_item(self)
|
457
|
+
|
446
458
|
|
447
459
|
class PrefetchDelay(Node):
|
448
460
|
def __init__(self, module: Dynamics, item: str):
|
@@ -14,7 +14,7 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
from __future__ import annotations
|
16
16
|
|
17
|
-
from typing import Union, Callable
|
17
|
+
from typing import Union, Callable, Optional
|
18
18
|
|
19
19
|
from brainstate._state import State
|
20
20
|
from brainstate.mixin import AlignPost, ParamDescriber, BindCondData, JointTypes
|
@@ -60,24 +60,28 @@ def is_instance(x, cls) -> bool:
|
|
60
60
|
return isinstance(x, cls)
|
61
61
|
|
62
62
|
|
63
|
-
def get_post_repr(syn, out):
|
64
|
-
|
63
|
+
def get_post_repr(label, syn, out):
    """Build the string key identifying a synapse/output pair.

    The key has the form ``'<syn.identifier> // <out.identifier>'``. When
    ``label`` is not None, it is prepended directly, with no separator.
    """
    prefix = '' if label is None else f'{label}'
    return f'{prefix}{syn.identifier} // {out.identifier}'
|
65
68
|
|
66
69
|
|
67
70
|
def align_post_add_bef_update(
|
68
71
|
syn_desc: ParamDescriber[AlignPost],
|
69
72
|
out_desc: ParamDescriber[BindCondData],
|
70
73
|
post: Dynamics,
|
71
|
-
proj_name: str
|
74
|
+
proj_name: str,
|
75
|
+
label: str,
|
72
76
|
):
|
73
77
|
# synapse and output initialization
|
74
|
-
_post_repr = get_post_repr(syn_desc, out_desc)
|
78
|
+
_post_repr = get_post_repr(label, syn_desc, out_desc)
|
75
79
|
if not post._has_before_update(_post_repr):
|
76
80
|
syn_cls = syn_desc()
|
77
81
|
out_cls = out_desc()
|
78
82
|
|
79
83
|
# synapse and output initialization
|
80
|
-
post.add_current_input(proj_name, out_cls)
|
84
|
+
post.add_current_input(proj_name, out_cls, label=label)
|
81
85
|
post._add_before_update(_post_repr, _AlignPost(syn_cls, out_cls))
|
82
86
|
syn = post._get_before_update(_post_repr).syn
|
83
87
|
out = post._get_before_update(_post_repr).out
|
@@ -139,6 +143,7 @@ class AlignPostProj(Interaction):
|
|
139
143
|
syn: Union[ParamDescriber[AlignPost], AlignPost],
|
140
144
|
out: Union[ParamDescriber[SynOut], SynOut],
|
141
145
|
post: Dynamics,
|
146
|
+
label: Optional[str] = None,
|
142
147
|
):
|
143
148
|
super().__init__(name=get_unique_name(self.__class__.__name__))
|
144
149
|
|
@@ -154,12 +159,21 @@ class AlignPostProj(Interaction):
|
|
154
159
|
# checking synapse and output models
|
155
160
|
if is_instance(syn, ParamDescriber[AlignPost]):
|
156
161
|
if not is_instance(out, ParamDescriber[SynOut]):
|
162
|
+
if is_instance(out, ParamDescriber):
|
163
|
+
raise TypeError(
|
164
|
+
f'The output should be an instance of describer {ParamDescriber[SynOut]} when '
|
165
|
+
f'the synapse is an instance of {AlignPost}, but got {out}.'
|
166
|
+
)
|
157
167
|
raise TypeError(
|
158
168
|
f'The output should be an instance of describer {ParamDescriber[SynOut]} when '
|
159
169
|
f'the synapse is a describer, but we got {out}.'
|
160
170
|
)
|
161
171
|
merging = True
|
162
172
|
else:
|
173
|
+
if is_instance(syn, ParamDescriber):
|
174
|
+
raise TypeError(
|
175
|
+
f'The synapse should be an instance of describer {ParamDescriber[AlignPost]}, but got {syn}.'
|
176
|
+
)
|
163
177
|
if not is_instance(out, SynOut):
|
164
178
|
raise TypeError(
|
165
179
|
f'The output should be an instance of {SynOut} when the synapse is '
|
@@ -176,7 +190,11 @@ class AlignPostProj(Interaction):
|
|
176
190
|
|
177
191
|
if merging:
|
178
192
|
# synapse and output initialization
|
179
|
-
syn, out = align_post_add_bef_update(syn_desc=syn,
|
193
|
+
syn, out = align_post_add_bef_update(syn_desc=syn,
|
194
|
+
out_desc=out,
|
195
|
+
post=post,
|
196
|
+
proj_name=self.name,
|
197
|
+
label=label)
|
180
198
|
else:
|
181
199
|
post.add_current_input(self.name, out)
|
182
200
|
|
@@ -59,17 +59,17 @@ class TestDropout(unittest.TestCase):
|
|
59
59
|
expected_non_zero_elements = input_data[output_data != 0] * scale_factor
|
60
60
|
np.testing.assert_almost_equal(non_zero_elements, expected_non_zero_elements)
|
61
61
|
|
62
|
-
def test_Dropout1d(self):
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
62
|
+
# def test_Dropout1d(self):
|
63
|
+
# dropout_layer = bst.nn.Dropout1d(prob=0.5)
|
64
|
+
# input_data = np.random.randn(2, 3, 4)
|
65
|
+
# with bst.environ.context(fit=True):
|
66
|
+
# output_data = dropout_layer(input_data)
|
67
|
+
# self.assertEqual(input_data.shape, output_data.shape)
|
68
|
+
# self.assertTrue(np.any(output_data == 0))
|
69
|
+
# scale_factor = 1 / (1 - 0.5)
|
70
|
+
# non_zero_elements = output_data[output_data != 0]
|
71
|
+
# expected_non_zero_elements = input_data[output_data != 0] * scale_factor
|
72
|
+
# np.testing.assert_almost_equal(non_zero_elements, expected_non_zero_elements, decimal=4)
|
73
73
|
|
74
74
|
def test_Dropout2d(self):
|
75
75
|
dropout_layer = bst.nn.Dropout2d(prob=0.5)
|
@@ -20,10 +20,7 @@ from __future__ import annotations
|
|
20
20
|
from typing import Callable, Union, Optional
|
21
21
|
|
22
22
|
import brainunit as u
|
23
|
-
import jax
|
24
23
|
import jax.numpy as jnp
|
25
|
-
from jax.experimental.sparse.coo import coo_matvec_p, coo_matmat_p, COOInfo
|
26
|
-
from jax.experimental.sparse.csr import csr_matvec_p, csr_matmat_p
|
27
24
|
|
28
25
|
from brainstate import init, functional
|
29
26
|
from brainstate._state import ParamState
|
@@ -34,9 +31,7 @@ __all__ = [
|
|
34
31
|
'Linear',
|
35
32
|
'ScaledWSLinear',
|
36
33
|
'SignedWLinear',
|
37
|
-
'
|
38
|
-
'CSCLinear',
|
39
|
-
'COOLinear',
|
34
|
+
'SparseLinear',
|
40
35
|
'AllToAll',
|
41
36
|
'OneToOne',
|
42
37
|
]
|
@@ -198,270 +193,48 @@ class ScaledWSLinear(Module):
|
|
198
193
|
return y
|
199
194
|
|
200
195
|
|
201
|
-
|
202
|
-
"""Product of CSR sparse matrix and a dense matrix.
|
203
|
-
|
204
|
-
Args:
|
205
|
-
data : array of shape ``(nse,)``.
|
206
|
-
indices : array of shape ``(nse,)``
|
207
|
-
indptr : array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``
|
208
|
-
B : array of shape ``(mat.shape[0] if transpose else mat.shape[1], cols)`` and
|
209
|
-
dtype ``mat.dtype``
|
210
|
-
transpose : boolean specifying whether to transpose the sparse matrix
|
211
|
-
before computing.
|
212
|
-
|
213
|
-
Returns:
|
214
|
-
C : array of shape ``(mat.shape[1] if transpose else mat.shape[0], cols)``
|
215
|
-
representing the matrix vector product.
|
196
|
+
class SparseLinear(Module):
|
216
197
|
"""
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
def csr_matvec(data, indices, indptr, v, *, shape, transpose=False) -> jax.Array:
|
221
|
-
"""Product of CSR sparse matrix and a dense vector.
|
198
|
+
Linear layer with Sparse Matrix (can be ``brainunit.sparse.CSR``,
|
199
|
+
``brainunit.sparse.CSC``, ``brainunit.sparse.COO``, or any other sparse matrix).
|
222
200
|
|
223
201
|
Args:
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
v : array of shape ``(shape[0] if transpose else shape[1],)``
|
228
|
-
and dtype ``data.dtype``
|
229
|
-
shape : length-2 tuple representing the matrix shape
|
230
|
-
transpose : boolean specifying whether to transpose the sparse matrix
|
231
|
-
before computing.
|
232
|
-
|
233
|
-
Returns:
|
234
|
-
y : array of shape ``(shape[1] if transpose else shape[0],)`` representing
|
235
|
-
the matrix vector product.
|
236
|
-
"""
|
237
|
-
return csr_matvec_p.bind(data, indices, indptr, v, shape=shape, transpose=transpose)
|
238
|
-
|
239
|
-
|
240
|
-
class CSRLinear(Module):
|
241
|
-
"""
|
242
|
-
Linear layer with Compressed Sparse Row (CSR) matrix.
|
243
|
-
"""
|
244
|
-
__module__ = 'brainstate.nn'
|
245
|
-
|
246
|
-
def __init__(
|
247
|
-
self,
|
248
|
-
in_size: Size,
|
249
|
-
out_size: Size,
|
250
|
-
indptr: ArrayLike,
|
251
|
-
indices: ArrayLike,
|
252
|
-
weight: Union[Callable, ArrayLike],
|
253
|
-
b_init: Optional[Union[Callable, ArrayLike]] = None,
|
254
|
-
name: Optional[str] = None,
|
255
|
-
):
|
256
|
-
super().__init__(name=name)
|
257
|
-
|
258
|
-
# input and output shape
|
259
|
-
self.in_size = in_size
|
260
|
-
self.out_size = out_size
|
261
|
-
assert self.in_size[:-1] == self.out_size[:-1], ('The first n-1 dimensions of "in_size" '
|
262
|
-
'and "out_size" must be the same.')
|
263
|
-
|
264
|
-
# CSR data structure
|
265
|
-
indptr = jnp.asarray(indptr)
|
266
|
-
indices = jnp.asarray(indices)
|
267
|
-
assert indptr.ndim == 1, f"indptr must be 1D. Got: {indptr.ndim}"
|
268
|
-
assert indices.ndim == 1, f"indices must be 1D. Got: {indices.ndim}"
|
269
|
-
assert indptr.size == self.in_size[-1] + 1, f"indptr must have size {self.in_size[-1] + 1}. Got: {indptr.size}"
|
270
|
-
with jax.ensure_compile_time_eval():
|
271
|
-
self.indptr = u.math.asarray(indptr)
|
272
|
-
self.indices = u.math.asarray(indices)
|
273
|
-
|
274
|
-
# weights
|
275
|
-
weight = init.param(weight, (len(indices),), allow_none=False, allow_scalar=False)
|
276
|
-
params = dict(weight=weight)
|
277
|
-
if b_init is not None:
|
278
|
-
params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False)
|
279
|
-
self.weight = ParamState(params)
|
280
|
-
|
281
|
-
def update(self, x):
|
282
|
-
data = self.weight.value['weight']
|
283
|
-
data, w_unit = u.get_mantissa(data), u.get_unit(data)
|
284
|
-
x, x_unit = u.get_mantissa(x), u.get_unit(x)
|
285
|
-
shape = [self.in_size[-1], self.out_size[-1]]
|
286
|
-
if x.ndim == 1:
|
287
|
-
y = csr_matvec(data, self.indices, self.indptr, x, shape=shape)
|
288
|
-
elif x.ndim == 2:
|
289
|
-
y = csr_matmat(data, self.indices, self.indptr, x, shape=shape)
|
290
|
-
else:
|
291
|
-
raise NotImplementedError(f"matmul with object of shape {x.shape}")
|
292
|
-
y = u.maybe_decimal(u.Quantity(y, unit=w_unit * x_unit))
|
293
|
-
if 'bias' in self.weight.value:
|
294
|
-
y = y + self.weight.value['bias']
|
295
|
-
return y
|
296
|
-
|
297
|
-
|
298
|
-
class CSCLinear(Module):
|
299
|
-
"""
|
300
|
-
Linear layer with Compressed Sparse Column (CSC) matrix.
|
202
|
+
spar_mat: SparseMatrix. The sparse weight matrix.
|
203
|
+
in_size: Size. The input size.
|
204
|
+
name: str. The object name.
|
301
205
|
"""
|
302
206
|
__module__ = 'brainstate.nn'
|
303
207
|
|
304
208
|
def __init__(
|
305
209
|
self,
|
306
|
-
|
307
|
-
out_size: Size,
|
308
|
-
indptr: ArrayLike,
|
309
|
-
indices: ArrayLike,
|
310
|
-
weight: Union[Callable, ArrayLike],
|
311
|
-
b_init: Optional[Union[Callable, ArrayLike]] = None,
|
312
|
-
name: Optional[str] = None,
|
313
|
-
):
|
314
|
-
super().__init__(name=name)
|
315
|
-
|
316
|
-
# input and output shape
|
317
|
-
self.in_size = in_size
|
318
|
-
self.out_size = out_size
|
319
|
-
assert self.in_size[:-1] == self.out_size[:-1], ('The first n-1 dimensions of "in_size" '
|
320
|
-
'and "out_size" must be the same.')
|
321
|
-
|
322
|
-
# CSR data structure
|
323
|
-
indptr = jnp.asarray(indptr)
|
324
|
-
indices = jnp.asarray(indices)
|
325
|
-
assert indptr.ndim == 1, f"indptr must be 1D. Got: {indptr.ndim}"
|
326
|
-
assert indices.ndim == 1, f"indices must be 1D. Got: {indices.ndim}"
|
327
|
-
assert indptr.size == self.in_size[-1] + 1, f"indptr must have size {self.in_size[-1] + 1}. Got: {indptr.size}"
|
328
|
-
with jax.ensure_compile_time_eval():
|
329
|
-
self.indptr = u.math.asarray(indptr)
|
330
|
-
self.indices = u.math.asarray(indices)
|
331
|
-
|
332
|
-
# weights
|
333
|
-
weight = init.param(weight, (len(indices),), allow_none=False, allow_scalar=False)
|
334
|
-
params = dict(weight=weight)
|
335
|
-
if b_init is not None:
|
336
|
-
params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False)
|
337
|
-
self.weight = ParamState(params)
|
338
|
-
|
339
|
-
def update(self, x):
|
340
|
-
data = self.weight.value['weight']
|
341
|
-
data, w_unit = u.get_mantissa(data), u.get_unit(data)
|
342
|
-
x, x_unit = u.get_mantissa(x), u.get_unit(x)
|
343
|
-
shape = [self.out_size[-1], self.in_size[-1]]
|
344
|
-
if x.ndim == 1:
|
345
|
-
y = csr_matvec(data, self.indices, self.indptr, x, shape=shape, transpose=True)
|
346
|
-
elif x.ndim == 2:
|
347
|
-
y = csr_matmat(data, self.indices, self.indptr, x, shape=shape, transpose=True)
|
348
|
-
else:
|
349
|
-
raise NotImplementedError(f"matmul with object of shape {x.shape}")
|
350
|
-
y = u.maybe_decimal(u.Quantity(y, unit=w_unit * x_unit))
|
351
|
-
if 'bias' in self.weight.value:
|
352
|
-
y = y + self.weight.value['bias']
|
353
|
-
return y
|
354
|
-
|
355
|
-
|
356
|
-
def coo_matvec(
|
357
|
-
data: jax.Array,
|
358
|
-
row: jax.Array,
|
359
|
-
col: jax.Array,
|
360
|
-
v: jax.Array, *,
|
361
|
-
spinfo: COOInfo,
|
362
|
-
transpose: bool = False
|
363
|
-
) -> jax.Array:
|
364
|
-
"""Product of COO sparse matrix and a dense vector.
|
365
|
-
|
366
|
-
Args:
|
367
|
-
data : array of shape ``(nse,)``.
|
368
|
-
row : array of shape ``(nse,)``
|
369
|
-
col : array of shape ``(nse,)`` and dtype ``row.dtype``
|
370
|
-
v : array of shape ``(shape[0] if transpose else shape[1],)`` and
|
371
|
-
dtype ``data.dtype``
|
372
|
-
spinfo : COOInfo object containing the shape of the matrix and the dtype
|
373
|
-
transpose : boolean specifying whether to transpose the sparse matrix
|
374
|
-
before computing.
|
375
|
-
|
376
|
-
Returns:
|
377
|
-
y : array of shape ``(shape[1] if transpose else shape[0],)`` representing
|
378
|
-
the matrix vector product.
|
379
|
-
"""
|
380
|
-
return coo_matvec_p.bind(data, row, col, v, spinfo=spinfo, transpose=transpose)
|
381
|
-
|
382
|
-
|
383
|
-
def coo_matmat(
|
384
|
-
data: jax.Array, row: jax.Array, col: jax.Array, B: jax.Array, *,
|
385
|
-
spinfo: COOInfo, transpose: bool = False
|
386
|
-
) -> jax.Array:
|
387
|
-
"""Product of COO sparse matrix and a dense matrix.
|
388
|
-
|
389
|
-
Args:
|
390
|
-
data : array of shape ``(nse,)``.
|
391
|
-
row : array of shape ``(nse,)``
|
392
|
-
col : array of shape ``(nse,)`` and dtype ``row.dtype``
|
393
|
-
B : array of shape ``(shape[0] if transpose else shape[1], cols)`` and
|
394
|
-
dtype ``data.dtype``
|
395
|
-
spinfo : COOInfo object containing the shape of the matrix and the dtype
|
396
|
-
transpose : boolean specifying whether to transpose the sparse matrix
|
397
|
-
before computing.
|
398
|
-
|
399
|
-
Returns:
|
400
|
-
C : array of shape ``(shape[1] if transpose else shape[0], cols)``
|
401
|
-
representing the matrix vector product.
|
402
|
-
"""
|
403
|
-
return coo_matmat_p.bind(data, row, col, B, spinfo=spinfo, transpose=transpose)
|
404
|
-
|
405
|
-
|
406
|
-
class COOLinear(Module):
|
407
|
-
|
408
|
-
def __init__(
|
409
|
-
self,
|
410
|
-
in_size: Size,
|
411
|
-
out_size: Size,
|
412
|
-
row: ArrayLike,
|
413
|
-
col: ArrayLike,
|
414
|
-
weight: Union[Callable, ArrayLike],
|
210
|
+
spar_mat: u.sparse.SparseMatrix,
|
415
211
|
b_init: Optional[Union[Callable, ArrayLike]] = None,
|
416
|
-
|
417
|
-
cols_sorted: bool = False,
|
212
|
+
in_size: Size = None,
|
418
213
|
name: Optional[str] = None,
|
419
214
|
):
|
420
215
|
super().__init__(name=name)
|
421
216
|
|
422
217
|
# input and output shape
|
423
|
-
|
424
|
-
|
425
|
-
|
426
|
-
|
427
|
-
|
428
|
-
|
429
|
-
|
430
|
-
|
431
|
-
assert row.ndim == 1, f"row must be 1D. Got: {row.ndim}"
|
432
|
-
assert col.ndim == 1, f"col must be 1D. Got: {col.ndim}"
|
433
|
-
assert row.size == col.size, f"row and col must have the same size. Got: {row.size} and {col.size}"
|
434
|
-
with jax.ensure_compile_time_eval():
|
435
|
-
self.row = u.math.asarray(row)
|
436
|
-
self.col = u.math.asarray(col)
|
437
|
-
|
438
|
-
# COO structure information
|
439
|
-
self.rows_sorted = rows_sorted
|
440
|
-
self.cols_sorted = cols_sorted
|
218
|
+
if in_size is not None:
|
219
|
+
self.in_size = in_size
|
220
|
+
self.out_size = spar_mat.shape[-1]
|
221
|
+
if in_size is not None:
|
222
|
+
assert self.in_size[:-1] == self.out_size[:-1], (
|
223
|
+
'The first n-1 dimensions of "in_size" '
|
224
|
+
'and "out_size" must be the same.'
|
225
|
+
)
|
441
226
|
|
442
227
|
# weights
|
443
|
-
|
444
|
-
|
228
|
+
assert isinstance(spar_mat, u.sparse.SparseMatrix), '"weight" must be a SparseMatrix.'
|
229
|
+
self.spar_mat = spar_mat
|
230
|
+
params = dict(weight=spar_mat.data)
|
445
231
|
if b_init is not None:
|
446
232
|
params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False)
|
447
233
|
self.weight = ParamState(params)
|
448
234
|
|
449
235
|
def update(self, x):
|
450
236
|
data = self.weight.value['weight']
|
451
|
-
|
452
|
-
x, x_unit = u.get_mantissa(x), u.get_unit(x)
|
453
|
-
spinfo = COOInfo(
|
454
|
-
shape=(self.in_size[-1], self.out_size[-1]),
|
455
|
-
rows_sorted=self.rows_sorted,
|
456
|
-
cols_sorted=self.cols_sorted
|
457
|
-
)
|
458
|
-
if x.ndim == 1:
|
459
|
-
y = coo_matvec(data, self.row, self.col, x, spinfo=spinfo, transpose=False)
|
460
|
-
elif x.ndim == 2:
|
461
|
-
y = coo_matmat(data, self.row, self.col, x, spinfo=spinfo, transpose=False)
|
462
|
-
else:
|
463
|
-
raise NotImplementedError(f"matmul with object of shape {x.shape}")
|
464
|
-
y = u.maybe_decimal(u.Quantity(y, unit=w_unit * x_unit))
|
237
|
+
y = x @ self.spar_mat.with_data(data)
|
465
238
|
if 'bias' in self.weight.value:
|
466
239
|
y = y + self.weight.value['bias']
|
467
240
|
return y
|