@genai-fi/nanogpt 0.2.12 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115)
  1. package/dist/Generator.js +30 -25
  2. package/dist/NanoGPTModel.d.ts +13 -14
  3. package/dist/NanoGPTModel.js +167 -85
  4. package/dist/TeachableLLM.d.ts +3 -5
  5. package/dist/TeachableLLM.js +47 -35
  6. package/dist/Trainer.js +8 -8
  7. package/dist/concat-BIZS_td9.js +33 -0
  8. package/dist/data/parquet.js +1 -1
  9. package/dist/exports_layers-7idKoYqh.js +25 -0
  10. package/dist/{sum-D7fu15XL.js → gather-BPGW8RsB.js} +6 -8
  11. package/dist/index-C4L8Cm77.js +349 -0
  12. package/dist/{index-YPKosni4.js → index-pWA4_lUh.js} +1020 -782
  13. package/dist/layers/CausalSelfAttention.d.ts +11 -11
  14. package/dist/layers/CausalSelfAttention.js +71 -63
  15. package/dist/layers/MLP.d.ts +6 -7
  16. package/dist/layers/MLP.js +18 -16
  17. package/dist/layers/RMSNorm.d.ts +6 -7
  18. package/dist/layers/RMSNorm.js +15 -13
  19. package/dist/layers/RoPECache.d.ts +4 -5
  20. package/dist/layers/RoPECache.js +36 -12
  21. package/dist/layers/TiedEmbedding.d.ts +7 -8
  22. package/dist/layers/TiedEmbedding.js +16 -418
  23. package/dist/layers/TransformerBlock.d.ts +8 -9
  24. package/dist/layers/TransformerBlock.js +12 -12
  25. package/dist/main.d.ts +1 -0
  26. package/dist/main.js +35 -21
  27. package/dist/{mat_mul-Bu7bhLms.js → mat_mul-D7_a4KJn.js} +5 -5
  28. package/dist/moments-DfcpfwKi.js +132 -0
  29. package/dist/ones-Cog-G2ag.js +29 -0
  30. package/dist/ops/appendCache.d.ts +2 -0
  31. package/dist/ops/appendCache.js +9 -0
  32. package/dist/ops/attentionMask.d.ts +1 -1
  33. package/dist/ops/attentionMask.js +7 -85
  34. package/dist/ops/cpu/appendCache.d.ts +2 -0
  35. package/dist/ops/cpu/appendCache.js +28 -0
  36. package/dist/ops/cpu/attentionMask.js +18 -0
  37. package/dist/ops/cpu/gatherSub.d.ts +1 -0
  38. package/dist/ops/cpu/gatherSub.js +34 -0
  39. package/dist/ops/cpu/qkv.d.ts +5 -0
  40. package/dist/ops/cpu/qkv.js +38 -0
  41. package/dist/ops/cpu/rope.d.ts +6 -0
  42. package/dist/ops/cpu/rope.js +38 -0
  43. package/dist/ops/cpu/scatterSub.d.ts +1 -0
  44. package/dist/ops/cpu/scatterSub.js +70 -0
  45. package/dist/ops/gatherSub.d.ts +1 -1
  46. package/dist/ops/gatherSub.js +6 -63
  47. package/dist/ops/grads/attentionMask.d.ts +1 -0
  48. package/dist/ops/grads/attentionMask.js +21 -0
  49. package/dist/ops/grads/qkv.d.ts +1 -0
  50. package/dist/ops/grads/qkv.js +20 -0
  51. package/dist/ops/grads/rope.d.ts +1 -0
  52. package/dist/ops/grads/rope.js +14 -0
  53. package/dist/ops/node/sparseCrossEntropy.js +1 -1
  54. package/dist/ops/qkv.d.ts +1 -6
  55. package/dist/ops/qkv.js +7 -124
  56. package/dist/ops/rope.d.ts +0 -5
  57. package/dist/ops/rope.js +7 -151
  58. package/dist/ops/scatterSub.d.ts +1 -1
  59. package/dist/ops/scatterSub.js +6 -147
  60. package/dist/ops/webgl/appendCache.d.ts +1 -0
  61. package/dist/ops/webgl/appendCache.js +43 -0
  62. package/dist/ops/webgl/attentionMask.d.ts +1 -0
  63. package/dist/ops/webgl/attentionMask.js +43 -0
  64. package/dist/ops/webgl/gatherSub.d.ts +1 -0
  65. package/dist/ops/webgl/gatherSub.js +27 -0
  66. package/dist/ops/webgl/qkv.d.ts +1 -0
  67. package/dist/ops/webgl/qkv.js +46 -0
  68. package/dist/ops/webgl/rope.d.ts +1 -0
  69. package/dist/ops/webgl/rope.js +56 -0
  70. package/dist/ops/webgl/scatterSub.d.ts +1 -0
  71. package/dist/ops/webgl/scatterSub.js +27 -0
  72. package/dist/{parquet-BRl5lE_I.js → parquet-C0Tlmv9c.js} +3045 -3048
  73. package/dist/random_width-PbCt7RXv.js +15489 -0
  74. package/dist/range-CcDl05lo.js +26 -0
  75. package/dist/{reshape-DmnmKT6r.js → reshape-C8CR_Bad.js} +3 -3
  76. package/dist/sin-BJIrfnj7.js +47 -0
  77. package/dist/softmax-Be_lsqUc.js +105 -0
  78. package/dist/{complex-CJ-qCcLB.js → split-DZbvruEP.js} +6 -8
  79. package/dist/stack-BMm-efee.js +27 -0
  80. package/dist/sum-C7Mgy9Bw.js +104 -0
  81. package/dist/tensor-DJVbYhh1.js +24 -0
  82. package/dist/tensor2d-ZuQSh2D-.js +30 -0
  83. package/dist/tokeniser/bpe.d.ts +17 -6
  84. package/dist/tokeniser/bpe.js +88 -60
  85. package/dist/training/AdamExt.js +1 -1
  86. package/dist/training/DatasetBuilder.d.ts +6 -6
  87. package/dist/training/DatasetBuilder.js +1262 -17
  88. package/dist/training/Evaluator.d.ts +3 -2
  89. package/dist/training/FullTrainer.d.ts +9 -8
  90. package/dist/training/FullTrainer.js +26 -25
  91. package/dist/training/LayerTrainer.d.ts +9 -8
  92. package/dist/training/LayerTrainer.js +34 -33
  93. package/dist/training/Trainer.d.ts +22 -21
  94. package/dist/training/Trainer.js +21 -18
  95. package/dist/training/sparseCrossEntropy.js +22 -166
  96. package/dist/utilities/dummy.js +10 -8
  97. package/dist/utilities/generate.js +14 -11
  98. package/dist/utilities/load.d.ts +1 -2
  99. package/dist/utilities/load.js +37 -35
  100. package/dist/utilities/profile.js +1 -1
  101. package/dist/utilities/save.js +14 -9
  102. package/dist/utilities/tokenParse.d.ts +1 -1
  103. package/dist/utilities/tokenParse.js +7 -61
  104. package/dist/utilities/weights.d.ts +3 -3
  105. package/dist/utilities/weights.js +21 -19
  106. package/dist/variable-Dl_ub3pk.js +23 -0
  107. package/dist/{stack-BtKpB0Ry.js → zeros-CCy9C3uU.js} +18 -16
  108. package/package.json +2 -1
  109. package/dist/assets/worker-BYeSPNkq.js +0 -1
  110. package/dist/tokeniser/NodeTokeniser.d.ts +0 -20
  111. package/dist/tokeniser/NodeTokeniser.js +0 -46
  112. package/dist/tokeniser/WebTokeniser.d.ts +0 -18
  113. package/dist/tokeniser/WebTokeniser.js +0 -96
  114. package/dist/tokeniser/worker.js +0 -53
  115. /package/dist/{tokeniser/worker.d.ts → ops/cpu/attentionMask.d.ts} +0 -0
@@ -1,31 +1,1276 @@
1
- class z {
1
+ import { a7 as $, a8 as m, a9 as M, a as R, aa as f, ab as v, ac as z, j as _, t as x } from "../index-pWA4_lUh.js";
2
+ import { s as E } from "../index-C4L8Cm77.js";
3
+ import { s as P } from "../stack-BMm-efee.js";
4
+ import { t as D } from "../tensor-DJVbYhh1.js";
5
+ import "../index-Tf7vU29b.js";
6
+ /**
7
+ * @license
8
+ * Copyright 2018 Google LLC. All Rights Reserved.
9
+ * Licensed under the Apache License, Version 2.0 (the "License");
10
+ * you may not use this file except in compliance with the License.
11
+ * You may obtain a copy of the License at
12
+ *
13
+ * http://www.apache.org/licenses/LICENSE-2.0
14
+ *
15
+ * Unless required by applicable law or agreed to in writing, software
16
+ * distributed under the License is distributed on an "AS IS" BASIS,
17
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ * See the License for the specific language governing permissions and
19
+ * limitations under the License.
20
+ *
21
+ * =============================================================================
22
+ */
23
/**
 * Deep-maps `input` through `mapFn`, starting the recursion with fresh
 * memo and cycle-tracking state.
 *
 * @param input The (possibly nested) value to map.
 * @param mapFn Called per node; returns `{ value, recurse }`.
 * @returns The mapped structure.
 */
function L(input, mapFn) {
  const mapped = I(input, mapFn);
  return mapped;
}
26
/**
 * Recursive worker for deep mapping.
 *
 * Handles nodes in this order:
 * - `null`/`undefined` map to `null`.
 * - `Blob`s are copied with `slice()`.
 * - Every other node is handed to `mapFn`. When `mapFn` asks to recurse, the
 *   node's enumerable keys are mapped into a fresh array/object, and that
 *   copy keeps the source's prototype.
 *
 * @param input Current node.
 * @param mapFn Returns `{ value, recurse }` for a node.
 * @param mappedObjects Cache of leaf results (only non-recursed results are
 *     stored).
 * @param containedIn Nodes on the current recursion path, used to detect
 *     cycles.
 * @returns The mapped node.
 * @throws Error on circular input, when `mapFn` returns both a value and
 *     `recurse: true`, or when asked to recurse into a non-iterable node.
 */
function I(input, mapFn, mappedObjects = new Map(), containedIn = new Set()) {
  if (input == null) {
    return null;
  }
  if (typeof Blob === "function" && input instanceof Blob) {
    return input.slice();
  }
  if (containedIn.has(input)) {
    throw new Error("Circular references are not supported.");
  }
  if (mappedObjects.has(input)) {
    return mappedObjects.get(input);
  }
  const result = mapFn(input);
  if (result.recurse && result.value !== null) {
    throw new Error("A deep map function may not return both a value and recurse=true.");
  }
  if (!result.recurse) {
    mappedObjects.set(input, result.value);
    return result.value;
  }
  if (!p(input)) {
    throw new Error(`Can't recurse into non-iterable type: ${input}`);
  }
  const copy = Array.isArray(input) ? [] : {};
  containedIn.add(input);
  for (const key in input) {
    copy[key] = I(input[key], mapFn, mappedObjects, containedIn);
  }
  containedIn.delete(input);
  // Preserve the source's prototype on the mapped copy.
  if (input.__proto__) {
    copy.__proto__ = input.__proto__;
  }
  return copy;
}
51
/**
 * Zips the parallel structures in `inputs` into one structure.
 *
 * @param inputs Array of structures to zip.
 * @param zipFn Decides, per node, whether to recurse; defaults to `T`.
 * @returns The zipped structure.
 */
function O(inputs, zipFn = T) {
  const zipped = A(inputs, zipFn);
  return zipped;
}
54
/**
 * Recursive worker for deep zipping.
 *
 * Recursion follows the keys of the first input; the other inputs are read
 * at the same keys.
 *
 * @param inputs Parallel values at the current node.
 * @param zipFn Returns `{ value, recurse }` for the list of parallel values.
 * @param containedIn First-input nodes on the current recursion path, used
 *     to detect cycles.
 * @returns The zipped node.
 * @throws Error on circular input, when `zipFn` returns both a value and
 *     `recurse: true`, or when asked to recurse into a non-iterable node.
 */
function A(inputs, zipFn, containedIn = new Set()) {
  const first = inputs[0];
  if (containedIn.has(first)) {
    throw new Error("Circular references are not supported.");
  }
  const result = zipFn(inputs);
  if (result.recurse && result.value !== null) {
    throw new Error("A deep zip function may not return both a value and recurse=true.");
  }
  if (!result.recurse) {
    return result.value;
  }
  if (!p(first)) {
    throw new Error(`Can't recurse into non-iterable type: ${first}`);
  }
  const zipped = Array.isArray(first) ? [] : {};
  containedIn.add(first);
  for (const key in first) {
    const column = inputs.map((item) => item[key]);
    zipped[key] = A(column, zipFn, containedIn);
  }
  containedIn.delete(first);
  return zipped;
}
74
/**
 * Default zip function used by `O`/`A`.
 *
 * Recurses while the first parallel value is an iterable structure (per
 * `p`). Otherwise it stops and uses the list of parallel values itself as
 * the leaf.
 *
 * @param s Parallel values taken at one node of the structures being zipped.
 * @returns A `{ value, recurse }` result. It is always an object, even for
 *     `null`/`undefined` input; the caller reads `result.recurse` unconditionally.
 */
function T(s) {
  if (s == null) {
    // Previously this returned a bare `null`, which breaks the
    // `{ value, recurse }` contract that `A` relies on.
    return { value: null, recurse: false };
  }
  return p(s[0]) ? { value: null, recurse: true } : { value: s, recurse: false };
}
77
+ function p(s) {
78
+ let t = !1;
79
+ if (M().get("IS_BROWSER"))
80
+ t = s instanceof TextDecoder;
81
+ else {
82
+ const { StringDecoder: e } = require("string_decoder");
83
+ t = s instanceof e;
84
+ }
85
+ return s != null && !ArrayBuffer.isView(s) && (Array.isArray(s) || typeof s == "object" && !(s instanceof m) && !(s instanceof Promise) && !t);
86
+ }
87
/**
 * Reports whether `s` is a null/primitive value, an array, a Tensor (`m`),
 * or a value accepted by the imported predicate `$`.
 *
 * @param s Value to test.
 * @returns `true` for the cases above; otherwise the result of `$(s)`.
 */
function H(s) {
  if (s == null || q(s) || Array.isArray(s)) {
    return true;
  }
  if (typeof s == "object" && s instanceof m) {
    return true;
  }
  return $(s);
}
90
/**
 * Reports whether `s` is `null` or a primitive (anything that is neither an
 * object nor a function).
 *
 * @param s Value to test.
 * @returns `true` for `null` and primitives.
 */
function q(s) {
  if (s === null) {
    return true;
  }
  const kind = typeof s;
  return kind !== "object" && kind !== "function";
}
93
+ /**
94
+ * @license
95
+ * Copyright 2018 Google LLC. All Rights Reserved.
96
+ * Licensed under the Apache License, Version 2.0 (the "License");
97
+ * you may not use this file except in compliance with the License.
98
+ * You may obtain a copy of the License at
99
+ *
100
+ * http://www.apache.org/licenses/LICENSE-2.0
101
+ *
102
+ * Unless required by applicable law or agreed to in writing, software
103
+ * distributed under the License is distributed on an "AS IS" BASIS,
104
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
105
+ * See the License for the specific language governing permissions and
106
+ * limitations under the License.
107
+ *
108
+ * =============================================================================
109
+ */
110
/**
 * Deep-copies `s`, cloning any Tensors it contains (see `G`).
 *
 * @param s Value to copy.
 * @returns The copied structure.
 */
function Q(s) {
  const copy = L(s, G);
  return copy;
}
113
/**
 * Deep-map callback used by `Q`.
 *
 * - Tensors (`m`) are replaced by `clone()`.
 * - Iterable containers are recursed into.
 * - Anything else is kept as-is.
 *
 * @param s Current node.
 * @returns A `{ value, recurse }` result.
 */
function G(s) {
  if (s instanceof m) {
    return { value: s.clone(), recurse: false };
  }
  if (p(s)) {
    return { value: null, recurse: true };
  }
  return { value: s, recurse: false };
}
116
+ /**
117
+ * @license
118
+ * Copyright 2018 Google LLC. All Rights Reserved.
119
+ * Licensed under the Apache License, Version 2.0 (the "License");
120
+ * you may not use this file except in compliance with the License.
121
+ * You may obtain a copy of the License at
122
+ *
123
+ * http://www.apache.org/licenses/LICENSE-2.0
124
+ *
125
+ * Unless required by applicable law or agreed to in writing, software
126
+ * distributed under the License is distributed on an "AS IS" BASIS,
127
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
128
+ * See the License for the specific language governing permissions and
129
+ * limitations under the License.
130
+ *
131
+ * =============================================================================
132
+ */
133
/**
 * Fixed-capacity ring buffer (`RingBuffer`).
 *
 * `begin` and `end` are kept in [0, 2 * capacity) by `wrap`, while storage
 * slots are addressed modulo `capacity`. This lets a full buffer
 * (length === capacity) be told apart from an empty one (length === 0).
 */
class C {
  /**
   * Constructs a `RingBuffer`.
   *
   * @param capacity The number of items that the buffer can accommodate.
   * @throws RangeError if `capacity` is null/undefined or less than 1.
   */
  constructor(t) {
    this.capacity = t;
    this.begin = 0;
    this.end = 0;
    if (t == null) {
      throw new RangeError("Can't create a ring buffer of unknown capacity.");
    }
    if (t < 1) {
      throw new RangeError("Can't create ring buffer of capacity < 1.");
    }
    this.data = new Array(t);
    this.doubledCapacity = 2 * t;
  }

  /**
   * Maps any index into the range 0 <= index < 2 * capacity.
   */
  wrap(t) {
    while (t < 0) {
      t += this.doubledCapacity;
    }
    return t % this.doubledCapacity;
  }

  /** Reads the slot for (non-negative) index `t`. */
  get(t) {
    if (t < 0) {
      throw new RangeError("Can't get item at a negative index.");
    }
    return this.data[t % this.capacity];
  }

  /** Writes `e` into the slot for (non-negative) index `t`. */
  set(t, e) {
    if (t < 0) {
      throw new RangeError("Can't set item at a negative index.");
    }
    this.data[t % this.capacity] = e;
  }

  /**
   * Returns the current number of items in the buffer.
   */
  length() {
    let t = this.end - this.begin;
    if (t < 0) {
      t = this.doubledCapacity + t;
    }
    return t;
  }

  /**
   * @returns true if the number of items equals the capacity.
   */
  isFull() {
    return this.length() === this.capacity;
  }

  /**
   * @returns true if the buffer holds no items.
   */
  isEmpty() {
    return this.length() === 0;
  }

  /**
   * Adds an item to the end of the buffer.
   * @throws RangeError if the buffer is full.
   */
  push(t) {
    if (this.isFull()) {
      throw new RangeError("Ring buffer is full.");
    }
    this.set(this.end, t);
    this.end = this.wrap(this.end + 1);
  }

  /**
   * Adds many items to the end of the buffer, in order.
   */
  pushAll(t) {
    for (const e of t) {
      this.push(e);
    }
  }

  /**
   * Removes and returns the last item in the buffer.
   * @throws RangeError if the buffer is empty.
   */
  pop() {
    if (this.isEmpty()) {
      throw new RangeError("Ring buffer is empty.");
    }
    this.end = this.wrap(this.end - 1);
    const t = this.get(this.end);
    this.set(this.end, undefined);
    return t;
  }

  /**
   * Adds an item to the beginning of the buffer.
   * @throws RangeError if the buffer is full.
   */
  unshift(t) {
    if (this.isFull()) {
      throw new RangeError("Ring buffer is full.");
    }
    this.begin = this.wrap(this.begin - 1);
    this.set(this.begin, t);
  }

  /**
   * Removes and returns the first item in the buffer.
   * @throws RangeError if the buffer is empty.
   */
  shift() {
    if (this.isEmpty()) {
      throw new RangeError("Ring buffer is empty.");
    }
    const t = this.get(this.begin);
    this.set(this.begin, undefined);
    this.begin = this.wrap(this.begin + 1);
    return t;
  }

  /**
   * Removes and returns a specific item, and moves the last item into the
   * vacated slot. Useful for a shuffling stream; the original order is
   * necessarily scrambled.
   *
   * @param relativeIndex Index of the item to remove, relative to the first
   *     item in the buffer (hiding the ring nature of the storage). Must be
   *     an integer in [0, length()).
   * @throws RangeError if the buffer is empty or the index is out of range.
   */
  shuffleExcise(t) {
    if (this.isEmpty()) {
      throw new RangeError("Ring buffer is empty.");
    }
    // Without this check an out-of-range index wrote the popped (live) item
    // into a slot outside the live range, silently losing it.
    if (!Number.isInteger(t) || t < 0 || t >= this.length()) {
      throw new RangeError(`Ring buffer index ${t} is out of range.`);
    }
    const index = this.wrap(this.begin + t);
    const result = this.get(index);
    const last = this.pop();
    // After pop(), `end` is the wrapped index of the popped slot. If that is
    // the slot we are excising, pop() already vacated it. Writing `last` back
    // would leave a stale reference outside the live range.
    if (index !== this.end) {
      this.set(index, last);
    }
    return result;
  }
}
244
+ /**
245
+ * @license
246
+ * Copyright 2018 Google LLC. All Rights Reserved.
247
+ * Licensed under the Apache License, Version 2.0 (the "License");
248
+ * you may not use this file except in compliance with the License.
249
+ * You may obtain a copy of the License at
250
+ *
251
+ * http://www.apache.org/licenses/LICENSE-2.0
252
+ *
253
+ * Unless required by applicable law or agreed to in writing, software
254
+ * distributed under the License is distributed on an "AS IS" BASIS,
255
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
256
+ * See the License for the specific language governing permissions and
257
+ * limitations under the License.
258
+ *
259
+ * =============================================================================
260
+ */
261
/**
 * Ring buffer that doubles its capacity instead of rejecting writes
 * (`GrowingRingBuffer`).
 */
class y extends C {
  /**
   * Constructs a `GrowingRingBuffer` with `y.INITIAL_CAPACITY` slots.
   */
  constructor() {
    super(y.INITIAL_CAPACITY);
  }

  /** Never reports full: writes grow the storage instead. */
  isFull() {
    return false;
  }

  /** Appends `t`, expanding first if the underlying storage is full. */
  push(t) {
    if (super.isFull()) {
      this.expand();
    }
    super.push(t);
  }

  /** Prepends `t`, expanding first if the underlying storage is full. */
  unshift(t) {
    if (super.isFull()) {
      this.expand();
    }
    super.unshift(t);
  }

  /**
   * Doubles the capacity and re-packs the live items starting at slot 0.
   */
  expand() {
    const count = this.length();
    const grownCapacity = this.capacity * 2;
    const grown = new Array(grownCapacity);
    // Copy while `capacity` still has its old value, because `get` addresses
    // slots modulo the current capacity.
    for (let k = 0; k < count; k++) {
      grown[k] = this.get(this.wrap(this.begin + k));
    }
    this.data = grown;
    this.capacity = grownCapacity;
    this.doubledCapacity = 2 * this.capacity;
    this.begin = 0;
    this.end = count;
  }
}
y.INITIAL_CAPACITY = 32;
288
+ /**
289
+ * @license
290
+ * Copyright 2018 Google LLC. All Rights Reserved.
291
+ * Licensed under the Apache License, Version 2.0 (the "License");
292
+ * you may not use this file except in compliance with the License.
293
+ * You may obtain a copy of the License at
294
+ *
295
+ * http://www.apache.org/licenses/LICENSE-2.0
296
+ *
297
+ * Unless required by applicable law or agreed to in writing, software
298
+ * distributed under the License is distributed on an "AS IS" BASIS,
299
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
300
+ * See the License for the specific language governing permissions and
301
+ * limitations under the License.
302
+ *
303
+ * =============================================================================
304
+ */
305
/**
 * Creates a lazy iterator over the elements of an array (`Y`).
 *
 * @param items Array to iterate.
 */
function W(items) {
  const iterator = new Y(items);
  return iterator;
}
308
/**
 * Creates a lazy iterator whose `next()` delegates to `nextFn` (`J`).
 *
 * @param nextFn Called for each element.
 */
function k(nextFn) {
  const iterator = new J(nextFn);
  return iterator;
}
311
/**
 * Creates a chained iterator (`F`) over a stream of iterators.
 *
 * @param iterators Iterator yielding the iterators to chain.
 * @param baseErrorHandler Optional handler passed through to `F`.
 */
function U(iterators, baseErrorHandler) {
  const chained = new F(iterators, baseErrorHandler);
  return chained;
}
314
/**
 * Base class for lazy, pull-based streams (`LazyIterator`).
 *
 * Subclasses implement `next()`, resolving to `{ value, done }`, and
 * `summary()`. The methods here build derived iterators on top of them.
 */
class i {
  /**
   * Collects all remaining elements of a bounded stream into an array.
   * Only suitable for streams that fit in memory; useful for testing.
   *
   * @returns A Promise for the array, resolved when the stream is exhausted.
   */
  async toArray() {
    const t = [];
    let e = await this.next();
    while (!e.done) {
      t.push(e.value);
      e = await this.next();
    }
    return t;
  }

  /**
   * Like `toArray`, but reads through `prefetch(100)`.
   *
   * The prefetch changes the order in which Promises resolve along the
   * pipeline. This can expose bugs where results depend on resolution order
   * rather than on the logical order of the stream.
   *
   * @returns A Promise for the array, resolved when the stream is exhausted.
   */
  async toArrayForTest() {
    const t = this.prefetch(100);
    const e = [];
    let r = await t.next();
    while (!r.done) {
      e.push(r.value);
      r = await t.next();
    }
    return e;
  }

  /**
   * Draws items until the stream is exhausted. Useful when the stream has
   * side effects but no output.
   */
  async resolveFully() {
    let t = await this.next();
    while (!t.done) {
      t = await this.next();
    }
  }

  /**
   * Draws items until the stream is exhausted or `predicate` returns a falsy
   * value for an item.
   *
   * The predicate is called only for real items, never for the terminal
   * `{ done: true }` result. Previously it was also handed that result's
   * `null` value.
   *
   * @param predicate Called with each item's value; return falsy to stop.
   */
  async resolveWhile(t) {
    for (;;) {
      const e = await this.next();
      if (e.done || !t(e.value)) {
        return;
      }
    }
  }

  /**
   * Handles errors thrown during upstream `next()` calls.
   *
   * @param handler Returns true to drop the failed call and continue, or
   *     false to end the stream quietly. If it throws, that error propagates.
   * @returns A `LazyIterator` of elements passed through from upstream.
   */
  handleErrors(t) {
    return new et(this, t);
  }

  // TODO(soergel): Implement reduce() etc.

  /**
   * Filters this stream.
   *
   * @param predicate Maps an element to a boolean.
   * @returns A `LazyIterator` of elements for which the predicate was true.
   */
  filter(t) {
    return new Z(this, t);
  }

  /**
   * Maps this stream through a 1-to-1 transform.
   *
   * @param transform Maps an element to a transformed element.
   * @returns A `LazyIterator` of transformed elements.
   */
  map(t) {
    return new tt(this, t);
  }

  /**
   * Maps this stream through an async 1-to-1 transform.
   *
   * @param transform Maps an element to a Promise for a transformed element.
   * @returns A `LazyIterator` of transformed elements.
   */
  mapAsync(t) {
    return new g(this, t);
  }

  /**
   * Maps this stream through a 1-to-1 transform, forcing serial execution.
   *
   * @param transform Maps an element to a transformed element.
   * @returns A `LazyIterator` of transformed elements.
   */
  serialMapAsync(t) {
    return new g(this, t).serial();
  }

  /**
   * Maps this stream through a 1-to-many transform.
   *
   * @param transform Maps an element to an array of transformed elements.
   * @returns A `LazyIterator` of the flattened results.
   */
  flatmap(t) {
    return new rt(this, t);
  }

  /**
   * Applies `f` to every element of the stream.
   *
   * @param f Function applied to each element.
   */
  async forEachAsync(t) {
    return this.map(t).resolveFully();
  }

  /**
   * Applies `f` to every element of the stream, serially, stopping once `f`
   * returns anything other than `true`.
   *
   * @param f Function applied to each element; return true to continue.
   */
  async serialForEach(t) {
    return this.serialMapAsync(t).resolveWhile((e) => e === true);
  }

  /**
   * Groups elements into row-major batches (arrays of elements).
   *
   * @param batchSize Number of elements per batch.
   * @param smallLastBatch Whether to emit a final batch with fewer than
   *     `batchSize` elements. Default true.
   * @returns A `LazyIterator` of arrays of elements.
   */
  rowMajorBatch(t, e = true) {
    return new j(this, t, e);
  }

  /**
   * Groups elements into column-major batches.
   *
   * Each batch is one structure shaped like the elements. Its leaves collect
   * the values found at that location across the batch, or the result of
   * `zipFn` applied to them.
   *
   * @param batchSize Number of elements per batch.
   * @param smallLastBatch Whether to emit a final short batch. Default true.
   * @param zipFn Optional per-node zip function; defaults to `T`.
   * @returns A `LazyIterator` of column-major batches.
   */
  columnMajorBatch(t, e = true, r = T) {
    return this.rowMajorBatch(t, e).map((a) => O(a, r));
  }

  /**
   * Concatenates this `LazyIterator` with another.
   *
   * @param iterator The iterator to append.
   * @param baseErrorHandler Optional handler for errors raised by `next()`
   *     on the base stream.
   * @returns A `LazyIterator`.
   */
  concatenate(t, e) {
    return new F(W([this, t]), e);
  }

  /**
   * Limits this stream to at most `count` items.
   *
   * @param count Maximum number of items. A negative or null/undefined value
   *     returns this stream unaltered.
   */
  take(t) {
    if (t < 0 || t == null) {
      return this;
    }
    return new K(this, t);
  }

  /**
   * Skips the first `count` items of this stream.
   *
   * @param count Number of items to skip. A negative or null/undefined value
   *     returns this stream unaltered.
   */
  skip(t) {
    if (t < 0 || t == null) {
      return this;
    }
    return new X(this, t);
  }

  /**
   * Prefetches up to `bufferSize` items. This prefetches Promises, with no
   * guarantee about when they resolve.
   *
   * @param bufferSize Number of elements to prefetch.
   */
  prefetch(t) {
    return new S(this, t);
  }

  // TODO(soergel): deep sharded shuffle, where supported

  /**
   * Randomly shuffles the elements of this stream.
   *
   * @param bufferSize Number of elements from which each sample is drawn.
   * @param seed Optional random seed.
   */
  shuffle(t, e) {
    return new nt(this, t, e);
  }

  /**
   * Forces serial execution: each `next()` call awaits the previous one.
   */
  serial() {
    return new V(this);
  }
}
581
/**
 * Lazy iterator over an in-memory array (`ArrayIterator`). Each element is
 * deep-copied with `Q` as it is read.
 */
class Y extends i {
  constructor(t) {
    super();
    this.items = t;
    this.trav = 0;
  }

  summary() {
    return `Array of ${this.items.length} items`;
  }

  async next() {
    if (this.trav < this.items.length) {
      const current = this.items[this.trav];
      // Advance before copying so the cursor moves even if Q throws.
      this.trav += 1;
      return { value: Q(current), done: false };
    }
    return { value: null, done: true };
  }
}
595
/**
 * Lazy iterator whose `next()` delegates to a user-supplied function
 * (`FunctionCallIterator`). Errors raised by that function are annotated
 * with dataset context and rethrown.
 */
class J extends i {
  /**
   * @param nextFn Function producing `{ value, done }` (or a Promise for it).
   */
  constructor(t) {
    super();
    this.nextFn = t;
  }

  summary() {
    return "Function call";
  }

  async next() {
    try {
      // `await` is required: without it a rejected Promise from an async
      // nextFn escapes this try/catch and is never annotated.
      return await this.nextFn();
    } catch (t) {
      // Only annotate objects. Assigning `.message` on a thrown primitive is
      // a TypeError in strict (module) code and would mask the real error.
      if (t !== null && typeof t == "object") {
        t.message = `Error thrown while iterating through a dataset: ${t.message}`;
      }
      throw t;
    }
  }
}
610
/**
 * Forces upstream reads to happen one at a time (`SerialIterator`). Each
 * `next()` is chained onto the previous read.
 */
class V extends i {
  constructor(t) {
    super();
    this.upstream = t;
    this.lastRead = Promise.resolve({ value: null, done: false });
  }

  summary() {
    return `${this.upstream.summary()} -> Serial`;
  }

  async next() {
    const chained = this.lastRead.then(() => this.serialNext());
    this.lastRead = chained;
    return chained;
  }

  async serialNext() {
    return this.upstream.next();
  }
}
624
/**
 * Skips the first `maxCount` upstream items, disposing their tensors with
 * `R` (`SkipIterator`). Reads are serialized through `lastRead`.
 */
class X extends i {
  constructor(t, e) {
    super();
    this.upstream = t;
    this.maxCount = e;
    this.count = 0;
    this.lastRead = Promise.resolve({ value: null, done: false });
  }

  summary() {
    return `${this.upstream.summary()} -> Skip`;
  }

  async next() {
    const chained = this.lastRead.then(() => this.serialNext());
    this.lastRead = chained;
    return chained;
  }

  async serialNext() {
    // `count` is post-incremented on every check, including the final
    // failing one, matching the original bookkeeping.
    while (this.count++ < this.maxCount) {
      const skipped = await this.upstream.next();
      if (skipped.done) {
        // Upstream ran out while skipping: report exhaustion immediately.
        return skipped;
      }
      R(skipped.value);
    }
    return this.upstream.next();
  }
}
644
/**
 * Passes through at most `maxCount` upstream items (`TakeIterator`).
 */
class K extends i {
  constructor(t, e) {
    super();
    this.upstream = t;
    this.maxCount = e;
    this.count = 0;
  }

  summary() {
    return `${this.upstream.summary()} -> Take`;
  }

  async next() {
    const limitReached = this.count++ >= this.maxCount;
    if (limitReached) {
      return { value: null, done: true };
    }
    return this.upstream.next();
  }
}
655
/**
 * Groups upstream items into arrays of `batchSize` (`RowMajorBatchIterator`).
 * A short final batch is emitted only when `enableSmallLastBatch` is set.
 */
class j extends i {
  constructor(t, e, r = true) {
    super();
    this.upstream = t;
    this.batchSize = e;
    this.enableSmallLastBatch = r;
    this.lastRead = Promise.resolve({ value: null, done: false });
  }

  summary() {
    return `${this.upstream.summary()} -> RowMajorBatch`;
  }

  async next() {
    const chained = this.lastRead.then(() => this.serialNext());
    this.lastRead = chained;
    return chained;
  }

  async serialNext() {
    const batch = [];
    while (batch.length < this.batchSize) {
      const item = await this.upstream.next();
      if (item.done) {
        const emitPartial = this.enableSmallLastBatch && batch.length > 0;
        return emitPartial ? { value: batch, done: false } : { value: null, done: true };
      }
      batch.push(item.value);
    }
    return { value: batch, done: false };
  }
}
676
/**
 * Passes through upstream items accepted by `predicate` and disposes the
 * rejected ones with `R` (`FilterIterator`).
 */
class Z extends i {
  constructor(t, e) {
    super();
    this.upstream = t;
    this.predicate = e;
    this.lastRead = Promise.resolve({ value: null, done: false });
  }

  summary() {
    return `${this.upstream.summary()} -> Filter`;
  }

  async next() {
    const chained = this.lastRead.then(() => this.serialNext());
    this.lastRead = chained;
    return chained;
  }

  async serialNext() {
    let item = await this.upstream.next();
    // The predicate is never consulted for the terminal `done` result.
    while (!item.done && !this.predicate(item.value)) {
      R(item.value);
      item = await this.upstream.next();
    }
    return item;
  }
}
695
/**
 * Applies a synchronous 1-to-1 `transform` to upstream items (`MapIterator`).
 * Input tensors that do not appear in the output are disposed.
 */
class tt extends i {
  constructor(t, e) {
    super();
    this.upstream = t;
    this.transform = e;
  }

  summary() {
    return `${this.upstream.summary()} -> Map`;
  }

  async next() {
    const item = await this.upstream.next();
    if (item.done) {
      return { value: null, done: true };
    }
    // Collect input tensors before transforming, since transform may consume them.
    const inputTensors = f(item.value);
    const mapped = this.transform(item.value);
    const outputTensors = f(mapped);
    for (const tensor of inputTensors) {
      if (!v(tensor, outputTensors)) {
        tensor.dispose();
      }
    }
    return { value: mapped, done: false };
  }
}
712
/**
 * Routes errors from upstream `next()` through `handler`
 * (`ErrorHandlingLazyIterator`).
 *
 * When the handler returns truthy, the failed read is dropped and reading
 * retries. When it returns falsy, the stream ends.
 */
class et extends i {
  constructor(t, e) {
    super();
    this.upstream = t;
    this.handler = e;
    this.count = 0;
    this.lastRead = Promise.resolve({ value: null, done: false });
  }

  summary() {
    return `${this.upstream.summary()} -> handleErrors`;
  }

  async next() {
    const chained = this.lastRead.then(() => this.serialNext());
    this.lastRead = chained;
    return chained;
  }

  async serialNext() {
    while (true) {
      try {
        // `await` keeps upstream rejections inside this try/catch.
        return await this.upstream.next();
      } catch (err) {
        const keepGoing = this.handler(err);
        if (!keepGoing) {
          return { value: null, done: true };
        }
      }
    }
  }
}
732
/**
 * Iterator applying an asynchronous 1-to-1 `transform` to every upstream
 * element (the transform's result is awaited).
 */
class g extends i {
  constructor(upstream, transform) {
    super();
    this.upstream = upstream;
    this.transform = transform;
  }

  summary() {
    return `${this.upstream.summary()} -> AsyncMap`;
  }

  async next() {
    const item = await this.upstream.next();
    if (item.done) {
      return { value: null, done: true };
    }
    const inputParts = f(item.value);
    const mapped = await this.transform(item.value);
    const outputParts = f(mapped);
    // Dispose input parts (presumably tensors) absent from the output.
    for (const part of inputParts) {
      if (!v(part, outputParts)) {
        part.dispose();
      }
    }
    return { value: mapped, done: false };
  }
}
749
/**
 * Base for iterators that may produce several outputs per upstream read.
 * Subclasses implement `pump()`, which appends to `outputQueue` and returns
 * false once no more input is available.
 */
class st extends i {
  constructor() {
    super();
    this.outputQueue = new y();
    this.lastRead = Promise.resolve({ value: null, done: false });
  }

  async next() {
    this.lastRead = this.lastRead.then(() => this.serialNext());
    return this.lastRead;
  }

  /** Pump until something is queued, then hand out the oldest entry. */
  async serialNext() {
    while (this.outputQueue.length() === 0) {
      const pumped = await this.pump();
      if (!pumped) {
        return { value: null, done: true };
      }
    }
    return { value: this.outputQueue.shift(), done: false };
  }
}
763
/**
 * Flat-map iterator: `transform` maps one upstream element to an array of
 * outputs, all of which are queued for emission.
 */
class rt extends st {
  constructor(upstream, transform) {
    super();
    this.upstream = upstream;
    this.transform = transform;
  }

  summary() {
    return `${this.upstream.summary()} -> Flatmap`;
  }

  /** Read one upstream element and enqueue its expansion; false at end of input. */
  async pump() {
    const item = await this.upstream.next();
    if (item.done) {
      return false;
    }
    const inputParts = f(item.value);
    const expanded = this.transform(item.value);
    const outputParts = f(expanded);
    this.outputQueue.pushAll(expanded);
    // Dispose input parts (presumably tensors) not reused by any output.
    for (const part of inputParts) {
      if (!v(part, outputParts)) {
        part.dispose();
      }
    }
    return true;
  }
}
781
/**
 * Iterator that drains a sequence of iterators (`moreIterators`) one after
 * another. When `baseErrorHandler` is set, each sub-iterator is wrapped with
 * `handleErrors(baseErrorHandler)`.
 */
class F extends i {
  constructor(moreIterators, baseErrorHandler) {
    super();
    this.baseErrorHandler = baseErrorHandler;
    this.lastRead = null;
    this.iterator = null;
    this.moreIterators = moreIterators;
  }

  summary() {
    return "TODO: fill in upstream of chained summaries -> Chained";
  }

  async next() {
    this.lastRead = this.readFromChain(this.lastRead);
    return this.lastRead;
  }

  /**
   * Wait for the previous read, then return the next element from the
   * current sub-iterator, advancing to the next sub-iterator when the
   * current one is exhausted.
   */
  async readFromChain(previousRead) {
    await previousRead;
    if (this.iterator == null) {
      const nextIterator = await this.moreIterators.next();
      if (nextIterator.done) {
        return { value: null, done: true };
      }
      this.iterator = nextIterator.value;
      if (this.baseErrorHandler != null) {
        this.iterator = this.iterator.handleErrors(this.baseErrorHandler);
      }
    }
    const item = await this.iterator.next();
    if (!item.done) {
      return item;
    }
    this.iterator = null;
    return this.readFromChain(previousRead);
  }
}
802
// Bidirectional enum (TypeScript-style): b.FAIL === 0 and b[0] === "FAIL",
// likewise for SHORTEST (1) and LONGEST (2). Not referenced in this chunk;
// presumably a length-mismatch policy for zipped iterators.
var b;
b = b || {};
for (const [name, code] of [["FAIL", 0], ["SHORTEST", 1], ["LONGEST", 2]]) {
  b[name] = code;
  b[code] = name;
}
806
/**
 * Prefetching iterator: keeps a buffer (of capacity `bufferSize`) filled
 * with pending upstream `next()` results and hands them out in order.
 */
class S extends i {
  constructor(upstream, bufferSize) {
    super();
    this.upstream = upstream;
    this.bufferSize = bufferSize;
    this.buffer = new C(bufferSize);
  }

  summary() {
    return `${this.upstream.summary()} -> Prefetch`;
  }

  /**
   * Top up the buffer until it is full. Entries are the (unawaited) results
   * of `upstream.next()`, so this returns without waiting for any of them.
   */
  refill() {
    while (!this.buffer.isFull()) {
      this.buffer.push(this.upstream.next());
    }
  }

  next() {
    this.refill();
    return this.buffer.shift();
  }
}
827
/**
 * Streaming shuffle: keeps `windowSize` upstream results buffered and emits
 * a uniformly chosen one each time, seeded via `E.alea` (falling back to a
 * seed from `z()` when none is given).
 */
class nt extends S {
  constructor(upstream, windowSize, seed) {
    super(upstream, windowSize);
    this.upstream = upstream;
    this.windowSize = windowSize;
    this.upstreamExhausted = false;
    this.random = E.alea(seed || z().toString());
    this.lastRead = Promise.resolve({ value: null, done: false });
  }

  async next() {
    this.lastRead = this.lastRead.then(() => this.serialNext());
    return this.lastRead;
  }

  /** Integer in [0, max) drawn from the seeded generator. */
  randomInt(max) {
    return Math.floor(this.random() * max);
  }

  chooseIndex() {
    return this.randomInt(this.buffer.length());
  }

  async serialNext() {
    if (!this.upstreamExhausted) {
      this.refill();
    }
    while (!this.buffer.isEmpty()) {
      const index = this.chooseIndex();
      const result = await this.buffer.shuffleExcise(index);
      if (!result.done) {
        this.refill();
        return result;
      }
      // A "done" entry means upstream ran dry; keep draining what is left.
      this.upstreamExhausted = true;
    }
    return { value: null, done: true };
  }
}
851
+ /**
852
+ * @license
853
+ * Copyright 2018 Google LLC. All Rights Reserved.
854
+ * Licensed under the Apache License, Version 2.0 (the "License");
855
+ * you may not use this file except in compliance with the License.
856
+ * You may obtain a copy of the License at
857
+ *
858
+ * http://www.apache.org/licenses/LICENSE-2.0
859
+ *
860
+ * Unless required by applicable law or agreed to in writing, software
861
+ * distributed under the License is distributed on an "AS IS" BASIS,
862
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
863
+ * See the License for the specific language governing permissions and
864
+ * limitations under the License.
865
+ *
866
+ * =============================================================================
867
+ */
868
/**
 * A lazily evaluated sequence of elements.
 *
 * Concrete datasets are created with `o()`, which supplies `iterator()`.
 * Every transformation below returns a NEW Dataset; no data is read until
 * one of them is iterated. `size` is the element count when it is known,
 * `Infinity` for endless streams, or `null` when it cannot be determined.
 *
 * @doc {heading: 'Data', subheading: 'Classes'}
 */
class N {
  // Class-level constant (10000). Not referenced in this chunk; presumably a
  // buffer-size limit/default used elsewhere.
  static MAX_BUFFER_SIZE = 1e4;

  constructor() {
    this.size = null;
  }

  // TODO(soergel): report whether repeated iterator() calls yield the same
  // elements (e.g. a file) or different ones (e.g. a webcam).

  /**
   * Groups consecutive elements into batches.
   *
   * Elements are assumed to share one structure. Primitives are batched into
   * a 1-D tensor, tensors are stacked along a new leading (batch) axis,
   * arrays are converted to tensors first, and objects are batched key by key.
   *
   * ```js
   * const a = tf.data.array([1, 2, 3, 4, 5, 6, 7, 8]).batch(4);
   * await a.forEachAsync(e => e.print());
   * ```
   *
   * @param batchSize Elements per batch; must be positive.
   * @param smallLastBatch Emit a trailing batch with fewer than `batchSize`
   *     elements (default true).
   * @returns A Dataset of batches.
   *
   * @doc {heading: 'Data', subheading: 'Classes'}
   */
  batch(batchSize, smallLastBatch = true) {
    const source = this;
    _(batchSize > 0, () => `batchSize needs to be positive, but it is
${batchSize}`);
    let batchedSize;
    if (this.size === Infinity || this.size == null) {
      batchedSize = this.size;
    } else if (smallLastBatch) {
      batchedSize = Math.ceil(this.size / batchSize);
    } else {
      batchedSize = Math.floor(this.size / batchSize);
    }
    return o(
      async () => (await source.iterator()).columnMajorBatch(batchSize, smallLastBatch, at),
      batchedSize
    );
  }

  /**
   * Appends the elements of `dataset` after those of this one.
   *
   * ```js
   * const c = tf.data.array([1, 2, 3]).concatenate(tf.data.array([4, 5, 6]));
   * ```
   *
   * @param dataset Dataset to append.
   * @returns A Dataset whose size is the sum of both sizes, Infinity if either
   *     is infinite, or null if either is unknown.
   *
   * @doc {heading: 'Data', subheading: 'Classes'}
   */
  concatenate(dataset) {
    const source = this;
    let combinedSize;
    if (this.size === Infinity || dataset.size === Infinity) {
      combinedSize = Infinity;
    } else if (this.size != null && dataset.size != null) {
      combinedSize = this.size + dataset.size;
    } else {
      combinedSize = null;
    }
    return o(
      async () => (await source.iterator()).concatenate(await dataset.iterator()),
      combinedSize
    );
  }

  /**
   * Keeps only the elements for which `predicate` returns true. The predicate
   * runs inside `x` (tf.tidy).
   *
   * ```js
   * const a = tf.data.array([1, 2, 3, 4]).filter(x => x % 2 === 0);
   * ```
   *
   * @param predicate Element -> boolean (or a Promise for one).
   * @returns A Dataset that is infinite if this one is, otherwise of unknown size.
   *
   * @doc {heading: 'Data', subheading: 'Classes'}
   */
  filter(predicate) {
    const source = this;
    const filteredSize = this.size === Infinity ? Infinity : null;
    return o(
      async () => (await source.iterator()).filter((item) => x(() => predicate(item))),
      filteredSize
    );
  }

  /**
   * Calls `f` on every element. Tensors contained in an element are disposed
   * after `f` has seen it.
   *
   * @param f Callback invoked once per element.
   * @returns A Promise that resolves once every element has been processed.
   *
   * @doc {heading: 'Data', subheading: 'Classes'}
   */
  async forEachAsync(f) {
    const stream = await this.iterator();
    return stream.forEachAsync(f);
  }

  /**
   * Applies a synchronous 1-to-1 `transform` to every element, inside `x`
   * (tf.tidy) so intermediate tensors are cleaned up.
   *
   * ```js
   * const a = tf.data.array([1, 2, 3]).map(x => x * x);
   * ```
   *
   * @param transform Element -> transformed element.
   * @returns A Dataset of the same size.
   *
   * @doc {heading: 'Data', subheading: 'Classes'}
   */
  map(transform) {
    const source = this;
    return o(
      async () => (await source.iterator()).map((item) => x(() => transform(item))),
      this.size
    );
  }

  /**
   * Applies an asynchronous 1-to-1 `transform` to every element. Unlike
   * `map()`, intermediate tensors are NOT tidied automatically; the transform
   * must dispose of them itself (e.g. with `tf.tidy()`).
   *
   * @param transform Element -> Promise of the transformed element.
   * @returns A Dataset of the same size.
   *
   * @doc {heading: 'Data', subheading: 'Classes'}
   */
  mapAsync(transform) {
    const source = this;
    return o(async () => (await source.iterator()).mapAsync(transform), this.size);
  }

  /**
   * Prefetches up to `bufferSize` elements ahead of the consumer.
   *
   * @param bufferSize Number of elements to prefetch; required.
   * @returns A Dataset of the same size.
   * @throws RangeError if `bufferSize` is null or undefined.
   *
   * @doc {heading: 'Data', subheading: 'Classes'}
   */
  prefetch(bufferSize) {
    if (bufferSize == null) {
      throw new RangeError("`Dataset.prefetch()` requires bufferSize to be specified.");
    }
    const source = this;
    return o(async () => (await source.iterator()).prefetch(bufferSize), this.size);
  }

  /**
   * Repeats this dataset `count` times. Each repetition calls `iterator()`
   * again, so datasets that depend on global state (e.g. a random number
   * generator) may produce different elements per repetition.
   *
   * ```js
   * const a = tf.data.array([1, 2, 3]).repeat(3);
   * ```
   *
   * @param count Number of repetitions. When `undefined` or negative, the
   *     dataset is repeated indefinitely.
   * @returns A Dataset.
   *
   * @doc {heading: 'Data', subheading: 'Classes'}
   */
  repeat(count) {
    const source = this;
    let repeatedSize;
    if (this.size != null && count > 0) {
      repeatedSize = this.size * count;
    } else if (count === 0) {
      repeatedSize = 0;
    } else if (this.size != null && (count === undefined || count < 0)) {
      repeatedSize = Infinity;
    } else {
      repeatedSize = null;
    }
    return o(async () => {
      const iterators = k(async () => ({ value: await source.iterator(), done: false }));
      return U(iterators.take(count));
    }, repeatedSize);
  }

  /**
   * Drops the first `count` elements.
   *
   * ```js
   * const a = tf.data.array([1, 2, 3, 4, 5, 6]).skip(3);
   * ```
   *
   * @param count Elements to skip. If larger than the dataset, or `undefined`
   *     or negative, the whole dataset is skipped.
   * @returns A Dataset.
   *
   * @doc {heading: 'Data', subheading: 'Classes'}
   */
  skip(count) {
    const source = this;
    let remaining;
    if (this.size != null && count >= 0 && this.size >= count) {
      remaining = this.size - count;
    } else if (this.size != null && (this.size < count || count === undefined || count < 0)) {
      remaining = 0;
    } else {
      remaining = null;
    }
    return o(async () => (await source.iterator()).skip(count), remaining);
  }

  /**
   * Pseudorandomly shuffles elements in a streaming fashion by sampling from
   * a window of `bufferSize` prefetched elements.
   *
   * ```js
   * const a = tf.data.array([1, 2, 3, 4, 5, 6]).shuffle(3);
   * ```
   *
   * @param bufferSize Size of the sampling window; required and non-negative.
   * @param seed Optional seed for the random generator.
   * @param reshuffleEachIteration When true (default), each iteration uses a
   *     different order; when false, every iteration repeats the same order.
   * @returns A Dataset of the same size.
   * @throws RangeError if `bufferSize` is missing or negative.
   *
   * @doc {heading: 'Data', subheading: 'Classes'}
   */
  shuffle(bufferSize, seed, reshuffleEachIteration = true) {
    if (bufferSize == null || bufferSize < 0) {
      if (this.size == null) {
        throw new RangeError("`Dataset.shuffle()` requires bufferSize to be specified.");
      }
      throw new RangeError(`\`Dataset.shuffle()\` requires bufferSize to be specified. If your data fits in main memory (for regular JS objects), and/or GPU memory (for \`tf.Tensor\`s), consider setting bufferSize to the dataset size (${this.size} elements)`);
    }
    const source = this;
    const seeder = E.alea(seed || z().toString());
    return o(async () => {
      let iterationSeed = seeder.int32();
      if (reshuffleEachIteration) {
        iterationSeed += seeder.int32();
      }
      return (await source.iterator()).shuffle(bufferSize, iterationSeed.toString());
    }, this.size);
  }

  /**
   * Keeps at most the first `count` elements.
   *
   * ```js
   * const a = tf.data.array([1, 2, 3, 4, 5, 6]).take(3);
   * ```
   *
   * @param count Elements to keep. When `undefined`, negative, or larger than
   *     the dataset, every element is kept.
   * @returns A Dataset.
   *
   * @doc {heading: 'Data', subheading: 'Classes'}
   */
  take(count) {
    const source = this;
    let takenSize;
    if (this.size != null && this.size > count) {
      takenSize = count;
    } else if (this.size != null && this.size <= count) {
      takenSize = this.size;
    } else {
      takenSize = null;
    }
    return o(async () => (await source.iterator()).take(count), takenSize);
  }

  /**
   * Collects every element into an array. Only suitable for small, finite
   * datasets (mainly tests).
   *
   * @returns A Promise for the array of elements.
   * @throws Error (as a rejection) when the dataset is known to be infinite.
   *
   * @doc {heading: 'Data', subheading: 'Classes'}
   */
  async toArray() {
    if (this.size === Infinity) {
      throw new Error("Can not convert infinite data stream to array.");
    }
    const stream = await this.iterator();
    return stream.toArray();
  }

  /**
   * Like `toArray()`, but the underlying iterator prefetches, which changes
   * the order in which promises resolve. Useful for exposing bugs that
   * depend on resolution order rather than stream order.
   *
   * @returns A Promise for the array of elements.
   * @throws Error (as a rejection) when the dataset is known to be infinite.
   */
  async toArrayForTest() {
    if (this.size === Infinity) {
      throw new Error("Can not convert infinite data stream to array.");
    }
    const stream = await this.iterator();
    return stream.toArrayForTest();
  }
}
1199
/**
 * Build a Dataset from an iterator factory.
 *
 * @param s Zero-argument function returning (a promise of) a fresh iterator;
 *     invoked anew on every `iterator()` call, so each iteration restarts any
 *     upstream datasets too.
 * @param t Known size of the dataset (default null = unknown).
 * @returns An instance of an anonymous Dataset subclass.
 */
function o(s, t = null) {
  return new (class extends N {
    constructor(...args) {
      super(...args);
      this.size = t;
    }

    async iterator() {
      return s();
    }
  })();
}
1213
/**
 * Batching callback passed to `columnMajorBatch`. If the first row is
 * accepted by `H` (presumably "can be turned into a tensor"), the whole column
 * is concatenated via `it` and recursion stops. Otherwise the caller is told
 * to recurse into the rows' structure.
 *
 * @param s Array of rows for one column, or null.
 * @returns null for null input, otherwise `{value, recurse}`.
 */
function at(s) {
  if (s === null) {
    return null;
  }
  const firstRow = s[0];
  if (!H(firstRow)) {
    return { value: null, recurse: true };
  }
  return { value: it(s), recurse: false };
}
1219
/**
 * Combine a non-empty array of rows into one batch. Rows that are instances
 * of `m` (presumably tf.Tensor) are combined with `P`; anything else is
 * passed to `D`.
 *
 * @throws Error when `s` is empty.
 */
function it(s) {
  if (s.length === 0) {
    throw new Error("Can't make a batch of zero elements.");
  }
  const rowsAreTensors = s[0] instanceof m;
  return rowsAreTensors ? P(s) : D(s);
}
1224
+ /**
1225
+ * @license
1226
+ * Copyright 2018 Google LLC. All Rights Reserved.
1227
+ * Licensed under the Apache License, Version 2.0 (the "License");
1228
+ * you may not use this file except in compliance with the License.
1229
+ * You may obtain a copy of the License at
1230
+ *
1231
+ * http://www.apache.org/licenses/LICENSE-2.0
1232
+ *
1233
+ * Unless required by applicable law or agreed to in writing, software
1234
+ * distributed under the License is distributed on an "AS IS" BASIS,
1235
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1236
+ * See the License for the specific language governing permissions and
1237
+ * limitations under the License.
1238
+ *
1239
+ * =============================================================================
1240
+ */
1241
/**
 * Wrap an iterator factory (e.g. a generator function) as a Dataset of
 * unknown size. Each iteration calls `s` afresh and adapts the returned
 * object's `next()` through `k`.
 *
 * @param s Function returning (a promise of) an object with `next()`.
 */
function ut(s) {
  const makeIterator = async () => {
    const source = await s();
    return k(() => source.next());
  };
  return o(makeIterator);
}
1247
/**
 * Builds next-token-prediction training datasets from raw text.
 *
 * Each example is a random window of `blockSize` tokens (`xs`) paired with
 * the same window shifted one token to the right (`ys`).
 */
class dt {
  tokenizer;
  blockSize;

  /**
   * @param t Tokenizer providing `encode(text)` (sync or async) and an
   *     `eosToken` id; the id is appended after each text only when >= 0.
   * @param e Number of tokens per example (default 128).
   */
  constructor(t, e = 128) {
    this.tokenizer = t;
    this.blockSize = e;
  }

  /**
   * Create an endless, randomly sampled dataset from a list of texts.
   *
   * @param t Array of texts; they are tokenized and concatenated in order.
   * @param e Batch size (default 32).
   * @param r Fraction of the concatenated token stream at which the used
   *     range starts (default 0).
   * @param n Fraction at which it ends; 1 (default) means "to the end".
   * @returns A Dataset of `{xs, ys}` batches cast to int32, prefetching 2.
   * @throws Error (as a rejected promise) when the selected range holds fewer
   *     than `blockSize + 1` tokens, so not even one example fits.
   */
  async createTextDataset(t, e = 32, r = 0, n = 1) {
    const encoded = await Promise.all(t.map((text) => this.tokenizer.encode(text)));
    const appendEos = this.tokenizer.eosToken >= 0;
    const stream = encoded
      .map((ids) => (appendEos ? [...ids, this.tokenizer.eosToken] : ids))
      .flat();
    const tokens = stream.slice(
      Math.floor(r * stream.length),
      n === 1 ? undefined : Math.floor(n * stream.length)
    );
    // One example reads blockSize + 1 consecutive tokens (inputs plus the
    // shifted targets). With fewer, the start offset below would be negative
    // and slice() would silently return short or empty windows.
    if (tokens.length < this.blockSize + 1) {
      throw new Error(
        `Not enough tokens to form a training example: need at least ${this.blockSize + 1} (blockSize + 1), got ${tokens.length}.`
      );
    }
    const sampleWindows = (function* () {
      for (;;) {
        // Valid starts are 0 .. tokens.length - blockSize - 1 inclusive,
        // i.e. tokens.length - blockSize choices, so the final window
        // (ending on the last token) can be sampled too.
        const start = Math.floor(Math.random() * (tokens.length - this.blockSize));
        yield {
          xs: tokens.slice(start, start + this.blockSize),
          ys: tokens.slice(start + 1, start + this.blockSize + 1),
        };
      }
    }).bind(this);
    return ut(sampleWindows)
      .batch(e)
      .map((batch) => x(() => ({
        xs: batch.xs.cast("int32"),
        ys: batch.ys.cast("int32"),
      })))
      .prefetch(2);
  }
}
29
1274
  export {
30
- z as DatasetBuilder
1275
+ dt as DatasetBuilder
31
1276
  };