brainstate-0.2.1-py2.py3-none-any.whl → brainstate-0.2.2-py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +167 -169
- brainstate/_compatible_import.py +340 -340
- brainstate/_compatible_import_test.py +681 -681
- brainstate/_deprecation.py +210 -210
- brainstate/_deprecation_test.py +2297 -2319
- brainstate/_error.py +45 -45
- brainstate/_state.py +2157 -1652
- brainstate/_state_test.py +1129 -52
- brainstate/_utils.py +47 -47
- brainstate/environ.py +1495 -1495
- brainstate/environ_test.py +1223 -1223
- brainstate/graph/__init__.py +22 -22
- brainstate/graph/_node.py +240 -240
- brainstate/graph/_node_test.py +589 -589
- brainstate/graph/_operation.py +1620 -1624
- brainstate/graph/_operation_test.py +1147 -1147
- brainstate/mixin.py +1447 -1433
- brainstate/mixin_test.py +1017 -1017
- brainstate/nn/__init__.py +146 -137
- brainstate/nn/_activations.py +1100 -1100
- brainstate/nn/_activations_test.py +354 -354
- brainstate/nn/_collective_ops.py +635 -633
- brainstate/nn/_collective_ops_test.py +774 -774
- brainstate/nn/_common.py +226 -226
- brainstate/nn/_common_test.py +134 -154
- brainstate/nn/_conv.py +2010 -2010
- brainstate/nn/_conv_test.py +849 -849
- brainstate/nn/_delay.py +575 -575
- brainstate/nn/_delay_test.py +243 -243
- brainstate/nn/_dropout.py +618 -618
- brainstate/nn/_dropout_test.py +480 -477
- brainstate/nn/_dynamics.py +870 -1267
- brainstate/nn/_dynamics_test.py +53 -67
- brainstate/nn/_elementwise.py +1298 -1298
- brainstate/nn/_elementwise_test.py +829 -829
- brainstate/nn/_embedding.py +408 -408
- brainstate/nn/_embedding_test.py +156 -156
- brainstate/nn/_event_fixedprob.py +233 -233
- brainstate/nn/_event_fixedprob_test.py +115 -115
- brainstate/nn/_event_linear.py +83 -83
- brainstate/nn/_event_linear_test.py +121 -121
- brainstate/nn/_exp_euler.py +254 -254
- brainstate/nn/_exp_euler_test.py +377 -377
- brainstate/nn/_linear.py +744 -744
- brainstate/nn/_linear_test.py +475 -475
- brainstate/nn/_metrics.py +1070 -1070
- brainstate/nn/_metrics_test.py +611 -611
- brainstate/nn/_module.py +391 -384
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_normalizations.py +1334 -1334
- brainstate/nn/_normalizations_test.py +699 -699
- brainstate/nn/_paddings.py +1020 -1020
- brainstate/nn/_paddings_test.py +722 -722
- brainstate/nn/_poolings.py +2239 -2239
- brainstate/nn/_poolings_test.py +952 -952
- brainstate/nn/_rnns.py +946 -946
- brainstate/nn/_rnns_test.py +592 -592
- brainstate/nn/_utils.py +216 -216
- brainstate/nn/_utils_test.py +401 -401
- brainstate/nn/init.py +809 -809
- brainstate/nn/init_test.py +180 -180
- brainstate/random/__init__.py +270 -270
- brainstate/random/{_rand_funs.py → _fun.py} +3938 -3938
- brainstate/random/{_rand_funs_test.py → _fun_test.py} +638 -640
- brainstate/random/_impl.py +672 -0
- brainstate/random/{_rand_seed.py → _seed.py} +675 -675
- brainstate/random/{_rand_seed_test.py → _seed_test.py} +48 -48
- brainstate/random/{_rand_state.py → _state.py} +1320 -1617
- brainstate/random/{_rand_state_test.py → _state_test.py} +551 -551
- brainstate/transform/__init__.py +56 -59
- brainstate/transform/_ad_checkpoint.py +176 -176
- brainstate/transform/_ad_checkpoint_test.py +49 -49
- brainstate/transform/_autograd.py +1025 -1025
- brainstate/transform/_autograd_test.py +1289 -1289
- brainstate/transform/_conditions.py +316 -316
- brainstate/transform/_conditions_test.py +220 -220
- brainstate/transform/_error_if.py +94 -94
- brainstate/transform/_error_if_test.py +52 -52
- brainstate/transform/_find_state.py +200 -0
- brainstate/transform/_find_state_test.py +84 -0
- brainstate/transform/_jit.py +399 -399
- brainstate/transform/_jit_test.py +143 -143
- brainstate/transform/_loop_collect_return.py +675 -675
- brainstate/transform/_loop_collect_return_test.py +58 -58
- brainstate/transform/_loop_no_collection.py +283 -283
- brainstate/transform/_loop_no_collection_test.py +50 -50
- brainstate/transform/_make_jaxpr.py +2176 -2016
- brainstate/transform/_make_jaxpr_test.py +1634 -1510
- brainstate/transform/_mapping.py +607 -529
- brainstate/transform/_mapping_test.py +104 -194
- brainstate/transform/_progress_bar.py +255 -255
- brainstate/transform/_unvmap.py +256 -256
- brainstate/transform/_util.py +286 -286
- brainstate/typing.py +837 -837
- brainstate/typing_test.py +780 -780
- brainstate/util/__init__.py +27 -27
- brainstate/util/_others.py +1024 -1024
- brainstate/util/_others_test.py +962 -962
- brainstate/util/_pretty_pytree.py +1301 -1301
- brainstate/util/_pretty_pytree_test.py +675 -675
- brainstate/util/_pretty_repr.py +462 -462
- brainstate/util/_pretty_repr_test.py +696 -696
- brainstate/util/filter.py +945 -945
- brainstate/util/filter_test.py +911 -911
- brainstate/util/struct.py +910 -910
- brainstate/util/struct_test.py +602 -602
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/METADATA +108 -108
- brainstate-0.2.2.dist-info/RECORD +111 -0
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +202 -202
- brainstate/transform/_eval_shape.py +0 -145
- brainstate/transform/_eval_shape_test.py +0 -38
- brainstate/transform/_random.py +0 -171
- brainstate-0.2.1.dist-info/RECORD +0 -111
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
brainstate/random/{_rand_state.py → _state.py}
@@ -1,1617 +1,1320 @@
|
|
1
|
-
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
# -*- coding: utf-8 -*-
|
17
|
-
|
18
|
-
from
|
19
|
-
from
|
20
|
-
|
21
|
-
|
22
|
-
import
|
23
|
-
import jax
|
24
|
-
import jax.
|
25
|
-
import
|
26
|
-
import
|
27
|
-
|
28
|
-
from
|
29
|
-
|
30
|
-
from brainstate import
|
31
|
-
from
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
self
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
isinstance(self._value, jax.
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
key = jr.PRNGKey(seed_or_key) if use_prng_key else jr.key(seed_or_key)
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
if
|
174
|
-
self.
|
175
|
-
if n is None
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
self.value =
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
def
|
201
|
-
self,
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
|
242
|
-
low
|
243
|
-
high
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
self
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
|
291
|
-
self,
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
self,
|
301
|
-
|
302
|
-
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
self,
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
|
317
|
-
|
318
|
-
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
|
325
|
-
|
326
|
-
|
327
|
-
|
328
|
-
|
329
|
-
|
330
|
-
|
331
|
-
|
332
|
-
key =
|
333
|
-
|
334
|
-
|
335
|
-
|
336
|
-
|
337
|
-
|
338
|
-
|
339
|
-
|
340
|
-
|
341
|
-
|
342
|
-
|
343
|
-
|
344
|
-
|
345
|
-
|
346
|
-
|
347
|
-
|
348
|
-
|
349
|
-
|
350
|
-
|
351
|
-
|
352
|
-
|
353
|
-
|
354
|
-
|
355
|
-
|
356
|
-
|
357
|
-
|
358
|
-
|
359
|
-
|
360
|
-
|
361
|
-
|
362
|
-
|
363
|
-
|
364
|
-
|
365
|
-
|
366
|
-
|
367
|
-
|
368
|
-
|
369
|
-
|
370
|
-
|
371
|
-
|
372
|
-
|
373
|
-
|
374
|
-
|
375
|
-
|
376
|
-
|
377
|
-
|
378
|
-
|
379
|
-
|
380
|
-
|
381
|
-
|
382
|
-
|
383
|
-
|
384
|
-
|
385
|
-
|
386
|
-
|
387
|
-
|
388
|
-
|
389
|
-
|
390
|
-
|
391
|
-
|
392
|
-
|
393
|
-
|
394
|
-
|
395
|
-
|
396
|
-
|
397
|
-
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
|
402
|
-
|
403
|
-
|
404
|
-
|
405
|
-
|
406
|
-
|
407
|
-
|
408
|
-
|
409
|
-
|
410
|
-
|
411
|
-
r = _loc_scale(loc, scale, jr.gumbel(key, shape=_size2shape(size), dtype=dtype))
|
412
|
-
return r
|
413
|
-
|
414
|
-
def laplace(
|
415
|
-
self,
|
416
|
-
loc=None,
|
417
|
-
scale=None,
|
418
|
-
size: Optional[Size] = None,
|
419
|
-
key: Optional[SeedOrKey] = None,
|
420
|
-
dtype: DTypeLike = None
|
421
|
-
):
|
422
|
-
loc = _check_py_seq(loc)
|
423
|
-
scale = _check_py_seq(scale)
|
424
|
-
if size is None:
|
425
|
-
size = lax.broadcast_shapes(u.math.shape(loc), u.math.shape(scale))
|
426
|
-
key = self.
|
427
|
-
|
428
|
-
r
|
429
|
-
|
430
|
-
|
431
|
-
|
432
|
-
|
433
|
-
|
434
|
-
|
435
|
-
|
436
|
-
|
437
|
-
|
438
|
-
|
439
|
-
|
440
|
-
|
441
|
-
|
442
|
-
|
443
|
-
u.math.shape(
|
444
|
-
|
445
|
-
|
446
|
-
|
447
|
-
|
448
|
-
|
449
|
-
|
450
|
-
|
451
|
-
|
452
|
-
|
453
|
-
|
454
|
-
|
455
|
-
|
456
|
-
|
457
|
-
|
458
|
-
|
459
|
-
|
460
|
-
|
461
|
-
|
462
|
-
|
463
|
-
|
464
|
-
|
465
|
-
|
466
|
-
|
467
|
-
|
468
|
-
|
469
|
-
|
470
|
-
|
471
|
-
|
472
|
-
|
473
|
-
|
474
|
-
|
475
|
-
|
476
|
-
|
477
|
-
|
478
|
-
|
479
|
-
|
480
|
-
|
481
|
-
|
482
|
-
|
483
|
-
|
484
|
-
|
485
|
-
|
486
|
-
|
487
|
-
|
488
|
-
|
489
|
-
|
490
|
-
|
491
|
-
|
492
|
-
|
493
|
-
|
494
|
-
|
495
|
-
|
496
|
-
|
497
|
-
|
498
|
-
|
499
|
-
|
500
|
-
|
501
|
-
|
502
|
-
|
503
|
-
|
504
|
-
|
505
|
-
|
506
|
-
|
507
|
-
|
508
|
-
|
509
|
-
|
510
|
-
|
511
|
-
|
512
|
-
|
513
|
-
|
514
|
-
|
515
|
-
|
516
|
-
|
517
|
-
|
518
|
-
|
519
|
-
|
520
|
-
|
521
|
-
|
522
|
-
|
523
|
-
|
524
|
-
|
525
|
-
|
526
|
-
|
527
|
-
|
528
|
-
|
529
|
-
|
530
|
-
|
531
|
-
|
532
|
-
|
533
|
-
|
534
|
-
|
535
|
-
|
536
|
-
|
537
|
-
|
538
|
-
|
539
|
-
|
540
|
-
|
541
|
-
|
542
|
-
|
543
|
-
|
544
|
-
|
545
|
-
|
546
|
-
|
547
|
-
|
548
|
-
|
549
|
-
|
550
|
-
|
551
|
-
|
552
|
-
|
553
|
-
|
554
|
-
|
555
|
-
|
556
|
-
|
557
|
-
|
558
|
-
|
559
|
-
|
560
|
-
|
561
|
-
|
562
|
-
|
563
|
-
|
564
|
-
|
565
|
-
|
566
|
-
|
567
|
-
|
568
|
-
|
569
|
-
|
570
|
-
|
571
|
-
|
572
|
-
|
573
|
-
|
574
|
-
|
575
|
-
|
576
|
-
|
577
|
-
|
578
|
-
|
579
|
-
|
580
|
-
|
581
|
-
|
582
|
-
|
583
|
-
|
584
|
-
|
585
|
-
|
586
|
-
|
587
|
-
|
588
|
-
|
589
|
-
|
590
|
-
|
591
|
-
|
592
|
-
|
593
|
-
|
594
|
-
|
595
|
-
|
596
|
-
|
597
|
-
|
598
|
-
dtype
|
599
|
-
|
600
|
-
|
601
|
-
|
602
|
-
|
603
|
-
|
604
|
-
scale =
|
605
|
-
|
606
|
-
|
607
|
-
|
608
|
-
|
609
|
-
|
610
|
-
|
611
|
-
|
612
|
-
|
613
|
-
|
614
|
-
|
615
|
-
|
616
|
-
|
617
|
-
|
618
|
-
|
619
|
-
|
620
|
-
u.math.
|
621
|
-
|
622
|
-
|
623
|
-
|
624
|
-
|
625
|
-
|
626
|
-
|
627
|
-
|
628
|
-
|
629
|
-
|
630
|
-
|
631
|
-
|
632
|
-
|
633
|
-
#
|
634
|
-
#
|
635
|
-
|
636
|
-
|
637
|
-
|
638
|
-
|
639
|
-
|
640
|
-
|
641
|
-
|
642
|
-
|
643
|
-
|
644
|
-
|
645
|
-
|
646
|
-
|
647
|
-
|
648
|
-
|
649
|
-
#
|
650
|
-
|
651
|
-
|
652
|
-
|
653
|
-
|
654
|
-
|
655
|
-
|
656
|
-
|
657
|
-
|
658
|
-
|
659
|
-
|
660
|
-
|
661
|
-
|
662
|
-
|
663
|
-
|
664
|
-
|
665
|
-
|
666
|
-
|
667
|
-
|
668
|
-
|
669
|
-
|
670
|
-
|
671
|
-
|
672
|
-
|
673
|
-
|
674
|
-
|
675
|
-
|
676
|
-
|
677
|
-
|
678
|
-
|
679
|
-
|
680
|
-
|
681
|
-
|
682
|
-
|
683
|
-
|
684
|
-
|
685
|
-
|
686
|
-
mean=
|
687
|
-
sigma=
|
688
|
-
|
689
|
-
|
690
|
-
|
691
|
-
|
692
|
-
|
693
|
-
|
694
|
-
|
695
|
-
|
696
|
-
|
697
|
-
|
698
|
-
|
699
|
-
|
700
|
-
|
701
|
-
|
702
|
-
|
703
|
-
|
704
|
-
|
705
|
-
|
706
|
-
|
707
|
-
|
708
|
-
|
709
|
-
|
710
|
-
|
711
|
-
|
712
|
-
|
713
|
-
|
714
|
-
|
715
|
-
|
716
|
-
|
717
|
-
|
718
|
-
|
719
|
-
|
720
|
-
|
721
|
-
|
722
|
-
|
723
|
-
|
724
|
-
|
725
|
-
|
726
|
-
|
727
|
-
|
728
|
-
|
729
|
-
|
730
|
-
|
731
|
-
|
732
|
-
|
733
|
-
|
734
|
-
|
735
|
-
|
736
|
-
|
737
|
-
|
738
|
-
|
739
|
-
|
740
|
-
|
741
|
-
|
742
|
-
|
743
|
-
|
744
|
-
|
745
|
-
|
746
|
-
|
747
|
-
|
748
|
-
|
749
|
-
|
750
|
-
|
751
|
-
|
752
|
-
|
753
|
-
|
754
|
-
|
755
|
-
|
756
|
-
|
757
|
-
|
758
|
-
|
759
|
-
self
|
760
|
-
alpha
|
761
|
-
|
762
|
-
key
|
763
|
-
|
764
|
-
|
765
|
-
|
766
|
-
|
767
|
-
|
768
|
-
|
769
|
-
|
770
|
-
|
771
|
-
|
772
|
-
|
773
|
-
|
774
|
-
|
775
|
-
key
|
776
|
-
dtype
|
777
|
-
|
778
|
-
|
779
|
-
|
780
|
-
|
781
|
-
|
782
|
-
|
783
|
-
|
784
|
-
|
785
|
-
|
786
|
-
|
787
|
-
|
788
|
-
|
789
|
-
|
790
|
-
|
791
|
-
|
792
|
-
|
793
|
-
|
794
|
-
|
795
|
-
|
796
|
-
|
797
|
-
|
798
|
-
|
799
|
-
|
800
|
-
|
801
|
-
|
802
|
-
|
803
|
-
|
804
|
-
|
805
|
-
|
806
|
-
|
807
|
-
|
808
|
-
|
809
|
-
|
810
|
-
|
811
|
-
|
812
|
-
|
813
|
-
|
814
|
-
|
815
|
-
|
816
|
-
|
817
|
-
|
818
|
-
|
819
|
-
|
820
|
-
|
821
|
-
|
822
|
-
|
823
|
-
|
824
|
-
|
825
|
-
|
826
|
-
|
827
|
-
|
828
|
-
|
829
|
-
|
830
|
-
|
831
|
-
|
832
|
-
|
833
|
-
|
834
|
-
|
835
|
-
|
836
|
-
|
837
|
-
|
838
|
-
if
|
839
|
-
|
840
|
-
|
841
|
-
|
842
|
-
|
843
|
-
|
844
|
-
if
|
845
|
-
|
846
|
-
|
847
|
-
|
848
|
-
|
849
|
-
|
850
|
-
|
851
|
-
|
852
|
-
|
853
|
-
|
854
|
-
|
855
|
-
|
856
|
-
|
857
|
-
|
858
|
-
|
859
|
-
|
860
|
-
|
861
|
-
|
862
|
-
|
863
|
-
|
864
|
-
|
865
|
-
|
866
|
-
key
|
867
|
-
dtype
|
868
|
-
|
869
|
-
|
870
|
-
|
871
|
-
|
872
|
-
|
873
|
-
|
874
|
-
|
875
|
-
|
876
|
-
|
877
|
-
|
878
|
-
|
879
|
-
|
880
|
-
|
881
|
-
|
882
|
-
|
883
|
-
|
884
|
-
|
885
|
-
|
886
|
-
|
887
|
-
|
888
|
-
|
889
|
-
|
890
|
-
|
891
|
-
|
892
|
-
|
893
|
-
|
894
|
-
|
895
|
-
|
896
|
-
|
897
|
-
|
898
|
-
|
899
|
-
|
900
|
-
|
901
|
-
|
902
|
-
|
903
|
-
|
904
|
-
|
905
|
-
|
906
|
-
|
907
|
-
|
908
|
-
|
909
|
-
self
|
910
|
-
a
|
911
|
-
size
|
912
|
-
|
913
|
-
|
914
|
-
|
915
|
-
|
916
|
-
|
917
|
-
|
918
|
-
|
919
|
-
|
920
|
-
|
921
|
-
|
922
|
-
|
923
|
-
|
924
|
-
|
925
|
-
|
926
|
-
|
927
|
-
|
928
|
-
|
929
|
-
|
930
|
-
|
931
|
-
|
932
|
-
|
933
|
-
|
934
|
-
|
935
|
-
|
936
|
-
|
937
|
-
|
938
|
-
|
939
|
-
|
940
|
-
|
941
|
-
|
942
|
-
|
943
|
-
|
944
|
-
|
945
|
-
|
946
|
-
|
947
|
-
|
948
|
-
|
949
|
-
|
950
|
-
|
951
|
-
|
952
|
-
|
953
|
-
|
954
|
-
|
955
|
-
|
956
|
-
|
957
|
-
|
958
|
-
|
959
|
-
|
960
|
-
|
961
|
-
|
962
|
-
|
963
|
-
|
964
|
-
|
965
|
-
|
966
|
-
|
967
|
-
n
|
968
|
-
p
|
969
|
-
size
|
970
|
-
|
971
|
-
|
972
|
-
|
973
|
-
|
974
|
-
|
975
|
-
|
976
|
-
|
977
|
-
size =
|
978
|
-
|
979
|
-
|
980
|
-
|
981
|
-
|
982
|
-
|
983
|
-
|
984
|
-
|
985
|
-
|
986
|
-
|
987
|
-
|
988
|
-
|
989
|
-
|
990
|
-
|
991
|
-
|
992
|
-
|
993
|
-
|
994
|
-
|
995
|
-
|
996
|
-
|
997
|
-
|
998
|
-
|
999
|
-
|
1000
|
-
|
1001
|
-
|
1002
|
-
|
1003
|
-
|
1004
|
-
#
|
1005
|
-
#
|
1006
|
-
#
|
1007
|
-
#
|
1008
|
-
#
|
1009
|
-
#
|
1010
|
-
# x
|
1011
|
-
#
|
1012
|
-
#
|
1013
|
-
#
|
1014
|
-
#
|
1015
|
-
#
|
1016
|
-
#
|
1017
|
-
#
|
1018
|
-
#
|
1019
|
-
|
1020
|
-
|
1021
|
-
|
1022
|
-
|
1023
|
-
|
1024
|
-
|
1025
|
-
|
1026
|
-
|
1027
|
-
|
1028
|
-
|
1029
|
-
|
1030
|
-
|
1031
|
-
|
1032
|
-
|
1033
|
-
|
1034
|
-
|
1035
|
-
|
1036
|
-
|
1037
|
-
df,
|
1038
|
-
size
|
1039
|
-
|
1040
|
-
|
1041
|
-
|
1042
|
-
|
1043
|
-
|
1044
|
-
|
1045
|
-
|
1046
|
-
|
1047
|
-
|
1048
|
-
|
1049
|
-
|
1050
|
-
|
1051
|
-
|
1052
|
-
|
1053
|
-
|
1054
|
-
|
1055
|
-
|
1056
|
-
|
1057
|
-
|
1058
|
-
|
1059
|
-
|
1060
|
-
|
1061
|
-
|
1062
|
-
|
1063
|
-
size
|
1064
|
-
|
1065
|
-
|
1066
|
-
|
1067
|
-
|
1068
|
-
|
1069
|
-
|
1070
|
-
|
1071
|
-
|
1072
|
-
|
1073
|
-
|
1074
|
-
|
1075
|
-
|
1076
|
-
|
1077
|
-
|
1078
|
-
|
1079
|
-
|
1080
|
-
|
1081
|
-
|
1082
|
-
|
1083
|
-
|
1084
|
-
|
1085
|
-
|
1086
|
-
|
1087
|
-
|
1088
|
-
|
1089
|
-
|
1090
|
-
|
1091
|
-
|
1092
|
-
|
1093
|
-
|
1094
|
-
|
1095
|
-
|
1096
|
-
|
1097
|
-
|
1098
|
-
|
1099
|
-
|
1100
|
-
|
1101
|
-
|
1102
|
-
|
1103
|
-
|
1104
|
-
|
1105
|
-
|
1106
|
-
|
1107
|
-
|
1108
|
-
|
1109
|
-
|
1110
|
-
|
1111
|
-
|
1112
|
-
|
1113
|
-
|
1114
|
-
|
1115
|
-
|
1116
|
-
|
1117
|
-
|
1118
|
-
|
1119
|
-
|
1120
|
-
self
|
1121
|
-
logits
|
1122
|
-
|
1123
|
-
|
1124
|
-
|
1125
|
-
|
1126
|
-
|
1127
|
-
|
1128
|
-
|
1129
|
-
|
1130
|
-
|
1131
|
-
|
1132
|
-
|
1133
|
-
|
1134
|
-
|
1135
|
-
|
1136
|
-
|
1137
|
-
|
1138
|
-
|
1139
|
-
|
1140
|
-
|
1141
|
-
|
1142
|
-
|
1143
|
-
|
1144
|
-
|
1145
|
-
|
1146
|
-
|
1147
|
-
|
1148
|
-
|
1149
|
-
|
1150
|
-
|
1151
|
-
|
1152
|
-
|
1153
|
-
|
1154
|
-
|
1155
|
-
|
1156
|
-
|
1157
|
-
|
1158
|
-
|
1159
|
-
|
1160
|
-
|
1161
|
-
|
1162
|
-
|
1163
|
-
|
1164
|
-
|
1165
|
-
|
1166
|
-
|
1167
|
-
|
1168
|
-
|
1169
|
-
|
1170
|
-
|
1171
|
-
|
1172
|
-
|
1173
|
-
|
1174
|
-
|
1175
|
-
|
1176
|
-
|
1177
|
-
|
1178
|
-
|
1179
|
-
|
1180
|
-
|
1181
|
-
|
1182
|
-
|
1183
|
-
|
1184
|
-
|
1185
|
-
|
1186
|
-
|
1187
|
-
|
1188
|
-
|
1189
|
-
|
1190
|
-
|
1191
|
-
|
1192
|
-
|
1193
|
-
|
1194
|
-
|
1195
|
-
|
1196
|
-
|
1197
|
-
|
1198
|
-
|
1199
|
-
|
1200
|
-
|
1201
|
-
|
1202
|
-
|
1203
|
-
|
1204
|
-
|
1205
|
-
|
1206
|
-
|
1207
|
-
|
1208
|
-
|
1209
|
-
|
1210
|
-
|
1211
|
-
|
1212
|
-
|
1213
|
-
|
1214
|
-
|
1215
|
-
|
1216
|
-
|
1217
|
-
|
1218
|
-
|
1219
|
-
|
1220
|
-
|
1221
|
-
|
1222
|
-
|
1223
|
-
|
1224
|
-
|
1225
|
-
|
1226
|
-
size
|
1227
|
-
|
1228
|
-
|
1229
|
-
|
1230
|
-
|
1231
|
-
|
1232
|
-
|
1233
|
-
|
1234
|
-
|
1235
|
-
|
1236
|
-
|
1237
|
-
|
1238
|
-
|
1239
|
-
|
1240
|
-
|
1241
|
-
|
1242
|
-
|
1243
|
-
|
1244
|
-
dfnum
|
1245
|
-
dfden
|
1246
|
-
nonc
|
1247
|
-
size
|
1248
|
-
|
1249
|
-
|
1250
|
-
|
1251
|
-
|
1252
|
-
|
1253
|
-
|
1254
|
-
|
1255
|
-
|
1256
|
-
|
1257
|
-
|
1258
|
-
|
1259
|
-
|
1260
|
-
|
1261
|
-
|
1262
|
-
|
1263
|
-
|
1264
|
-
|
1265
|
-
|
1266
|
-
|
1267
|
-
|
1268
|
-
|
1269
|
-
|
1270
|
-
|
1271
|
-
|
1272
|
-
|
1273
|
-
|
1274
|
-
|
1275
|
-
|
1276
|
-
|
1277
|
-
|
1278
|
-
|
1279
|
-
|
1280
|
-
|
1281
|
-
|
1282
|
-
|
1283
|
-
|
1284
|
-
|
1285
|
-
|
1286
|
-
|
1287
|
-
|
1288
|
-
|
1289
|
-
|
1290
|
-
|
1291
|
-
|
1292
|
-
|
1293
|
-
|
1294
|
-
|
1295
|
-
|
1296
|
-
|
1297
|
-
|
1298
|
-
|
1299
|
-
|
1300
|
-
|
1301
|
-
|
1302
|
-
|
1303
|
-
|
1304
|
-
|
1305
|
-
|
1306
|
-
|
1307
|
-
|
1308
|
-
|
1309
|
-
|
1310
|
-
|
1311
|
-
|
1312
|
-
|
1313
|
-
|
1314
|
-
|
1315
|
-
|
1316
|
-
input,
|
1317
|
-
|
1318
|
-
|
1319
|
-
|
1320
|
-
|
1321
|
-
key: Optional[SeedOrKey] = None
|
1322
|
-
):
|
1323
|
-
if high is None:
|
1324
|
-
high = max(input)
|
1325
|
-
return self.randint(low, high=high, size=u.math.shape(input), dtype=dtype, key=key)
|
1326
|
-
|
1327
|
-
|
1328
|
-
# default random generator
|
1329
|
-
DEFAULT = RandomState(np.random.randint(0, 10000, size=2, dtype=np.uint32))
|
1330
|
-
|
1331
|
-
|
1332
|
-
# ---------------------------------------------------------------------------------------------------------------
|
1333
|
-
|
1334
|
-
|
1335
|
-
def _formalize_key(key):
|
1336
|
-
if isinstance(key, int):
|
1337
|
-
return jr.PRNGKey(key) if use_prng_key else jr.key(key)
|
1338
|
-
elif isinstance(key, (jax.Array, np.ndarray)):
|
1339
|
-
if jnp.issubdtype(key.dtype, jax.dtypes.prng_key):
|
1340
|
-
return key
|
1341
|
-
if key.size == 1 and jnp.issubdtype(key.dtype, jnp.integer):
|
1342
|
-
return jr.PRNGKey(key) if use_prng_key else jr.key(key)
|
1343
|
-
|
1344
|
-
if key.dtype != jnp.uint32:
|
1345
|
-
raise TypeError('key must be a int or an array with two uint32.')
|
1346
|
-
if key.size != 2:
|
1347
|
-
raise TypeError('key must be a int or an array with two uint32.')
|
1348
|
-
return u.math.asarray(key, dtype=jnp.uint32)
|
1349
|
-
else:
|
1350
|
-
raise TypeError('key must be a int or an array with two uint32.')
|
1351
|
-
|
1352
|
-
|
1353
|
-
def _size2shape(size):
|
1354
|
-
if size is None:
|
1355
|
-
return ()
|
1356
|
-
elif isinstance(size, (tuple, list)):
|
1357
|
-
return tuple(size)
|
1358
|
-
else:
|
1359
|
-
return (size,)
|
1360
|
-
|
1361
|
-
|
1362
|
-
def _check_shape(
|
1363
|
-
name,
|
1364
|
-
shape,
|
1365
|
-
*param_shapes
|
1366
|
-
):
|
1367
|
-
if param_shapes:
|
1368
|
-
shape_ = lax.broadcast_shapes(shape, *param_shapes)
|
1369
|
-
if shape != shape_:
|
1370
|
-
msg = ("{} parameter shapes must be broadcast-compatible with shape "
|
1371
|
-
"argument, and the result of broadcasting the shapes must equal "
|
1372
|
-
"the shape argument, but got result {} for shape argument {}.")
|
1373
|
-
raise ValueError(msg.format(name, shape_, shape))
|
1374
|
-
|
1375
|
-
|
1376
|
-
def _is_python_scalar(x):
|
1377
|
-
if hasattr(x, 'aval'):
|
1378
|
-
return x.aval.weak_type
|
1379
|
-
elif np.ndim(x) == 0:
|
1380
|
-
return True
|
1381
|
-
elif isinstance(x, (bool, int, float, complex)):
|
1382
|
-
return True
|
1383
|
-
else:
|
1384
|
-
return False
|
1385
|
-
|
1386
|
-
|
1387
|
-
python_scalar_dtypes = {
|
1388
|
-
bool: np.dtype('bool'),
|
1389
|
-
int: np.dtype('int64'),
|
1390
|
-
float: np.dtype('float64'),
|
1391
|
-
complex: np.dtype('complex128'),
|
1392
|
-
}
|
1393
|
-
|
1394
|
-
|
1395
|
-
def _dtype(
|
1396
|
-
x,
|
1397
|
-
*,
|
1398
|
-
canonicalize: bool = False
|
1399
|
-
):
|
1400
|
-
"""Return the dtype object for a value or type, optionally canonicalized based on X64 mode."""
|
1401
|
-
if x is None:
|
1402
|
-
raise ValueError(f"Invalid argument to dtype: {x}.")
|
1403
|
-
elif isinstance(x, type) and x in python_scalar_dtypes:
|
1404
|
-
dt = python_scalar_dtypes[x]
|
1405
|
-
elif type(x) in python_scalar_dtypes:
|
1406
|
-
dt = python_scalar_dtypes[type(x)]
|
1407
|
-
elif hasattr(x, 'dtype'):
|
1408
|
-
dt = x.dtype
|
1409
|
-
else:
|
1410
|
-
dt = np.result_type(x)
|
1411
|
-
return dtypes.canonicalize_dtype(dt) if canonicalize else dt
|
1412
|
-
|
1413
|
-
|
1414
|
-
def _const(
|
1415
|
-
example,
|
1416
|
-
val
|
1417
|
-
):
|
1418
|
-
if _is_python_scalar(example):
|
1419
|
-
dtype = dtypes.canonicalize_dtype(type(example))
|
1420
|
-
val = dtypes.scalar_type_of(example)(val)
|
1421
|
-
return val if dtype == _dtype(val, canonicalize=True) else np.array(val, dtype)
|
1422
|
-
else:
|
1423
|
-
dtype = dtypes.canonicalize_dtype(example.dtype)
|
1424
|
-
return np.array(val, dtype)
|
1425
|
-
|
1426
|
-
|
1427
|
-
@partial(jit, static_argnums=(2,))
|
1428
|
-
def _categorical(
|
1429
|
-
key,
|
1430
|
-
p,
|
1431
|
-
shape
|
1432
|
-
):
|
1433
|
-
# this implementation is fast when event shape is small, and slow otherwise
|
1434
|
-
# Ref: https://stackoverflow.com/a/34190035
|
1435
|
-
shape = shape or p.shape[:-1]
|
1436
|
-
s = jnp.cumsum(p, axis=-1)
|
1437
|
-
r = jr.uniform(key, shape=shape + (1,))
|
1438
|
-
return jnp.sum(s < r, axis=-1)
|
1439
|
-
|
1440
|
-
|
1441
|
-
def _scatter_add_one(
|
1442
|
-
operand,
|
1443
|
-
indices,
|
1444
|
-
updates
|
1445
|
-
):
|
1446
|
-
return lax.scatter_add(
|
1447
|
-
operand,
|
1448
|
-
indices,
|
1449
|
-
updates,
|
1450
|
-
lax.ScatterDimensionNumbers(
|
1451
|
-
update_window_dims=(),
|
1452
|
-
inserted_window_dims=(0,),
|
1453
|
-
scatter_dims_to_operand_dims=(0,),
|
1454
|
-
),
|
1455
|
-
)
|
1456
|
-
|
1457
|
-
|
1458
|
-
def _reshape(x, shape):
|
1459
|
-
if isinstance(x, (int, float, np.ndarray, np.generic)):
|
1460
|
-
return np.reshape(x, shape)
|
1461
|
-
else:
|
1462
|
-
return jnp.reshape(x, shape)
|
1463
|
-
|
1464
|
-
|
1465
|
-
def _promote_shapes(
|
1466
|
-
*args,
|
1467
|
-
shape=()
|
1468
|
-
):
|
1469
|
-
# adapted from lax.lax_numpy
|
1470
|
-
if len(args) < 2 and not shape:
|
1471
|
-
return args
|
1472
|
-
else:
|
1473
|
-
shapes = [u.math.shape(arg) for arg in args]
|
1474
|
-
num_dims = len(lax.broadcast_shapes(shape, *shapes))
|
1475
|
-
return [
|
1476
|
-
_reshape(arg, (1,) * (num_dims - len(s)) + s) if len(s) < num_dims else arg
|
1477
|
-
for arg, s in zip(args, shapes)
|
1478
|
-
]
|
1479
|
-
|
1480
|
-
|
1481
|
-
@partial(jit, static_argnums=(3, 4))
|
1482
|
-
def _multinomial(
|
1483
|
-
key,
|
1484
|
-
p,
|
1485
|
-
n,
|
1486
|
-
n_max,
|
1487
|
-
shape=()
|
1488
|
-
):
|
1489
|
-
if u.math.shape(n) != u.math.shape(p)[:-1]:
|
1490
|
-
broadcast_shape = lax.broadcast_shapes(u.math.shape(n), u.math.shape(p)[:-1])
|
1491
|
-
n = jnp.broadcast_to(n, broadcast_shape)
|
1492
|
-
p = jnp.broadcast_to(p, broadcast_shape + u.math.shape(p)[-1:])
|
1493
|
-
shape = shape or p.shape[:-1]
|
1494
|
-
if n_max == 0:
|
1495
|
-
return jnp.zeros(shape + p.shape[-1:], dtype=jnp.result_type(int))
|
1496
|
-
# get indices from categorical distribution then gather the result
|
1497
|
-
indices = _categorical(key, p, (n_max,) + shape)
|
1498
|
-
# mask out values when counts is heterogeneous
|
1499
|
-
if jnp.ndim(n) > 0:
|
1500
|
-
mask = _promote_shapes(jnp.arange(n_max) < jnp.expand_dims(n, -1), shape=shape + (n_max,))[0]
|
1501
|
-
mask = jnp.moveaxis(mask, -1, 0).astype(indices.dtype)
|
1502
|
-
excess = jnp.concatenate([jnp.expand_dims(n_max - n, -1),
|
1503
|
-
jnp.zeros(u.math.shape(n) + (p.shape[-1] - 1,))],
|
1504
|
-
-1)
|
1505
|
-
else:
|
1506
|
-
mask = 1
|
1507
|
-
excess = 0
|
1508
|
-
# NB: we transpose to move batch shape to the front
|
1509
|
-
indices_2D = (jnp.reshape(indices * mask, (n_max, -1))).T
|
1510
|
-
samples_2D = vmap(_scatter_add_one)(
|
1511
|
-
jnp.zeros((indices_2D.shape[0], p.shape[-1]), dtype=indices.dtype),
|
1512
|
-
jnp.expand_dims(indices_2D, axis=-1),
|
1513
|
-
jnp.ones(indices_2D.shape, dtype=indices.dtype)
|
1514
|
-
)
|
1515
|
-
return jnp.reshape(samples_2D, shape + p.shape[-1:]) - excess
|
1516
|
-
|
1517
|
-
|
1518
|
-
@partial(jit, static_argnums=(2, 3), static_argnames=['shape', 'dtype'])
|
1519
|
-
def _von_mises_centered(
|
1520
|
-
key,
|
1521
|
-
concentration,
|
1522
|
-
shape,
|
1523
|
-
dtype=None
|
1524
|
-
):
|
1525
|
-
"""Compute centered von Mises samples using rejection sampling from [1]_ with wrapped Cauchy proposal.
|
1526
|
-
|
1527
|
-
Returns
|
1528
|
-
-------
|
1529
|
-
out: array_like
|
1530
|
-
centered samples from von Mises
|
1531
|
-
|
1532
|
-
References
|
1533
|
-
----------
|
1534
|
-
.. [1] Luc Devroye "Non-Uniform Random Variate Generation", Springer-Verlag, 1986;
|
1535
|
-
Chapter 9, p. 473-476. http://www.nrbook.com/devroye/Devroye_files/chapter_nine.pdf
|
1536
|
-
|
1537
|
-
"""
|
1538
|
-
shape = shape or u.math.shape(concentration)
|
1539
|
-
dtype = dtype or environ.dftype()
|
1540
|
-
concentration = lax.convert_element_type(concentration, dtype)
|
1541
|
-
concentration = jnp.broadcast_to(concentration, shape)
|
1542
|
-
|
1543
|
-
if dtype == jnp.float16:
|
1544
|
-
s_cutoff = 1.8e-1
|
1545
|
-
elif dtype == jnp.float32:
|
1546
|
-
s_cutoff = 2e-2
|
1547
|
-
elif dtype == jnp.float64:
|
1548
|
-
s_cutoff = 1.2e-4
|
1549
|
-
else:
|
1550
|
-
raise ValueError(f"Unsupported dtype: {dtype}")
|
1551
|
-
|
1552
|
-
r = 1.0 + jnp.sqrt(1.0 + 4.0 * concentration ** 2)
|
1553
|
-
rho = (r - jnp.sqrt(2.0 * r)) / (2.0 * concentration)
|
1554
|
-
s_exact = (1.0 + rho ** 2) / (2.0 * rho)
|
1555
|
-
|
1556
|
-
s_approximate = 1.0 / concentration
|
1557
|
-
|
1558
|
-
s = jnp.where(concentration > s_cutoff, s_exact, s_approximate)
|
1559
|
-
|
1560
|
-
def cond_fn(
|
1561
|
-
*args
|
1562
|
-
):
|
1563
|
-
"""check if all are done or reached max number of iterations"""
|
1564
|
-
i, _, done, _, _ = args[0]
|
1565
|
-
return jnp.bitwise_and(i < 100, jnp.logical_not(jnp.all(done)))
|
1566
|
-
|
1567
|
-
def body_fn(
|
1568
|
-
*args
|
1569
|
-
):
|
1570
|
-
i, key, done, _, w = args[0]
|
1571
|
-
uni_ukey, uni_vkey, key = jr.split(key, 3)
|
1572
|
-
u_ = jr.uniform(
|
1573
|
-
key=uni_ukey,
|
1574
|
-
shape=shape,
|
1575
|
-
dtype=concentration.dtype,
|
1576
|
-
minval=-1.0,
|
1577
|
-
maxval=1.0,
|
1578
|
-
)
|
1579
|
-
z = jnp.cos(jnp.pi * u_)
|
1580
|
-
w = jnp.where(done, w, (1.0 + s * z) / (s + z)) # Update where not done
|
1581
|
-
y = concentration * (s - w)
|
1582
|
-
v = jr.uniform(key=uni_vkey, shape=shape, dtype=concentration.dtype)
|
1583
|
-
accept = (y * (2.0 - y) >= v) | (jnp.log(y / v) + 1.0 >= y)
|
1584
|
-
return i + 1, key, accept | done, u_, w
|
1585
|
-
|
1586
|
-
init_done = jnp.zeros(shape, dtype=bool)
|
1587
|
-
init_u = jnp.zeros(shape)
|
1588
|
-
init_w = jnp.zeros(shape)
|
1589
|
-
|
1590
|
-
_, _, done, uu, w = lax.while_loop(
|
1591
|
-
cond_fun=cond_fn,
|
1592
|
-
body_fun=body_fn,
|
1593
|
-
init_val=(jnp.array(0), key, init_done, init_u, init_w),
|
1594
|
-
)
|
1595
|
-
|
1596
|
-
return jnp.sign(uu) * jnp.arccos(w)
|
1597
|
-
|
1598
|
-
|
1599
|
-
def _loc_scale(
|
1600
|
-
loc,
|
1601
|
-
scale,
|
1602
|
-
value
|
1603
|
-
):
|
1604
|
-
if loc is None:
|
1605
|
-
if scale is None:
|
1606
|
-
return value
|
1607
|
-
else:
|
1608
|
-
return value * scale
|
1609
|
-
else:
|
1610
|
-
if scale is None:
|
1611
|
-
return value + loc
|
1612
|
-
else:
|
1613
|
-
return value * scale + loc
|
1614
|
-
|
1615
|
-
|
1616
|
-
def _check_py_seq(seq):
|
1617
|
-
return u.math.asarray(seq) if isinstance(seq, (tuple, list)) else seq
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
# -*- coding: utf-8 -*-
|
17
|
+
|
18
|
+
from operator import index
|
19
|
+
from typing import Optional
|
20
|
+
|
21
|
+
import brainunit as u
|
22
|
+
import jax
|
23
|
+
import jax.numpy as jnp
|
24
|
+
import jax.random as jr
|
25
|
+
import numpy as np
|
26
|
+
from jax import lax, core
|
27
|
+
|
28
|
+
from brainstate import environ
|
29
|
+
from brainstate._state import State
|
30
|
+
from brainstate.typing import DTypeLike, Size, SeedOrKey
|
31
|
+
from ._impl import (
|
32
|
+
multinomial, von_mises_centered, const,
|
33
|
+
formalize_key, _loc_scale, _size2shape, _check_py_seq, _check_shape,
|
34
|
+
noncentral_f, logseries, hypergeometric, f, power, zipf
|
35
|
+
)
|
36
|
+
|
37
|
+
__all__ = [
|
38
|
+
'RandomState',
|
39
|
+
'DEFAULT',
|
40
|
+
]
|
41
|
+
|
42
|
+
use_prng_key = True
|
43
|
+
|
44
|
+
|
45
|
+
class RandomState(State):
|
46
|
+
"""RandomState that track the random generator state. """
|
47
|
+
|
48
|
+
# __slots__ = ('_backup', '_value')
|
49
|
+
|
50
|
+
def __init__(
|
51
|
+
self,
|
52
|
+
seed_or_key: Optional[SeedOrKey] = None
|
53
|
+
):
|
54
|
+
"""RandomState constructor.
|
55
|
+
|
56
|
+
Parameters
|
57
|
+
----------
|
58
|
+
seed_or_key: int, Array, optional
|
59
|
+
It can be an integer for initial seed of the random number generator,
|
60
|
+
or it can be a JAX's PRNKey, which is an array with two elements and `uint32` dtype.
|
61
|
+
"""
|
62
|
+
with jax.ensure_compile_time_eval():
|
63
|
+
if seed_or_key is None:
|
64
|
+
seed_or_key = np.random.randint(0, 100000, 2, dtype=np.uint32)
|
65
|
+
if isinstance(seed_or_key, int):
|
66
|
+
key = jr.PRNGKey(seed_or_key) if use_prng_key else jr.key(seed_or_key)
|
67
|
+
else:
|
68
|
+
if jnp.issubdtype(seed_or_key.dtype, jax.dtypes.prng_key):
|
69
|
+
key = seed_or_key
|
70
|
+
else:
|
71
|
+
if len(seed_or_key) != 2 and seed_or_key.dtype != np.uint32:
|
72
|
+
raise ValueError('key must be an array with dtype uint32. '
|
73
|
+
f'But we got {seed_or_key}')
|
74
|
+
key = seed_or_key
|
75
|
+
super().__init__(key)
|
76
|
+
|
77
|
+
self._backup = None
|
78
|
+
|
79
|
+
def __repr__(self):
|
80
|
+
return f'{self.__class__.__name__}({self.value})'
|
81
|
+
|
82
|
+
def check_if_deleted(self):
|
83
|
+
if not use_prng_key and isinstance(self._value, np.ndarray):
|
84
|
+
self._value = jr.key(np.random.randint(0, 10000))
|
85
|
+
|
86
|
+
if (
|
87
|
+
isinstance(self._value, jax.Array) and
|
88
|
+
not isinstance(self._value, jax.core.Tracer) and
|
89
|
+
self._value.is_deleted()
|
90
|
+
):
|
91
|
+
self.seed()
|
92
|
+
|
93
|
+
@staticmethod
|
94
|
+
def _batch_keys(batch_size: int):
|
95
|
+
key = jr.PRNGKey(0) if use_prng_key else jr.key(0)
|
96
|
+
return jr.split(key, batch_size)
|
97
|
+
|
98
|
+
# ------------------- #
|
99
|
+
# seed and random key #
|
100
|
+
# ------------------- #
|
101
|
+
|
102
|
+
def backup_key(self):
|
103
|
+
if self._backup is not None:
|
104
|
+
raise ValueError('The random key has been backed up, and has not been restored.')
|
105
|
+
self._backup = self.value
|
106
|
+
|
107
|
+
def restore_key(self):
|
108
|
+
if self._backup is None:
|
109
|
+
raise ValueError('The random key has not been backed up.')
|
110
|
+
self.value = self._backup
|
111
|
+
self._backup = None
|
112
|
+
|
113
|
+
def clone(self):
|
114
|
+
return type(self)(self.split_key())
|
115
|
+
|
116
|
+
def set_key(self, key: SeedOrKey):
|
117
|
+
self.value = key
|
118
|
+
|
119
|
+
def seed(
|
120
|
+
self,
|
121
|
+
seed_or_key: Optional[SeedOrKey] = None
|
122
|
+
):
|
123
|
+
"""Sets a new random seed.
|
124
|
+
|
125
|
+
Parameters
|
126
|
+
----------
|
127
|
+
seed_or_key: int, ArrayLike, optional
|
128
|
+
It can be an integer for initial seed of the random number generator,
|
129
|
+
or it can be a JAX's PRNKey, which is an array with two elements and `uint32` dtype.
|
130
|
+
"""
|
131
|
+
with jax.ensure_compile_time_eval():
|
132
|
+
if seed_or_key is None:
|
133
|
+
seed_or_key = np.random.randint(0, 10000000, 2, dtype=np.uint32)
|
134
|
+
if np.size(seed_or_key) == 1:
|
135
|
+
if isinstance(seed_or_key, int):
|
136
|
+
key = jr.PRNGKey(seed_or_key) if use_prng_key else jr.key(seed_or_key)
|
137
|
+
elif jnp.issubdtype(seed_or_key.dtype, jax.dtypes.prng_key):
|
138
|
+
key = seed_or_key
|
139
|
+
elif isinstance(seed_or_key, (jnp.ndarray, np.ndarray)) and jnp.issubdtype(seed_or_key.dtype, jnp.integer):
|
140
|
+
key = jr.PRNGKey(seed_or_key) if use_prng_key else jr.key(seed_or_key)
|
141
|
+
else:
|
142
|
+
raise ValueError(f'Invalid seed_or_key: {seed_or_key}')
|
143
|
+
else:
|
144
|
+
if len(seed_or_key) == 2 and seed_or_key.dtype == np.uint32:
|
145
|
+
key = seed_or_key
|
146
|
+
else:
|
147
|
+
raise ValueError(f'Invalid seed_or_key: {seed_or_key}')
|
148
|
+
self.value = key
|
149
|
+
|
150
|
+
def split_key(
|
151
|
+
self,
|
152
|
+
n: Optional[int] = None,
|
153
|
+
backup: bool = False
|
154
|
+
) -> SeedOrKey:
|
155
|
+
"""
|
156
|
+
Create a new seed from the current seed.
|
157
|
+
|
158
|
+
Parameters
|
159
|
+
----------
|
160
|
+
n: int, optional
|
161
|
+
The number of seeds to generate.
|
162
|
+
backup : bool, optional
|
163
|
+
Whether to back up the current key.
|
164
|
+
|
165
|
+
Returns
|
166
|
+
-------
|
167
|
+
key : SeedOrKey
|
168
|
+
The new seed or a tuple of JAX random keys.
|
169
|
+
"""
|
170
|
+
if n is not None:
|
171
|
+
assert isinstance(n, int) and n >= 1, f'n should be an integer greater than 1, but we got {n}'
|
172
|
+
|
173
|
+
if not isinstance(self.value, jax.Array):
|
174
|
+
self.value = u.math.asarray(self.value, dtype=jnp.uint32)
|
175
|
+
keys = jr.split(self.value, num=2 if n is None else n + 1)
|
176
|
+
self.value = keys[0]
|
177
|
+
if backup:
|
178
|
+
self.backup_key()
|
179
|
+
if n is None:
|
180
|
+
return keys[1]
|
181
|
+
else:
|
182
|
+
return keys[1:]
|
183
|
+
|
184
|
+
def self_assign_multi_keys(
|
185
|
+
self,
|
186
|
+
n: int,
|
187
|
+
backup: bool = True
|
188
|
+
):
|
189
|
+
"""
|
190
|
+
Self-assign multiple keys to the current random state.
|
191
|
+
"""
|
192
|
+
if backup:
|
193
|
+
keys = jr.split(self.value, n + 1)
|
194
|
+
self.value = keys[0]
|
195
|
+
self.backup_key()
|
196
|
+
self.value = keys[1:]
|
197
|
+
else:
|
198
|
+
self.value = jr.split(self.value, n)
|
199
|
+
|
200
|
+
def __get_key(self, key):
|
201
|
+
return self.split_key() if key is None else formalize_key(key, use_prng_key)
|
202
|
+
|
203
|
+
# ---------------- #
|
204
|
+
# random functions #
|
205
|
+
# ---------------- #
|
206
|
+
|
207
|
+
def rand(
|
208
|
+
self,
|
209
|
+
*dn,
|
210
|
+
key: Optional[SeedOrKey] = None,
|
211
|
+
dtype: DTypeLike = None
|
212
|
+
):
|
213
|
+
key = self.__get_key(key)
|
214
|
+
dtype = dtype or environ.dftype()
|
215
|
+
r = jr.uniform(key, dn, dtype)
|
216
|
+
return r
|
217
|
+
|
218
|
+
def randint(
|
219
|
+
self,
|
220
|
+
low,
|
221
|
+
high=None,
|
222
|
+
size: Optional[Size] = None,
|
223
|
+
dtype: DTypeLike = None,
|
224
|
+
key: Optional[SeedOrKey] = None
|
225
|
+
):
|
226
|
+
if high is None:
|
227
|
+
high = low
|
228
|
+
low = 0
|
229
|
+
high = _check_py_seq(high)
|
230
|
+
low = _check_py_seq(low)
|
231
|
+
if size is None:
|
232
|
+
size = lax.broadcast_shapes(u.math.shape(low), u.math.shape(high))
|
233
|
+
key = self.__get_key(key)
|
234
|
+
dtype = dtype or environ.ditype()
|
235
|
+
r = jr.randint(key,
|
236
|
+
shape=_size2shape(size),
|
237
|
+
minval=low, maxval=high, dtype=dtype)
|
238
|
+
return r
|
239
|
+
|
240
|
+
def random_integers(
|
241
|
+
self,
|
242
|
+
low,
|
243
|
+
high=None,
|
244
|
+
size: Optional[Size] = None,
|
245
|
+
key: Optional[SeedOrKey] = None,
|
246
|
+
dtype: DTypeLike = None
|
247
|
+
):
|
248
|
+
low = _check_py_seq(low)
|
249
|
+
high = _check_py_seq(high)
|
250
|
+
if high is None:
|
251
|
+
high = low
|
252
|
+
low = 1
|
253
|
+
high += 1
|
254
|
+
if size is None:
|
255
|
+
size = lax.broadcast_shapes(u.math.shape(low), u.math.shape(high))
|
256
|
+
key = self.__get_key(key)
|
257
|
+
dtype = dtype or environ.ditype()
|
258
|
+
r = jr.randint(key,
|
259
|
+
shape=_size2shape(size),
|
260
|
+
minval=low,
|
261
|
+
maxval=high,
|
262
|
+
dtype=dtype)
|
263
|
+
return r
|
264
|
+
|
265
|
+
def randn(
|
266
|
+
self,
|
267
|
+
*dn,
|
268
|
+
key: Optional[SeedOrKey] = None,
|
269
|
+
dtype: DTypeLike = None
|
270
|
+
):
|
271
|
+
key = self.__get_key(key)
|
272
|
+
r = jr.normal(key, shape=dn, dtype=dtype or environ.dftype())
|
273
|
+
return r
|
274
|
+
|
275
|
+
def random(
|
276
|
+
self,
|
277
|
+
size: Optional[Size] = None,
|
278
|
+
key: Optional[SeedOrKey] = None,
|
279
|
+
dtype: DTypeLike = None
|
280
|
+
):
|
281
|
+
key = self.__get_key(key)
|
282
|
+
r = jr.uniform(key, _size2shape(size), dtype=dtype or environ.dftype())
|
283
|
+
return r
|
284
|
+
|
285
|
+
def random_sample(
|
286
|
+
self,
|
287
|
+
size: Optional[Size] = None,
|
288
|
+
key: Optional[SeedOrKey] = None,
|
289
|
+
dtype: DTypeLike = None
|
290
|
+
):
|
291
|
+
r = self.random(size=size, key=key, dtype=dtype or environ.dftype())
|
292
|
+
return r
|
293
|
+
|
294
|
+
def ranf(
|
295
|
+
self,
|
296
|
+
size: Optional[Size] = None,
|
297
|
+
key: Optional[SeedOrKey] = None,
|
298
|
+
dtype: DTypeLike = None
|
299
|
+
):
|
300
|
+
r = self.random(size=size, key=key, dtype=dtype or environ.dftype())
|
301
|
+
return r
|
302
|
+
|
303
|
+
def sample(
|
304
|
+
self,
|
305
|
+
size: Optional[Size] = None,
|
306
|
+
key: Optional[SeedOrKey] = None,
|
307
|
+
dtype: DTypeLike = None
|
308
|
+
):
|
309
|
+
r = self.random(size=size, key=key, dtype=dtype or environ.dftype())
|
310
|
+
return r
|
311
|
+
|
312
|
+
def choice(
|
313
|
+
self,
|
314
|
+
a,
|
315
|
+
size: Optional[Size] = None,
|
316
|
+
replace=True,
|
317
|
+
p=None,
|
318
|
+
key: Optional[SeedOrKey] = None
|
319
|
+
):
|
320
|
+
a = _check_py_seq(a)
|
321
|
+
a, unit = u.split_mantissa_unit(a)
|
322
|
+
p = _check_py_seq(p)
|
323
|
+
key = self.__get_key(key)
|
324
|
+
r = jr.choice(key, a=a, shape=_size2shape(size), replace=replace, p=p)
|
325
|
+
return u.maybe_decimal(r * unit)
|
326
|
+
|
327
|
+
def permutation(
|
328
|
+
self,
|
329
|
+
x,
|
330
|
+
axis: int = 0,
|
331
|
+
independent: bool = False,
|
332
|
+
key: Optional[SeedOrKey] = None
|
333
|
+
):
|
334
|
+
x = _check_py_seq(x)
|
335
|
+
x, unit = u.split_mantissa_unit(x)
|
336
|
+
key = self.__get_key(key)
|
337
|
+
r = jr.permutation(key, x, axis, independent=independent)
|
338
|
+
return u.maybe_decimal(r * unit)
|
339
|
+
|
340
|
+
def shuffle(
|
341
|
+
self,
|
342
|
+
x,
|
343
|
+
axis=0,
|
344
|
+
key: Optional[SeedOrKey] = None
|
345
|
+
):
|
346
|
+
return self.permutation(x, axis=axis, key=key, independent=False)
|
347
|
+
|
348
|
+
def beta(
|
349
|
+
self,
|
350
|
+
a,
|
351
|
+
b,
|
352
|
+
size: Optional[Size] = None,
|
353
|
+
key: Optional[SeedOrKey] = None,
|
354
|
+
dtype: DTypeLike = None
|
355
|
+
):
|
356
|
+
a = _check_py_seq(a)
|
357
|
+
b = _check_py_seq(b)
|
358
|
+
if size is None:
|
359
|
+
size = lax.broadcast_shapes(u.math.shape(a), u.math.shape(b))
|
360
|
+
key = self.__get_key(key)
|
361
|
+
r = jr.beta(key, a=a, b=b, shape=_size2shape(size), dtype=dtype or environ.dftype())
|
362
|
+
return r
|
363
|
+
|
364
|
+
def exponential(
|
365
|
+
self,
|
366
|
+
scale=None,
|
367
|
+
size: Optional[Size] = None,
|
368
|
+
key: Optional[SeedOrKey] = None,
|
369
|
+
dtype: DTypeLike = None
|
370
|
+
):
|
371
|
+
if size is None:
|
372
|
+
size = u.math.shape(scale)
|
373
|
+
key = self.__get_key(key)
|
374
|
+
r = jr.exponential(key, shape=_size2shape(size), dtype=dtype or environ.dftype())
|
375
|
+
if scale is not None:
|
376
|
+
scale = u.math.asarray(scale, dtype=dtype)
|
377
|
+
r = r / scale
|
378
|
+
return r
|
379
|
+
|
380
|
+
def gamma(
|
381
|
+
self,
|
382
|
+
shape,
|
383
|
+
scale=None,
|
384
|
+
size: Optional[Size] = None,
|
385
|
+
key: Optional[SeedOrKey] = None,
|
386
|
+
dtype: DTypeLike = None
|
387
|
+
):
|
388
|
+
shape = _check_py_seq(shape)
|
389
|
+
scale = _check_py_seq(scale)
|
390
|
+
if size is None:
|
391
|
+
size = lax.broadcast_shapes(u.math.shape(shape), u.math.shape(scale))
|
392
|
+
key = self.__get_key(key)
|
393
|
+
r = jr.gamma(key, a=shape, shape=_size2shape(size), dtype=dtype or environ.dftype())
|
394
|
+
if scale is not None:
|
395
|
+
r = r * scale
|
396
|
+
return r
|
397
|
+
|
398
|
+
def gumbel(
|
399
|
+
self,
|
400
|
+
loc=None,
|
401
|
+
scale=None,
|
402
|
+
size: Optional[Size] = None,
|
403
|
+
key: Optional[SeedOrKey] = None,
|
404
|
+
dtype: DTypeLike = None
|
405
|
+
):
|
406
|
+
loc = _check_py_seq(loc)
|
407
|
+
scale = _check_py_seq(scale)
|
408
|
+
if size is None:
|
409
|
+
size = lax.broadcast_shapes(u.math.shape(loc), u.math.shape(scale))
|
410
|
+
key = self.__get_key(key)
|
411
|
+
r = _loc_scale(loc, scale, jr.gumbel(key, shape=_size2shape(size), dtype=dtype or environ.dftype()))
|
412
|
+
return r
|
413
|
+
|
414
|
+
def laplace(
|
415
|
+
self,
|
416
|
+
loc=None,
|
417
|
+
scale=None,
|
418
|
+
size: Optional[Size] = None,
|
419
|
+
key: Optional[SeedOrKey] = None,
|
420
|
+
dtype: DTypeLike = None
|
421
|
+
):
|
422
|
+
loc = _check_py_seq(loc)
|
423
|
+
scale = _check_py_seq(scale)
|
424
|
+
if size is None:
|
425
|
+
size = lax.broadcast_shapes(u.math.shape(loc), u.math.shape(scale))
|
426
|
+
key = self.__get_key(key)
|
427
|
+
r = _loc_scale(loc, scale, jr.laplace(key, shape=_size2shape(size), dtype=dtype or environ.dftype()))
|
428
|
+
return r
|
429
|
+
|
430
|
+
def logistic(
|
431
|
+
self,
|
432
|
+
loc=None,
|
433
|
+
scale=None,
|
434
|
+
size: Optional[Size] = None,
|
435
|
+
key: Optional[SeedOrKey] = None,
|
436
|
+
dtype: DTypeLike = None
|
437
|
+
):
|
438
|
+
loc = _check_py_seq(loc)
|
439
|
+
scale = _check_py_seq(scale)
|
440
|
+
if size is None:
|
441
|
+
size = lax.broadcast_shapes(
|
442
|
+
u.math.shape(loc) if loc is not None else (),
|
443
|
+
u.math.shape(scale) if scale is not None else ()
|
444
|
+
)
|
445
|
+
key = self.__get_key(key)
|
446
|
+
r = _loc_scale(loc, scale, jr.logistic(key, shape=_size2shape(size), dtype=dtype or environ.dftype()))
|
447
|
+
return r
|
448
|
+
|
449
|
+
def normal(
|
450
|
+
self,
|
451
|
+
loc=None,
|
452
|
+
scale=None,
|
453
|
+
size: Optional[Size] = None,
|
454
|
+
key: Optional[SeedOrKey] = None,
|
455
|
+
dtype: DTypeLike = None
|
456
|
+
):
|
457
|
+
loc = _check_py_seq(loc)
|
458
|
+
scale = _check_py_seq(scale)
|
459
|
+
if size is None:
|
460
|
+
size = lax.broadcast_shapes(
|
461
|
+
u.math.shape(scale) if scale is not None else (),
|
462
|
+
u.math.shape(loc) if loc is not None else ()
|
463
|
+
)
|
464
|
+
key = self.__get_key(key)
|
465
|
+
dtype = dtype or environ.dftype()
|
466
|
+
r = _loc_scale(loc, scale, jr.normal(key, shape=_size2shape(size), dtype=dtype))
|
467
|
+
return r
|
468
|
+
|
469
|
+
def pareto(
|
470
|
+
self,
|
471
|
+
a,
|
472
|
+
size: Optional[Size] = None,
|
473
|
+
key: Optional[SeedOrKey] = None,
|
474
|
+
dtype: DTypeLike = None
|
475
|
+
):
|
476
|
+
if size is None:
|
477
|
+
size = u.math.shape(a)
|
478
|
+
key = self.__get_key(key)
|
479
|
+
dtype = dtype or environ.dftype()
|
480
|
+
a = u.math.asarray(a, dtype=dtype)
|
481
|
+
r = jr.pareto(key, b=a, shape=_size2shape(size), dtype=dtype)
|
482
|
+
return r
|
483
|
+
|
484
|
+
def poisson(
|
485
|
+
self,
|
486
|
+
lam=1.0,
|
487
|
+
size: Optional[Size] = None,
|
488
|
+
key: Optional[SeedOrKey] = None,
|
489
|
+
dtype: DTypeLike = None
|
490
|
+
):
|
491
|
+
lam = _check_py_seq(lam)
|
492
|
+
if size is None:
|
493
|
+
size = u.math.shape(lam)
|
494
|
+
key = self.__get_key(key)
|
495
|
+
dtype = dtype or environ.ditype()
|
496
|
+
r = jr.poisson(key, lam=lam, shape=_size2shape(size), dtype=dtype)
|
497
|
+
return r
|
498
|
+
|
499
|
+
def standard_cauchy(
|
500
|
+
self,
|
501
|
+
size: Optional[Size] = None,
|
502
|
+
key: Optional[SeedOrKey] = None,
|
503
|
+
dtype: DTypeLike = None
|
504
|
+
):
|
505
|
+
key = self.__get_key(key)
|
506
|
+
dtype = dtype or environ.dftype()
|
507
|
+
r = jr.cauchy(key, shape=_size2shape(size), dtype=dtype)
|
508
|
+
return r
|
509
|
+
|
510
|
+
def standard_exponential(
|
511
|
+
self,
|
512
|
+
size: Optional[Size] = None,
|
513
|
+
key: Optional[SeedOrKey] = None,
|
514
|
+
dtype: DTypeLike = None
|
515
|
+
):
|
516
|
+
key = self.__get_key(key)
|
517
|
+
dtype = dtype or environ.dftype()
|
518
|
+
r = jr.exponential(key, shape=_size2shape(size), dtype=dtype)
|
519
|
+
return r
|
520
|
+
|
521
|
+
def standard_gamma(
|
522
|
+
self,
|
523
|
+
shape,
|
524
|
+
size: Optional[Size] = None,
|
525
|
+
key: Optional[SeedOrKey] = None,
|
526
|
+
dtype: DTypeLike = None
|
527
|
+
):
|
528
|
+
shape = _check_py_seq(shape)
|
529
|
+
if size is None:
|
530
|
+
size = u.math.shape(shape) if shape is not None else ()
|
531
|
+
key = self.__get_key(key)
|
532
|
+
dtype = dtype or environ.dftype()
|
533
|
+
r = jr.gamma(key, a=shape, shape=_size2shape(size), dtype=dtype)
|
534
|
+
return r
|
535
|
+
|
536
|
+
def standard_normal(
|
537
|
+
self,
|
538
|
+
size: Optional[Size] = None,
|
539
|
+
key: Optional[SeedOrKey] = None,
|
540
|
+
dtype: DTypeLike = None
|
541
|
+
):
|
542
|
+
key = self.__get_key(key)
|
543
|
+
dtype = dtype or environ.dftype()
|
544
|
+
r = jr.normal(key, shape=_size2shape(size), dtype=dtype)
|
545
|
+
return r
|
546
|
+
|
547
|
+
def standard_t(
|
548
|
+
self,
|
549
|
+
df,
|
550
|
+
size: Optional[Size] = None,
|
551
|
+
key: Optional[SeedOrKey] = None,
|
552
|
+
dtype: DTypeLike = None
|
553
|
+
):
|
554
|
+
df = _check_py_seq(df)
|
555
|
+
if size is None:
|
556
|
+
size = u.math.shape(size) if size is not None else ()
|
557
|
+
key = self.__get_key(key)
|
558
|
+
dtype = dtype or environ.dftype()
|
559
|
+
r = jr.t(key, df=df, shape=_size2shape(size), dtype=dtype)
|
560
|
+
return r
|
561
|
+
|
562
|
+
def uniform(
|
563
|
+
self,
|
564
|
+
low=0.0,
|
565
|
+
high=1.0,
|
566
|
+
size: Optional[Size] = None,
|
567
|
+
key: Optional[SeedOrKey] = None,
|
568
|
+
dtype: DTypeLike = None
|
569
|
+
):
|
570
|
+
low, unit = u.split_mantissa_unit(_check_py_seq(low))
|
571
|
+
high = u.Quantity(_check_py_seq(high)).to(unit).mantissa
|
572
|
+
if size is None:
|
573
|
+
size = lax.broadcast_shapes(u.math.shape(low), u.math.shape(high))
|
574
|
+
key = self.__get_key(key)
|
575
|
+
dtype = dtype or environ.dftype()
|
576
|
+
r = jr.uniform(key, _size2shape(size), dtype=dtype, minval=low, maxval=high)
|
577
|
+
return u.maybe_decimal(r * unit)
|
578
|
+
|
579
|
+
def __norm_cdf(self, x, sqrt2, dtype):
|
580
|
+
# Computes standard normal cumulative distribution function
|
581
|
+
return (np.asarray(1., dtype) + lax.erf(x / sqrt2)) / np.asarray(2., dtype)
|
582
|
+
|
583
|
+
def truncated_normal(
|
584
|
+
self,
|
585
|
+
lower,
|
586
|
+
upper,
|
587
|
+
size: Optional[Size] = None,
|
588
|
+
loc=0.0,
|
589
|
+
scale=1.0,
|
590
|
+
key: Optional[SeedOrKey] = None,
|
591
|
+
dtype: DTypeLike = None,
|
592
|
+
check_valid: bool = True
|
593
|
+
):
|
594
|
+
lower = _check_py_seq(lower)
|
595
|
+
upper = _check_py_seq(upper)
|
596
|
+
loc = _check_py_seq(loc)
|
597
|
+
scale = _check_py_seq(scale)
|
598
|
+
dtype = dtype or environ.dftype()
|
599
|
+
|
600
|
+
lower, unit = u.split_mantissa_unit(u.math.asarray(lower, dtype=dtype))
|
601
|
+
upper = u.math.asarray(upper, dtype=dtype)
|
602
|
+
loc = u.math.asarray(loc, dtype=dtype)
|
603
|
+
scale = u.math.asarray(scale, dtype=dtype)
|
604
|
+
upper, loc, scale = (
|
605
|
+
u.Quantity(upper).in_unit(unit).mantissa,
|
606
|
+
u.Quantity(loc).in_unit(unit).mantissa,
|
607
|
+
u.Quantity(scale).in_unit(unit).mantissa
|
608
|
+
)
|
609
|
+
|
610
|
+
if check_valid:
|
611
|
+
from brainstate.transform._error_if import jit_error_if
|
612
|
+
jit_error_if(
|
613
|
+
u.math.any(u.math.logical_or(loc < lower - 2 * scale, loc > upper + 2 * scale)),
|
614
|
+
"mean is more than 2 std from [lower, upper] in truncated_normal. "
|
615
|
+
"The distribution of values may be incorrect."
|
616
|
+
)
|
617
|
+
|
618
|
+
if size is None:
|
619
|
+
size = u.math.broadcast_shapes(
|
620
|
+
u.math.shape(lower),
|
621
|
+
u.math.shape(upper),
|
622
|
+
u.math.shape(loc),
|
623
|
+
u.math.shape(scale)
|
624
|
+
)
|
625
|
+
|
626
|
+
# Values are generated by using a truncated uniform distribution and
|
627
|
+
# then using the inverse CDF for the normal distribution.
|
628
|
+
# Get upper and lower cdf values
|
629
|
+
sqrt2 = np.array(np.sqrt(2), dtype=dtype)
|
630
|
+
l = self.__norm_cdf((lower - loc) / scale, sqrt2, dtype)
|
631
|
+
u_ = self.__norm_cdf((upper - loc) / scale, sqrt2, dtype)
|
632
|
+
|
633
|
+
# Uniformly fill tensor with values from [l, u], then translate to
|
634
|
+
# [2l-1, 2u-1].
|
635
|
+
key = self.__get_key(key)
|
636
|
+
out = jr.uniform(
|
637
|
+
key, size, dtype,
|
638
|
+
minval=lax.nextafter(2 * l - 1, np.array(np.inf, dtype=dtype)),
|
639
|
+
maxval=lax.nextafter(2 * u_ - 1, np.array(-np.inf, dtype=dtype))
|
640
|
+
)
|
641
|
+
|
642
|
+
# Use inverse cdf transform for normal distribution to get truncated
|
643
|
+
# standard normal
|
644
|
+
out = lax.erf_inv(out)
|
645
|
+
|
646
|
+
# Transform to proper mean, std
|
647
|
+
out = out * scale * sqrt2 + loc
|
648
|
+
|
649
|
+
# Clamp to ensure it's in the proper range
|
650
|
+
out = jnp.clip(
|
651
|
+
out,
|
652
|
+
lax.nextafter(lax.stop_gradient(lower), np.array(np.inf, dtype=dtype)),
|
653
|
+
lax.nextafter(lax.stop_gradient(upper), np.array(-np.inf, dtype=dtype))
|
654
|
+
)
|
655
|
+
return u.maybe_decimal(out * unit)
|
656
|
+
|
657
|
+
def _check_p(self, *args, **kwargs):
|
658
|
+
raise ValueError('Parameter p should be within [0, 1], but we got {p}')
|
659
|
+
|
660
|
+
def bernoulli(
|
661
|
+
self,
|
662
|
+
p,
|
663
|
+
size: Optional[Size] = None,
|
664
|
+
key: Optional[SeedOrKey] = None,
|
665
|
+
check_valid: bool = True
|
666
|
+
):
|
667
|
+
p = _check_py_seq(p)
|
668
|
+
if check_valid:
|
669
|
+
from brainstate.transform._error_if import jit_error_if
|
670
|
+
jit_error_if(jnp.any(jnp.logical_or(p < 0, p > 1)), self._check_p, p=p)
|
671
|
+
if size is None:
|
672
|
+
size = u.math.shape(p)
|
673
|
+
key = self.__get_key(key)
|
674
|
+
r = jr.bernoulli(key, p=p, shape=_size2shape(size))
|
675
|
+
return r
|
676
|
+
|
677
|
+
def lognormal(
|
678
|
+
self,
|
679
|
+
mean=None,
|
680
|
+
sigma=None,
|
681
|
+
size: Optional[Size] = None,
|
682
|
+
key: Optional[SeedOrKey] = None,
|
683
|
+
dtype: DTypeLike = None
|
684
|
+
):
|
685
|
+
dtype = dtype or environ.dftype()
|
686
|
+
mean = _check_py_seq(mean)
|
687
|
+
sigma = _check_py_seq(sigma)
|
688
|
+
mean = u.math.asarray(mean, dtype=dtype)
|
689
|
+
sigma = u.math.asarray(sigma, dtype=dtype)
|
690
|
+
unit = mean.unit if isinstance(mean, u.Quantity) else u.UNITLESS
|
691
|
+
mean = mean.mantissa if isinstance(mean, u.Quantity) else mean
|
692
|
+
sigma = sigma.in_unit(unit).mantissa if isinstance(sigma, u.Quantity) else sigma
|
693
|
+
|
694
|
+
if size is None:
|
695
|
+
size = jnp.broadcast_shapes(
|
696
|
+
u.math.shape(mean) if mean is not None else (),
|
697
|
+
u.math.shape(sigma) if sigma is not None else ()
|
698
|
+
)
|
699
|
+
key = self.__get_key(key)
|
700
|
+
dtype = dtype or environ.dftype()
|
701
|
+
samples = jr.normal(key, shape=_size2shape(size), dtype=dtype)
|
702
|
+
samples = _loc_scale(mean, sigma, samples)
|
703
|
+
samples = jnp.exp(samples)
|
704
|
+
return u.maybe_decimal(samples * unit)
|
705
|
+
|
706
|
+
def binomial(
|
707
|
+
self,
|
708
|
+
n,
|
709
|
+
p,
|
710
|
+
size: Optional[Size] = None,
|
711
|
+
key: Optional[SeedOrKey] = None,
|
712
|
+
dtype: DTypeLike = None,
|
713
|
+
check_valid: bool = True
|
714
|
+
):
|
715
|
+
n = _check_py_seq(n)
|
716
|
+
p = _check_py_seq(p)
|
717
|
+
if check_valid:
|
718
|
+
from brainstate.transform._error_if import jit_error_if
|
719
|
+
jit_error_if(
|
720
|
+
jnp.any(jnp.logical_or(p < 0, p > 1)),
|
721
|
+
'Parameter p should be within [0, 1], but we got {p}',
|
722
|
+
p=p
|
723
|
+
)
|
724
|
+
if size is None:
|
725
|
+
size = jnp.broadcast_shapes(u.math.shape(n), u.math.shape(p))
|
726
|
+
key = self.__get_key(key)
|
727
|
+
r = jr.binomial(key, n, p, shape=_size2shape(size))
|
728
|
+
dtype = dtype or environ.ditype()
|
729
|
+
return u.math.asarray(r, dtype=dtype)
|
730
|
+
|
731
|
+
    def chisquare(
        self,
        df,
        size: Optional[Size] = None,
        key: Optional[SeedOrKey] = None,
        dtype: DTypeLike = None
    ):
        df = _check_py_seq(df)
        key = self.__get_key(key)
        dtype = dtype or environ.dftype()
        if size is None:
            if jnp.ndim(df) == 0:
                dist = jr.normal(key, (df,), dtype=dtype) ** 2
                dist = dist.sum()
            else:
                raise NotImplementedError('Non-scalar "df" is not supported when "size" is None')
        else:
            dist = jr.normal(key, (df,) + _size2shape(size), dtype=dtype) ** 2
            dist = dist.sum(axis=0)
        return dist

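    # Construction note (illustrative sketch): a chi-square draw with `df` degrees of
    # freedom is built here as the sum of `df` squared standard-normal draws, so its
    # sample mean should sit near `df`. Assuming seeding as in `DEFAULT` below:
    #
    #     rng = RandomState(np.random.randint(0, 10000, size=2, dtype=np.uint32))
    #     x = rng.chisquare(4, size=(10000,))
    #     # float(x.mean()) is expected to be close to 4.0
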
    def dirichlet(
        self,
        alpha,
        size: Optional[Size] = None,
        key: Optional[SeedOrKey] = None,
        dtype: DTypeLike = None
    ):
        key = self.__get_key(key)
        alpha = _check_py_seq(alpha)
        dtype = dtype or environ.dftype()
        r = jr.dirichlet(key, alpha=alpha, shape=_size2shape(size), dtype=dtype)
        return r

    def geometric(
        self,
        p,
        size: Optional[Size] = None,
        key: Optional[SeedOrKey] = None,
        dtype: DTypeLike = None
    ):
        p = _check_py_seq(p)
        if size is None:
            size = u.math.shape(p)
        key = self.__get_key(key)
        dtype = dtype or environ.dftype()
        u_ = jr.uniform(key, _size2shape(size), dtype)
        r = jnp.floor(jnp.log1p(-u_) / jnp.log1p(-p))
        return r

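    # Construction note (illustrative sketch): `geometric` inverts the geometric CDF,
    # floor(log(1 - U) / log(1 - p)) for U ~ Uniform[0, 1), which counts the failures
    # before the first success. Assuming seeding as in `DEFAULT` below:
    #
    #     rng = RandomState(np.random.randint(0, 10000, size=2, dtype=np.uint32))
    #     k = rng.geometric(0.25, size=(10000,))
    #     # float(k.mean()) is expected near (1 - 0.25) / 0.25 = 3.0
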
    def _check_p2(self, p):
        raise ValueError(f'We require `sum(pvals[:-1]) <= 1`. But we got {p}')

    def multinomial(
        self,
        n,
        pvals,
        size: Optional[Size] = None,
        key: Optional[SeedOrKey] = None,
        dtype: DTypeLike = None,
        check_valid: bool = True
    ):
        key = self.__get_key(key)
        n = _check_py_seq(n)
        pvals = _check_py_seq(pvals)
        if check_valid:
            from brainstate.transform._error_if import jit_error_if
            jit_error_if(jnp.sum(pvals[:-1]) > 1., self._check_p2, pvals)
        if isinstance(n, jax.core.Tracer):
            raise ValueError("The total count parameter `n` should not be a jax abstract array.")
        size = _size2shape(size)
        n_max = int(np.max(jax.device_get(n)))
        batch_shape = lax.broadcast_shapes(u.math.shape(pvals)[:-1], u.math.shape(n))
        r = multinomial(key, pvals, n, n_max=n_max, shape=batch_shape + size)
        dtype = dtype or environ.ditype()
        return u.math.asarray(r, dtype=dtype)

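    # Illustrative usage sketch, assuming seeding as in `DEFAULT` below. `n` must be
    # a concrete (non-traced) integer, `pvals` should sum to one over its last axis,
    # and the result uses the integer dtype from `environ.ditype()` unless `dtype`
    # is given:
    #
    #     rng = RandomState(np.random.randint(0, 10000, size=2, dtype=np.uint32))
    #     counts = rng.multinomial(20, [0.2, 0.3, 0.5], size=(8,))
    #     # counts is expected to have one row of 3 category counts per draw,
    #     # each row summing to 20
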
    def multivariate_normal(
        self,
        mean,
        cov,
        size: Optional[Size] = None,
        method: str = 'cholesky',
        key: Optional[SeedOrKey] = None,
        dtype: DTypeLike = None
    ):
        if method not in {'svd', 'eigh', 'cholesky'}:
            raise ValueError("method must be one of {'svd', 'eigh', 'cholesky'}")
        dtype = dtype or environ.dftype()
        mean = u.math.asarray(_check_py_seq(mean), dtype=dtype)
        cov = u.math.asarray(_check_py_seq(cov), dtype=dtype)
        if isinstance(mean, u.Quantity):
            assert isinstance(cov, u.Quantity)
            assert mean.unit ** 2 == cov.unit
        unit = mean.unit if isinstance(mean, u.Quantity) else u.Unit()
        mean = mean.mantissa if isinstance(mean, u.Quantity) else mean
        cov = cov.mantissa if isinstance(cov, u.Quantity) else cov

        key = self.__get_key(key)
        if not jnp.ndim(mean) >= 1:
            raise ValueError(f"multivariate_normal requires mean.ndim >= 1, got mean.ndim == {jnp.ndim(mean)}")
        if not jnp.ndim(cov) >= 2:
            raise ValueError(f"multivariate_normal requires cov.ndim >= 2, got cov.ndim == {jnp.ndim(cov)}")
        n = mean.shape[-1]
        if u.math.shape(cov)[-2:] != (n, n):
            raise ValueError(f"multivariate_normal requires cov.shape == (..., n, n) for n={n}, "
                             f"but got cov.shape == {u.math.shape(cov)}.")
        if size is None:
            size = lax.broadcast_shapes(mean.shape[:-1], cov.shape[:-2])
        else:
            size = _size2shape(size)
            _check_shape("normal", size, mean.shape[:-1], cov.shape[:-2])

        if method == 'svd':
            (u_, s, _) = jnp.linalg.svd(cov)
            factor = u_ * jnp.sqrt(s[..., None, :])
        elif method == 'eigh':
            (w, v) = jnp.linalg.eigh(cov)
            factor = v * jnp.sqrt(w[..., None, :])
        else:  # 'cholesky'
            factor = jnp.linalg.cholesky(cov)
        normal_samples = jr.normal(key, size + mean.shape[-1:], dtype=dtype)
        r = mean + jnp.einsum('...ij,...j->...i', factor, normal_samples)
        return u.maybe_decimal(r * unit)

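    # Illustrative usage sketch, assuming seeding as in `DEFAULT` below. Every
    # `method` factors `cov` into `factor @ factor.T` and applies the affine map
    # `mean + factor @ z` to standard-normal draws `z`; 'cholesky' is the cheapest
    # and assumes a positive-definite covariance, while 'svd' and 'eigh' also
    # tolerate positive semi-definite ones:
    #
    #     rng = RandomState(np.random.randint(0, 10000, size=2, dtype=np.uint32))
    #     mean = np.zeros(2)
    #     cov = np.array([[1.0, 0.8], [0.8, 1.0]])
    #     x = rng.multivariate_normal(mean, cov, size=(5000,), method='cholesky')
    #     # np.cov(np.asarray(x).T) is expected to approximate cov
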
    def rayleigh(
        self,
        scale=1.0,
        size: Optional[Size] = None,
        key: Optional[SeedOrKey] = None,
        dtype: DTypeLike = None
    ):
        scale = _check_py_seq(scale)
        if size is None:
            size = u.math.shape(scale)
        key = self.__get_key(key)
        dtype = dtype or environ.dftype()
        x = jnp.sqrt(-2. * jnp.log(jr.uniform(key, shape=_size2shape(size), dtype=dtype)))
        r = x * scale
        return r

    def triangular(
        self,
        size: Optional[Size] = None,
        key: Optional[SeedOrKey] = None
    ):
        key = self.__get_key(key)
        bernoulli_samples = jr.bernoulli(key, p=0.5, shape=_size2shape(size))
        r = 2 * bernoulli_samples - 1
        return r

    def vonmises(
        self,
        mu,
        kappa,
        size: Optional[Size] = None,
        key: Optional[SeedOrKey] = None,
        dtype: DTypeLike = None
    ):
        key = self.__get_key(key)
        dtype = dtype or environ.dftype()
        mu = u.math.asarray(_check_py_seq(mu), dtype=dtype)
        kappa = u.math.asarray(_check_py_seq(kappa), dtype=dtype)
        if size is None:
            size = lax.broadcast_shapes(u.math.shape(mu), u.math.shape(kappa))
        size = _size2shape(size)
        samples = von_mises_centered(key, kappa, size, dtype=dtype)
        samples = samples + mu
        samples = (samples + jnp.pi) % (2.0 * jnp.pi) - jnp.pi
        return samples

    def weibull(
        self,
        a,
        size: Optional[Size] = None,
        key: Optional[SeedOrKey] = None,
        dtype: DTypeLike = None
    ):
        key = self.__get_key(key)
        a = _check_py_seq(a)
        if size is None:
            size = u.math.shape(a)
        else:
            if jnp.size(a) > 1:
                raise ValueError(f'"a" should be a scalar when "size" is provided. But we got {a}')
        size = _size2shape(size)
        dtype = dtype or environ.dftype()
        random_uniform = jr.uniform(key=key, shape=size, dtype=dtype)
        r = jnp.power(-jnp.log1p(-random_uniform), 1.0 / a)
        return r

    def weibull_min(
        self,
        a,
        scale=None,
        size: Optional[Size] = None,
        key: Optional[SeedOrKey] = None,
        dtype: DTypeLike = None
    ):
        key = self.__get_key(key)
        a = _check_py_seq(a)
        scale = _check_py_seq(scale)
        if size is None:
            size = jnp.broadcast_shapes(u.math.shape(a), u.math.shape(scale) if scale is not None else ())
        else:
            if jnp.size(a) > 1:
                raise ValueError(f'"a" should be a scalar when "size" is provided. But we got {a}')
        size = _size2shape(size)
        dtype = dtype or environ.dftype()
        random_uniform = jr.uniform(key=key, shape=size, dtype=dtype)
        r = jnp.power(-jnp.log1p(-random_uniform), 1.0 / a)
        if scale is not None:
            r /= scale
        return r

    def maxwell(
        self,
        size: Optional[Size] = None,
        key: Optional[SeedOrKey] = None,
        dtype: DTypeLike = None
    ):
        key = self.__get_key(key)
        shape = _size2shape(size) + (3,)
        dtype = dtype or environ.dftype()
        norm_rvs = jr.normal(key=key, shape=shape, dtype=dtype)
        r = jnp.linalg.norm(norm_rvs, axis=-1)
        return r

    def negative_binomial(
        self,
        n,
        p,
        size: Optional[Size] = None,
        key: Optional[SeedOrKey] = None,
        dtype: DTypeLike = None
    ):
        n = _check_py_seq(n)
        p = _check_py_seq(p)
        if size is None:
            size = lax.broadcast_shapes(u.math.shape(n), u.math.shape(p))
        size = _size2shape(size)
        logits = jnp.log(p) - jnp.log1p(-p)
        if key is None:
            keys = self.split_key(2)
        else:
            keys = jr.split(formalize_key(key, use_prng_key), 2)
        rate = self.gamma(shape=n, scale=jnp.exp(-logits), size=size, key=keys[0], dtype=environ.dftype())
        r = self.poisson(lam=rate, key=keys[1], dtype=dtype or environ.ditype())
        return r

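    # Construction note (illustrative sketch): the negative binomial is drawn as a
    # gamma-Poisson mixture: rate ~ Gamma(n, (1 - p) / p), then r ~ Poisson(rate),
    # so the sample mean should land near n * (1 - p) / p. Assuming seeding as in
    # `DEFAULT` below:
    #
    #     rng = RandomState(np.random.randint(0, 10000, size=2, dtype=np.uint32))
    #     r = rng.negative_binomial(5, 0.4, size=(10000,))
    #     # float(r.mean()) is expected near 5 * 0.6 / 0.4 = 7.5
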
    def wald(
        self,
        mean,
        scale,
        size: Optional[Size] = None,
        key: Optional[SeedOrKey] = None,
        dtype: DTypeLike = None
    ):
        dtype = dtype or environ.dftype()
        key = self.__get_key(key)
        mean = u.math.asarray(_check_py_seq(mean), dtype=dtype)
        scale = u.math.asarray(_check_py_seq(scale), dtype=dtype)
        if size is None:
            size = lax.broadcast_shapes(u.math.shape(mean), u.math.shape(scale))
        size = _size2shape(size)
        sampled_chi2 = jnp.square(self.randn(*size))
        sampled_uniform = self.uniform(size=size, key=key, dtype=dtype)
        # Wikipedia defines an intermediate x with the formula
        #   x = loc + loc ** 2 * y / (2 * conc) - loc / (2 * conc) * sqrt(4 * loc * conc * y + loc ** 2 * y ** 2)
        # where y ~ N(0, 1)**2 (sampled_chi2 above) and conc is the concentration.
        # Let us write
        #   w = loc * y / (2 * conc)
        # Then we can extract the common factor in the last two terms to obtain
        #   x = loc + loc * w * (1 - sqrt(2 / w + 1))
        # Now we see that the Wikipedia formula suffers from catastrophic
        # cancellation for large w (e.g., if conc << loc).
        #
        # Fortunately, we can fix this by multiplying both sides
        # by 1 + sqrt(2 / w + 1). We get
        #   x * (1 + sqrt(2 / w + 1)) =
        #     = loc * (1 + sqrt(2 / w + 1)) + loc * w * (1 - (2 / w + 1))
        #     = loc * (sqrt(2 / w + 1) - 1)
        # The term sqrt(2 / w + 1) + 1 no longer presents numerical
        # difficulties for large w, and sqrt(2 / w + 1) - 1 is just
        # sqrt1pm1(2 / w), which we know how to compute accurately.
        # This just leaves the matter of small w, where 2 / w may
        # overflow. In the limit as w -> 0, x -> loc, so we just mask
        # that case.
        sqrt1pm1_arg = 4 * scale / (mean * sampled_chi2)  # 2 / w above
        safe_sqrt1pm1_arg = jnp.where(sqrt1pm1_arg < np.inf, sqrt1pm1_arg, 1.0)
        denominator = 1.0 + jnp.sqrt(safe_sqrt1pm1_arg + 1.0)
        ratio = jnp.expm1(0.5 * jnp.log1p(safe_sqrt1pm1_arg)) / denominator
        sampled = mean * jnp.where(sqrt1pm1_arg < np.inf, ratio, 1.0)  # x above
        res = jnp.where(sampled_uniform <= mean / (mean + sampled),
                        sampled,
                        jnp.square(mean) / sampled)
        return res

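    # Numerical note (illustrative sketch): the reformulation above evaluates
    # x = loc * (sqrt(2 / w + 1) - 1) / (sqrt(2 / w + 1) + 1), algebraically equal to
    # the Wikipedia expression but finite when conc << loc, and the final jnp.where
    # is the usual acceptance step of the Michael, Schucany and Haas sampler, so the
    # sample mean should match `mean`. Assuming seeding as in `DEFAULT` below:
    #
    #     rng = RandomState(np.random.randint(0, 10000, size=2, dtype=np.uint32))
    #     x = rng.wald(2.0, 1.0, size=(10000,))
    #     # float(x.mean()) is expected near 2.0
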
    def t(
        self,
        df,
        size: Optional[Size] = None,
        key: Optional[SeedOrKey] = None,
        dtype: DTypeLike = None
    ):
        dtype = dtype or environ.dftype()
        df = u.math.asarray(_check_py_seq(df), dtype=dtype)
        if size is None:
            size = np.shape(df)
        else:
            size = _size2shape(size)
            _check_shape("t", size, np.shape(df))
        if key is None:
            keys = self.split_key(2)
        else:
            keys = jr.split(formalize_key(key, use_prng_key), 2)
        n = jr.normal(keys[0], size, dtype=dtype)
        two = const(n, 2)
        half_df = lax.div(df, two)
        g = jr.gamma(keys[1], half_df, size, dtype=dtype)
        r = n * jnp.sqrt(half_df / g)
        return r

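    # Construction note (illustrative sketch): a Student-t draw with `df` degrees of
    # freedom is formed as N * sqrt((df / 2) / G) with N standard normal and
    # G ~ Gamma(df / 2), which matches t = N / sqrt(chi2_df / df). Assuming seeding
    # as in `DEFAULT` below:
    #
    #     rng = RandomState(np.random.randint(0, 10000, size=2, dtype=np.uint32))
    #     x = rng.t(5.0, size=(10000,))
    #     # float(x.var()) is expected near 5 / (5 - 2), i.e. about 1.67
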
    def orthogonal(
        self,
        n: int,
        size: Optional[Size] = None,
        key: Optional[SeedOrKey] = None,
        dtype: DTypeLike = None
    ):
        dtype = dtype or environ.dftype()
        key = self.__get_key(key)
        size = _size2shape(size)
        _check_shape("orthogonal", size)
        n = core.concrete_or_error(index, n, "The error occurred in jax.random.orthogonal()")
        z = jr.normal(key, size + (n, n), dtype=dtype)
        q, r = jnp.linalg.qr(z)
        d = jnp.diagonal(r, 0, -2, -1)
        r = q * jnp.expand_dims(d / abs(d), -2)
        return r

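    # Construction note (illustrative sketch): `orthogonal` QR-decomposes a Gaussian
    # matrix and rescales Q by the signs of R's diagonal, which yields a Haar-
    # distributed orthogonal matrix. A quick self-check, assuming seeding as in
    # `DEFAULT` below:
    #
    #     rng = RandomState(np.random.randint(0, 10000, size=2, dtype=np.uint32))
    #     q = rng.orthogonal(4)
    #     # q @ q.T is expected to be close to the 4 x 4 identity
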
    def noncentral_chisquare(
        self,
        df,
        nonc,
        size: Optional[Size] = None,
        key: Optional[SeedOrKey] = None,
        dtype: DTypeLike = None
    ):
        dtype = dtype or environ.dftype()
        df = u.math.asarray(_check_py_seq(df), dtype=dtype)
        nonc = u.math.asarray(_check_py_seq(nonc), dtype=dtype)
        if size is None:
            size = lax.broadcast_shapes(u.math.shape(df), u.math.shape(nonc))
        size = _size2shape(size)
        if key is None:
            keys = self.split_key(3)
        else:
            keys = jr.split(formalize_key(key, use_prng_key), 3)
        i = jr.poisson(keys[0], 0.5 * nonc, shape=size, dtype=environ.ditype())
        n = jr.normal(keys[1], shape=size, dtype=dtype) + jnp.sqrt(nonc)
        cond = jnp.greater(df, 1.0)
        df2 = jnp.where(cond, df - 1.0, df + 2.0 * i)
        chi2 = 2.0 * jr.gamma(keys[2], 0.5 * df2, shape=size, dtype=dtype)
        r = jnp.where(cond, chi2 + n * n, chi2)
        return r

    def loggamma(
        self,
        a,
        size: Optional[Size] = None,
        key: Optional[SeedOrKey] = None,
        dtype: DTypeLike = None
    ):
        dtype = dtype or environ.dftype()
        key = self.__get_key(key)
        a = _check_py_seq(a)
        if size is None:
            size = u.math.shape(a)
        r = jr.loggamma(key, a, shape=_size2shape(size), dtype=dtype)
        return r

    def categorical(
        self,
        logits,
        axis: int = -1,
        size: Optional[Size] = None,
        key: Optional[SeedOrKey] = None
    ):
        key = self.__get_key(key)
        logits = _check_py_seq(logits)
        if size is None:
            size = list(u.math.shape(logits))
            size.pop(axis)
        r = jr.categorical(key, logits, axis=axis, shape=_size2shape(size))
        return r

    def zipf(
        self,
        a,
        size: Optional[Size] = None,
        key: Optional[SeedOrKey] = None,
        dtype: DTypeLike = None
    ):
        a = _check_py_seq(a)
        if size is None:
            size = u.math.shape(a)
        r = zipf(
            self.__get_key(key),
            a,
            shape=size,
            dtype=dtype or environ.ditype()
        )
        return r

    def power(
        self,
        a,
        size: Optional[Size] = None,
        key: Optional[SeedOrKey] = None,
        dtype: DTypeLike = None
    ):
        a = _check_py_seq(a)
        if size is None:
            size = u.math.shape(a)
        size = _size2shape(size)
        r = power(
            self.__get_key(key),
            a,
            shape=size,
            dtype=dtype or environ.dftype(),
        )
        return r

    def f(
        self,
        dfnum,
        dfden,
        size: Optional[Size] = None,
        key: Optional[SeedOrKey] = None,
        dtype: DTypeLike = None
    ):
        dfnum = _check_py_seq(dfnum)
        dfden = _check_py_seq(dfden)
        if size is None:
            size = jnp.broadcast_shapes(u.math.shape(dfnum), u.math.shape(dfden))
        size = _size2shape(size)
        r = f(
            self.__get_key(key),
            dfnum,
            dfden,
            shape=size,
            dtype=dtype or environ.dftype(),
        )
        return r

    def hypergeometric(
        self,
        ngood,
        nbad,
        nsample,
        size: Optional[Size] = None,
        key: Optional[SeedOrKey] = None,
        dtype: DTypeLike = None
    ):
        ngood = _check_py_seq(ngood)
        nbad = _check_py_seq(nbad)
        nsample = _check_py_seq(nsample)
        if size is None:
            size = lax.broadcast_shapes(
                u.math.shape(ngood),
                u.math.shape(nbad),
                u.math.shape(nsample)
            )
        size = _size2shape(size)
        r = hypergeometric(
            self.__get_key(key),
            ngood,
            nbad,
            nsample,
            shape=size,
            dtype=dtype or environ.ditype(),
        )
        return r

    def logseries(
        self,
        p,
        size: Optional[Size] = None,
        key: Optional[SeedOrKey] = None,
        dtype: DTypeLike = None
    ):
        p = _check_py_seq(p)
        if size is None:
            size = u.math.shape(p)
        size = _size2shape(size)
        r = logseries(
            self.__get_key(key),
            p,
            shape=size,
            dtype=dtype or environ.ditype()
        )
        return r

    def noncentral_f(
        self,
        dfnum,
        dfden,
        nonc,
        size: Optional[Size] = None,
        key: Optional[SeedOrKey] = None,
        dtype: DTypeLike = None
    ):
        dfnum = _check_py_seq(dfnum)
        dfden = _check_py_seq(dfden)
        nonc = _check_py_seq(nonc)
        if size is None:
            size = lax.broadcast_shapes(u.math.shape(dfnum),
                                        u.math.shape(dfden),
                                        u.math.shape(nonc))
        size = _size2shape(size)
        r = noncentral_f(
            self.__get_key(key),
            dfnum,
            dfden,
            nonc,
            shape=size,
            dtype=dtype or environ.dftype(),
        )
        return r

    # PyTorch compatibility #
    # --------------------- #

    def rand_like(
        self,
        input,
        *,
        dtype=None,
        key: Optional[SeedOrKey] = None
    ):
        """Returns a tensor with the same size as ``input`` that is filled with random
        numbers from a uniform distribution on the interval ``[0, 1)``.

        Args:
            input: the size of ``input`` determines the size of the output tensor.
            dtype: the desired data type of the returned tensor. Default: if ``None``, defaults to the dtype of ``input``.
            key: the seed or key for the random number generator.

        Returns:
            The random data.
        """
        return self.random(u.math.shape(input), key=key).astype(dtype)

    def randn_like(
        self,
        input,
        *,
        dtype=None,
        key: Optional[SeedOrKey] = None
    ):
        """Returns a tensor with the same size as ``input`` that is filled with
        random numbers from a normal distribution with mean 0 and variance 1.

        Args:
            input: the size of ``input`` determines the size of the output tensor.
            dtype: the desired data type of the returned tensor. Default: if ``None``, defaults to the dtype of ``input``.
            key: the seed or key for the random number generator.

        Returns:
            The random data.
        """
        return self.randn(*u.math.shape(input), key=key).astype(dtype)

    def randint_like(
        self,
        input,
        low=0,
        high=None,
        *,
        dtype=None,
        key: Optional[SeedOrKey] = None
    ):
        if high is None:
            high = max(input)
        return self.randint(low, high=high, size=u.math.shape(input), dtype=dtype, key=key)

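    # Illustrative usage sketch, assuming seeding as in `DEFAULT` below. The `*_like`
    # helpers mirror torch.rand_like / torch.randn_like / torch.randint_like: only the
    # shape (and optionally dtype) of `input` is used, not its values, except that
    # `randint_like` falls back to `max(input)` as the exclusive upper bound when
    # `high` is omitted:
    #
    #     rng = RandomState(np.random.randint(0, 10000, size=2, dtype=np.uint32))
    #     template = jnp.zeros((3, 4))
    #     a = rng.rand_like(template)            # uniform on [0, 1), shape (3, 4)
    #     b = rng.randn_like(template)           # standard normal, shape (3, 4)
    #     c = rng.randint_like(template, 0, 10)  # integers in [0, 10), shape (3, 4)

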
# default random generator
DEFAULT = RandomState(np.random.randint(0, 10000, size=2, dtype=np.uint32))