brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. brainstate/__init__.py +31 -11
  2. brainstate/_state.py +760 -316
  3. brainstate/_state_test.py +41 -12
  4. brainstate/_utils.py +31 -4
  5. brainstate/augment/__init__.py +40 -0
  6. brainstate/augment/_autograd.py +608 -0
  7. brainstate/augment/_autograd_test.py +1193 -0
  8. brainstate/augment/_eval_shape.py +102 -0
  9. brainstate/augment/_eval_shape_test.py +40 -0
  10. brainstate/augment/_mapping.py +525 -0
  11. brainstate/augment/_mapping_test.py +210 -0
  12. brainstate/augment/_random.py +99 -0
  13. brainstate/{transform → compile}/__init__.py +25 -13
  14. brainstate/compile/_ad_checkpoint.py +204 -0
  15. brainstate/compile/_ad_checkpoint_test.py +51 -0
  16. brainstate/compile/_conditions.py +259 -0
  17. brainstate/compile/_conditions_test.py +221 -0
  18. brainstate/compile/_error_if.py +94 -0
  19. brainstate/compile/_error_if_test.py +54 -0
  20. brainstate/compile/_jit.py +314 -0
  21. brainstate/compile/_jit_test.py +143 -0
  22. brainstate/compile/_loop_collect_return.py +516 -0
  23. brainstate/compile/_loop_collect_return_test.py +59 -0
  24. brainstate/compile/_loop_no_collection.py +185 -0
  25. brainstate/compile/_loop_no_collection_test.py +51 -0
  26. brainstate/compile/_make_jaxpr.py +756 -0
  27. brainstate/compile/_make_jaxpr_test.py +134 -0
  28. brainstate/compile/_progress_bar.py +111 -0
  29. brainstate/compile/_unvmap.py +159 -0
  30. brainstate/compile/_util.py +147 -0
  31. brainstate/environ.py +408 -381
  32. brainstate/environ_test.py +34 -32
  33. brainstate/{nn/event → event}/__init__.py +6 -6
  34. brainstate/event/_csr.py +308 -0
  35. brainstate/event/_csr_test.py +118 -0
  36. brainstate/event/_fixed_probability.py +271 -0
  37. brainstate/event/_fixed_probability_test.py +128 -0
  38. brainstate/event/_linear.py +219 -0
  39. brainstate/event/_linear_test.py +112 -0
  40. brainstate/{nn/event → event}/_misc.py +7 -7
  41. brainstate/functional/_activations.py +521 -511
  42. brainstate/functional/_activations_test.py +300 -300
  43. brainstate/functional/_normalization.py +43 -43
  44. brainstate/functional/_others.py +15 -15
  45. brainstate/functional/_spikes.py +49 -49
  46. brainstate/graph/__init__.py +33 -0
  47. brainstate/graph/_graph_context.py +443 -0
  48. brainstate/graph/_graph_context_test.py +65 -0
  49. brainstate/graph/_graph_convert.py +246 -0
  50. brainstate/graph/_graph_node.py +300 -0
  51. brainstate/graph/_graph_node_test.py +75 -0
  52. brainstate/graph/_graph_operation.py +1746 -0
  53. brainstate/graph/_graph_operation_test.py +724 -0
  54. brainstate/init/_base.py +28 -10
  55. brainstate/init/_generic.py +175 -172
  56. brainstate/init/_random_inits.py +470 -415
  57. brainstate/init/_random_inits_test.py +150 -0
  58. brainstate/init/_regular_inits.py +66 -69
  59. brainstate/init/_regular_inits_test.py +51 -0
  60. brainstate/mixin.py +236 -244
  61. brainstate/mixin_test.py +44 -46
  62. brainstate/nn/__init__.py +26 -51
  63. brainstate/nn/_collective_ops.py +199 -0
  64. brainstate/nn/_dyn_impl/__init__.py +46 -0
  65. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  66. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  67. brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
  68. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  69. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  70. brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
  71. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  72. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  73. brainstate/nn/_dyn_impl/_readout.py +128 -0
  74. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  75. brainstate/nn/_dynamics/__init__.py +37 -0
  76. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  77. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  78. brainstate/nn/_dynamics/_projection_base.py +346 -0
  79. brainstate/nn/_dynamics/_state_delay.py +453 -0
  80. brainstate/nn/_dynamics/_synouts.py +161 -0
  81. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  82. brainstate/nn/_elementwise/__init__.py +22 -0
  83. brainstate/nn/_elementwise/_dropout.py +418 -0
  84. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  85. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  86. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  87. brainstate/nn/_exp_euler.py +97 -0
  88. brainstate/nn/_exp_euler_test.py +36 -0
  89. brainstate/nn/_interaction/__init__.py +32 -0
  90. brainstate/nn/_interaction/_connections.py +726 -0
  91. brainstate/nn/_interaction/_connections_test.py +254 -0
  92. brainstate/nn/_interaction/_embedding.py +59 -0
  93. brainstate/nn/_interaction/_normalizations.py +388 -0
  94. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  95. brainstate/nn/_interaction/_poolings.py +1179 -0
  96. brainstate/nn/_interaction/_poolings_test.py +219 -0
  97. brainstate/nn/_module.py +328 -0
  98. brainstate/nn/_module_test.py +211 -0
  99. brainstate/nn/metrics.py +309 -309
  100. brainstate/optim/__init__.py +14 -2
  101. brainstate/optim/_base.py +66 -0
  102. brainstate/optim/_lr_scheduler.py +363 -400
  103. brainstate/optim/_lr_scheduler_test.py +25 -24
  104. brainstate/optim/_optax_optimizer.py +103 -176
  105. brainstate/optim/_optax_optimizer_test.py +41 -1
  106. brainstate/optim/_sgd_optimizer.py +950 -1025
  107. brainstate/random/_rand_funs.py +3269 -3268
  108. brainstate/random/_rand_funs_test.py +568 -0
  109. brainstate/random/_rand_seed.py +149 -117
  110. brainstate/random/_rand_seed_test.py +50 -0
  111. brainstate/random/_rand_state.py +1356 -1321
  112. brainstate/random/_random_for_unit.py +13 -13
  113. brainstate/surrogate.py +1262 -1243
  114. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  115. brainstate/typing.py +157 -130
  116. brainstate/util/__init__.py +52 -0
  117. brainstate/util/_caller.py +100 -0
  118. brainstate/util/_dict.py +734 -0
  119. brainstate/util/_dict_test.py +160 -0
  120. brainstate/util/_error.py +28 -0
  121. brainstate/util/_filter.py +178 -0
  122. brainstate/util/_others.py +497 -0
  123. brainstate/util/_pretty_repr.py +208 -0
  124. brainstate/util/_scaling.py +260 -0
  125. brainstate/util/_struct.py +524 -0
  126. brainstate/util/_tracers.py +75 -0
  127. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  128. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
  129. brainstate-0.1.0.dist-info/RECORD +135 -0
  130. brainstate/_module.py +0 -1637
  131. brainstate/_module_test.py +0 -207
  132. brainstate/nn/_base.py +0 -251
  133. brainstate/nn/_connections.py +0 -686
  134. brainstate/nn/_dynamics.py +0 -426
  135. brainstate/nn/_elementwise.py +0 -1438
  136. brainstate/nn/_embedding.py +0 -66
  137. brainstate/nn/_misc.py +0 -133
  138. brainstate/nn/_normalizations.py +0 -389
  139. brainstate/nn/_others.py +0 -101
  140. brainstate/nn/_poolings.py +0 -1229
  141. brainstate/nn/_poolings_test.py +0 -231
  142. brainstate/nn/_projection/_align_post.py +0 -546
  143. brainstate/nn/_projection/_align_pre.py +0 -599
  144. brainstate/nn/_projection/_delta.py +0 -241
  145. brainstate/nn/_projection/_vanilla.py +0 -101
  146. brainstate/nn/_rate_rnns.py +0 -410
  147. brainstate/nn/_readout.py +0 -136
  148. brainstate/nn/_synouts.py +0 -166
  149. brainstate/nn/event/csr.py +0 -312
  150. brainstate/nn/event/csr_test.py +0 -118
  151. brainstate/nn/event/fixed_probability.py +0 -276
  152. brainstate/nn/event/fixed_probability_test.py +0 -127
  153. brainstate/nn/event/linear.py +0 -220
  154. brainstate/nn/event/linear_test.py +0 -111
  155. brainstate/random/random_test.py +0 -593
  156. brainstate/transform/_autograd.py +0 -585
  157. brainstate/transform/_autograd_test.py +0 -1181
  158. brainstate/transform/_conditions.py +0 -334
  159. brainstate/transform/_conditions_test.py +0 -220
  160. brainstate/transform/_error_if.py +0 -94
  161. brainstate/transform/_error_if_test.py +0 -55
  162. brainstate/transform/_jit.py +0 -265
  163. brainstate/transform/_jit_test.py +0 -118
  164. brainstate/transform/_loop_collect_return.py +0 -502
  165. brainstate/transform/_loop_no_collection.py +0 -170
  166. brainstate/transform/_make_jaxpr.py +0 -739
  167. brainstate/transform/_make_jaxpr_test.py +0 -131
  168. brainstate/transform/_mapping.py +0 -109
  169. brainstate/transform/_progress_bar.py +0 -111
  170. brainstate/transform/_unvmap.py +0 -143
  171. brainstate/util.py +0 -746
  172. brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
  173. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
  174. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
  175. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
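The bulk of this release is a package reorganization: brainstate.transform becomes brainstate.compile, brainstate.nn.event moves to a top-level brainstate.event, the flat brainstate/util.py becomes a brainstate.util package, and new brainstate.augment and brainstate.graph subpackages appear. A rough import-path sketch is below; the subpackage names come directly from the file list above, but whether each name is re-exported at the package level is an assumption, not verified against the released wheel.

# Import-path migration sketch for brainstate 0.1.0 (hypothetical usage;
# module names are taken from the renamed paths listed above).
#
#   0.0.2.post20241010                      0.1.0
#   brainstate.transform            ->      brainstate.compile
#   brainstate.nn.event             ->      brainstate.event
#   brainstate/util.py (module)     ->      brainstate.util (package)

from brainstate import compile, event, augment, graph, util  # renamed / new subpackages
from brainstate.graph import Node  # new Node base class, used by the scheduler diff below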
brainstate/optim/_lr_scheduler.py +363 -400
@@ -14,6 +14,7 @@
  # ==============================================================================

  # -*- coding: utf-8 -*-
+ from __future__ import annotations

  from typing import Sequence, Union

@@ -21,22 +22,22 @@ import jax
  import jax.numpy as jnp
  import numpy as np

- from .. import environ
- from .._module import Module
- from .._state import State, LongTermState
+ from brainstate import environ
+ from brainstate._state import State, LongTermState
+ from brainstate.graph import Node

  __all__ = [
- 'LearningRateScheduler',
- 'ConstantLR',
- 'StepLR',
- 'MultiStepLR',
- 'CosineAnnealingLR',
- 'CosineAnnealingWarmRestarts',
- 'ExponentialLR',
- 'ExponentialDecayLR',
- 'InverseTimeDecayLR',
- 'PolynomialDecayLR',
- 'PiecewiseConstantLR',
+ 'LearningRateScheduler',
+ 'ConstantLR',
+ 'StepLR',
+ 'MultiStepLR',
+ 'CosineAnnealingLR',
+ 'CosineAnnealingWarmRestarts',
+ 'ExponentialLR',
+ 'ExponentialDecayLR',
+ 'InverseTimeDecayLR',
+ 'PolynomialDecayLR',
+ 'PiecewiseConstantLR',
  ]


@@ -45,442 +46,404 @@ __all__ = [


  def make_schedule(scalar_or_schedule):
- if isinstance(scalar_or_schedule, LearningRateScheduler):
- return scalar_or_schedule
- elif isinstance(scalar_or_schedule, (int, float, State)):
- return ConstantLR(scalar_or_schedule)
- else:
- raise TypeError(type(scalar_or_schedule))
-
-
- class LearningRateScheduler(Module):
- """
- The learning rate scheduler.
-
- Attributes
- ----------
- lr: float, State
- The learning rate.
- last_epoch: int
- The index of last epoch.
-
- """
-
- def __init__(self, lr: Union[float, State], last_epoch: int = -1):
- super().__init__()
- if isinstance(lr, State):
- lr.value = jnp.asarray(lr.value, dtype=environ.dftype())
+ if isinstance(scalar_or_schedule, LearningRateScheduler):
+ return scalar_or_schedule
+ elif isinstance(scalar_or_schedule, (int, float, State)):
+ return ConstantLR(scalar_or_schedule)
  else:
- lr = jnp.asarray(lr, dtype=environ.dftype())
- self._lr = lr
- assert last_epoch >= -1, 'last_epoch should be greater than -1.'
- self.last_epoch = LongTermState(jnp.asarray(last_epoch, dtype=environ.ditype()))
-
- @property
- def lr(self):
- return self._lr.value if isinstance(self._lr, State) else self._lr
-
- @lr.setter
- def lr(self, value):
- if isinstance(value, State):
- value = value.value
- assert jnp.ndim(value) == 0, 'The learning rate should be a scalar.'
- if isinstance(self._lr, State):
- self._lr.value = value
- else:
- self._lr = value
+ raise TypeError(type(scalar_or_schedule))

- def step_epoch(self):
- """
- Update the epoch count.
- """
- self.last_epoch.value += 1

- def step_call(self):
- """
- Update the call count.
+ class LearningRateScheduler(Node):
  """
- pass
+ The learning rate scheduler.

- def __repr__(self):
- return f'{self.__class__.__name__}(lr={self.lr.value}, last_epoch={self.last_epoch.value}{self.extra_repr()})'
+ Parameters
+ ----------
+ lr: float, State
+ The learning rate.
+ last_epoch: int
+ The index of last epoch.

- def extra_repr(self):
- return ''
+ """

- def __call__(self, i=None):
- raise NotImplementedError
+ def __init__(self, lr: Union[float, State], last_epoch: int = -1):
+ super().__init__()
+ if isinstance(lr, State):
+ lr.value = jnp.asarray(lr.value, dtype=environ.dftype())
+ else:
+ lr = jnp.asarray(lr, dtype=environ.dftype())
+ self._lr = lr
+ assert last_epoch >= -1, 'last_epoch should be greater than -1.'
+ self.last_epoch = LongTermState(jnp.asarray(last_epoch, dtype=environ.ditype()))
+
+ @property
+ def lr(self):
+ return self._lr.value if isinstance(self._lr, State) else self._lr
+
+ @lr.setter
+ def lr(self, value):
+ if isinstance(value, State):
+ value = value.value
+ assert jnp.ndim(value) == 0, 'The learning rate should be a scalar.'
+ if isinstance(self._lr, State):
+ self._lr.value = value
+ else:
+ self._lr = value
+
+ def step_epoch(self):
+ """
+ Update the epoch count.
+ """
+ self.last_epoch.value += 1
+
+ def step_call(self):
+ """
+ Update the call count.
+ """
+ pass
+
+ def __call__(self, i=None):
+ raise NotImplementedError


  class ConstantLR(LearningRateScheduler):
- """
- Constant learning rate scheduler.
- """
+ """
+ Constant learning rate scheduler.
+ """

- def __call__(self, i=None):
- return self.lr
+ def __call__(self, i=None):
+ return self.lr


  class CallBasedLRScheduler(LearningRateScheduler):
- """
- The learning rate scheduler based on the call count.
+ """
+ The learning rate scheduler based on the call count.
+
+ Parameters
+ ----------
+ lr: float
+ The learning rate.
+ last_epoch: int
+ The index of last epoch.
+ last_call: int
+ The index of last call.

- Parameters
- ----------
- lr: float
- The learning rate.
- last_epoch: int
- The index of last epoch.
- last_call: int
- The index of last call.
+ """

- """
+ def __init__(self, lr: Union[float, State], last_epoch: int = -1, last_call: int = -1):
+ super().__init__(lr=lr, last_epoch=last_epoch)

- def __init__(self, lr: Union[float, State], last_epoch: int = -1, last_call: int = -1):
- super().__init__(lr=lr, last_epoch=last_epoch)
+ assert last_call >= -1, 'last_call should be greater than -1.'
+ self.last_call = LongTermState(jnp.asarray(last_call, dtype=environ.ditype()))

- assert last_call >= -1, 'last_call should be greater than -1.'
- self.last_call = LongTermState(jnp.asarray(last_call, dtype=environ.ditype()))
+ def step_call(self):
+ """
+ Update the call count.
+ """
+ self.last_call.value += 1

- def step_call(self):
- """
- Update the call count.
+
+ class StepLR(LearningRateScheduler):
+ """Decays the learning rate of each parameter group by gamma every
+ `step_size` epochs.
+
+ Parameters
+ ----------
+ lr: float
+ Initial learning rate.
+ step_size: int
+ Period of learning rate decay.
+ gamma: float
+ Multiplicative factor of learning rate decay.
+ Default: 0.1.
+ last_epoch: int
+ The index of last epoch. Default: -1.
  """
- self.last_call.value += 1

- def __repr__(self):
- return (f'{self.__class__.__name__}(lr={self.lr.value}, '
- f'last_epoch={self.last_epoch.value}, '
- f'last_call={self.last_call.value}{self.extra_repr()})')
+ def __init__(
+ self,
+ lr: float,
+ step_size: int,
+ gamma: float = 0.1,
+ last_epoch: int = -1
+ ):
+ super().__init__(lr=lr, last_epoch=last_epoch)

+ assert step_size >= 1, 'step_size should be greater than or equal to 1.'
+ assert 1. >= gamma >= 0, 'gamma should be in the range [0, 1].'
+ self.step_size = step_size
+ self.gamma = gamma

- class StepLR(LearningRateScheduler):
- """Decays the learning rate of each parameter group by gamma every
- `step_size` epochs.
-
- Parameters
- ----------
- lr: float
- Initial learning rate.
- step_size: int
- Period of learning rate decay.
- gamma: float
- Multiplicative factor of learning rate decay.
- Default: 0.1.
- last_epoch: int
- The index of last epoch. Default: -1.
- """
-
- def __init__(
- self,
- lr: float,
- step_size: int,
- gamma: float = 0.1,
- last_epoch: int = -1
- ):
- super().__init__(lr=lr, last_epoch=last_epoch)
-
- assert step_size >= 1, 'step_size should be greater than or equal to 1.'
- assert 1. >= gamma >= 0, 'gamma should be in the range [0, 1].'
- self.step_size = step_size
- self.gamma = gamma
-
- def __call__(self, i=None):
- i = (self.last_epoch.value + 1) if i is None else i
- return self.lr * self.gamma ** (jnp.floor_divide(i, self.step_size))
-
- def extra_repr(self):
- return f', gamma={self.gamma}, step_size={self.step_size}'
+ def __call__(self, i=None):
+ i = (self.last_epoch.value + 1) if i is None else i
+ return self.lr * self.gamma ** (jnp.floor_divide(i, self.step_size))


  class MultiStepLR(LearningRateScheduler):
- """Decays the learning rate of each parameter group by gamma once the
- number of epoch reaches one of the milestones. Notice that such decay can
- happen simultaneously with other changes to the learning rate from outside
- this scheduler. When last_epoch=-1, sets initial lr as lr.
-
- Parameters
- ----------
- lr: float
- Initial learning rate.
- milestones: sequence of int
- List of epoch indices. Must be increasing.
- gamma: float
- Multiplicative factor of learning rate decay.
- Default: 0.1.
- last_epoch: int
- The index of last epoch. Default: -1.
- """
-
- def __init__(
- self,
- lr: float,
- milestones: Sequence[int],
- gamma: float = 0.1,
- last_epoch: int = -1
- ):
- super().__init__(lr=lr, last_epoch=last_epoch)
-
- assert len(milestones) > 0, 'milestones should be a non-empty sequence.'
- assert all([milestones[i] < milestones[i + 1] for i in range(len(milestones) - 1)]), (
- 'milestones should be a sequence of increasing integers.'
- )
- assert 1. >= gamma >= 0, 'gamma should be in the range [0, 1].'
- self.milestones = jnp.asarray((-1,) + tuple(milestones) + (np.iinfo(np.int32).max,), dtype=environ.ditype())
- self.gamma = gamma
-
- def __call__(self, i=None):
- i = (self.last_epoch.value + 1) if i is None else i
- conditions = jnp.logical_and((i >= self.milestones[:-1]), (i < self.milestones[1:]))
- p = jnp.argmax(conditions)
- return self.lr * self.gamma ** p
-
- def extra_repr(self):
- return f', milestones={self.milestones}, gamma={self.gamma}'
+ """Decays the learning rate of each parameter group by gamma once the
+ number of epoch reaches one of the milestones. Notice that such decay can
+ happen simultaneously with other changes to the learning rate from outside
+ this scheduler. When last_epoch=-1, sets initial lr as lr.
+
+ Parameters
+ ----------
+ lr: float
+ Initial learning rate.
+ milestones: sequence of int
+ List of epoch indices. Must be increasing.
+ gamma: float
+ Multiplicative factor of learning rate decay.
+ Default: 0.1.
+ last_epoch: int
+ The index of last epoch. Default: -1.
+ """
+
+ def __init__(
+ self,
+ lr: float,
+ milestones: Sequence[int],
+ gamma: float = 0.1,
+ last_epoch: int = -1
+ ):
+ super().__init__(lr=lr, last_epoch=last_epoch)
+
+ assert len(milestones) > 0, 'milestones should be a non-empty sequence.'
+ assert all([milestones[i] < milestones[i + 1] for i in range(len(milestones) - 1)]), (
+ 'milestones should be a sequence of increasing integers.'
+ )
+ assert 1. >= gamma >= 0, 'gamma should be in the range [0, 1].'
+ self.milestones = jnp.asarray((-1,) + tuple(milestones) + (np.iinfo(np.int32).max,), dtype=environ.ditype())
+ self.gamma = gamma
+
+ def __call__(self, i=None):
+ i = (self.last_epoch.value + 1) if i is None else i
+ conditions = jnp.logical_and((i >= self.milestones[:-1]), (i < self.milestones[1:]))
+ p = jnp.argmax(conditions)
+ return self.lr * self.gamma ** p


  class CosineAnnealingLR(LearningRateScheduler):
- r"""Set the learning rate of each parameter group using a cosine annealing
- schedule, where :math:`\eta_{max}` is set to the initial lr and
- :math:`T_{cur}` is the number of epochs since the last restart in SGDR:
-
- .. math::
- \begin{aligned}
- \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1
- + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right),
- & T_{cur} \neq (2k+1)T_{max}; \\
- \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min})
- \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right),
- & T_{cur} = (2k+1)T_{max}.
- \end{aligned}
-
- When last_epoch=-1, sets initial lr as lr. Notice that because the schedule
- is defined recursively, the learning rate can be simultaneously modified
- outside this scheduler by other operators. If the learning rate is set
- solely by this scheduler, the learning rate at each step becomes:
-
- .. math::
- \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
- \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right)
-
- It has been proposed in
- `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only
- implements the cosine annealing part of SGDR, and not the restarts.
-
- Parameters
- ----------
- lr: float
- Initial learning rate.
- T_max: int
- Maximum number of iterations.
- eta_min: float
- Minimum learning rate. Default: 0.
- last_epoch: int
- The index of last epoch. Default: -1.
-
- .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
- https://arxiv.org/abs/1608.03983
- """
-
- def __init__(
- self,
- lr: float,
- T_max: int,
- eta_min: float = 0.,
- last_epoch: int = -1,
- ):
- super().__init__(lr=lr, last_epoch=last_epoch)
-
- assert T_max >= 1, 'T_max should be greater than or equal to 1.'
- self._init_epoch = last_epoch
- self.T_max = T_max
- self.eta_min = eta_min
-
- def __call__(self, i=None):
- i = (self.last_epoch.value + 1) if i is None else i
- return self.eta_min + (self.lr - self.eta_min) * (1 + jnp.cos(jnp.pi * i / self.T_max)) / 2
-
- def extra_repr(self):
- return f', T_max={self.T_max}, eta_min={self.eta_min}'
+ r"""Set the learning rate of each parameter group using a cosine annealing
+ schedule, where :math:`\eta_{max}` is set to the initial lr and
+ :math:`T_{cur}` is the number of epochs since the last restart in SGDR:
+
+ .. math::
+ \begin{aligned}
+ \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1
+ + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right),
+ & T_{cur} \neq (2k+1)T_{max}; \\
+ \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min})
+ \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right),
+ & T_{cur} = (2k+1)T_{max}.
+ \end{aligned}
+
+ When last_epoch=-1, sets initial lr as lr. Notice that because the schedule
+ is defined recursively, the learning rate can be simultaneously modified
+ outside this scheduler by other operators. If the learning rate is set
+ solely by this scheduler, the learning rate at each step becomes:
+
+ .. math::
+ \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
+ \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right)
+
+ It has been proposed in
+ `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only
+ implements the cosine annealing part of SGDR, and not the restarts.
+
+ Parameters
+ ----------
+ lr: float
+ Initial learning rate.
+ T_max: int
+ Maximum number of iterations.
+ eta_min: float
+ Minimum learning rate. Default: 0.
+ last_epoch: int
+ The index of last epoch. Default: -1.
+
+ .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
+ https://arxiv.org/abs/1608.03983
+ """

+ def __init__(
+ self,
+ lr: float,
+ T_max: int,
+ eta_min: float = 0.,
+ last_epoch: int = -1,
+ ):
+ super().__init__(lr=lr, last_epoch=last_epoch)

- class CosineAnnealingWarmRestarts(CallBasedLRScheduler):
- """Set the learning rate of each parameter group using a cosine annealing
- schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}`
- is the number of epochs since the last restart and :math:`T_{i}` is the number
- of epochs between two warm restarts in SGDR:
-
- .. math::
- \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
- \cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right)
-
- When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`.
- When :math:`T_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`.
-
- It has been proposed in
- `SGDR: Stochastic Gradient Descent with Warm Restarts`_.
-
- Parameters
- ----------
- lr: float
- Initial learning rate.
- num_call_per_epoch: int
- The number the scheduler to call in each epoch.
- This usually means the number of batch in each epoch training.
- T_0: int
- Number of iterations for the first restart.
- T_mult: int
- A factor increases :math:`T_{i}` after a restart. Default: 1.
- eta_min: float
- Minimum learning rate. Default: 0.
- last_call: int
- The index of last call. Default: -1.
-
- .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
- https://arxiv.org/abs/1608.03983
- """
-
- def __init__(
- self,
- lr: float,
- num_call_per_epoch: int,
- T_0: int,
- T_mult: int = 1,
- eta_min: float = 0.,
- last_epoch: int = -1,
- last_call: int = -1
- ):
- super().__init__(lr=lr, last_call=last_call, last_epoch=last_epoch)
- if T_0 <= 0 or not isinstance(T_0, int):
- raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
- if T_mult < 1 or not isinstance(T_mult, int):
- raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult))
-
- self.T_mult = T_mult
- self.eta_min = eta_min
- self.T_0 = T_0
- self.num_call_per_epoch = num_call_per_epoch
-
- def _cond1(self, epoch):
- if self.T_mult == 1:
- T_cur = epoch % self.T_0
- T_i = self.T_0
- else:
- n = jnp.floor(jnp.log(epoch / self.T_0 * (self.T_mult - 1) + 1) / jnp.log(self.T_mult))
- T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1)
- T_i = self.T_0 * self.T_mult ** n
- return T_cur, T_i
+ assert T_max >= 1, 'T_max should be greater than or equal to 1.'
+ self._init_epoch = last_epoch
+ self.T_max = T_max
+ self.eta_min = eta_min

- def _cond2(self, epoch):
- return epoch, self.T_0
+ def __call__(self, i=None):
+ i = (self.last_epoch.value + 1) if i is None else i
+ return self.eta_min + (self.lr - self.eta_min) * (1 + jnp.cos(jnp.pi * i / self.T_max)) / 2

- def __call__(self, i=None):
- epoch = self.current_epoch(i)
- T_cur, T_i = jax.lax.cond(epoch >= self.T_0, self._cond1, self._cond2, epoch)
- return self.eta_min + (self.lr - self.eta_min) * (1 + jnp.cos(jnp.pi * T_cur / T_i)) / 2

- def current_epoch(self, i=None):
- i = (self.last_call.value + 1) if i is None else i
- return jnp.floor(i / self.num_call_per_epoch)
+ class CosineAnnealingWarmRestarts(CallBasedLRScheduler):
+ """Set the learning rate of each parameter group using a cosine annealing
+ schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}`
+ is the number of epochs since the last restart and :math:`T_{i}` is the number
+ of epochs between two warm restarts in SGDR:
+
+ .. math::
+ \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
+ \cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right)
+
+ When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`.
+ When :math:`T_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`.
+
+ It has been proposed in
+ `SGDR: Stochastic Gradient Descent with Warm Restarts`_.
+
+ Parameters
+ ----------
+ lr: float
+ Initial learning rate.
+ num_call_per_epoch: int
+ The number the scheduler to call in each epoch.
+ This usually means the number of batch in each epoch training.
+ T_0: int
+ Number of iterations for the first restart.
+ T_mult: int
+ A factor increases :math:`T_{i}` after a restart. Default: 1.
+ eta_min: float
+ Minimum learning rate. Default: 0.
+ last_call: int
+ The index of last call. Default: -1.
+
+ .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
+ https://arxiv.org/abs/1608.03983
+ """

- def extra_repr(self):
- return f', T_0={self.T_0}, T_mult={self.T_mult}, eta_min={self.eta_min}'
+ def __init__(
+ self,
+ lr: float,
+ num_call_per_epoch: int,
+ T_0: int,
+ T_mult: int = 1,
+ eta_min: float = 0.,
+ last_epoch: int = -1,
+ last_call: int = -1
+ ):
+ super().__init__(lr=lr, last_call=last_call, last_epoch=last_epoch)
+ if T_0 <= 0 or not isinstance(T_0, int):
+ raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
+ if T_mult < 1 or not isinstance(T_mult, int):
+ raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult))
+
+ self.T_mult = T_mult
+ self.eta_min = eta_min
+ self.T_0 = T_0
+ self.num_call_per_epoch = num_call_per_epoch
+
+ def _cond1(self, epoch):
+ if self.T_mult == 1:
+ T_cur = epoch % self.T_0
+ T_i = self.T_0
+ else:
+ n = jnp.floor(jnp.log(epoch / self.T_0 * (self.T_mult - 1) + 1) / jnp.log(self.T_mult))
+ T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1)
+ T_i = self.T_0 * self.T_mult ** n
+ return T_cur, T_i
+
+ def _cond2(self, epoch):
+ return epoch, self.T_0
+
+ def __call__(self, i=None):
+ epoch = self.current_epoch(i)
+ T_cur, T_i = jax.lax.cond(epoch >= self.T_0, self._cond1, self._cond2, epoch)
+ return self.eta_min + (self.lr - self.eta_min) * (1 + jnp.cos(jnp.pi * T_cur / T_i)) / 2
+
+ def current_epoch(self, i=None):
+ i = (self.last_call.value + 1) if i is None else i
+ return jnp.floor(i / self.num_call_per_epoch)


  class ExponentialLR(LearningRateScheduler):
- """Decays the learning rate of each parameter group by gamma every epoch.
- When last_epoch=-1, sets initial lr as lr.
-
- Parameters
- ----------
- lr: float
- Initial learning rate.
- gamma: float
- Multiplicative factor of learning rate decay.
- last_epoch: int
- The index of last epoch. Default: -1.
- """
-
- def __init__(self,
- lr: float,
- gamma: float,
- last_epoch: int = -1):
- super(ExponentialLR, self).__init__(lr=lr, last_epoch=last_epoch)
- assert 1. >= gamma >= 0, 'gamma should be in the range [0, 1].'
- self.gamma = gamma
-
- def __call__(self, i: int = None):
- i = (self.last_epoch.value + 1) if i is None else i
- return self.lr * self.gamma ** i
-
- def extra_repr(self):
- return f', gamma={self.gamma}'
+ """Decays the learning rate of each parameter group by gamma every epoch.
+ When last_epoch=-1, sets initial lr as lr.
+
+ Parameters
+ ----------
+ lr: float
+ Initial learning rate.
+ gamma: float
+ Multiplicative factor of learning rate decay.
+ last_epoch: int
+ The index of last epoch. Default: -1.
+ """

+ def __init__(self,
+ lr: float,
+ gamma: float,
+ last_epoch: int = -1):
+ super(ExponentialLR, self).__init__(lr=lr, last_epoch=last_epoch)
+ assert 1. >= gamma >= 0, 'gamma should be in the range [0, 1].'
+ self.gamma = gamma
+
+ def __call__(self, i: int = None):
+ i = (self.last_epoch.value + 1) if i is None else i
+ return self.lr * self.gamma ** i

- class ExponentialDecayLR(CallBasedLRScheduler):
- def __init__(self, lr, decay_steps, decay_rate, last_epoch: int = -1, last_call: int = -1):
- super().__init__(lr=lr, last_epoch=last_epoch, last_call=last_call)
- self.decay_steps = decay_steps
- self.decay_rate = decay_rate

- def __call__(self, i=None):
- i = (self.last_call.value + 1) if i is None else i
- return self.lr * self.decay_rate ** (i / self.decay_steps)
+ class ExponentialDecayLR(CallBasedLRScheduler):
+ def __init__(self, lr, decay_steps, decay_rate, last_epoch: int = -1, last_call: int = -1):
+ super().__init__(lr=lr, last_epoch=last_epoch, last_call=last_call)
+ self.decay_steps = decay_steps
+ self.decay_rate = decay_rate

- def extra_repr(self):
- return f', decay_steps={self.decay_steps}, decay_rate={self.decay_rate}'
+ def __call__(self, i=None):
+ i = (self.last_call.value + 1) if i is None else i
+ return self.lr * self.decay_rate ** (i / self.decay_steps)


  class InverseTimeDecayLR(ExponentialDecayLR):
- def __init__(self, lr, decay_steps, decay_rate, staircase=False,
- last_epoch: int = -1, last_call: int = -1):
- super().__init__(lr, decay_steps, decay_rate, last_epoch=last_epoch, last_call=last_call)
- self.staircase = staircase
-
- def __call__(self, i=None):
- i = (self.last_call.value + 1) if i is None else i
- if self.staircase:
- return self.lr / (1 + self.decay_rate * jnp.floor(i / self.decay_steps))
- else:
- return self.lr / (1 + self.decay_rate * i / self.decay_steps)
+ def __init__(self, lr, decay_steps, decay_rate, staircase=False,
+ last_epoch: int = -1, last_call: int = -1):
+ super().__init__(lr, decay_steps, decay_rate, last_epoch=last_epoch, last_call=last_call)
+ self.staircase = staircase

- def extra_repr(self):
- return f', decay_steps={self.decay_steps}, decay_rate={self.decay_rate}, staircase={self.staircase}'
+ def __call__(self, i=None):
+ i = (self.last_call.value + 1) if i is None else i
+ if self.staircase:
+ return self.lr / (1 + self.decay_rate * jnp.floor(i / self.decay_steps))
+ else:
+ return self.lr / (1 + self.decay_rate * i / self.decay_steps)


  class PolynomialDecayLR(CallBasedLRScheduler):
- def __init__(self, lr, decay_steps, final_lr, power=1.0, last_epoch: int = -1, last_call: int = -1):
- super(PolynomialDecayLR, self).__init__(lr, last_epoch=last_epoch, last_call=last_call)
- self.decay_steps = decay_steps
- self.final_lr = final_lr
- self.power = power
-
- def __call__(self, i=None):
- i = (self.last_call.value + 1) if i is None else i
- i = jnp.minimum(i, self.decay_steps)
- step_mult = (1 - i / self.decay_steps) ** self.power
- return step_mult * (self.lr - self.final_lr) + self.final_lr
+ def __init__(self, lr, decay_steps, final_lr, power=1.0, last_epoch: int = -1, last_call: int = -1):
+ super(PolynomialDecayLR, self).__init__(lr, last_epoch=last_epoch, last_call=last_call)
+ self.decay_steps = decay_steps
+ self.final_lr = final_lr
+ self.power = power

- def extra_repr(self):
- return f', decay_steps={self.decay_steps}, final_lr={self.final_lr}, power={self.power}'
+ def __call__(self, i=None):
+ i = (self.last_call.value + 1) if i is None else i
+ i = jnp.minimum(i, self.decay_steps)
+ step_mult = (1 - i / self.decay_steps) ** self.power
+ return step_mult * (self.lr - self.final_lr) + self.final_lr


  class PiecewiseConstantLR(CallBasedLRScheduler):
- def __init__(self, boundaries, values, last_epoch: int = -1, last_call: int = -1):
- super().__init__(0., last_epoch=last_epoch, last_call=last_call)
-
- boundaries = jnp.array(boundaries)
- values = jnp.array(values)
- if not boundaries.ndim == values.ndim == 1:
- raise ValueError("boundaries and values must be sequences")
- if not boundaries.shape[0] == values.shape[0] - 1:
- raise ValueError("boundaries length must be one shorter than values length")
- self.boundaries = boundaries
- self.values = values
-
- def __call__(self, i=None):
- i = (self.last_call.value + 1) if i is None else i
- return self.values[jnp.sum(i > self.boundaries)]
-
- def extra_repr(self):
- return f', boundaries={self.boundaries}, values={self.values}'
+ def __init__(self, boundaries, values, last_epoch: int = -1, last_call: int = -1):
+ super().__init__(0., last_epoch=last_epoch, last_call=last_call)
+
+ boundaries = jnp.array(boundaries)
+ values = jnp.array(values)
+ if not boundaries.ndim == values.ndim == 1:
+ raise ValueError("boundaries and values must be sequences")
+ if not boundaries.shape[0] == values.shape[0] - 1:
+ raise ValueError("boundaries length must be one shorter than values length")
+ self.boundaries = boundaries
+ self.values = values
+
+ def __call__(self, i=None):
+ i = (self.last_call.value + 1) if i is None else i
+ return self.values[jnp.sum(i > self.boundaries)]
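
To make the reworked scheduler API concrete, here is a minimal usage sketch based on the 0.1.0 code added above. Class and method names come from the diff; exposing them as brainstate.optim attributes is assumed from the __all__ list, and the training loop is illustrative only.

import brainstate

# StepLR halves the learning rate every 10 epochs: lr * gamma ** (epoch // step_size)
sched = brainstate.optim.StepLR(lr=0.1, step_size=10, gamma=0.5)

for epoch in range(30):
    lr = sched()          # with no index, uses last_epoch.value + 1
    # ... run one epoch of training with `lr` ...
    sched.step_epoch()    # advances the LongTermState epoch counter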