brainstate 0.1.10__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. brainstate/__init__.py +169 -58
  2. brainstate/_compatible_import.py +340 -148
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +45 -55
  7. brainstate/_state.py +1652 -1605
  8. brainstate/_state_test.py +52 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -563
  11. brainstate/environ_test.py +1223 -62
  12. brainstate/graph/__init__.py +22 -29
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +1624 -1738
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1433 -365
  18. brainstate/mixin_test.py +1017 -77
  19. brainstate/nn/__init__.py +137 -135
  20. brainstate/nn/_activations.py +1100 -808
  21. brainstate/nn/_activations_test.py +354 -331
  22. brainstate/nn/_collective_ops.py +633 -514
  23. brainstate/nn/_collective_ops_test.py +774 -43
  24. brainstate/nn/_common.py +226 -178
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +2010 -501
  27. brainstate/nn/_conv_test.py +849 -238
  28. brainstate/nn/_delay.py +575 -588
  29. brainstate/nn/_delay_test.py +243 -238
  30. brainstate/nn/_dropout.py +618 -426
  31. brainstate/nn/_dropout_test.py +477 -100
  32. brainstate/nn/_dynamics.py +1267 -1343
  33. brainstate/nn/_dynamics_test.py +67 -78
  34. brainstate/nn/_elementwise.py +1298 -1119
  35. brainstate/nn/_elementwise_test.py +830 -169
  36. brainstate/nn/_embedding.py +408 -58
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +233 -239
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +115 -114
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +83 -83
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +121 -120
  42. brainstate/nn/_exp_euler.py +254 -92
  43. brainstate/nn/_exp_euler_test.py +377 -35
  44. brainstate/nn/_linear.py +744 -424
  45. brainstate/nn/_linear_test.py +475 -107
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +384 -377
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -975
  51. brainstate/nn/_normalizations_test.py +699 -73
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +2239 -1177
  55. brainstate/nn/_poolings_test.py +953 -217
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +946 -554
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +216 -89
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +809 -553
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +180 -149
  62. brainstate/random/__init__.py +270 -24
  63. brainstate/random/_rand_funs.py +3938 -3616
  64. brainstate/random/_rand_funs_test.py +640 -567
  65. brainstate/random/_rand_seed.py +675 -210
  66. brainstate/random/_rand_seed_test.py +48 -48
  67. brainstate/random/_rand_state.py +1617 -1409
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +49 -49
  72. brainstate/{augment → transform}/_autograd.py +1025 -778
  73. brainstate/{augment → transform}/_autograd_test.py +1289 -1289
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +220 -220
  76. brainstate/{compile → transform}/_error_if.py +94 -92
  77. brainstate/{compile → transform}/_error_if_test.py +52 -52
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +38 -38
  80. brainstate/{compile → transform}/_jit.py +399 -346
  81. brainstate/{compile → transform}/_jit_test.py +143 -143
  82. brainstate/{compile → transform}/_loop_collect_return.py +675 -536
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +58 -58
  84. brainstate/{compile → transform}/_loop_no_collection.py +283 -184
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +50 -50
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +255 -202
  91. brainstate/{augment → transform}/_random.py +171 -151
  92. brainstate/{compile → transform}/_unvmap.py +256 -159
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +837 -304
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +27 -50
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +462 -328
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +945 -469
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +910 -523
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -91
  108. brainstate-0.2.1.dist-info/RECORD +111 -0
  109. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/augment/__init__.py +0 -30
  111. brainstate/augment/_eval_shape.py +0 -99
  112. brainstate/augment/_mapping.py +0 -1060
  113. brainstate/augment/_mapping_test.py +0 -597
  114. brainstate/compile/__init__.py +0 -38
  115. brainstate/compile/_ad_checkpoint.py +0 -204
  116. brainstate/compile/_conditions.py +0 -256
  117. brainstate/compile/_make_jaxpr.py +0 -888
  118. brainstate/compile/_make_jaxpr_test.py +0 -156
  119. brainstate/compile/_util.py +0 -147
  120. brainstate/functional/__init__.py +0 -27
  121. brainstate/graph/_graph_node.py +0 -244
  122. brainstate/graph/_graph_node_test.py +0 -73
  123. brainstate/graph/_graph_operation_test.py +0 -563
  124. brainstate/init/__init__.py +0 -26
  125. brainstate/init/_base.py +0 -52
  126. brainstate/init/_generic.py +0 -244
  127. brainstate/init/_regular_inits.py +0 -105
  128. brainstate/init/_regular_inits_test.py +0 -50
  129. brainstate/nn/_inputs.py +0 -608
  130. brainstate/nn/_ltp.py +0 -28
  131. brainstate/nn/_neuron.py +0 -705
  132. brainstate/nn/_neuron_test.py +0 -161
  133. brainstate/nn/_others.py +0 -46
  134. brainstate/nn/_projection.py +0 -486
  135. brainstate/nn/_rate_rnns_test.py +0 -63
  136. brainstate/nn/_readout.py +0 -209
  137. brainstate/nn/_readout_test.py +0 -53
  138. brainstate/nn/_stp.py +0 -236
  139. brainstate/nn/_synapse.py +0 -505
  140. brainstate/nn/_synapse_test.py +0 -131
  141. brainstate/nn/_synaptic_projection.py +0 -423
  142. brainstate/nn/_synouts.py +0 -162
  143. brainstate/nn/_synouts_test.py +0 -57
  144. brainstate/nn/metrics.py +0 -388
  145. brainstate/optim/__init__.py +0 -38
  146. brainstate/optim/_base.py +0 -64
  147. brainstate/optim/_lr_scheduler.py +0 -448
  148. brainstate/optim/_lr_scheduler_test.py +0 -50
  149. brainstate/optim/_optax_optimizer.py +0 -152
  150. brainstate/optim/_optax_optimizer_test.py +0 -53
  151. brainstate/optim/_sgd_optimizer.py +0 -1104
  152. brainstate/random/_random_for_unit.py +0 -52
  153. brainstate/surrogate.py +0 -1957
  154. brainstate/transform.py +0 -23
  155. brainstate/util/caller.py +0 -98
  156. brainstate/util/others.py +0 -540
  157. brainstate/util/pretty_pytree.py +0 -945
  158. brainstate/util/pretty_pytree_test.py +0 -159
  159. brainstate/util/pretty_table.py +0 -2954
  160. brainstate/util/scaling.py +0 -258
  161. brainstate-0.1.10.dist-info/RECORD +0 -130
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
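Beyond the per-file counts, the renames in this list describe a package-level reorganization: the `compile` and `augment` packages are merged into a new `transform` package, the random initializers move under `nn`, and several top-level packages (`optim`, `functional`, `init`, `surrogate`, among others) are dropped. The sketch below is a path-level summary of that mapping, taken only from the file paths listed above; which symbols each new module actually re-exports is an assumption to be checked against the brainstate 0.2.1 release.

# Path-level mapping implied by the file renames above (0.1.10 -> 0.2.1).
# Module contents and exported symbols are assumptions; verify against the
# 0.2.1 documentation before migrating imports.

MOVED = {
    "brainstate/compile/*": "brainstate/transform/*",        # e.g. _jit.py, _conditions.py, _loop_*.py
    "brainstate/augment/*": "brainstate/transform/*",        # e.g. _autograd.py, _mapping.py, _random.py
    "brainstate/init/_random_inits.py": "brainstate/nn/init.py",
    "brainstate/util/error.py": "brainstate/_error.py",
    "brainstate/nn/_fixedprob.py": "brainstate/nn/_event_fixedprob.py",
    "brainstate/nn/_linear_mv.py": "brainstate/nn/_event_linear.py",
    "brainstate/nn/_rate_rnns.py": "brainstate/nn/_rnns.py",
}

REMOVED = [
    "brainstate/optim/",          # optimizer package removed in 0.2.1
    "brainstate/functional/",
    "brainstate/init/",           # regular/generic initializers removed; random inits moved to nn/init.py
    "brainstate/surrogate.py",
    "brainstate/nn/_neuron.py", "brainstate/nn/_synapse.py", "brainstate/nn/_projection.py",
]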
brainstate/nn/_embedding.py
@@ -1,58 +1,408 @@
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
- from typing import Optional, Callable, Union
-
- from brainstate import init
- from brainstate._state import ParamState
- from brainstate.typing import ArrayLike
- from ._module import Module
-
- __all__ = [
-     'Embedding',
- ]
-
-
- class Embedding(Module):
-     r"""
-     A simple lookup table that stores embeddings of a fixed size.
-
-     Args:
-         num_embeddings: Size of embedding dictionary. Must be non-negative.
-         embedding_size: Size of each embedding vector. Must be non-negative.
-         embedding_init: The initializer for the embedding lookup table, of shape `(num_embeddings, embedding_size)`.
-     """
-
-     def __init__(
-         self,
-         num_embeddings: int,
-         embedding_size: int,
-         embedding_init: Union[Callable, ArrayLike] = init.LecunUniform(),
-         name: Optional[str] = None,
-     ):
-         super().__init__(name=name)
-         if num_embeddings < 0:
-             raise ValueError("num_embeddings must not be negative.")
-         if embedding_size < 0:
-             raise ValueError("embedding_size must not be negative.")
-         self.num_embeddings = num_embeddings
-         self.embedding_size = embedding_size
-         self.out_size = (embedding_size,)
-
-         weight = init.param(embedding_init, (self.num_embeddings, self.embedding_size))
-         self.weight = ParamState(weight)
-
-     def update(self, indices: ArrayLike):
-         return self.weight.value[indices]
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ from functools import lru_cache
+ from typing import Callable, Optional, Tuple, Union
+
+ import numpy as np
+ import jax
+ import jax.numpy as jnp
+ import jax.tree_util as jtu
+ from jax import core as jax_core
+
+ from brainstate._state import ParamState
+ from brainstate.typing import ArrayLike, Size
+ from . import init as init
+ from ._module import Module
+
+ __all__ = [
+     'Embedding',
+ ]
+
+
+ def _normalize_embedding_size(size: Size) -> Tuple[int, ...]:
+     """Convert ``Size`` specifications to a validated tuple of integers."""
+     size_array = np.asarray(size)
+     if size_array.size == 0:
+         raise ValueError('embedding_size must contain at least one dimension.')
+     flat = size_array.reshape(-1)
+     normalized = tuple(int(dim) for dim in flat)
+     if any(dim < 0 for dim in normalized):
+         raise ValueError('embedding_size must not contain negative values.')
+     return normalized
+
+
+ @lru_cache(maxsize=None)
+ def _embedding_lookup_fn(
+     padding_idx: Optional[int],
+     scale_grad_by_freq: bool,
+ ):
+     """Return a lookup function with a custom VJP implementing embedding semantics."""
+
+     @jax.custom_vjp
+     def _lookup(weight: jax.Array, indices: jax.Array) -> jax.Array:
+         indices = jnp.asarray(indices)
+         return weight[indices]
+
+     def _lookup_fwd(weight: jax.Array, indices: jax.Array):
+         indices = jnp.asarray(indices)
+         return weight[indices], (indices, weight.shape)
+
+     def _lookup_bwd(residual, grad_output: jax.Array):
+         indices, weight_shape = residual
+         grad_output = jnp.asarray(grad_output)
+         flat_idx = jnp.ravel(indices)
+         if flat_idx.size == 0:
+             return jnp.zeros(weight_shape, dtype=grad_output.dtype), None
+
+         flat_idx = jnp.asarray(flat_idx, dtype=jnp.int32)
+         grad_flat = grad_output.reshape((flat_idx.shape[0],) + weight_shape[1:])
+
+         if scale_grad_by_freq:
+             counts = jnp.bincount(flat_idx, length=weight_shape[0])
+             counts = counts.astype(grad_flat.dtype)
+             counts = jnp.where(counts == 0, 1.0, counts)
+             scale = counts[flat_idx]
+             grad_flat = grad_flat / scale.reshape((flat_idx.shape[0],) + (1,) * (grad_flat.ndim - 1))
+
+         if padding_idx is not None:
+             pad_value = jnp.asarray(padding_idx, dtype=flat_idx.dtype)
+             mask = flat_idx != pad_value
+             broadcast_shape = (flat_idx.shape[0],) + (1,) * (grad_flat.ndim - 1)
+             grad_flat = grad_flat * mask.reshape(broadcast_shape).astype(grad_flat.dtype)
+
+         grad_weight = jnp.zeros(weight_shape, dtype=grad_output.dtype)
+         grad_weight = grad_weight.at[flat_idx].add(grad_flat)
+         return grad_weight, None
+
+     _lookup.defvjp(_lookup_fwd, _lookup_bwd)
+     return _lookup
+
+
+ def _contains_tracer(tree) -> bool:
+     """Return True if the pytree contains any JAX tracer values."""
+     return any(isinstance(leaf, jax_core.Tracer) for leaf in jtu.tree_leaves(tree))
+
+
+
+ class Embedding(Module):
+     r"""
+     A simple lookup table that stores embeddings of a fixed size.
+
+     This module is commonly used to store word embeddings and retrieve them using indices.
+     The input to the module is a list of indices, and the output is the corresponding
+     word embeddings.
+
+     Parameters
+     ----------
+     num_embeddings : int
+         Size of embedding dictionary. Must be non-negative.
+     embedding_size : Size
+         Size of each embedding vector. Can be an int or a sequence of ints, and must
+         contain only non-negative values.
+     embedding_init : Callable or ArrayLike, optional
+         The initializer for the embedding lookup table, of shape
+         ``(num_embeddings, embedding_size)``. Default is ``LecunUniform()``.
+     padding_idx : int, optional
+         If specified, the entries at ``padding_idx`` do not contribute to the gradient;
+         therefore, the embedding vector at ``padding_idx`` is not updated during training,
+         i.e., it remains as a fixed "pad". For a newly constructed Embedding, the embedding
+         vector at ``padding_idx`` will default to all zeros. Default is ``None``.
+     max_norm : float, optional
+         If given, each embedding vector with norm larger than ``max_norm`` is renormalized
+         to have norm ``max_norm``. Default is ``None``.
+     norm_type : float, optional
+         The p of the p-norm to compute for the ``max_norm`` option. Default is ``2.0``.
+     scale_grad_by_freq : bool, optional
+         If given, this scales gradients by the inverse frequency of the words in
+         the mini-batch. Default is ``False``.
+     name : str, optional
+         The name of the module.
+     param_type : type, optional
+         The parameter state type to use. Default is ``ParamState``.
+
+     Attributes
+     ----------
+     num_embeddings : int
+         Size of the embedding dictionary.
+     embedding_size : tuple[int, ...]
+         Size of each embedding vector.
+     out_size : tuple[int, ...]
+         Output size, same as ``embedding_size``.
+     weight : ParamState
+         The learnable weights of the module of shape
+         ``(num_embeddings, *embedding_size)``.
+     padding_idx : int or None
+         Index of the padding token.
+     max_norm : float or None
+         Maximum norm for embedding vectors.
+     norm_type : float
+         Type of p-norm to compute for max_norm.
+     scale_grad_by_freq : bool
+         Whether to scale gradients by frequency.
+     freeze : bool
+         Whether the embedding weights are frozen.
+
+     Examples
+     --------
+     Create an embedding layer with 10 words and 3-dimensional embeddings:
+
+     .. code-block:: python
+
+         >>> import brainstate as bst
+         >>> embedding = bst.nn.Embedding(num_embeddings=10, embedding_size=3)
+         >>> embedding.weight.value.shape
+         (10, 3)
+
+     Retrieve embeddings for specific indices:
+
+     .. code-block:: python
+
+         >>> import jax.numpy as jnp
+         >>> indices = jnp.array([1, 3, 5])
+         >>> output = embedding(indices)
+         >>> output.shape
+         (3, 3)
+
+     Use with a batch of sequences:
+
+     .. code-block:: python
+
+         >>> # Batch of 2 sequences, each with 4 tokens
+         >>> batch_indices = jnp.array([[1, 2, 3, 4],
+         ...                            [5, 6, 7, 8]])
+         >>> output = embedding(batch_indices)
+         >>> output.shape
+         (2, 4, 3)
+
+     Use padding_idx to keep padding embeddings fixed:
+
+     .. code-block:: python
+
+         >>> embedding = bst.nn.Embedding(num_embeddings=10, embedding_size=3, padding_idx=0)
+         >>> # The embedding at index 0 will remain zeros and not be updated during training
+         >>> indices = jnp.array([0, 2, 0, 5])
+         >>> output = embedding(indices)
+         >>> output[0] # Will be zeros
+         Array([0., 0., 0.], dtype=float32)
+
+     Use max_norm to constrain embedding norms:
+
+     .. code-block:: python
+
+         >>> embedding = bst.nn.Embedding(num_embeddings=10, embedding_size=3, max_norm=1.0)
+         >>> # All embeddings accessed in a forward pass are renormalized to have norm <= 1.0
+
+     Load pretrained embeddings:
+
+     .. code-block:: python
+
+         >>> import brainstate
+         >>> import jax.numpy as jnp
+         >>> pretrained = jnp.array([[1.0, 2.0, 3.0],
+         ...                         [4.0, 5.0, 6.0],
+         ...                         [7.0, 8.0, 9.0]])
+         >>> embedding = bst.nn.Embedding.from_pretrained(pretrained, param_type=brainstate.FakeState)
+         >>> embedding.weight.value.shape
+         (3, 3)
+     """
+
+     def __init__(
+         self,
+         num_embeddings: int,
+         embedding_size: Size,
+         embedding_init: Union[Callable, ArrayLike] = init.LecunUniform(),
+         padding_idx: Optional[int] = None,
+         max_norm: Optional[float] = None,
+         norm_type: float = 2.0,
+         scale_grad_by_freq: bool = False,
+         freeze: bool = False,
+         name: Optional[str] = None,
+         param_type: type = ParamState,
+     ):
+         super().__init__(name=name)
+
+         self.num_embeddings = int(num_embeddings)
+         if self.num_embeddings < 0:
+             raise ValueError('num_embeddings must not be negative.')
+
+         embedding_size_tuple = _normalize_embedding_size(embedding_size)
+         self.embedding_size = embedding_size_tuple
+         self.out_size = embedding_size_tuple
+
+         if padding_idx is not None:
+             padding_idx = int(padding_idx)
+             if padding_idx < 0 or padding_idx >= self.num_embeddings:
+                 raise ValueError(f'padding_idx must be within [0, {self.num_embeddings}).')
+         self.padding_idx = padding_idx
+
+         if max_norm is not None and max_norm <= 0:
+             raise ValueError('max_norm must be positive when provided.')
+         self.max_norm = max_norm
+         self.norm_type = norm_type
+         self.scale_grad_by_freq = bool(scale_grad_by_freq)
+         self.freeze = bool(freeze)
+
+         weight_shape = (self.num_embeddings, *self.out_size)
+         weight = init.param(embedding_init, weight_shape)
+
+         if self.padding_idx is not None:
+             weight = weight.at[self.padding_idx].set(0)
+
+         self.weight = param_type(weight)
+         self._lookup = _embedding_lookup_fn(self.padding_idx, self.scale_grad_by_freq)
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         embeddings: ArrayLike,
+         padding_idx: Optional[int] = None,
+         max_norm: Optional[float] = None,
+         norm_type: float = 2.0,
+         scale_grad_by_freq: bool = False,
+         freeze: bool = True,
+         name: Optional[str] = None,
+         param_type: type = ParamState,
+     ):
+         r"""
+         Create an Embedding instance from given 2-dimensional array.
+
+         Parameters
+         ----------
+         embeddings : ArrayLike
+             Array containing weights for the Embedding. First dimension is passed to
+             Embedding as ``num_embeddings``, remaining dimensions as ``embedding_size``.
+         padding_idx : int, optional
+             If specified, the entries at ``padding_idx`` do not contribute to the gradient.
+             Default is ``None``.
+         max_norm : float, optional
+             See module initialization documentation. Default is ``None``.
+         norm_type : float, optional
+             See module initialization documentation. Default is ``2.0``.
+         scale_grad_by_freq : bool, optional
+             See module initialization documentation. Default is ``False``.
+         freeze : bool, optional
+             If ``True``, embeddings are frozen (no gradients). Default is ``True``.
+         name : str, optional
+             The name of the module.
+
+         Returns
+         -------
+         Embedding
+             An Embedding module with pretrained weights.
+
+         Examples
+         --------
+         Load pretrained word embeddings:
+
+         .. code-block:: python
+
+             >>> import jax.numpy as jnp
+             >>> import brainstate as bst
+             >>> pretrained = jnp.array([[1.0, 2.0, 3.0],
+             ...                         [4.0, 5.0, 6.0],
+             ...                         [7.0, 8.0, 9.0]])
+             >>> embedding = bst.nn.Embedding.from_pretrained(pretrained)
+             >>> embedding.weight.value.shape
+             (3, 3)
+             >>> indices = jnp.array([1])
+             >>> embedding(indices)
+             Array([[4., 5., 6.]], dtype=float32)
+         """
+         embeddings = jnp.asarray(embeddings)
+         if embeddings.ndim < 2:
+             raise ValueError('embeddings must be at least 2-dimensional')
+
+         num_embeddings = embeddings.shape[0]
+         embedding_size = embeddings.shape[1:]
+
+         instance = cls(
+             num_embeddings=num_embeddings,
+             embedding_size=embedding_size,
+             embedding_init=embeddings,
+             padding_idx=padding_idx,
+             max_norm=max_norm,
+             norm_type=norm_type,
+             scale_grad_by_freq=scale_grad_by_freq,
+             freeze=freeze,
+             name=name,
+             param_type=param_type,
+         )
+
+         instance.weight = param_type(jnp.asarray(embeddings))
+         return instance
+
+     def update(self, indices: ArrayLike):
+         """
+         Retrieve embeddings for the given indices.
+
+         Parameters
+         ----------
+         indices : ArrayLike
+             Indices to retrieve embeddings for. Can be any shape.
+
+         Returns
+         -------
+         ArrayLike
+             Embeddings corresponding to the indices, with shape
+             ``(*indices.shape, *embedding_size)``.
+         """
+         indices = jnp.asarray(indices)
+         if not jnp.issubdtype(indices.dtype, jnp.integer):
+             raise TypeError('Embedding indices must be integers.')
+
+         weight_value = self.weight.value
+         effective_weight = weight_value
+
+         if self.max_norm is not None:
+             renormed_weight = self._apply_max_norm(weight_value, indices)
+             effective_weight = weight_value + jax.lax.stop_gradient(renormed_weight - weight_value)
+             if not _contains_tracer(renormed_weight):
+                 self.weight.value = renormed_weight
+
+         if self.freeze:
+             effective_weight = jax.lax.stop_gradient(effective_weight)
+
+         embeddings = self._lookup(effective_weight, indices)
+         return embeddings
+
+     def _apply_max_norm(self, weight: jax.Array, indices: jax.Array) -> jax.Array:
+         """Apply max_norm constraint to the embedding weights for the given indices."""
+         flat_idx = jnp.ravel(indices)
+         if flat_idx.size == 0:
+             return weight
+
+         flat_idx = jnp.asarray(flat_idx, dtype=jnp.int32)
+         if self.padding_idx is not None:
+             pad_value = jnp.asarray(self.padding_idx, dtype=flat_idx.dtype)
+             flat_idx = flat_idx[flat_idx != pad_value]
+
+         if flat_idx.size == 0:
+             return weight
+
+         rows = weight[flat_idx]
+         rows_flat = rows.reshape((rows.shape[0], -1))
+         row_dtype = rows_flat.dtype
+
+         norms = jnp.linalg.norm(rows_flat, ord=self.norm_type, axis=1, keepdims=True)
+         max_norm = jnp.asarray(self.max_norm, dtype=row_dtype)
+         eps = jnp.asarray(1e-8, dtype=row_dtype)
+         one = jnp.asarray(1.0, dtype=row_dtype)
+         scale = jnp.minimum(one, max_norm / (norms + eps))
+         rows_scaled = (rows_flat * scale).reshape(rows.shape)
+
+         return weight.at[flat_idx].set(rows_scaled)
+
+
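The rewritten `_embedding.py` above replaces the plain `weight[indices]` lookup with a cached, custom-VJP lookup: the forward pass is still a gather, while the backward pass scatter-adds the output gradient into the table, optionally divides each contribution by its index frequency (`scale_grad_by_freq`) and zeroes the rows belonging to `padding_idx`. A small standalone sketch of that backward rule, using only `jax.numpy` (the array values are illustrative and not taken from the package):

import jax
import jax.numpy as jnp

# Plain lookup: its gradient scatter-adds ones into every referenced row,
# including the padding row.
weight = jnp.arange(12.0).reshape(4, 3)      # table of 4 embeddings, dim 3
indices = jnp.array([0, 1, 0, 2])            # row 0 plays the role of padding_idx

grad_plain = jax.grad(lambda w: jnp.sum(w[indices]))(weight)

# Masked scatter-add, mirroring _lookup_bwd with padding_idx=0 and
# scale_grad_by_freq=False: padding contributions are zeroed before accumulation.
g_out = jnp.ones((indices.shape[0], weight.shape[1]))
mask = (indices != 0)[:, None].astype(g_out.dtype)
grad_masked = jnp.zeros_like(weight).at[indices].add(g_out * mask)

print(grad_plain[0])    # [2. 2. 2.]  -> padding row still accumulates gradient
print(grad_masked[0])   # [0. 0. 0.]  -> what Embedding(padding_idx=0) produces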
brainstate/nn/_embedding_test.py (new file)
@@ -0,0 +1,156 @@
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ # -*- coding: utf-8 -*-
+
+ import unittest
+
+ import jax
+ import jax.numpy as jnp
+
+ import brainstate as bs
+
+
+ class TestEmbedding(unittest.TestCase):
+     """Comprehensive tests for the Embedding module."""
+
+     def setUp(self):
+         settings = bs.environ.all()
+         self._prev_fit = settings.get('fit', None)
+         bs.environ.set(fit=True)
+
+     def tearDown(self):
+         if self._prev_fit is None:
+             bs.environ.pop('fit', None)
+         else:
+             bs.environ.set(fit=self._prev_fit)
+
+     def test_padding_idx_zero_gradient(self):
+         embedding = bs.nn.Embedding(num_embeddings=4, embedding_size=3, padding_idx=0)
+         lookup = embedding._lookup
+         weight = jnp.arange(12.0, dtype=jnp.float32).reshape(4, 3)
+         indices = jnp.array([0, 1, 0, 2], dtype=jnp.int32)
+
+         def loss_fn(w):
+             return jnp.sum(lookup(w, indices))
+
+         grad = jax.grad(loss_fn)(weight)
+         self.assertTrue(jnp.allclose(grad[0], 0.0))
+         self.assertFalse(jnp.allclose(grad[1], 0.0))
+
+     def test_scale_grad_by_freq(self):
+         base = bs.nn.Embedding(num_embeddings=5, embedding_size=2)
+         scaled = bs.nn.Embedding(num_embeddings=5, embedding_size=2, scale_grad_by_freq=True)
+         base_lookup = base._lookup
+         scaled_lookup = scaled._lookup
+
+         weight = jnp.arange(10.0, dtype=jnp.float32).reshape(5, 2)
+         indices = jnp.array([1, 1, 2], dtype=jnp.int32)
+
+         def loss_base(w):
+             return jnp.sum(base_lookup(w, indices))
+
+         def loss_scaled(w):
+             return jnp.sum(scaled_lookup(w, indices))
+
+         base_grad = jax.grad(loss_base)(weight)
+         scaled_grad = jax.grad(loss_scaled)(weight)
+
+         counts = jnp.bincount(indices, length=weight.shape[0])
+         ones = jnp.ones((indices.shape[0], weight.shape[1]), dtype=weight.dtype)
+         expected_base = jnp.zeros_like(weight).at[indices].add(ones)
+         expected_scaled = jnp.where(counts[:, None] > 0, jnp.ones_like(weight), 0.0)
+
+         self.assertTrue(jnp.allclose(base_grad, expected_base))
+         self.assertTrue(jnp.allclose(scaled_grad, expected_scaled))
+
+     def test_lookup_grad_jit_consistent(self):
+         embedding = bs.nn.Embedding(num_embeddings=4, embedding_size=2)
+         lookup = embedding._lookup
+         weight = jnp.arange(8.0, dtype=jnp.float32).reshape(4, 2)
+         indices = jnp.array([0, 1, 1, 3], dtype=jnp.int32)
+
+         def loss_fn(w):
+             return jnp.sum(lookup(w, indices))
+
+         grad_eager = jax.grad(loss_fn)(weight)
+         grad_jitted = jax.grad(jax.jit(loss_fn))(weight)
+
+         expected = jnp.zeros_like(weight).at[indices].add(jnp.ones((indices.shape[0], weight.shape[1]), dtype=weight.dtype))
+
+         self.assertTrue(jnp.allclose(grad_eager, grad_jitted))
+         self.assertTrue(jnp.allclose(grad_eager, expected))
+
+     def test_jit_forward_with_max_norm(self):
+         embedding = bs.nn.Embedding(num_embeddings=3, embedding_size=3, max_norm=0.5)
+         heavy = jnp.array([[0.0, 0.0, 0.0], [1.0, 2.0, 2.0], [0.3, -0.4, 0.5]], dtype=jnp.float32)
+         embedding.weight.value = heavy
+         indices = jnp.array([1, 2, 1], dtype=jnp.int32)
+
+         compiled = jax.jit(lambda ids: embedding(ids))
+         out = compiled(indices)
+         self.assertEqual(out.shape, (3, 3))
+         output_norms = jnp.linalg.norm(out, axis=-1)
+         self.assertTrue(jnp.all(output_norms <= 0.5 + 1e-6))
+         # Weight remains unclipped during JIT execution but must be usable without tracer leaks
+         self.assertGreater(float(jnp.linalg.norm(embedding.weight.value[1])), 0.5)
+
+     def test_freeze_disables_gradients(self):
+         embedding = bs.nn.Embedding(num_embeddings=4, embedding_size=2, freeze=True)
+         indices = jnp.array([1, 2, 3], dtype=jnp.int32)
+
+         def loss_fn(weight):
+             embedding.weight.value = weight
+             return jnp.sum(embedding(indices))
+
+         grad = jax.grad(loss_fn)(embedding.weight.value)
+         self.assertTrue(jnp.allclose(grad, 0.0))
+
+     def test_from_pretrained_defaults_to_freeze(self):
+         pretrained = jnp.arange(12.0, dtype=jnp.float32).reshape(4, 3)
+         embedding = bs.nn.Embedding.from_pretrained(pretrained)
+         self.assertTrue(embedding.freeze)
+
+         def loss_fn(weight):
+             embedding.weight.value = weight
+             return jnp.sum(embedding(jnp.array([1, 2], dtype=jnp.int32)))
+
+         grad = jax.grad(loss_fn)(embedding.weight.value)
+         self.assertTrue(jnp.allclose(grad, 0.0))
+
+     def test_from_pretrained_unfrozen_gradients(self):
+         pretrained = jnp.arange(6.0, dtype=jnp.float32).reshape(2, 3)
+         embedding = bs.nn.Embedding.from_pretrained(pretrained, freeze=False)
+
+         def loss_fn(weight):
+             embedding.weight.value = weight
+             return jnp.sum(embedding(jnp.array([0, 1], dtype=jnp.int32)))
+
+         grad = jax.grad(loss_fn)(embedding.weight.value)
+         self.assertFalse(jnp.allclose(grad, 0.0))
+
+     def test_max_norm_renormalizes_weights(self):
+         embedding = bs.nn.Embedding(num_embeddings=3, embedding_size=3, max_norm=1.0, norm_type=2.0)
+         custom_weight = jnp.array([[0.0, 0.0, 0.0], [3.0, 4.0, 0.0], [0.5, 0.5, 0.5]], dtype=jnp.float32)
+         embedding.weight.value = custom_weight
+         _ = embedding(jnp.array([1, 2], dtype=jnp.int32))
+
+         row_norm = jnp.linalg.norm(embedding.weight.value[1])
+         self.assertLessEqual(float(row_norm), 1.0 + 1e-6)
+         self.assertTrue(jnp.allclose(embedding.weight.value[0], custom_weight[0]))
+
+
+ if __name__ == '__main__':
+     unittest.main()
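One behavior exercised by `test_jit_forward_with_max_norm` and `test_max_norm_renormalizes_weights` is worth spelling out: under `jax.jit` the renormalized table is a tracer, so `_contains_tracer` stops `update` from writing it back into `self.weight` and only the returned embeddings are clipped, whereas an eager call renormalizes the stored rows in place. A short sketch of that difference, assuming the 0.2.1 `brainstate.nn.Embedding` API shown in this diff (values are illustrative):

import jax
import jax.numpy as jnp
import brainstate as bs

emb = bs.nn.Embedding(num_embeddings=3, embedding_size=3, max_norm=1.0)
emb.weight.value = jnp.array([[0., 0., 0.], [3., 4., 0.], [.5, .5, .5]], dtype=jnp.float32)
idx = jnp.array([1], dtype=jnp.int32)

out = jax.jit(lambda i: emb(i))(idx)            # under jit: the returned row is clipped ...
print(jnp.linalg.norm(out[0]))                  # ~1.0
print(jnp.linalg.norm(emb.weight.value[1]))     # still 5.0 -> stored table untouched while tracing

_ = emb(idx)                                    # eager call: the stored row is renormalized in place
print(jnp.linalg.norm(emb.weight.value[1]))     # ~1.0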