brainstate 0.1.10__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- brainstate/__init__.py +169 -58
- brainstate/_compatible_import.py +340 -148
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +45 -55
- brainstate/_state.py +1652 -1605
- brainstate/_state_test.py +52 -52
- brainstate/_utils.py +47 -47
- brainstate/environ.py +1495 -563
- brainstate/environ_test.py +1223 -62
- brainstate/graph/__init__.py +22 -29
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +1624 -1738
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1433 -365
- brainstate/mixin_test.py +1017 -77
- brainstate/nn/__init__.py +137 -135
- brainstate/nn/_activations.py +1100 -808
- brainstate/nn/_activations_test.py +354 -331
- brainstate/nn/_collective_ops.py +633 -514
- brainstate/nn/_collective_ops_test.py +774 -43
- brainstate/nn/_common.py +226 -178
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +2010 -501
- brainstate/nn/_conv_test.py +849 -238
- brainstate/nn/_delay.py +575 -588
- brainstate/nn/_delay_test.py +243 -238
- brainstate/nn/_dropout.py +618 -426
- brainstate/nn/_dropout_test.py +477 -100
- brainstate/nn/_dynamics.py +1267 -1343
- brainstate/nn/_dynamics_test.py +67 -78
- brainstate/nn/_elementwise.py +1298 -1119
- brainstate/nn/_elementwise_test.py +830 -169
- brainstate/nn/_embedding.py +408 -58
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +233 -239
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +115 -114
- brainstate/nn/{_linear_mv.py → _event_linear.py} +83 -83
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +121 -120
- brainstate/nn/_exp_euler.py +254 -92
- brainstate/nn/_exp_euler_test.py +377 -35
- brainstate/nn/_linear.py +744 -424
- brainstate/nn/_linear_test.py +475 -107
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +384 -377
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_normalizations.py +1334 -975
- brainstate/nn/_normalizations_test.py +699 -73
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +2239 -1177
- brainstate/nn/_poolings_test.py +953 -217
- brainstate/nn/{_rate_rnns.py → _rnns.py} +946 -554
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +216 -89
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +809 -553
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +180 -149
- brainstate/random/__init__.py +270 -24
- brainstate/random/_rand_funs.py +3938 -3616
- brainstate/random/_rand_funs_test.py +640 -567
- brainstate/random/_rand_seed.py +675 -210
- brainstate/random/_rand_seed_test.py +48 -48
- brainstate/random/_rand_state.py +1617 -1409
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +49 -49
- brainstate/{augment → transform}/_autograd.py +1025 -778
- brainstate/{augment → transform}/_autograd_test.py +1289 -1289
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +220 -220
- brainstate/{compile → transform}/_error_if.py +94 -92
- brainstate/{compile → transform}/_error_if_test.py +52 -52
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +38 -38
- brainstate/{compile → transform}/_jit.py +399 -346
- brainstate/{compile → transform}/_jit_test.py +143 -143
- brainstate/{compile → transform}/_loop_collect_return.py +675 -536
- brainstate/{compile → transform}/_loop_collect_return_test.py +58 -58
- brainstate/{compile → transform}/_loop_no_collection.py +283 -184
- brainstate/{compile → transform}/_loop_no_collection_test.py +50 -50
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +255 -202
- brainstate/{augment → transform}/_random.py +171 -151
- brainstate/{compile → transform}/_unvmap.py +256 -159
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +837 -304
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +27 -50
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +462 -328
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +945 -469
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +910 -523
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -91
- brainstate-0.2.1.dist-info/RECORD +111 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.10.dist-info/RECORD +0 -130
- {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
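
The renames above move the former `compile` and `augment` transformation modules into a single `brainstate.transform` package and fold the `init` initializers into `brainstate.nn.init`, while the old `functional`, `optim`, and `surrogate` entry points disappear from this wheel. As a minimal sketch of what the moves imply for downstream imports — assuming the new packages re-export their contents publicly, which the file listing alone does not confirm:

```python
# Hypothetical migration sketch inferred from the file moves listed above; the exact
# public re-exports of brainstate 0.2.1 are an assumption, not shown in this diff.

# 0.1.10-era namespaces that no longer ship in the wheel:
#   brainstate.compile, brainstate.augment, brainstate.init,
#   brainstate.functional, brainstate.optim, brainstate.surrogate

# 0.2.1-era packages/modules that do ship in the wheel:
from brainstate import transform   # hosts the former compile/ and augment/ machinery
from brainstate.nn import init     # hosts the former brainstate.init initializers
```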
brainstate/nn/_embedding.py
CHANGED
@@ -1,58 +1,408 @@

The previous 58-line module (old Apache license header, imports, and module docstring; its body is not recoverable from this rendering) is replaced by:

# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from functools import lru_cache
from typing import Callable, Optional, Tuple, Union

import numpy as np
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
from jax import core as jax_core

from brainstate._state import ParamState
from brainstate.typing import ArrayLike, Size
from . import init as init
from ._module import Module

__all__ = [
    'Embedding',
]


def _normalize_embedding_size(size: Size) -> Tuple[int, ...]:
    """Convert ``Size`` specifications to a validated tuple of integers."""
    size_array = np.asarray(size)
    if size_array.size == 0:
        raise ValueError('embedding_size must contain at least one dimension.')
    flat = size_array.reshape(-1)
    normalized = tuple(int(dim) for dim in flat)
    if any(dim < 0 for dim in normalized):
        raise ValueError('embedding_size must not contain negative values.')
    return normalized


@lru_cache(maxsize=None)
def _embedding_lookup_fn(
    padding_idx: Optional[int],
    scale_grad_by_freq: bool,
):
    """Return a lookup function with a custom VJP implementing embedding semantics."""

    @jax.custom_vjp
    def _lookup(weight: jax.Array, indices: jax.Array) -> jax.Array:
        indices = jnp.asarray(indices)
        return weight[indices]

    def _lookup_fwd(weight: jax.Array, indices: jax.Array):
        indices = jnp.asarray(indices)
        return weight[indices], (indices, weight.shape)

    def _lookup_bwd(residual, grad_output: jax.Array):
        indices, weight_shape = residual
        grad_output = jnp.asarray(grad_output)
        flat_idx = jnp.ravel(indices)
        if flat_idx.size == 0:
            return jnp.zeros(weight_shape, dtype=grad_output.dtype), None

        flat_idx = jnp.asarray(flat_idx, dtype=jnp.int32)
        grad_flat = grad_output.reshape((flat_idx.shape[0],) + weight_shape[1:])

        if scale_grad_by_freq:
            counts = jnp.bincount(flat_idx, length=weight_shape[0])
            counts = counts.astype(grad_flat.dtype)
            counts = jnp.where(counts == 0, 1.0, counts)
            scale = counts[flat_idx]
            grad_flat = grad_flat / scale.reshape((flat_idx.shape[0],) + (1,) * (grad_flat.ndim - 1))

        if padding_idx is not None:
            pad_value = jnp.asarray(padding_idx, dtype=flat_idx.dtype)
            mask = flat_idx != pad_value
            broadcast_shape = (flat_idx.shape[0],) + (1,) * (grad_flat.ndim - 1)
            grad_flat = grad_flat * mask.reshape(broadcast_shape).astype(grad_flat.dtype)

        grad_weight = jnp.zeros(weight_shape, dtype=grad_output.dtype)
        grad_weight = grad_weight.at[flat_idx].add(grad_flat)
        return grad_weight, None

    _lookup.defvjp(_lookup_fwd, _lookup_bwd)
    return _lookup


def _contains_tracer(tree) -> bool:
    """Return True if the pytree contains any JAX tracer values."""
    return any(isinstance(leaf, jax_core.Tracer) for leaf in jtu.tree_leaves(tree))


class Embedding(Module):
    r"""
    A simple lookup table that stores embeddings of a fixed size.

    This module is commonly used to store word embeddings and retrieve them using indices.
    The input to the module is a list of indices, and the output is the corresponding
    word embeddings.

    Parameters
    ----------
    num_embeddings : int
        Size of embedding dictionary. Must be non-negative.
    embedding_size : Size
        Size of each embedding vector. Can be an int or a sequence of ints, and must
        contain only non-negative values.
    embedding_init : Callable or ArrayLike, optional
        The initializer for the embedding lookup table, of shape
        ``(num_embeddings, embedding_size)``. Default is ``LecunUniform()``.
    padding_idx : int, optional
        If specified, the entries at ``padding_idx`` do not contribute to the gradient;
        therefore, the embedding vector at ``padding_idx`` is not updated during training,
        i.e., it remains as a fixed "pad". For a newly constructed Embedding, the embedding
        vector at ``padding_idx`` will default to all zeros. Default is ``None``.
    max_norm : float, optional
        If given, each embedding vector with norm larger than ``max_norm`` is renormalized
        to have norm ``max_norm``. Default is ``None``.
    norm_type : float, optional
        The p of the p-norm to compute for the ``max_norm`` option. Default is ``2.0``.
    scale_grad_by_freq : bool, optional
        If given, this scales gradients by the inverse frequency of the words in
        the mini-batch. Default is ``False``.
    name : str, optional
        The name of the module.
    param_type : type, optional
        The parameter state type to use. Default is ``ParamState``.

    Attributes
    ----------
    num_embeddings : int
        Size of the embedding dictionary.
    embedding_size : tuple[int, ...]
        Size of each embedding vector.
    out_size : tuple[int, ...]
        Output size, same as ``embedding_size``.
    weight : ParamState
        The learnable weights of the module of shape
        ``(num_embeddings, *embedding_size)``.
    padding_idx : int or None
        Index of the padding token.
    max_norm : float or None
        Maximum norm for embedding vectors.
    norm_type : float
        Type of p-norm to compute for max_norm.
    scale_grad_by_freq : bool
        Whether to scale gradients by frequency.
    freeze : bool
        Whether the embedding weights are frozen.

    Examples
    --------
    Create an embedding layer with 10 words and 3-dimensional embeddings:

    .. code-block:: python

        >>> import brainstate as bst
        >>> embedding = bst.nn.Embedding(num_embeddings=10, embedding_size=3)
        >>> embedding.weight.value.shape
        (10, 3)

    Retrieve embeddings for specific indices:

    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> indices = jnp.array([1, 3, 5])
        >>> output = embedding(indices)
        >>> output.shape
        (3, 3)

    Use with a batch of sequences:

    .. code-block:: python

        >>> # Batch of 2 sequences, each with 4 tokens
        >>> batch_indices = jnp.array([[1, 2, 3, 4],
        ...                            [5, 6, 7, 8]])
        >>> output = embedding(batch_indices)
        >>> output.shape
        (2, 4, 3)

    Use padding_idx to keep padding embeddings fixed:

    .. code-block:: python

        >>> embedding = bst.nn.Embedding(num_embeddings=10, embedding_size=3, padding_idx=0)
        >>> # The embedding at index 0 will remain zeros and not be updated during training
        >>> indices = jnp.array([0, 2, 0, 5])
        >>> output = embedding(indices)
        >>> output[0]  # Will be zeros
        Array([0., 0., 0.], dtype=float32)

    Use max_norm to constrain embedding norms:

    .. code-block:: python

        >>> embedding = bst.nn.Embedding(num_embeddings=10, embedding_size=3, max_norm=1.0)
        >>> # All embeddings accessed in a forward pass are renormalized to have norm <= 1.0

    Load pretrained embeddings:

    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>> pretrained = jnp.array([[1.0, 2.0, 3.0],
        ...                         [4.0, 5.0, 6.0],
        ...                         [7.0, 8.0, 9.0]])
        >>> embedding = bst.nn.Embedding.from_pretrained(pretrained, param_type=brainstate.FakeState)
        >>> embedding.weight.value.shape
        (3, 3)
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_size: Size,
        embedding_init: Union[Callable, ArrayLike] = init.LecunUniform(),
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        freeze: bool = False,
        name: Optional[str] = None,
        param_type: type = ParamState,
    ):
        super().__init__(name=name)

        self.num_embeddings = int(num_embeddings)
        if self.num_embeddings < 0:
            raise ValueError('num_embeddings must not be negative.')

        embedding_size_tuple = _normalize_embedding_size(embedding_size)
        self.embedding_size = embedding_size_tuple
        self.out_size = embedding_size_tuple

        if padding_idx is not None:
            padding_idx = int(padding_idx)
            if padding_idx < 0 or padding_idx >= self.num_embeddings:
                raise ValueError(f'padding_idx must be within [0, {self.num_embeddings}).')
        self.padding_idx = padding_idx

        if max_norm is not None and max_norm <= 0:
            raise ValueError('max_norm must be positive when provided.')
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = bool(scale_grad_by_freq)
        self.freeze = bool(freeze)

        weight_shape = (self.num_embeddings, *self.out_size)
        weight = init.param(embedding_init, weight_shape)

        if self.padding_idx is not None:
            weight = weight.at[self.padding_idx].set(0)

        self.weight = param_type(weight)
        self._lookup = _embedding_lookup_fn(self.padding_idx, self.scale_grad_by_freq)

    @classmethod
    def from_pretrained(
        cls,
        embeddings: ArrayLike,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        freeze: bool = True,
        name: Optional[str] = None,
        param_type: type = ParamState,
    ):
        r"""
        Create an Embedding instance from given 2-dimensional array.

        Parameters
        ----------
        embeddings : ArrayLike
            Array containing weights for the Embedding. First dimension is passed to
            Embedding as ``num_embeddings``, remaining dimensions as ``embedding_size``.
        padding_idx : int, optional
            If specified, the entries at ``padding_idx`` do not contribute to the gradient.
            Default is ``None``.
        max_norm : float, optional
            See module initialization documentation. Default is ``None``.
        norm_type : float, optional
            See module initialization documentation. Default is ``2.0``.
        scale_grad_by_freq : bool, optional
            See module initialization documentation. Default is ``False``.
        freeze : bool, optional
            If ``True``, embeddings are frozen (no gradients). Default is ``True``.
        name : str, optional
            The name of the module.

        Returns
        -------
        Embedding
            An Embedding module with pretrained weights.

        Examples
        --------
        Load pretrained word embeddings:

        .. code-block:: python

            >>> import jax.numpy as jnp
            >>> import brainstate as bst
            >>> pretrained = jnp.array([[1.0, 2.0, 3.0],
            ...                         [4.0, 5.0, 6.0],
            ...                         [7.0, 8.0, 9.0]])
            >>> embedding = bst.nn.Embedding.from_pretrained(pretrained)
            >>> embedding.weight.value.shape
            (3, 3)
            >>> indices = jnp.array([1])
            >>> embedding(indices)
            Array([[4., 5., 6.]], dtype=float32)
        """
        embeddings = jnp.asarray(embeddings)
        if embeddings.ndim < 2:
            raise ValueError('embeddings must be at least 2-dimensional')

        num_embeddings = embeddings.shape[0]
        embedding_size = embeddings.shape[1:]

        instance = cls(
            num_embeddings=num_embeddings,
            embedding_size=embedding_size,
            embedding_init=embeddings,
            padding_idx=padding_idx,
            max_norm=max_norm,
            norm_type=norm_type,
            scale_grad_by_freq=scale_grad_by_freq,
            freeze=freeze,
            name=name,
            param_type=param_type,
        )

        instance.weight = param_type(jnp.asarray(embeddings))
        return instance

    def update(self, indices: ArrayLike):
        """
        Retrieve embeddings for the given indices.

        Parameters
        ----------
        indices : ArrayLike
            Indices to retrieve embeddings for. Can be any shape.

        Returns
        -------
        ArrayLike
            Embeddings corresponding to the indices, with shape
            ``(*indices.shape, *embedding_size)``.
        """
        indices = jnp.asarray(indices)
        if not jnp.issubdtype(indices.dtype, jnp.integer):
            raise TypeError('Embedding indices must be integers.')

        weight_value = self.weight.value
        effective_weight = weight_value

        if self.max_norm is not None:
            renormed_weight = self._apply_max_norm(weight_value, indices)
            effective_weight = weight_value + jax.lax.stop_gradient(renormed_weight - weight_value)
            if not _contains_tracer(renormed_weight):
                self.weight.value = renormed_weight

        if self.freeze:
            effective_weight = jax.lax.stop_gradient(effective_weight)

        embeddings = self._lookup(effective_weight, indices)
        return embeddings

    def _apply_max_norm(self, weight: jax.Array, indices: jax.Array) -> jax.Array:
        """Apply max_norm constraint to the embedding weights for the given indices."""
        flat_idx = jnp.ravel(indices)
        if flat_idx.size == 0:
            return weight

        flat_idx = jnp.asarray(flat_idx, dtype=jnp.int32)
        if self.padding_idx is not None:
            pad_value = jnp.asarray(self.padding_idx, dtype=flat_idx.dtype)
            flat_idx = flat_idx[flat_idx != pad_value]

        if flat_idx.size == 0:
            return weight

        rows = weight[flat_idx]
        rows_flat = rows.reshape((rows.shape[0], -1))
        row_dtype = rows_flat.dtype

        norms = jnp.linalg.norm(rows_flat, ord=self.norm_type, axis=1, keepdims=True)
        max_norm = jnp.asarray(self.max_norm, dtype=row_dtype)
        eps = jnp.asarray(1e-8, dtype=row_dtype)
        one = jnp.asarray(1.0, dtype=row_dtype)
        scale = jnp.minimum(one, max_norm / (norms + eps))
        rows_scaled = (rows_flat * scale).reshape(rows.shape)

        return weight.at[flat_idx].set(rows_scaled)
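
To make the backward rule defined by `_embedding_lookup_fn` concrete — per-occurrence gradients are divided by the index frequency when `scale_grad_by_freq` is set, padding rows are masked out, and the result is scatter-added into the weight gradient — here is a small worked sketch in the same style as the new test file below (it calls the internal `_lookup` exactly as the tests do):

```python
import jax
import jax.numpy as jnp
import brainstate as bs

# Index 1 appears twice, index 2 once; index 0 is the padding row.
emb = bs.nn.Embedding(num_embeddings=4, embedding_size=3,
                      padding_idx=0, scale_grad_by_freq=True)
indices = jnp.array([0, 1, 1, 2], dtype=jnp.int32)

def loss(w):
    # emb._lookup is the cached custom-VJP lookup built in __init__ above.
    return jnp.sum(emb._lookup(w, indices))

grad = jax.grad(loss)(emb.weight.value)
# Row 0 (padding_idx) receives a zero gradient.
# Row 1: two occurrences, each scaled by 1/2, summing to 1.0 per component.
# Row 2: one occurrence with frequency 1, i.e. 1.0 per component.
print(grad)
```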
brainstate/nn/_embedding_test.py
ADDED
@@ -0,0 +1,156 @@

# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# -*- coding: utf-8 -*-

import unittest

import jax
import jax.numpy as jnp

import brainstate as bs


class TestEmbedding(unittest.TestCase):
    """Comprehensive tests for the Embedding module."""

    def setUp(self):
        settings = bs.environ.all()
        self._prev_fit = settings.get('fit', None)
        bs.environ.set(fit=True)

    def tearDown(self):
        if self._prev_fit is None:
            bs.environ.pop('fit', None)
        else:
            bs.environ.set(fit=self._prev_fit)

    def test_padding_idx_zero_gradient(self):
        embedding = bs.nn.Embedding(num_embeddings=4, embedding_size=3, padding_idx=0)
        lookup = embedding._lookup
        weight = jnp.arange(12.0, dtype=jnp.float32).reshape(4, 3)
        indices = jnp.array([0, 1, 0, 2], dtype=jnp.int32)

        def loss_fn(w):
            return jnp.sum(lookup(w, indices))

        grad = jax.grad(loss_fn)(weight)
        self.assertTrue(jnp.allclose(grad[0], 0.0))
        self.assertFalse(jnp.allclose(grad[1], 0.0))

    def test_scale_grad_by_freq(self):
        base = bs.nn.Embedding(num_embeddings=5, embedding_size=2)
        scaled = bs.nn.Embedding(num_embeddings=5, embedding_size=2, scale_grad_by_freq=True)
        base_lookup = base._lookup
        scaled_lookup = scaled._lookup

        weight = jnp.arange(10.0, dtype=jnp.float32).reshape(5, 2)
        indices = jnp.array([1, 1, 2], dtype=jnp.int32)

        def loss_base(w):
            return jnp.sum(base_lookup(w, indices))

        def loss_scaled(w):
            return jnp.sum(scaled_lookup(w, indices))

        base_grad = jax.grad(loss_base)(weight)
        scaled_grad = jax.grad(loss_scaled)(weight)

        counts = jnp.bincount(indices, length=weight.shape[0])
        ones = jnp.ones((indices.shape[0], weight.shape[1]), dtype=weight.dtype)
        expected_base = jnp.zeros_like(weight).at[indices].add(ones)
        expected_scaled = jnp.where(counts[:, None] > 0, jnp.ones_like(weight), 0.0)

        self.assertTrue(jnp.allclose(base_grad, expected_base))
        self.assertTrue(jnp.allclose(scaled_grad, expected_scaled))

    def test_lookup_grad_jit_consistent(self):
        embedding = bs.nn.Embedding(num_embeddings=4, embedding_size=2)
        lookup = embedding._lookup
        weight = jnp.arange(8.0, dtype=jnp.float32).reshape(4, 2)
        indices = jnp.array([0, 1, 1, 3], dtype=jnp.int32)

        def loss_fn(w):
            return jnp.sum(lookup(w, indices))

        grad_eager = jax.grad(loss_fn)(weight)
        grad_jitted = jax.grad(jax.jit(loss_fn))(weight)

        expected = jnp.zeros_like(weight).at[indices].add(jnp.ones((indices.shape[0], weight.shape[1]), dtype=weight.dtype))

        self.assertTrue(jnp.allclose(grad_eager, grad_jitted))
        self.assertTrue(jnp.allclose(grad_eager, expected))

    def test_jit_forward_with_max_norm(self):
        embedding = bs.nn.Embedding(num_embeddings=3, embedding_size=3, max_norm=0.5)
        heavy = jnp.array([[0.0, 0.0, 0.0], [1.0, 2.0, 2.0], [0.3, -0.4, 0.5]], dtype=jnp.float32)
        embedding.weight.value = heavy
        indices = jnp.array([1, 2, 1], dtype=jnp.int32)

        compiled = jax.jit(lambda ids: embedding(ids))
        out = compiled(indices)
        self.assertEqual(out.shape, (3, 3))
        output_norms = jnp.linalg.norm(out, axis=-1)
        self.assertTrue(jnp.all(output_norms <= 0.5 + 1e-6))
        # Weight remains unclipped during JIT execution but must be usable without tracer leaks
        self.assertGreater(float(jnp.linalg.norm(embedding.weight.value[1])), 0.5)

    def test_freeze_disables_gradients(self):
        embedding = bs.nn.Embedding(num_embeddings=4, embedding_size=2, freeze=True)
        indices = jnp.array([1, 2, 3], dtype=jnp.int32)

        def loss_fn(weight):
            embedding.weight.value = weight
            return jnp.sum(embedding(indices))

        grad = jax.grad(loss_fn)(embedding.weight.value)
        self.assertTrue(jnp.allclose(grad, 0.0))

    def test_from_pretrained_defaults_to_freeze(self):
        pretrained = jnp.arange(12.0, dtype=jnp.float32).reshape(4, 3)
        embedding = bs.nn.Embedding.from_pretrained(pretrained)
        self.assertTrue(embedding.freeze)

        def loss_fn(weight):
            embedding.weight.value = weight
            return jnp.sum(embedding(jnp.array([1, 2], dtype=jnp.int32)))

        grad = jax.grad(loss_fn)(embedding.weight.value)
        self.assertTrue(jnp.allclose(grad, 0.0))

    def test_from_pretrained_unfrozen_gradients(self):
        pretrained = jnp.arange(6.0, dtype=jnp.float32).reshape(2, 3)
        embedding = bs.nn.Embedding.from_pretrained(pretrained, freeze=False)

        def loss_fn(weight):
            embedding.weight.value = weight
            return jnp.sum(embedding(jnp.array([0, 1], dtype=jnp.int32)))

        grad = jax.grad(loss_fn)(embedding.weight.value)
        self.assertFalse(jnp.allclose(grad, 0.0))

    def test_max_norm_renormalizes_weights(self):
        embedding = bs.nn.Embedding(num_embeddings=3, embedding_size=3, max_norm=1.0, norm_type=2.0)
        custom_weight = jnp.array([[0.0, 0.0, 0.0], [3.0, 4.0, 0.0], [0.5, 0.5, 0.5]], dtype=jnp.float32)
        embedding.weight.value = custom_weight
        _ = embedding(jnp.array([1, 2], dtype=jnp.int32))

        row_norm = jnp.linalg.norm(embedding.weight.value[1])
        self.assertLessEqual(float(row_norm), 1.0 + 1e-6)
        self.assertTrue(jnp.allclose(embedding.weight.value[0], custom_weight[0]))


if __name__ == '__main__':
    unittest.main()
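
A short sketch of the `max_norm` behavior these tests pin down, assuming brainstate 0.2.1 is installed: in eager execution the accessed rows are renormalized in place, while under `jax.jit` the renormalized weight is a tracer, so `update` skips the write-back and only the returned embeddings are clipped.

```python
import jax
import jax.numpy as jnp
import brainstate as bs

emb = bs.nn.Embedding(num_embeddings=3, embedding_size=3, max_norm=1.0)
emb.weight.value = jnp.array([[0.0, 0.0, 0.0],
                              [3.0, 4.0, 0.0],
                              [0.5, 0.5, 0.5]], dtype=jnp.float32)
idx = jnp.array([1, 2], dtype=jnp.int32)

# Under jit: outputs are norm-clipped, but the stored weight keeps its norm-5 row.
out_jit = jax.jit(lambda i: emb(i))(idx)
print(jnp.linalg.norm(out_jit, axis=-1))            # each row norm <= 1.0
print(float(jnp.linalg.norm(emb.weight.value[1])))  # still 5.0

# Eager: the accessed rows are also written back in renormalized form.
_ = emb(idx)
print(float(jnp.linalg.norm(emb.weight.value[1])))  # now <= 1.0
```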