@genai-fi/nanogpt 0.5.1 → 0.5.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (102) hide show
  1. package/dist/Generator.js +90 -41
  2. package/dist/NanoGPTModel.d.ts +1 -0
  3. package/dist/NanoGPTModel.js +86 -73
  4. package/dist/{Reshape-BE5rA4rT.js → Reshape-Bt_t7RNz.js} +4 -4
  5. package/dist/TeachableLLM.js +1 -1
  6. package/dist/TiedEmbedding-DORsPlNL.js +44 -0
  7. package/dist/{axis_util-97KkkyRQ.js → axis_util-CVbf1vmL.js} +3 -3
  8. package/dist/{broadcast_to-CMlkG8NS.js → broadcast_to-BBoMQXbL.js} +4 -4
  9. package/dist/{concat-Cxbo2sOz.js → concat-BRRtq4S2.js} +1 -1
  10. package/dist/dataset-ZHEPJmED.js +1226 -0
  11. package/dist/{dropout-kbDY39Ci.js → dropout-lQm_YyX3.js} +1 -1
  12. package/dist/{gather-Bxe1Qip8.js → gather-BWyutxwi.js} +3 -3
  13. package/dist/{gpgpu_math-C0zyxKFi.js → gpgpu_math-Df7gzJWH.js} +1 -1
  14. package/dist/{index-iNhkcAEQ.js → index-CnHyhpKc.js} +32 -32
  15. package/dist/{kernel_funcs_utils-C4eIk4fE.js → kernel_funcs_utils-Dqo82NH4.js} +25 -25
  16. package/dist/layers/BaseLayer.js +114 -3
  17. package/dist/layers/CausalSelfAttention.js +29 -28
  18. package/dist/layers/MLP.js +10 -9
  19. package/dist/layers/RMSNorm.js +12 -11
  20. package/dist/layers/RoPECache.js +3 -3
  21. package/dist/layers/TiedEmbedding.js +8 -6
  22. package/dist/layers/TransformerBlock.js +2 -2
  23. package/dist/{log_sum_exp-CkumwesB.js → log_sum_exp-CRH7Np9v.js} +12 -12
  24. package/dist/main.js +1 -1
  25. package/dist/{mat_mul-D0SifYfJ.js → mat_mul-DeGU1U_C.js} +3 -3
  26. package/dist/{max-CYaAjEEp.js → max-CcnEArWK.js} +3 -3
  27. package/dist/{moments-B06NlR_V.js → moments-DLTE6-1p.js} +4 -4
  28. package/dist/{norm-D3676xIo.js → norm-BpWsOapl.js} +5 -5
  29. package/dist/{ones-BIeFnPHR.js → ones-CDWGzVnm.js} +6 -6
  30. package/dist/ops/appendCache.js +3 -3
  31. package/dist/ops/attentionMask.js +1 -1
  32. package/dist/ops/cpu/appendCache.js +2 -2
  33. package/dist/ops/cpu/attentionMask.js +5 -5
  34. package/dist/ops/cpu/fusedSoftmax.js +2 -2
  35. package/dist/ops/cpu/gatherSub.js +5 -5
  36. package/dist/ops/cpu/gelu.js +1 -1
  37. package/dist/ops/cpu/matMulGelu.js +1 -1
  38. package/dist/ops/cpu/matMulMul.js +1 -1
  39. package/dist/ops/cpu/mulDropout.js +1 -1
  40. package/dist/ops/cpu/normRMS.js +1 -1
  41. package/dist/ops/cpu/qkv.js +3 -3
  42. package/dist/ops/cpu/rope.js +5 -5
  43. package/dist/ops/cpu/scatterSub.js +27 -27
  44. package/dist/ops/fusedSoftmax.js +1 -1
  45. package/dist/ops/gatherSub.js +1 -1
  46. package/dist/ops/gelu.js +1 -1
  47. package/dist/ops/grads/attentionMask.js +1 -1
  48. package/dist/ops/grads/fusedSoftmax.js +2 -2
  49. package/dist/ops/grads/gelu.js +1 -1
  50. package/dist/ops/grads/matMulGelu.js +1 -1
  51. package/dist/ops/grads/normRMS.js +1 -1
  52. package/dist/ops/grads/qkv.js +1 -1
  53. package/dist/ops/grads/rope.js +1 -1
  54. package/dist/ops/matMulGelu.js +1 -1
  55. package/dist/ops/matMulMul.js +1 -1
  56. package/dist/ops/mulDrop.js +1 -1
  57. package/dist/ops/node/sparseCrossEntropy.js +1 -1
  58. package/dist/ops/normRMS.js +1 -1
  59. package/dist/ops/qkv.js +1 -1
  60. package/dist/ops/scatterSub.js +1 -1
  61. package/dist/ops/webgl/appendCache.js +1 -1
  62. package/dist/ops/webgl/attentionMask.js +1 -1
  63. package/dist/ops/webgl/fusedSoftmax.js +36 -36
  64. package/dist/ops/webgl/gatherSub.js +1 -1
  65. package/dist/ops/webgl/gelu.js +2 -2
  66. package/dist/ops/webgl/matMulGelu.js +22 -22
  67. package/dist/ops/webgl/matMulMul.js +1 -1
  68. package/dist/ops/webgl/mulDropout.js +1 -1
  69. package/dist/ops/webgl/normRMS.js +2 -2
  70. package/dist/ops/webgl/qkv.js +1 -1
  71. package/dist/ops/webgl/rope.js +1 -1
  72. package/dist/ops/webgl/scatterSub.js +1 -1
  73. package/dist/{ops-ObfXLHYQ.js → ops-DzQTmLIl.js} +60 -60
  74. package/dist/{TiedEmbedding-DsDRvLB0.js → random_width-DI2h9CMs.js} +1215 -1250
  75. package/dist/{range-BsFU-SNG.js → range-CkOJ7090.js} +1 -1
  76. package/dist/{reshape-DxTPgnwL.js → reshape-CTIbqjwm.js} +1 -1
  77. package/dist/{sin-BOX-JVAj.js → sin-HzioENy_.js} +5 -5
  78. package/dist/{slice_util-D-kaD4ZV.js → slice_util-n4wHKmex.js} +1 -1
  79. package/dist/{softmax-BjsptB07.js → softmax-DX6qXAbm.js} +2 -2
  80. package/dist/{split-BCbrzthj.js → split-CVwhL8Oe.js} +3 -3
  81. package/dist/{stack--cqr9Dgc.js → stack-S2-D2JAQ.js} +1 -1
  82. package/dist/{sum-B_92TaHD.js → sum-UdfvaNhB.js} +4 -4
  83. package/dist/{tensor-CfiPXsW4.js → tensor-IZex6Bwp.js} +1 -1
  84. package/dist/{tensor2d-tSxWdFMH.js → tensor2d-CqtBzOKq.js} +1 -1
  85. package/dist/{tfjs_backend-NucKez4s.js → tfjs_backend-DX9yVvwk.js} +41 -41
  86. package/dist/tokeniser/CharTokeniser.js +27 -27
  87. package/dist/tokeniser/bpe.d.ts +1 -0
  88. package/dist/tokeniser/bpe.js +38 -35
  89. package/dist/training/AdamExt.js +1 -1
  90. package/dist/training/DatasetBuilder.js +22 -1242
  91. package/dist/training/FullTrainer.js +1 -1
  92. package/dist/training/Trainer.js +5 -5
  93. package/dist/training/sparseCrossEntropy.js +4 -4
  94. package/dist/utilities/dummy.js +2 -2
  95. package/dist/utilities/generate.js +3 -3
  96. package/dist/utilities/load.js +1 -1
  97. package/dist/utilities/profile.js +1 -1
  98. package/dist/utilities/weights.js +2 -2
  99. package/dist/variable-BGvK-VN3.js +23 -0
  100. package/dist/{zeros-NMYTayy7.js → zeros-CYMicyqz.js} +3 -3
  101. package/package.json +1 -1
  102. package/dist/BaseLayer-BhrMN8JO.js +0 -135
@@ -0,0 +1,1226 @@
1
+ import { ag as S, T as h, N, a as v, ah as o, ai as p, aj as g, l as k, t as y } from "./index-CnHyhpKc.js";
2
+ import { s as R } from "./index-C4L8Cm77.js";
3
+ import { s as $ } from "./stack-S2-D2JAQ.js";
4
+ import { t as B } from "./tensor-IZex6Bwp.js";
5
+ /**
6
+ * @license
7
+ * Copyright 2018 Google LLC. All Rights Reserved.
8
+ * Licensed under the Apache License, Version 2.0 (the "License");
9
+ * you may not use this file except in compliance with the License.
10
+ * You may obtain a copy of the License at
11
+ *
12
+ * http://www.apache.org/licenses/LICENSE-2.0
13
+ *
14
+ * Unless required by applicable law or agreed to in writing, software
15
+ * distributed under the License is distributed on an "AS IS" BASIS,
16
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ * See the License for the specific language governing permissions and
18
+ * limitations under the License.
19
+ *
20
+ * =============================================================================
21
+ */
22
/**
 * Public entry point for a deep map: walks `input` with `mapFn`, starting
 * from a fresh memo table and a fresh cycle-detection set (the defaults of
 * `b`).
 *
 * @param input The (possibly nested) structure to map over.
 * @param mapFn Called on each node; must return `{ value, recurse }`.
 * @returns Whatever the recursive helper `b` produces for `input`.
 */
function _(input, mapFn) {
  const mapped = b(input, mapFn);
  return mapped;
}
25
/**
 * Recursive worker behind the deep map.
 *
 * @param input The node currently being mapped.
 * @param mapFn Called on each node; returns `{ value, recurse }`. Returning
 *   a non-null `value` together with `recurse=true` is rejected.
 * @param seen Memo of node -> mapped value for nodes that were mapped
 *   without recursion, so a repeated node yields the same mapped value.
 * @param containedIn Nodes currently being recursed into; meeting one of
 *   them again means the structure is circular.
 * @returns `null` for null/undefined input, a `slice()` copy for a `Blob`,
 *   otherwise the mapped value or a rebuilt array/object of mapped children.
 * @throws Error on circular references, on an inconsistent `mapFn` result,
 *   or when asked to recurse into something `c` does not treat as iterable.
 */
function b(input, mapFn, seen = /* @__PURE__ */ new Map(), containedIn = /* @__PURE__ */ new Set()) {
  if (input == null) {
    return null;
  }
  // Blob may be undefined outside browsers, hence the typeof guard.
  if (typeof Blob === "function" && input instanceof Blob) {
    return input.slice();
  }
  if (containedIn.has(input)) {
    throw new Error("Circular references are not supported.");
  }
  if (seen.has(input)) {
    return seen.get(input);
  }

  const result = mapFn(input);
  if (result.recurse && result.value !== null) {
    throw new Error("A deep map function may not return both a value and recurse=true.");
  }

  if (!result.recurse) {
    seen.set(input, result.value);
    return result.value;
  }
  if (!c(input)) {
    throw new Error(`Can't recurse into non-iterable type: ${input}`);
  }

  const mapped = Array.isArray(input) ? [] : {};
  containedIn.add(input);
  for (const key in input) {
    const child = input[key];
    mapped[key] = b(child, mapFn, seen, containedIn);
  }
  containedIn.delete(input);
  // Carry the prototype over so class instances keep their type.
  if (input.__proto__) {
    mapped.__proto__ = input.__proto__;
  }
  return mapped;
}
50
/**
 * Public entry point for a deep zip: combines an array of same-shaped
 * structures into one structure, starting with a fresh cycle-detection set.
 *
 * @param inputs Array of structures to zip together.
 * @param zipFn Called with the array of values found at one node; returns
 *   `{ value, recurse }`. Defaults to `I`.
 * @returns Whatever the recursive helper `E` produces.
 */
function P(inputs, zipFn = I) {
  const zipped = E(inputs, zipFn);
  return zipped;
}
53
/**
 * Recursive worker behind the deep zip.
 *
 * The first element of `inputs` serves as the structural template: its keys
 * drive the recursion and it is the object tracked for cycle detection.
 *
 * @param inputs Array of values found at the current node, one per input.
 * @param zipFn Called with `inputs`; returns `{ value, recurse }`. Returning
 *   a non-null `value` together with `recurse=true` is rejected.
 * @param containedIn Template nodes currently being recursed into.
 * @returns `zipFn`'s value for a leaf, otherwise an array/object whose
 *   entries are the zipped children.
 * @throws Error on circular references, on an inconsistent `zipFn` result,
 *   or when asked to recurse into something `c` does not treat as iterable.
 */
function E(inputs, zipFn, containedIn = /* @__PURE__ */ new Set()) {
  const template = inputs[0];
  if (containedIn.has(template)) {
    throw new Error("Circular references are not supported.");
  }

  const result = zipFn(inputs);
  if (result.recurse && result.value !== null) {
    throw new Error("A deep zip function may not return both a value and recurse=true.");
  }

  if (!result.recurse) {
    return result.value;
  }
  if (!c(template)) {
    throw new Error(`Can't recurse into non-iterable type: ${template}`);
  }

  const zipped = Array.isArray(template) ? [] : {};
  containedIn.add(template);
  for (const key in template) {
    const column = inputs.map((entry) => entry[key]);
    zipped[key] = E(column, zipFn, containedIn);
  }
  containedIn.delete(template);
  return zipped;
}
73
/**
 * Default zip function: recurse while the first input is iterable (per
 * `c`), otherwise keep the array of inputs itself as the leaf value.
 *
 * @param inputs Array of values found at one node, or `null`.
 * @returns `null` for a `null` argument, else a `{ value, recurse }` record.
 */
function I(inputs) {
  if (inputs === null) {
    return null;
  }
  if (c(inputs[0])) {
    return { value: null, recurse: true };
  }
  return { value: inputs, recurse: false };
}
76
/**
 * Decides whether the deep map/zip helpers may recurse into `obj`.
 *
 * Arrays always qualify. Other objects qualify unless they are typed-array
 * views, instances of the imported class `h`, Promises, or a text decoder
 * (`TextDecoder` when the environment flag `IS_BROWSER` is set, otherwise
 * Node's `StringDecoder`).
 *
 * NOTE(review): `h` is presumably the tfjs `Tensor` class and `N()` the tfjs
 * environment accessor — confirm against the `index-*.js` bundle.
 *
 * @param obj Any value.
 * @returns `true` when `obj` should be traversed, `false` otherwise.
 */
function c(obj) {
  // The decoder check runs first, exactly as before, even for null input.
  let isDecoder;
  if (N().get("IS_BROWSER")) {
    isDecoder = obj instanceof TextDecoder;
  } else {
    const { StringDecoder } = require("string_decoder");
    isDecoder = obj instanceof StringDecoder;
  }

  if (obj == null || ArrayBuffer.isView(obj)) {
    return false;
  }
  if (Array.isArray(obj)) {
    return true;
  }
  return typeof obj === "object" && !(obj instanceof h) && !(obj instanceof Promise) && !isDecoder;
}
86
/**
 * Reports whether `obj` is one of the accepted leaf kinds: null/undefined,
 * a primitive (per `L`), an array, an instance of the imported class `h`,
 * or anything the imported predicate `S` accepts.
 *
 * NOTE(review): `h` / `S` look like tfjs `Tensor` / `isTypedArray` — confirm.
 *
 * @param obj Any value.
 * @returns `true` for the first four cases, otherwise the result of `S(obj)`.
 */
function M(obj) {
  if (obj == null || L(obj) || Array.isArray(obj)) {
    return true;
  }
  if (typeof obj === "object" && obj instanceof h) {
    return true;
  }
  return S(obj);
}
89
/**
 * Reports whether `value` is a primitive: `null`, or anything whose
 * `typeof` is neither "object" nor "function".
 *
 * @param value Any value.
 * @returns `true` for primitives, `false` for objects and functions.
 */
function L(value) {
  if (value === null) {
    return true;
  }
  const kind = typeof value;
  return kind !== "object" && kind !== "function";
}
92
+ /**
93
+ * @license
94
+ * Copyright 2018 Google LLC. All Rights Reserved.
95
+ * Licensed under the Apache License, Version 2.0 (the "License");
96
+ * you may not use this file except in compliance with the License.
97
+ * You may obtain a copy of the License at
98
+ *
99
+ * http://www.apache.org/licenses/LICENSE-2.0
100
+ *
101
+ * Unless required by applicable law or agreed to in writing, software
102
+ * distributed under the License is distributed on an "AS IS" BASIS,
103
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
104
+ * See the License for the specific language governing permissions and
105
+ * limitations under the License.
106
+ *
107
+ * =============================================================================
108
+ */
109
/**
 * Deep-copies a nested structure by running the deep map `_` with `O` as
 * the per-node function.
 *
 * @param container The structure to copy.
 * @returns The copied structure.
 */
function D(container) {
  const copy = _(container, O);
  return copy;
}
112
/**
 * Per-node function for the deep copy: instances of the imported class `h`
 * are replaced by their `clone()`, iterable containers (per `c`) are
 * recursed into, and every other value is passed through unchanged.
 *
 * NOTE(review): `h` is presumably the tfjs `Tensor` class — confirm.
 *
 * @param item The node being visited.
 * @returns A `{ value, recurse }` record for the deep-map machinery.
 */
function O(item) {
  if (item instanceof h) {
    return { value: item.clone(), recurse: false };
  }
  if (c(item)) {
    return { value: null, recurse: true };
  }
  return { value: item, recurse: false };
}
115
+ /**
116
+ * @license
117
+ * Copyright 2018 Google LLC. All Rights Reserved.
118
+ * Licensed under the Apache License, Version 2.0 (the "License");
119
+ * you may not use this file except in compliance with the License.
120
+ * You may obtain a copy of the License at
121
+ *
122
+ * http://www.apache.org/licenses/LICENSE-2.0
123
+ *
124
+ * Unless required by applicable law or agreed to in writing, software
125
+ * distributed under the License is distributed on an "AS IS" BASIS,
126
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
127
+ * See the License for the specific language governing permissions and
128
+ * limitations under the License.
129
+ *
130
+ * =============================================================================
131
+ */
132
/**
 * Fixed-capacity ring buffer (double-ended queue).
 *
 * `begin` and `end` are kept in the range [0, 2 * capacity) rather than
 * [0, capacity), so `end - begin` tells an empty buffer apart from a full
 * one. Physical slots are addressed modulo `capacity`.
 */
class z {
  /**
   * Constructs a `RingBuffer`.
   * @param capacity The number of items that the buffer can accommodate.
   * @throws RangeError if `capacity` is null/undefined or less than 1.
   */
  constructor(capacity) {
    this.capacity = capacity;
    this.begin = 0;
    this.end = 0;
    if (capacity == null) {
      throw new RangeError("Can't create a ring buffer of unknown capacity.");
    }
    if (capacity < 1) {
      throw new RangeError("Can't create ring buffer of capacity < 1.");
    }
    this.data = new Array(capacity);
    this.doubledCapacity = 2 * capacity;
  }

  /**
   * Map any index into the range 0 <= index < 2*capacity.
   */
  wrap(index) {
    let wrapped = index;
    while (wrapped < 0) {
      wrapped += this.doubledCapacity;
    }
    return wrapped % this.doubledCapacity;
  }

  /**
   * Reads the slot addressed by `index` (taken modulo the capacity).
   * @throws RangeError for a negative index.
   */
  get(index) {
    if (index < 0) {
      throw new RangeError("Can't get item at a negative index.");
    }
    const slot = index % this.capacity;
    return this.data[slot];
  }

  /**
   * Writes `value` into the slot addressed by `index` (modulo the capacity).
   * @throws RangeError for a negative index.
   */
  set(index, value) {
    if (index < 0) {
      throw new RangeError("Can't set item at a negative index.");
    }
    const slot = index % this.capacity;
    this.data[slot] = value;
  }

  /**
   * Returns the current number of items in the buffer.
   */
  length() {
    const span = this.end - this.begin;
    // `end` may have wrapped past `begin` within the doubled range.
    return span < 0 ? this.doubledCapacity + span : span;
  }

  /**
   * Reports whether the buffer is full.
   * @returns true if the number of items in the buffer equals its capacity,
   *   and false otherwise.
   */
  isFull() {
    return this.length() === this.capacity;
  }

  /**
   * Reports whether the buffer is empty.
   * @returns true if the number of items in the buffer equals zero, and
   *   false otherwise.
   */
  isEmpty() {
    return this.length() === 0;
  }

  /**
   * Adds an item to the end of the buffer.
   * @throws RangeError if the buffer is full.
   */
  push(value) {
    if (this.isFull()) {
      throw new RangeError("Ring buffer is full.");
    }
    this.set(this.end, value);
    this.end = this.wrap(this.end + 1);
  }

  /**
   * Adds many items to the end of the buffer, in order.
   */
  pushAll(values) {
    for (const value of values) {
      this.push(value);
    }
  }

  /**
   * Removes and returns the last item in the buffer.
   * @throws RangeError if the buffer is empty.
   */
  pop() {
    if (this.isEmpty()) {
      throw new RangeError("Ring buffer is empty.");
    }
    this.end = this.wrap(this.end - 1);
    const last = this.get(this.end);
    // Clear the vacated slot so the buffer no longer references the item.
    this.set(this.end, void 0);
    return last;
  }

  /**
   * Adds an item to the beginning of the buffer.
   * @throws RangeError if the buffer is full.
   */
  unshift(value) {
    if (this.isFull()) {
      throw new RangeError("Ring buffer is full.");
    }
    this.begin = this.wrap(this.begin - 1);
    this.set(this.begin, value);
  }

  /**
   * Removes and returns the first item in the buffer.
   * @throws RangeError if the buffer is empty.
   */
  shift() {
    if (this.isEmpty()) {
      throw new RangeError("Ring buffer is empty.");
    }
    const first = this.get(this.begin);
    this.set(this.begin, void 0);
    this.begin = this.wrap(this.begin + 1);
    return first;
  }

  /**
   * Removes and returns a specific item in the buffer, and moves the last
   * item to the vacated slot. This is useful for implementing a shuffling
   * stream. Note that this operation necessarily scrambles the original
   * order.
   *
   * @param relativeIndex the index of the item to remove, relative to the
   *   first item in the buffer (e.g., hiding the ring nature of the
   *   underlying storage).
   * @throws RangeError if the buffer is empty.
   */
  shuffleExcise(relativeIndex) {
    if (this.isEmpty()) {
      throw new RangeError("Ring buffer is empty.");
    }
    const index = this.wrap(this.begin + relativeIndex);
    const excised = this.get(index);
    const last = this.pop();
    this.set(index, last);
    return excised;
  }
}
243
+ /**
244
+ * @license
245
+ * Copyright 2018 Google LLC. All Rights Reserved.
246
+ * Licensed under the Apache License, Version 2.0 (the "License");
247
+ * you may not use this file except in compliance with the License.
248
+ * You may obtain a copy of the License at
249
+ *
250
+ * http://www.apache.org/licenses/LICENSE-2.0
251
+ *
252
+ * Unless required by applicable law or agreed to in writing, software
253
+ * distributed under the License is distributed on an "AS IS" BASIS,
254
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
255
+ * See the License for the specific language governing permissions and
256
+ * limitations under the License.
257
+ *
258
+ * =============================================================================
259
+ */
260
/**
 * Ring buffer that never reports itself as full: when the underlying
 * fixed-size buffer fills up, its storage is doubled before the insert.
 */
class f extends z {
  /**
   * Constructs a `GrowingRingBuffer` with `f.INITIAL_CAPACITY` slots.
   */
  constructor() {
    super(f.INITIAL_CAPACITY);
  }

  /** Always false: the buffer grows instead of filling up. */
  isFull() {
    return false;
  }

  /** Appends `value`, growing the storage first when it is at capacity. */
  push(value) {
    if (super.isFull()) {
      this.expand();
    }
    super.push(value);
  }

  /** Prepends `value`, growing the storage first when it is at capacity. */
  unshift(value) {
    if (super.isFull()) {
      this.expand();
    }
    super.unshift(value);
  }

  /**
   * Doubles the capacity of the buffer, copying the live items to the start
   * of the new storage so that `begin` becomes 0.
   */
  expand() {
    const grownCapacity = this.capacity * 2;
    const grownData = new Array(grownCapacity);
    const count = this.length();
    // Copy while the old capacity is still in effect for get()/wrap().
    for (let offset = 0; offset < count; offset++) {
      grownData[offset] = this.get(this.wrap(this.begin + offset));
    }
    this.data = grownData;
    this.capacity = grownCapacity;
    this.doubledCapacity = 2 * this.capacity;
    this.begin = 0;
    this.end = count;
  }
}
f.INITIAL_CAPACITY = 32;
287
+ /**
288
+ * @license
289
+ * Copyright 2018 Google LLC. All Rights Reserved.
290
+ * Licensed under the Apache License, Version 2.0 (the "License");
291
+ * you may not use this file except in compliance with the License.
292
+ * You may obtain a copy of the License at
293
+ *
294
+ * http://www.apache.org/licenses/LICENSE-2.0
295
+ *
296
+ * Unless required by applicable law or agreed to in writing, software
297
+ * distributed under the License is distributed on an "AS IS" BASIS,
298
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
299
+ * See the License for the specific language governing permissions and
300
+ * limitations under the License.
301
+ *
302
+ * =============================================================================
303
+ */
304
/**
 * Wraps an array in a lazy iterator (class `G`).
 *
 * @param items The array to iterate over.
 * @returns A new `G` positioned at the first item.
 */
function H(items) {
  const iterator = new G(items);
  return iterator;
}
307
/**
 * Wraps a function in a lazy iterator (class `W`); each `next()` on the
 * result invokes `func` to obtain the next iterator result.
 *
 * @param func Function producing `{ value, done }` results (or a Promise).
 * @returns A new `W` around `func`.
 */
function q(func) {
  const iterator = new W(func);
  return iterator;
}
310
/**
 * Builds a chained iterator (class `A`) from an iterator of iterators.
 *
 * @param baseIterators Iterator whose elements are themselves iterators.
 * @param baseErrorHandler Optional handler forwarded to `A`.
 * @returns A new `A`.
 */
function Q(baseIterators, baseErrorHandler) {
  const chained = new A(baseIterators, baseErrorHandler);
  return chained;
}
313
/**
 * Base class for lazy, asynchronous streams of elements.
 *
 * Subclasses supply `async next()`, resolving to `{ value, done }`; every
 * method here is built on top of that single primitive. Most methods only
 * wrap `this` in another iterator class and do no work until `next()` is
 * called on the result.
 */
class i {
  /**
   * Collect all remaining elements of a bounded stream into an array.
   * Obviously this will succeed only for small streams that fit in memory.
   * Useful for testing.
   *
   * @returns A Promise for an array of stream elements, which will resolve
   *   when the stream is exhausted.
   */
  async toArray() {
    const collected = [];
    let item = await this.next();
    while (!item.done) {
      collected.push(item.value);
      item = await this.next();
    }
    return collected;
  }

  /**
   * Collect all elements of this dataset into an array with prefetching 100
   * elements. This is useful for testing, because the prefetch changes the
   * order in which the Promises are resolved along the processing pipeline.
   * This may help expose bugs where results are dependent on the order of
   * Promise resolution rather than on the logical order of the stream (i.e.,
   * due to hidden mutable state).
   *
   * @returns A Promise for an array of stream elements, which will resolve
   *   when the stream is exhausted.
   */
  async toArrayForTest() {
    const prefetched = this.prefetch(100);
    const collected = [];
    let item = await prefetched.next();
    while (!item.done) {
      collected.push(item.value);
      item = await prefetched.next();
    }
    return collected;
  }

  /**
   * Draw items from the stream until it is exhausted.
   *
   * This can be useful when the stream has side effects but no output. In
   * that case, calling this function guarantees that the stream will be
   * fully processed.
   */
  async resolveFully() {
    let item = await this.next();
    while (!item.done) {
      item = await this.next();
    }
  }

  /**
   * Draw items from the stream until it is exhausted, or a predicate fails.
   *
   * This can be useful when the stream has side effects but no output. In
   * that case, calling this function guarantees that the stream will be
   * processed up to the first element the predicate rejects.
   *
   * The predicate is evaluated only on real stream elements; it is never
   * called with the placeholder value of the terminal `done` result.
   *
   * @param predicate Called with each element; return a truthy value to
   *   keep drawing.
   */
  async resolveWhile(predicate) {
    let item = await this.next();
    // Check `done` first so the predicate never sees the end-of-stream value.
    while (!item.done && predicate(item.value)) {
      item = await this.next();
    }
  }

  /**
   * Handles errors thrown on this stream using a provided handler function.
   *
   * @param handler A function that handles any `Error` thrown during a
   *   `next()` call and returns true if the stream should continue (dropping
   *   the failed call) or false if the stream should quietly terminate. If
   *   the handler itself throws (or rethrows) an `Error`, that will be
   *   propagated.
   *
   * @returns A `LazyIterator` of elements passed through from upstream,
   *   possibly filtering or terminating on upstream `next()` calls that
   *   throw an `Error`.
   */
  handleErrors(handler) {
    return new j(this, handler);
  }

  // TODO(soergel): Implement reduce() etc.

  /**
   * Filters this stream according to `predicate`.
   *
   * @param predicate A function mapping a stream element to a boolean or a
   *   `Promise` for one.
   *
   * @returns A `LazyIterator` of elements for which the predicate was true.
   */
  filter(predicate) {
    return new X(this, predicate);
  }

  /**
   * Maps this stream through a 1-to-1 transform.
   *
   * @param transform A function mapping a stream element to a transformed
   *   element.
   *
   * @returns A `LazyIterator` of transformed elements.
   */
  map(transform) {
    return new K(this, transform);
  }

  /**
   * Maps this stream through an async 1-to-1 transform.
   *
   * @param transform A function mapping a stream element to a `Promise` for
   *   a transformed stream element.
   *
   * @returns A `LazyIterator` of transformed elements.
   */
  mapAsync(transform) {
    return new w(this, transform);
  }

  /**
   * Maps this stream through a 1-to-1 transform, forcing serial execution.
   *
   * @param transform A function mapping a stream element to a transformed
   *   element.
   *
   * @returns A `LazyIterator` of transformed elements.
   */
  serialMapAsync(transform) {
    return new w(this, transform).serial();
  }

  /**
   * Maps this stream through a 1-to-many transform.
   *
   * @param transform A function mapping a stream element to an array of
   *   transformed elements.
   *
   * @returns A `DataStream` of transformed elements.
   */
  flatmap(transform) {
    return new tt(this, transform);
  }

  /**
   * Apply a function to every element of the stream.
   *
   * @param f A function to apply to each stream element.
   */
  async forEachAsync(f) {
    return this.map(f).resolveFully();
  }

  /**
   * Apply a function to every element of the stream, forcing serial
   * execution.
   *
   * @param f A function to apply to each stream element. Should return
   *   'true' to indicate that the stream should continue, or 'false' to
   *   cause it to terminate.
   */
  async serialForEach(f) {
    return this.serialMapAsync(f).resolveWhile((outcome) => outcome === true);
  }

  /**
   * Groups elements into batches, represented as arrays of elements.
   *
   * We can think of the elements of this iterator as 'rows' (even if they
   * are nested structures). By the same token, consecutive values for a
   * given key within the elements form a 'column'. This matches the usual
   * sense of 'row' and 'column' when processing tabular data (e.g., parsing
   * a CSV).
   *
   * Thus, "Row-major" means that the resulting batch is simply a collection
   * of rows: `[row1, row2, row3, ...]`. This is in contrast to the
   * column-major form, which is needed for vectorized computation.
   *
   * @param batchSize The number of elements desired per batch.
   * @param smallLastBatch Whether to emit the final batch when it has fewer
   *   than batchSize elements. Default true.
   * @returns A `LazyIterator` of batches of elements, represented as arrays
   *   of the original element type.
   */
  rowMajorBatch(batchSize, smallLastBatch = true) {
    return new V(this, batchSize, smallLastBatch);
  }

  /**
   * Groups elements into batches, represented in column-major form.
   *
   * We can think of the elements of this iterator as 'rows' (even if they
   * are nested structures). By the same token, consecutive values for a
   * given key within the elements form a 'column'. This matches the usual
   * sense of 'row' and 'column' when processing tabular data (e.g., parsing
   * a CSV).
   *
   * Thus, "column-major" means that the resulting batch is a (potentially
   * nested) structure representing the columns. Each column entry, then,
   * contains a collection of the values found in that column for a range of
   * input elements. This representation allows for vectorized computation,
   * in contrast to the row-major form.
   *
   * The inputs should all have the same nested structure (i.e., of arrays
   * and dicts). The result is a single object with the same nested
   * structure, where the leaves are arrays collecting the values of the
   * inputs at that location (or, optionally, the result of a custom function
   * applied to those arrays).
   *
   * @param batchSize The number of elements desired per batch.
   * @param smallLastBatch Whether to emit the final batch when it has fewer
   *   than batchSize elements. Default true.
   * @param zipFn (optional) A function that expects an array of elements at
   *   a single node of the object tree, and returns a `DeepMapResult`. The
   *   `DeepMapResult` either provides a result value for that node (i.e.,
   *   representing the subtree), or indicates that the node should be
   *   processed recursively. The default zipFn recurses as far as possible
   *   and places arrays at the leaves.
   * @returns A `LazyIterator` of batches of elements, represented as an
   *   object with collections at the leaves.
   */
  columnMajorBatch(batchSize, smallLastBatch = true, zipFn = I) {
    const rowBatches = this.rowMajorBatch(batchSize, smallLastBatch);
    return rowBatches.map((rows) => P(rows, zipFn));
  }

  /**
   * Concatenate this `LazyIterator` with another.
   *
   * @param iterator A `LazyIterator` to be concatenated onto this one.
   * @param baseErrorHandler An optional function that can intercept `Error`s
   *   raised during a `next()` call on the base stream. This function can
   *   decide whether the error should be propagated, whether the error
   *   should be ignored, or whether the base stream should be terminated.
   * @returns A `LazyIterator`.
   */
  concatenate(iterator, baseErrorHandler) {
    return new A(H([this, iterator]), baseErrorHandler);
  }

  /**
   * Limits this stream to return at most `count` items.
   *
   * @param count The maximum number of items to provide from the stream. If
   *   a negative or undefined value is given, the entire stream is returned
   *   unaltered.
   */
  take(count) {
    if (count < 0 || count == null) {
      return this;
    }
    return new J(this, count);
  }

  /**
   * Skips the first `count` items in this stream.
   *
   * @param count The number of items to skip. If a negative or undefined
   *   value is given, the entire stream is returned unaltered.
   */
  skip(count) {
    if (count < 0 || count == null) {
      return this;
    }
    return new Y(this, count);
  }

  /**
   * Prefetch the first `bufferSize` items in this stream.
   *
   * Note this prefetches Promises, but makes no guarantees about when those
   * Promises resolve.
   *
   * @param bufferSize An integer specifying the number of elements to be
   *   prefetched.
   */
  prefetch(bufferSize) {
    return new C(this, bufferSize);
  }

  // TODO(soergel): deep sharded shuffle, where supported

  /**
   * Randomly shuffles the elements of this stream.
   *
   * @param windowSize An integer specifying the number of elements from
   *   this stream from which the new stream will sample.
   * @param seed (Optional.) An integer specifying the random seed that
   *   will be used to create the distribution.
   */
  shuffle(windowSize, seed) {
    return new et(this, windowSize, seed);
  }

  /**
   * Force an iterator to execute serially: each next() call will await the
   * prior one, so that they cannot execute concurrently.
   */
  serial() {
    return new U(this);
  }
}
580
/**
 * Lazy iterator over the items of an in-memory array. Each item is passed
 * through `D` before being handed out, so consumers do not receive the
 * stored item itself.
 */
class G extends i {
  /**
   * @param items The array to iterate over.
   */
  constructor(items) {
    super();
    this.items = items;
    // Index of the next item to hand out.
    this.trav = 0;
  }

  summary() {
    return `Array of ${this.items.length} items`;
  }

  /**
   * @returns `{ value, done: false }` with a copy of the next item, or
   *   `{ value: null, done: true }` once the array is exhausted.
   */
  async next() {
    if (this.trav >= this.items.length) {
      return { value: null, done: true };
    }
    const item = this.items[this.trav];
    this.trav += 1;
    return { value: D(item), done: false };
  }
}
594
/**
 * Lazy iterator whose elements come from repeatedly calling a user-supplied
 * function.
 */
class W extends i {
  /**
   * @param nextFn Function returning the next `{ value, done }` result, or
   *   a Promise for one.
   */
  constructor(nextFn) {
    super();
    this.nextFn = nextFn;
  }

  summary() {
    return "Function call";
  }

  /**
   * Calls `nextFn` and returns its result.
   *
   * Failures are rethrown with their message prefixed so the caller can tell
   * the error originated in dataset iteration. The call is awaited inside
   * the `try` so that this applies to rejected Promises as well as
   * synchronous throws.
   *
   * @returns The result produced by `nextFn`.
   * @throws Whatever `nextFn` threw or rejected with (same object, so the
   *   stack trace is preserved).
   */
  async next() {
    try {
      // `return await` (not a bare `return`) keeps an asynchronous rejection
      // inside this try block.
      return await this.nextFn();
    } catch (err) {
      // Only objects can carry a message; assigning to a property of a
      // thrown primitive would raise a TypeError that hides the real failure.
      if (err !== null && typeof err === "object") {
        err.message = `Error thrown while iterating through a dataset: ${err.message}`;
      }
      throw err;
    }
  }
}
609
/**
 * Wrapper that serialises access to an upstream iterator: each `next()` is
 * chained onto the Promise of the previous one, so upstream reads never
 * overlap.
 */
class U extends i {
  /**
   * @param upstream The iterator to read from.
   */
  constructor(upstream) {
    super();
    this.upstream = upstream;
    // Seed of the Promise chain; its value is never handed out.
    this.lastRead = Promise.resolve({ value: null, done: false });
  }

  summary() {
    return `${this.upstream.summary()} -> Serial`;
  }

  async next() {
    const pending = this.lastRead.then(() => this.serialNext());
    this.lastRead = pending;
    return pending;
  }

  /** Performs the actual upstream read; runs once the prior read settles. */
  async serialNext() {
    return this.upstream.next();
  }
}
623
/**
 * Iterator that discards the first `maxCount` upstream elements and then
 * passes the rest through. Reads are serialised via a Promise chain.
 */
class Y extends i {
  /**
   * @param upstream The iterator to read from.
   * @param maxCount Number of leading elements to drop.
   */
  constructor(upstream, maxCount) {
    super();
    this.upstream = upstream;
    this.maxCount = maxCount;
    this.count = 0;
    this.lastRead = Promise.resolve({ value: null, done: false });
  }

  summary() {
    return `${this.upstream.summary()} -> Skip`;
  }

  async next() {
    const pending = this.lastRead.then(() => this.serialNext());
    this.lastRead = pending;
    return pending;
  }

  /**
   * Drops elements until `maxCount` have been skipped, then returns the next
   * upstream result. If upstream ends during the skip, that terminal result
   * is returned as-is.
   */
  async serialNext() {
    while (this.count++ < this.maxCount) {
      const skipped = await this.upstream.next();
      if (skipped.done) {
        return skipped;
      }
      // `v` is an imported helper; presumably tfjs `dispose`, releasing any
      // tensors held by the discarded element — confirm.
      v(skipped.value);
    }
    return this.upstream.next();
  }
}
643
/**
 * Iterator that passes through at most `maxCount` upstream results and then
 * reports the stream as done.
 */
class J extends i {
  /**
   * @param upstream The iterator to read from.
   * @param maxCount Maximum number of `next()` calls forwarded upstream.
   */
  constructor(upstream, maxCount) {
    super();
    this.upstream = upstream;
    this.maxCount = maxCount;
    this.count = 0;
  }

  summary() {
    return `${this.upstream.summary()} -> Take`;
  }

  async next() {
    // The counter advances on every call, including those past the limit.
    const callIndex = this.count++;
    if (callIndex >= this.maxCount) {
      return { value: null, done: true };
    }
    return this.upstream.next();
  }
}
654
/**
 * Iterator that groups upstream elements into arrays of `batchSize`
 * elements. Reads are serialised via a Promise chain so batches do not
 * interleave.
 */
class V extends i {
  /**
   * @param upstream The iterator to read from.
   * @param batchSize Number of elements per batch.
   * @param enableSmallLastBatch Whether a final, shorter batch is emitted
   *   when upstream ends mid-batch. Default true.
   */
  constructor(upstream, batchSize, enableSmallLastBatch = true) {
    super();
    this.upstream = upstream;
    this.batchSize = batchSize;
    this.enableSmallLastBatch = enableSmallLastBatch;
    this.lastRead = Promise.resolve({ value: null, done: false });
  }

  summary() {
    return `${this.upstream.summary()} -> RowMajorBatch`;
  }

  async next() {
    const pending = this.lastRead.then(() => this.serialNext());
    this.lastRead = pending;
    return pending;
  }

  /**
   * Collects one batch. When upstream ends first, a non-empty partial batch
   * is returned only if small last batches are enabled; otherwise the stream
   * is reported as done.
   */
  async serialNext() {
    const batch = [];
    while (batch.length < this.batchSize) {
      const item = await this.upstream.next();
      if (item.done) {
        const emitPartial = this.enableSmallLastBatch && batch.length > 0;
        return emitPartial ? { value: batch, done: false } : { value: null, done: true };
      }
      batch.push(item.value);
    }
    return { value: batch, done: false };
  }
}
675
/**
 * Iterator that yields only the upstream elements accepted by `predicate`.
 * Reads are serialised via a Promise chain.
 */
class X extends i {
  /**
   * @param upstream The iterator to read from.
   * @param predicate Called with each element; truthy keeps the element.
   */
  constructor(upstream, predicate) {
    super();
    this.upstream = upstream;
    this.predicate = predicate;
    this.lastRead = Promise.resolve({ value: null, done: false });
  }

  summary() {
    return `${this.upstream.summary()} -> Filter`;
  }

  async next() {
    const pending = this.lastRead.then(() => this.serialNext());
    this.lastRead = pending;
    return pending;
  }

  /**
   * Reads upstream until an element passes the predicate or the stream ends,
   * and returns that result.
   */
  async serialNext() {
    while (true) {
      const item = await this.upstream.next();
      if (item.done || this.predicate(item.value)) {
        return item;
      }
      // `v` is an imported helper; presumably tfjs `dispose`, releasing any
      // tensors held by the rejected element — confirm.
      v(item.value);
    }
  }
}
694
/**
 * Iterator applying a synchronous 1-to-1 `transform` to each upstream
 * element.
 */
class K extends i {
  /**
   * @param upstream The iterator to read from.
   * @param transform Function mapping an element to its replacement.
   */
  constructor(upstream, transform) {
    super();
    this.upstream = upstream;
    this.transform = transform;
  }

  summary() {
    return `${this.upstream.summary()} -> Map`;
  }

  /**
   * Transforms the next upstream element. Afterwards, every object that `o`
   * found in the input and that `p` does not find among those of the output
   * is disposed.
   *
   * NOTE(review): `o` / `p` are imported helpers, presumably tfjs
   * `getTensorsInContainer` / `isTensorInList` — confirm.
   */
  async next() {
    const item = await this.upstream.next();
    if (item.done) {
      return { value: null, done: true };
    }
    const inputTensors = o(item.value);
    const mapped = this.transform(item.value);
    const outputTensors = o(mapped);
    for (const tensor of inputTensors) {
      if (!p(tensor, outputTensors)) {
        tensor.dispose();
      }
    }
    return { value: mapped, done: false };
  }
}
711
/**
 * Iterator that shields consumers from errors thrown by upstream `next()`
 * calls, delegating the decision to a handler. Reads are serialised via a
 * Promise chain.
 */
class j extends i {
  /**
   * @param upstream The iterator to read from.
   * @param handler Called with the thrown error; a truthy result retries the
   *   read, a falsy result ends the stream.
   */
  constructor(upstream, handler) {
    super();
    this.upstream = upstream;
    this.handler = handler;
    this.count = 0;
    this.lastRead = Promise.resolve({ value: null, done: false });
  }

  summary() {
    return `${this.upstream.summary()} -> handleErrors`;
  }

  async next() {
    const pending = this.lastRead.then(() => this.serialNext());
    this.lastRead = pending;
    return pending;
  }

  /**
   * Retries the upstream read for as long as the handler asks to continue.
   * An error thrown by the handler itself propagates to the caller.
   */
  async serialNext() {
    while (true) {
      try {
        return await this.upstream.next();
      } catch (err) {
        const keepGoing = this.handler(err);
        if (!keepGoing) {
          return { value: null, done: true };
        }
      }
    }
  }
}
731
/**
 * Iterator applying an asynchronous 1-to-1 `transform` to each upstream
 * element; the transform's result is awaited before being emitted.
 */
class w extends i {
  /**
   * @param upstream The iterator to read from.
   * @param transform Function mapping an element to a Promise (or value) for
   *   its replacement.
   */
  constructor(upstream, transform) {
    super();
    this.upstream = upstream;
    this.transform = transform;
  }

  summary() {
    return `${this.upstream.summary()} -> AsyncMap`;
  }

  /**
   * Transforms the next upstream element. Afterwards, every object that `o`
   * found in the input and that `p` does not find among those of the output
   * is disposed.
   *
   * NOTE(review): `o` / `p` are imported helpers, presumably tfjs
   * `getTensorsInContainer` / `isTensorInList` — confirm.
   */
  async next() {
    const item = await this.upstream.next();
    if (item.done) {
      return { value: null, done: true };
    }
    const inputTensors = o(item.value);
    const mapped = await this.transform(item.value);
    const outputTensors = o(mapped);
    for (const tensor of inputTensors) {
      if (!p(tensor, outputTensors)) {
        tensor.dispose();
      }
    }
    return { value: mapped, done: false };
  }
}
748
/**
 * Base class for iterators that may produce several outputs per upstream
 * read. Outputs are staged in `outputQueue` (an instance of `f`, a queue type
 * defined elsewhere in this file).
 *
 * Subclasses must provide an async `pump()` that adds zero or more items to
 * `outputQueue` and resolves to a falsy value when nothing more can be
 * produced.
 */
class Z extends i {
  constructor() {
    super();
    this.outputQueue = new f();
    // Seed of the promise chain that orders successive reads.
    this.lastRead = Promise.resolve({ value: null, done: false });
  }

  /** Queues a read behind every previously requested read. */
  async next() {
    const pending = this.lastRead.then(() => this.serialNext());
    this.lastRead = pending;
    return pending;
  }

  /**
   * Pumps until at least one item is queued, then emits the head of the
   * queue.
   * @returns `{value, done: false}` for a queued item, or
   *     `{value: null, done: true}` as soon as `pump()` reports exhaustion
   *     while the queue is empty.
   */
  async serialNext() {
    while (this.outputQueue.length() === 0) {
      const pumped = await this.pump();
      if (!pumped) {
        return { value: null, done: true };
      }
    }
    const head = this.outputQueue.shift();
    return { value: head, done: false };
  }
}
762
/**
 * One-to-many iterator: `transform` turns each upstream element into a
 * collection whose members are queued individually on `outputQueue`.
 *
 * `o` and `p` are helpers defined elsewhere in this file; from their use here
 * `o(x)` yields an iterable of disposable objects found in `x` and
 * `p(item, list)` tests membership (presumably tensors -- confirm).
 */
class tt extends Z {
  /**
   * @param upstream  Source iterator exposing `next()` and `summary()`.
   * @param transform Function mapping one element to a collection accepted
   *     by `outputQueue.pushAll`.
   */
  constructor(upstream, transform) {
    super();
    this.upstream = upstream;
    this.transform = transform;
  }

  /** @returns The upstream summary with this stage appended. */
  summary() {
    return `${this.upstream.summary()} -> Flatmap`;
  }

  /**
   * Consumes one upstream element and queues everything it expands to.
   * Objects collected from the input that do not reappear in the output are
   * disposed after the outputs have been queued.
   * @returns `false` when upstream is exhausted, otherwise `true`.
   */
  async pump() {
    const item = await this.upstream.next();
    if (item.done) {
      return false;
    }
    const inputs = o(item.value);
    const expanded = this.transform(item.value);
    const outputs = o(expanded);
    this.outputQueue.pushAll(expanded);
    for (const input of inputs) {
      if (!p(input, outputs)) {
        input.dispose();
      }
    }
    return true;
  }
}
780
/**
 * Iterator over the concatenation of several iterators, which are themselves
 * supplied lazily by `moreIterators` (an iterator of iterators).
 *
 * When `baseErrorHandler` is set, each inner iterator is wrapped through its
 * own `handleErrors(baseErrorHandler)` before being read.
 */
class A extends i {
  /**
   * @param moreIterators    Iterator whose elements are the iterators to
   *     chain.
   * @param baseErrorHandler Optional handler applied to every inner iterator.
   */
  constructor(moreIterators, baseErrorHandler) {
    super();
    this.baseErrorHandler = baseErrorHandler;
    this.lastRead = null;
    this.iterator = null;
    this.moreIterators = moreIterators;
  }

  /** @returns A fixed placeholder string; upstream summaries are not included. */
  summary() {
    return "TODO: fill in upstream of chained summaries -> Chained";
  }

  /** Starts a read that first waits for the previously requested read. */
  async next() {
    const pending = this.readFromChain(this.lastRead);
    this.lastRead = pending;
    return pending;
  }

  /**
   * Reads the next element of the chain.
   *
   * @param lastRead Promise of the preceding read (or null); awaited first so
   *     that reads complete in request order.
   * @returns The next element, or `{value: null, done: true}` when
   *     `moreIterators` has no further iterator to offer.
   */
  async readFromChain(lastRead) {
    await lastRead;
    if (this.iterator == null) {
      const nextIterator = await this.moreIterators.next();
      if (nextIterator.done) {
        return { value: null, done: true };
      }
      this.iterator = nextIterator.value;
      if (this.baseErrorHandler != null) {
        this.iterator = this.iterator.handleErrors(this.baseErrorHandler);
      }
    }
    const item = await this.iterator.next();
    if (!item.done) {
      return item;
    }
    // The current inner iterator is spent: drop it and move to the next one.
    this.iterator = null;
    return this.readFromChain(lastRead);
  }
}
801
// Two-way enum: each name maps to its ordinal and each ordinal back to its
// name (FAIL = 0, SHORTEST = 1, LONGEST = 2).
var x;
(function (modes) {
  ["FAIL", "SHORTEST", "LONGEST"].forEach((name, ordinal) => {
    modes[name] = ordinal;
    modes[ordinal] = name;
  });
})(x || (x = {}));
805
/**
 * Iterator that keeps a buffer of pending upstream reads so later elements
 * are already in flight when they are requested.
 *
 * The buffer is an instance of `z` (defined elsewhere in this file),
 * constructed with `bufferSize`.
 */
class C extends i {
  /**
   * @param upstream   Source iterator exposing `next()` and `summary()`.
   * @param bufferSize Value passed to the buffer constructor.
   */
  constructor(upstream, bufferSize) {
    super();
    this.upstream = upstream;
    this.bufferSize = bufferSize;
    this.buffer = new z(bufferSize);
  }

  /** @returns The upstream summary with this stage appended. */
  summary() {
    return `${this.upstream.summary()} -> Prefetch`;
  }

  /**
   * Tops the buffer up: keeps calling `upstream.next()` and storing the
   * returned value (not awaited here) until the buffer reports it is full.
   */
  refill() {
    while (!this.buffer.isFull()) {
      this.buffer.push(this.upstream.next());
    }
  }

  /**
   * Refills the buffer and hands out its oldest entry.
   * @returns Whatever `buffer.shift()` yields, i.e. a value previously
   *     returned by `upstream.next()`.
   */
  next() {
    this.refill();
    return this.buffer.shift();
  }
}
826
/**
 * Prefetching iterator that emits buffered elements in pseudo-random order.
 *
 * The random source comes from `R.alea`, seeded with the given seed or, when
 * that is falsy, with `g().toString()` (`R` and `g` are defined elsewhere in
 * this file). Reads are serialized through the `lastRead` promise chain.
 */
class et extends C {
  /**
   * @param upstream   Source iterator.
   * @param windowSize Buffer size, also forwarded to the parent constructor.
   * @param seed       Optional seed for the random generator.
   */
  constructor(upstream, windowSize, seed) {
    super(upstream, windowSize);
    this.upstream = upstream;
    this.windowSize = windowSize;
    this.upstreamExhausted = false;
    this.random = R.alea(seed || g().toString());
    this.lastRead = Promise.resolve({ value: null, done: false });
  }

  /** Queues a read behind every previously requested read. */
  async next() {
    const pending = this.lastRead.then(() => this.serialNext());
    this.lastRead = pending;
    return pending;
  }

  /** @returns An integer obtained by flooring `random() * max`. */
  randomInt(max) {
    const scaled = this.random() * max;
    return Math.floor(scaled);
  }

  /** @returns A random index based on the current buffer length. */
  chooseIndex() {
    const occupied = this.buffer.length();
    return this.randomInt(occupied);
  }

  /**
   * Picks random buffer entries until one holds a real element.
   *
   * Entries whose result is `done` mark the upstream as exhausted and are
   * discarded; once that flag is set, later calls no longer refill on entry.
   * @returns The chosen element, or `{value: null, done: true}` when the
   *     buffer has been drained.
   */
  async serialNext() {
    if (!this.upstreamExhausted) {
      this.refill();
    }
    while (!this.buffer.isEmpty()) {
      const index = this.chooseIndex();
      const picked = await this.buffer.shuffleExcise(index);
      if (!picked.done) {
        this.refill();
        return picked;
      }
      this.upstreamExhausted = true;
    }
    return { value: null, done: true };
  }
}
850
+ /**
851
+ * @license
852
+ * Copyright 2018 Google LLC. All Rights Reserved.
853
+ * Licensed under the Apache License, Version 2.0 (the "License");
854
+ * you may not use this file except in compliance with the License.
855
+ * You may obtain a copy of the License at
856
+ *
857
+ * http://www.apache.org/licenses/LICENSE-2.0
858
+ *
859
+ * Unless required by applicable law or agreed to in writing, software
860
+ * distributed under the License is distributed on an "AS IS" BASIS,
861
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
862
+ * See the License for the specific language governing permissions and
863
+ * limitations under the License.
864
+ *
865
+ * =============================================================================
866
+ */
867
/**
 * Represents a potentially large, lazily evaluated collection of elements.
 *
 * Concrete datasets supply an async `iterator()` method; every transformation
 * below returns a new dataset (built with `u`) whose iterator wraps this
 * one's. `size` holds the element count when known; the methods here treat
 * `null`/`undefined` as "unknown" and `Infinity` as "unbounded".
 *
 * Helpers referenced from outside this class: `k` (assertion), `u` (dataset
 * factory), `y`, `q`, `Q`, `R`, `g` and `st`.
 */
class T {
  constructor() {
    this.size = null;
  }

  // TODO(soergel): Make Datasets report whether repeated iterator() calls
  // produce the same result (e.g., reading from a file) or different results
  // (e.g., from the webcam). Currently we don't make this distinction but it
  // could be important for the user to know.
  // abstract isDeterministic(): boolean;

  /**
   * Groups elements into batches.
   *
   * It is assumed that each of the incoming dataset elements has the same
   * structure -- i.e. the same set of keys at each location in an object
   * hierarchy. For each key, the resulting `Dataset` provides a batched
   * element collecting all of the incoming values for that key.
   *
   *  * Incoming primitives are grouped into a 1-D Tensor.
   *  * Incoming Tensors are grouped into a new Tensor where the 0th axis is
   *    the batch dimension.
   *  * Incoming arrays are converted to Tensor and then batched.
   *  * A nested array is interpreted as an n-D Tensor, so the batched result
   *    has n+1 dimensions.
   *  * An array that cannot be converted to Tensor produces an error.
   *
   * If an array should not be batched as a unit, it should first be converted
   * to an object with integer keys.
   *
   * Here are a few examples:
   *
   * Batch a dataset of numbers:
   * ```js
   * const a = tf.data.array([1, 2, 3, 4, 5, 6, 7, 8]).batch(4);
   * await a.forEachAsync(e => e.print());
   * ```
   *
   * Batch a dataset of arrays:
   * ```js
   * const b = tf.data.array([[1], [2], [3], [4], [5], [6], [7], [8]]).batch(4);
   * await b.forEachAsync(e => e.print());
   * ```
   *
   * Batch a dataset of objects:
   * ```js
   * const c = tf.data.array([{a: 1, b: 11}, {a: 2, b: 12}, {a: 3, b: 13},
   *   {a: 4, b: 14}, {a: 5, b: 15}, {a: 6, b: 16}, {a: 7, b: 17},
   *   {a: 8, b: 18}]).batch(4);
   * await c.forEachAsync(e => {
   *   console.log('{');
   *   for(var key in e) {
   *     console.log(key+':');
   *     e[key].print();
   *   }
   *   console.log('}');
   * })
   * ```
   *
   * @param batchSize The number of elements desired per batch.
   * @param smallLastBatch Whether to emit the final batch when it has fewer
   *   than batchSize elements. Default true.
   * @returns A `Dataset`, from which a stream of batches can be obtained.
   *
   * @doc {heading: 'Data', subheading: 'Classes'}
   */
  batch(batchSize, smallLastBatch = true) {
    const base = this;
    k(batchSize > 0, () => `batchSize needs to be positive, but it is
${batchSize}`);
    let size;
    if (this.size === Infinity || this.size == null) {
      // Unbounded or unknown sizes carry over unchanged.
      size = this.size;
    } else if (smallLastBatch) {
      // A trailing partial batch counts as a batch.
      size = Math.ceil(this.size / batchSize);
    } else {
      // A trailing partial batch is dropped.
      size = Math.floor(this.size / batchSize);
    }
    return u(
      async () => (await base.iterator()).columnMajorBatch(batchSize, smallLastBatch, st),
      size
    );
  }

  /**
   * Concatenates this `Dataset` with another.
   *
   * ```js
   * const a = tf.data.array([1, 2, 3]);
   * const b = tf.data.array([4, 5, 6]);
   * const c = a.concatenate(b);
   * await c.forEachAsync(e => console.log(e));
   * ```
   *
   * @param dataset A `Dataset` to be concatenated onto this one.
   * @returns A `Dataset`.
   *
   * @doc {heading: 'Data', subheading: 'Classes'}
   */
  concatenate(dataset) {
    const base = this;
    let size;
    if (this.size === Infinity || dataset.size === Infinity) {
      size = Infinity;
    } else if (this.size != null && dataset.size != null) {
      size = this.size + dataset.size;
    } else {
      size = null;
    }
    return u(
      async () => (await base.iterator()).concatenate(await dataset.iterator()),
      size
    );
  }

  /**
   * Filters this dataset according to `predicate`.
   *
   * ```js
   * const a = tf.data.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
   *   .filter(x => x%2 === 0);
   * await a.forEachAsync(e => console.log(e));
   * ```
   *
   * @param predicate A function mapping a dataset element to a boolean or a
   * `Promise` for one.
   *
   * @returns A `Dataset` of elements for which the predicate was true.
   *
   * @doc {heading: 'Data', subheading: 'Classes'}
   */
  filter(predicate) {
    const base = this;
    // The number of surviving elements cannot be known in advance, except
    // that an unbounded dataset is still reported as unbounded.
    const size = this.size === Infinity ? Infinity : null;
    return u(
      async () => (await base.iterator()).filter((element) => y(() => predicate(element))),
      size
    );
  }

  /**
   * Apply a function to every element of the dataset.
   *
   * After the function is applied to a dataset element, any Tensors contained
   * within that element are disposed.
   *
   * ```js
   * const a = tf.data.array([1, 2, 3]);
   * await a.forEachAsync(e => console.log(e));
   * ```
   *
   * @param f A function to apply to each dataset element.
   * @returns A `Promise` that resolves after all elements have been processed.
   *
   * @doc {heading: 'Data', subheading: 'Classes'}
   */
  async forEachAsync(f) {
    const stream = await this.iterator();
    return stream.forEachAsync(f);
  }

  /**
   * Maps this dataset through a 1-to-1 transform.
   *
   * ```js
   * const a = tf.data.array([1, 2, 3]).map(x => x*x);
   * await a.forEachAsync(e => console.log(e));
   * ```
   *
   * @param transform A function mapping a dataset element to a transformed
   *   dataset element.
   *
   * @returns A `Dataset` of transformed elements.
   *
   * @doc {heading: 'Data', subheading: 'Classes'}
   */
  map(transform) {
    const base = this;
    return u(
      async () => (await base.iterator()).map((element) => y(() => transform(element))),
      this.size
    );
  }

  /**
   * Maps this dataset through an async 1-to-1 transform.
   *
   * ```js
   * const a =
   *  tf.data.array([1, 2, 3]).mapAsync(x => new Promise(function(resolve){
   *    setTimeout(() => {
   *      resolve(x * x);
   *    }, Math.random()*1000 + 500);
   *  }));
   * console.log(await a.toArray());
   * ```
   *
   * @param transform A function mapping a dataset element to a `Promise` for a
   *   transformed dataset element.  This transform is responsible for disposing
   *   any intermediate `Tensor`s, i.e. by wrapping its computation in
   *   `tf.tidy()`; that cannot be automated here (as it is in the synchronous
   *   `map()` case).
   *
   * @returns A `Dataset` of transformed elements.
   *
   * @doc {heading: 'Data', subheading: 'Classes'}
   */
  mapAsync(transform) {
    const base = this;
    return u(async () => (await base.iterator()).mapAsync(transform), this.size);
  }

  /**
   * Creates a `Dataset` that prefetches elements from this dataset.
   *
   * @param bufferSize: An integer specifying the number of elements to be
   *   prefetched.
   * @returns A `Dataset`.
   * @throws RangeError when `bufferSize` is null or undefined.
   *
   * @doc {heading: 'Data', subheading: 'Classes'}
   */
  prefetch(bufferSize) {
    if (bufferSize == null) {
      throw new RangeError("`Dataset.prefetch()` requires bufferSize to be specified.");
    }
    const base = this;
    return u(async () => (await base.iterator()).prefetch(bufferSize), this.size);
  }

  /**
   * Repeats this dataset `count` times.
   *
   * NOTE: If this dataset is a function of global state (e.g. a random number
   * generator), then different repetitions may produce different elements.
   *
   * ```js
   * const a = tf.data.array([1, 2, 3]).repeat(3);
   * await a.forEachAsync(e => console.log(e));
   * ```
   *
   * @param count: (Optional) An integer, representing the number of times
   *   the dataset should be repeated. The default behavior (if `count` is
   *   `undefined` or negative) is for the dataset be repeated indefinitely.
   * @returns A `Dataset`.
   *
   * @doc {heading: 'Data', subheading: 'Classes'}
   */
  repeat(count) {
    const base = this;
    let size;
    if (this.size != null && count > 0) {
      size = this.size * count;
    } else if (count === 0) {
      size = 0;
    } else if (this.size != null && (count === undefined || count < 0)) {
      size = Infinity;
    } else {
      size = null;
    }
    return u(async () => {
      // An endless supply of fresh iterators over this dataset, truncated to
      // `count` and then chained end to end.
      const iterators = q(async () => ({ value: await base.iterator(), done: false }));
      return Q(iterators.take(count));
    }, size);
  }

  /**
   * Creates a `Dataset` that skips `count` initial elements from this dataset.
   *
   * ```js
   * const a = tf.data.array([1, 2, 3, 4, 5, 6]).skip(3);
   * await a.forEachAsync(e => console.log(e));
   * ```
   *
   * @param count: The number of elements of this dataset that should be skipped
   *   to form the new dataset.  If `count` is greater than the size of this
   *   dataset, the new dataset will contain no elements.  If `count`
   *   is `undefined` or negative, skips the entire dataset.
   *
   * @returns A `Dataset`.
   *
   * @doc {heading: 'Data', subheading: 'Classes'}
   */
  skip(count) {
    const base = this;
    let size;
    if (this.size != null && count >= 0 && this.size >= count) {
      size = this.size - count;
    } else if (this.size != null && (this.size < count || count === undefined || count < 0)) {
      size = 0;
    } else {
      size = null;
    }
    return u(async () => (await base.iterator()).skip(count), size);
  }

  /**
   * Pseudorandomly shuffles the elements of this dataset. This is done in a
   * streaming manner, by sampling from a given number of prefetched elements.
   *
   * ```js
   * const a = tf.data.array([1, 2, 3, 4, 5, 6]).shuffle(3);
   * await a.forEachAsync(e => console.log(e));
   * ```
   *
   * @param bufferSize: An integer specifying the number of elements from this
   *   dataset from which the new dataset will sample.
   * @param seed: (Optional) An integer specifying the random seed that will
   *   be used to create the distribution.
   * @param reshuffleEachIteration: (Optional) A boolean, which if true
   *   indicates that the dataset should be pseudorandomly reshuffled each time
   *   it is iterated over. If false, elements will be returned in the same
   *   shuffled order on each iteration. (Defaults to `true`.)
   * @returns A `Dataset`.
   * @throws RangeError when `bufferSize` is null, undefined or negative.
   *
   * @doc {heading: 'Data', subheading: 'Classes'}
   */
  shuffle(bufferSize, seed, reshuffleEachIteration = true) {
    if (bufferSize == null || bufferSize < 0) {
      if (this.size == null) {
        throw new RangeError("`Dataset.shuffle()` requires bufferSize to be specified.");
      }
      throw new RangeError(`\`Dataset.shuffle()\` requires bufferSize to be specified. If your data fits in main memory (for regular JS objects), and/or GPU memory (for \`tf.Tensor\`s), consider setting bufferSize to the dataset size (${this.size} elements)`);
    }
    const base = this;
    const seedSource = R.alea(seed || g().toString());
    return u(async () => {
      let iterationSeed = seedSource.int32();
      if (reshuffleEachIteration) {
        iterationSeed += seedSource.int32();
      }
      return (await base.iterator()).shuffle(bufferSize, iterationSeed.toString());
    }, this.size);
  }

  /**
   * Creates a `Dataset` with at most `count` initial elements from this
   * dataset.
   *
   * ```js
   * const a = tf.data.array([1, 2, 3, 4, 5, 6]).take(3);
   * await a.forEachAsync(e => console.log(e));
   * ```
   *
   * @param count: The number of elements of this dataset that should be taken
   *   to form the new dataset.  If `count` is `undefined` or negative, or if
   *   `count` is greater than the size of this dataset, the new dataset will
   *   contain all elements of this dataset.
   * @returns A `Dataset`.
   *
   * @doc {heading: 'Data', subheading: 'Classes'}
   */
  take(count) {
    const base = this;
    let size;
    if (this.size == null) {
      // Unknown input size stays unknown.
      size = null;
    } else if (count == null || count < 0) {
      // Documented contract: an undefined or negative count keeps every
      // element, so the size is unchanged (never a negative number).
      size = this.size;
    } else if (this.size > count) {
      size = count;
    } else if (this.size <= count) {
      size = this.size;
    } else {
      // Not comparable (e.g. NaN): size cannot be determined.
      size = null;
    }
    return u(async () => (await base.iterator()).take(count), size);
  }

  /**
   * Collect all elements of this dataset into an array.
   *
   * Obviously this will succeed only for small datasets that fit in memory.
   * Useful for testing and generally should be avoided if possible.
   *
   * ```js
   * const a = tf.data.array([1, 2, 3, 4, 5, 6]);
   * console.log(await a.toArray());
   * ```
   *
   * @returns A Promise for an array of elements, which will resolve
   *   when a new stream has been obtained and fully consumed.
   * @throws Error when the dataset size is `Infinity`.
   *
   * @doc {heading: 'Data', subheading: 'Classes'}
   */
  async toArray() {
    if (this.size === Infinity) {
      throw new Error("Can not convert infinite data stream to array.");
    }
    const stream = await this.iterator();
    return stream.toArray();
  }

  /**
   * Collect all elements of this dataset into an array with prefetching 100
   * elements. This is useful for testing, because the prefetch changes the
   * order in which the Promises are resolved along the processing pipeline.
   * This may help expose bugs where results are dependent on the order of
   * Promise resolution rather than on the logical order of the stream (i.e.,
   * due to hidden mutable state).
   *
   * @returns A Promise for an array of elements, which will resolve
   *   when a new stream has been obtained and fully consumed.
   * @throws Error when the dataset size is `Infinity`.
   */
  async toArrayForTest() {
    if (this.size === Infinity) {
      throw new Error("Can not convert infinite data stream to array.");
    }
    const stream = await this.iterator();
    return stream.toArrayForTest();
  }
}
T.MAX_BUFFER_SIZE = 1e4;
1198
/**
 * Builds a dataset (an instance of an anonymous subclass of `T`) from a
 * function that produces iterators.
 *
 * @param s Function called, with no arguments, each time a new stream of
 *     elements is requested; note this also starts new streams from any
 *     underlying datasets it wraps.
 * @param t Size to report on the dataset; defaults to `null`.
 * @returns The new dataset instance.
 */
function u(s, t = null) {
  return new (class extends T {
    constructor() {
      super();
      // Overrides the size assigned by the base constructor.
      this.size = t;
    }

    /** Provides a new stream of elements by delegating to the factory. */
    async iterator() {
      return s();
    }
  })();
}
1212
/**
 * Batching callback: decides, from the first entry of `s`, whether the whole
 * array can be batched directly or must be descended into.
 *
 * `M` is a predicate defined elsewhere in this file (presumably "can this be
 * turned into a tensor" -- confirm); `rt` builds the batch.
 *
 * @param s Array of values collected for one position, or `null`.
 * @returns `null` for a `null` input; `{value: rt(s), recurse: false}` when
 *     `M` accepts the first entry; otherwise `{value: null, recurse: true}`.
 */
function st(s) {
  if (s === null) {
    return null;
  }
  const [first] = [s[0]];
  if (M(first)) {
    return { value: rt(s), recurse: false };
  }
  return { value: null, recurse: true };
}
1218
/**
 * Combines an array of same-position values into one batch.
 *
 * Dispatches on the first entry: instances of `h` go through `$`, everything
 * else through `B` (all three are defined elsewhere in this file; presumably
 * the tensor class, a stack op and a tensor constructor -- confirm).
 *
 * @param s Non-empty array of values to batch.
 * @returns The result of `$(s)` or `B(s)`.
 * @throws Error when `s` is empty.
 */
function rt(s) {
  if (s.length === 0) {
    throw new Error("Can't make a batch of zero elements.");
  }
  const combine = s[0] instanceof h ? $ : B;
  return combine(s);
}
1223
+ export {
1224
+ u as d,
1225
+ q as i
1226
+ };