@genai-fi/nanogpt 0.19.1 → 0.20.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/BaseTokeniser-DSg9zcYq.js +221 -0
- package/dist/DatasetBuilder-DgURD85T.js +712 -0
- package/dist/Generator.js +2 -11941
- package/dist/RealDiv-DBu0FQqT.js +362 -0
- package/dist/Reshape-CABOPB9d.js +94 -0
- package/dist/Reshape-DqO3r8BC.js +17 -0
- package/dist/TeachableLLM.d.ts +5 -5
- package/dist/TeachableLLM.js +2 -273
- package/dist/Trainer.js +2 -244
- package/dist/backend.js +12 -12
- package/dist/backend_util-Cg-roD1p.js +399 -0
- package/dist/binary_op_util-CrYk9LXL.js +103 -0
- package/dist/checks/appendCache.js +54 -21
- package/dist/checks/attentionMask.js +55 -36
- package/dist/checks/check.js +31 -19
- package/dist/checks/gelu.js +45 -17
- package/dist/checks/index.js +25 -25
- package/dist/checks/matMulGelu.js +83 -27
- package/dist/checks/normRMS.js +27 -15
- package/dist/checks/normRMSGrad.js +21 -11
- package/dist/checks/packUnpack.js +45 -17
- package/dist/checks/qkv.js +33 -33
- package/dist/checks/rope.js +29 -35
- package/dist/checks/weights.d.ts +1 -1
- package/dist/checks/weights.js +25 -29
- package/dist/chunk-BPntVaq0.js +23 -0
- package/dist/complex_util-CkazZsaH.js +60 -0
- package/dist/concat_util-CWDZCBlA.js +19 -0
- package/dist/data/docx.js +3044 -13
- package/dist/data/pdf.js +16 -13
- package/dist/data/textLoader.js +607 -112
- package/dist/dist-BewPQWjc.js +7572 -0
- package/dist/dist-DVmq73nz.js +8775 -0
- package/dist/dist-DXwIvKxl.js +896 -0
- package/dist/dist-VEU5mfO0.js +7545 -0
- package/dist/gelu-Bf1HW1RY.js +27 -0
- package/dist/gpgpu_math-DvLcCH6u.js +1612 -0
- package/dist/inference/types.js +0 -1
- package/dist/kernel_funcs_utils-HiXOOx3f.js +229 -0
- package/dist/layers/BaseLayer.d.ts +2 -2
- package/dist/layers/BaseLayer.js +75 -73
- package/dist/layers/CausalSelfAttention.d.ts +1 -1
- package/dist/layers/CausalSelfAttention.js +98 -85
- package/dist/layers/LoRA.js +47 -57
- package/dist/layers/MLP.d.ts +1 -1
- package/dist/layers/MLP.js +33 -43
- package/dist/layers/PositionEmbedding.d.ts +1 -1
- package/dist/layers/PositionEmbedding.js +26 -30
- package/dist/layers/RMSNorm.d.ts +1 -1
- package/dist/layers/RMSNorm.js +19 -21
- package/dist/layers/RoPECache.js +336 -49
- package/dist/layers/TiedEmbedding.d.ts +1 -1
- package/dist/layers/TiedEmbedding.js +30 -34
- package/dist/layers/TransformerBlock.d.ts +1 -1
- package/dist/layers/TransformerBlock.js +50 -39
- package/dist/layers/WeightStore.js +68 -75
- package/dist/loader/load.js +2 -68
- package/dist/loader/loadHF.d.ts +2 -2
- package/dist/loader/loadHF.js +2 -22
- package/dist/loader/loadTransformers.d.ts +1 -1
- package/dist/loader/loadTransformers.js +2 -44
- package/dist/loader/loadZipMeta.js +15 -15
- package/dist/loader/newZipLoad.js +2 -31
- package/dist/loader/oldZipLoad.d.ts +2 -2
- package/dist/loader/oldZipLoad.js +2 -80
- package/dist/loader/save.d.ts +5 -5
- package/dist/loader/save.js +2 -90
- package/dist/loader/types.d.ts +9 -8
- package/dist/loader/types.js +0 -1
- package/dist/main-CPjeMv0G.js +13500 -0
- package/dist/main.d.ts +1 -1
- package/dist/main.js +16 -109
- package/dist/matMul16-BNfZSnNM.js +81 -0
- package/dist/matMulGelu-CPTntosE.js +162 -0
- package/dist/models/NanoGPTV1.js +2 -99
- package/dist/models/NanoGPTV2.js +2 -90
- package/dist/models/config.js +34 -47
- package/dist/models/factory.d.ts +1 -1
- package/dist/models/factory.js +2 -16
- package/dist/models/model.d.ts +2 -2
- package/dist/models/model.js +2 -134
- package/dist/ops/adamAdjust.js +15 -6
- package/dist/ops/adamMoments.js +13 -6
- package/dist/ops/add16.js +9 -6
- package/dist/ops/appendCache.js +22 -19
- package/dist/ops/attentionMask.js +12 -6
- package/dist/ops/concat16.js +7 -8
- package/dist/ops/cpu/adamAdjust.js +15 -17
- package/dist/ops/cpu/adamMoments.js +15 -15
- package/dist/ops/cpu/appendCache.js +64 -22
- package/dist/ops/cpu/attentionMask.js +15 -21
- package/dist/ops/cpu/fusedSoftmax.js +21 -28
- package/dist/ops/cpu/gatherSub.js +11 -17
- package/dist/ops/cpu/gelu.js +34 -38
- package/dist/ops/cpu/matMul16.js +13 -14
- package/dist/ops/cpu/matMulGelu.js +39 -51
- package/dist/ops/cpu/matMulMul.js +19 -22
- package/dist/ops/cpu/mulDropout.js +19 -22
- package/dist/ops/cpu/normRMS.js +33 -37
- package/dist/ops/cpu/qkv.js +72 -40
- package/dist/ops/cpu/rope.js +79 -36
- package/dist/ops/cpu/scatterSub.js +11 -22
- package/dist/ops/dot16.js +28 -41
- package/dist/ops/dropout.js +10 -13
- package/dist/ops/dropout16.js +20 -23
- package/dist/ops/gatherSub.js +10 -6
- package/dist/ops/gelu.js +2 -8
- package/dist/ops/globalNorm.js +18 -12
- package/dist/ops/grads/add16.js +27 -26
- package/dist/ops/grads/attentionMask.js +26 -21
- package/dist/ops/grads/dropout16.js +0 -1
- package/dist/ops/grads/gelu.js +2 -5
- package/dist/ops/grads/matMul16.js +2 -9
- package/dist/ops/grads/matMulGelu.js +21 -16
- package/dist/ops/grads/mul16.js +0 -3
- package/dist/ops/grads/normRMS.js +34 -30
- package/dist/ops/grads/pack16.js +2 -6
- package/dist/ops/grads/qkv.js +44 -32
- package/dist/ops/grads/rope.js +2 -5
- package/dist/ops/grads/softmax16.js +21 -23
- package/dist/ops/grads/unpack16.js +2 -5
- package/dist/ops/grads/utils.js +9 -11
- package/dist/ops/matMul16.js +2 -13
- package/dist/ops/matMulGelu.js +16 -10
- package/dist/ops/matMulMul.js +13 -6
- package/dist/ops/mul16.js +42 -38
- package/dist/ops/mulDrop.js +12 -6
- package/dist/ops/normRMS.js +18 -15
- package/dist/ops/pack16.js +2 -5
- package/dist/ops/qkv.js +12 -6
- package/dist/ops/reshape16.js +31 -39
- package/dist/ops/rope.d.ts +1 -1
- package/dist/ops/rope.js +2 -7
- package/dist/ops/scatterSub.js +10 -6
- package/dist/ops/slice16.js +9 -7
- package/dist/ops/softmax16.js +7 -7
- package/dist/ops/sub16.js +9 -6
- package/dist/ops/sum16.js +12 -12
- package/dist/ops/transpose16.js +29 -37
- package/dist/ops/unpack16.js +2 -6
- package/dist/ops/webgl/adamAdjust.js +62 -29
- package/dist/ops/webgl/adamMoments.js +30 -26
- package/dist/ops/webgl/appendCache.js +30 -21
- package/dist/ops/webgl/attentionMask.js +43 -24
- package/dist/ops/webgl/dropout16.js +11 -10
- package/dist/ops/webgl/fusedSoftmax.js +69 -79
- package/dist/ops/webgl/gatherSub.js +27 -26
- package/dist/ops/webgl/gelu.js +32 -34
- package/dist/ops/webgl/log.js +14 -23
- package/dist/ops/webgl/matMul16.js +36 -44
- package/dist/ops/webgl/matMulGelu.js +2 -9
- package/dist/ops/webgl/matMulMul.js +23 -27
- package/dist/ops/webgl/mulDropout.js +31 -40
- package/dist/ops/webgl/normRMS.js +92 -71
- package/dist/ops/webgl/qkv.js +35 -27
- package/dist/ops/webgl/rope.js +37 -21
- package/dist/ops/webgl/scatterSub.js +27 -26
- package/dist/ops/webgpu/adamAdjust.js +59 -39
- package/dist/ops/webgpu/adamMoments.js +62 -46
- package/dist/ops/webgpu/add16.js +13 -12
- package/dist/ops/webgpu/appendCache.js +79 -54
- package/dist/ops/webgpu/attentionMask.js +41 -25
- package/dist/ops/webgpu/attentionMask32_program.js +34 -26
- package/dist/ops/webgpu/clipScale.js +44 -57
- package/dist/ops/webgpu/concat16.js +96 -111
- package/dist/ops/webgpu/dropout16.js +40 -32
- package/dist/ops/webgpu/gatherSub.js +43 -30
- package/dist/ops/webgpu/gelu.js +88 -82
- package/dist/ops/webgpu/index.js +16 -16
- package/dist/ops/webgpu/matMul16.js +69 -64
- package/dist/ops/webgpu/matMul16_program.js +152 -192
- package/dist/ops/webgpu/mul16.js +13 -12
- package/dist/ops/webgpu/norm2.js +45 -75
- package/dist/ops/webgpu/normRMS.js +25 -33
- package/dist/ops/webgpu/normRMS16_program.js +21 -18
- package/dist/ops/webgpu/normRMS32_program.js +21 -18
- package/dist/ops/webgpu/normRMSGrad.js +125 -184
- package/dist/ops/webgpu/pack16.js +20 -17
- package/dist/ops/webgpu/pack16_program.js +48 -47
- package/dist/ops/webgpu/qkv.js +63 -23
- package/dist/ops/webgpu/rope.js +85 -57
- package/dist/ops/webgpu/scatterSub.js +43 -30
- package/dist/ops/webgpu/slice16.js +66 -61
- package/dist/ops/webgpu/softmax16.js +17 -20
- package/dist/ops/webgpu/softmax16_program.js +34 -18
- package/dist/ops/webgpu/softmax16_subgroup_program.js +40 -45
- package/dist/ops/webgpu/softmax16grad.js +30 -36
- package/dist/ops/webgpu/sub16.js +13 -12
- package/dist/ops/webgpu/sum16.js +28 -37
- package/dist/ops/webgpu/transpose16.js +36 -33
- package/dist/ops/webgpu/transpose16_program.js +40 -39
- package/dist/ops/webgpu/transpose16_shared_program.js +53 -44
- package/dist/ops/webgpu/unpack16.js +49 -37
- package/dist/ops/webgpu/utils/binary_op.js +70 -68
- package/dist/ops/webgpu/utils/deviceInfo.d.ts +1 -1
- package/dist/ops/webgpu/utils/deviceInfo.js +10 -10
- package/dist/ops/webgpu/utils/reductions.js +136 -148
- package/dist/pack16-Ck-spx_F.js +39 -0
- package/dist/patches/webgpu_backend.d.ts +2 -2
- package/dist/patches/webgpu_backend.js +42 -55
- package/dist/patches/webgpu_base.js +21 -33
- package/dist/patches/webgpu_program.js +213 -320
- package/dist/pdf-UoDqCYzz.js +16726 -0
- package/dist/picomatch-3tUnMMbd.js +1063 -0
- package/dist/rope-CbeGlsV8.js +25 -0
- package/dist/selu_util-zkAx5doH.js +24 -0
- package/dist/shared-D1coEFea.js +1314 -0
- package/dist/shared-DOgWaqvL.js +5 -0
- package/dist/slice_util-Dgb3ANWI.js +208 -0
- package/dist/tfjs_backend-BjuQ5FqB.js +614 -0
- package/dist/tokeniser/BaseTokeniser.js +2 -124
- package/dist/tokeniser/CharTokeniser.js +91 -106
- package/dist/tokeniser/bpe.js +163 -166
- package/dist/tokeniser/messages.js +0 -1
- package/dist/tokeniser/type.js +0 -1
- package/dist/training/AdamW.js +127 -137
- package/dist/training/BasicTrainer.d.ts +1 -1
- package/dist/training/BasicTrainer.js +264 -264
- package/dist/training/DatasetBuilder.js +2 -86
- package/dist/training/Evaluator.d.ts +1 -1
- package/dist/training/Evaluator.js +47 -38
- package/dist/training/LRScheduler.js +37 -33
- package/dist/training/PreTrainer.d.ts +2 -2
- package/dist/training/PreTrainer.js +21 -19
- package/dist/training/SFTTrainer.d.ts +2 -2
- package/dist/training/SFTTrainer.js +23 -21
- package/dist/training/loss.js +17 -22
- package/dist/training/orthoGrad.js +9 -9
- package/dist/training/sparseCrossEntropy.js +45 -67
- package/dist/training/tasks/ConversationTask.d.ts +1 -1
- package/dist/training/tasks/ConversationTask.js +36 -38
- package/dist/training/tasks/PretrainingTask.d.ts +1 -1
- package/dist/training/tasks/PretrainingTask.js +41 -46
- package/dist/training/tasks/StartSentenceTask.d.ts +1 -1
- package/dist/training/tasks/StartSentenceTask.js +44 -48
- package/dist/training/tasks/Task.d.ts +1 -1
- package/dist/training/tasks/Task.js +53 -66
- package/dist/training/tasks/splitter.js +17 -20
- package/dist/training/types.d.ts +2 -2
- package/dist/training/types.js +0 -1
- package/dist/training/validation.d.ts +1 -1
- package/dist/training/validation.js +2 -84
- package/dist/utilities/arrayClose.js +15 -19
- package/dist/utilities/datasetID.js +17 -20
- package/dist/utilities/dummy.d.ts +1 -1
- package/dist/utilities/dummy.js +33 -40
- package/dist/utilities/multinomialCPU.js +8 -12
- package/dist/utilities/naming.js +0 -1
- package/dist/utilities/packed.js +10 -12
- package/dist/utilities/parameters.d.ts +1 -1
- package/dist/utilities/parameters.js +32 -51
- package/dist/utilities/performance.js +15 -15
- package/dist/utilities/profile.js +32 -37
- package/dist/utilities/safetensors.js +49 -79
- package/dist/utilities/sentences.d.ts +1 -1
- package/dist/utilities/sentences.js +29 -38
- package/dist/utilities/tokenParse.js +16 -20
- package/dist/utilities/topP.js +11 -12
- package/dist/utilities/waitForModel.d.ts +1 -1
- package/dist/utilities/waitForModel.js +11 -11
- package/dist/utilities/weights.js +37 -42
- package/dist/utilities/yielder.js +6 -6
- package/dist/webgpu-Dt7BMzWz.js +525 -0
- package/dist/webgpu_program-WOyIVMlZ.js +392 -0
- package/dist/webgpu_util-B_F3SShA.js +106 -0
- package/package.json +9 -10
- package/dist/RealDiv-CGwv0liw.js +0 -365
- package/dist/Reshape-BW__R4mZ.js +0 -79
- package/dist/Reshape-CPBkTIH2.js +0 -14
- package/dist/_commonjsHelpers-ByX85dGu.js +0 -33
- package/dist/axis_util-GTVlo58H.js +0 -55
- package/dist/backend_util-GaFarB78.js +0 -425
- package/dist/backend_webgpu-BqASlsbV.js +0 -545
- package/dist/binary_op_util-pKXltfxI.js +0 -192
- package/dist/broadcast_to-eS93CCN_.js +0 -28
- package/dist/clip_by_value-DDA7rrcT.js +0 -12
- package/dist/complex-DI35Q-gW.js +0 -11
- package/dist/complex_util-Yc1A_gV1.js +0 -55
- package/dist/concat-CAQpCret.js +0 -17
- package/dist/concat_util-D18dJ4fD.js +0 -22
- package/dist/data/parquet.d.ts +0 -2
- package/dist/data/parquet.js +0 -17
- package/dist/dataset-CGGp1z9P.js +0 -1124
- package/dist/dropout_util--NxWuYg2.js +0 -27
- package/dist/expand_dims-Bkd1YD5x.js +0 -11
- package/dist/exports_initializers-CYzKLjN7.js +0 -7
- package/dist/floor-BQtb-Azg.js +0 -9
- package/dist/gather-qIqEqaGn.js +0 -9
- package/dist/gelu-B220X1Go.js +0 -26
- package/dist/gpgpu_math-BwvV12df.js +0 -2022
- package/dist/index-CUXkjxiT.js +0 -3516
- package/dist/index-CieiGp4Y.js +0 -349
- package/dist/index-CjOWnMXP.js +0 -7308
- package/dist/index-Cp39cXWe.js +0 -1016
- package/dist/index-D5v913EJ.js +0 -4
- package/dist/index-DmeWGGmS.js +0 -1074
- package/dist/index-DvYrXKkX.js +0 -113
- package/dist/index-Ksja3su6.js +0 -151
- package/dist/index-xuotMAFm.js +0 -118
- package/dist/jszip.min-BZhlzntC.js +0 -2313
- package/dist/kernel_funcs_utils-pq0CK9co.js +0 -306
- package/dist/matMul16-BcVC_E62.js +0 -80
- package/dist/matMulGelu-JNLZqKQp.js +0 -163
- package/dist/mat_mul-DhG0Newp.js +0 -11
- package/dist/mod-CSdCpRjf.js +0 -11
- package/dist/non_max_suppression_impl-B2W7YjZB.js +0 -102
- package/dist/not_equal-hurPF26l.js +0 -64
- package/dist/ones-BytntneX.js +0 -14
- package/dist/ops-CsXeTq1P.js +0 -476
- package/dist/pack16-bqltoUlR.js +0 -39
- package/dist/papaparse.min-C0cScC2i.js +0 -418
- package/dist/parquet-Bqjmp2vo.js +0 -44231
- package/dist/pdf-NIhmP3sq.js +0 -19477
- package/dist/rand_util-CZ7yLoUm.js +0 -50
- package/dist/random_normal-IBRrha8a.js +0 -14
- package/dist/random_width-DN5ZtQkM.js +0 -9796
- package/dist/range-C-CjF-LI.js +0 -10
- package/dist/relu-J_X6MUzx.js +0 -9
- package/dist/reshape-BDOuCSNW.js +0 -9
- package/dist/resize_nearest_neighbor-BojqlfRe.js +0 -150
- package/dist/rope-DcrZM_e6.js +0 -24
- package/dist/scatter_nd_util-ByNJaL6I.js +0 -46
- package/dist/segment_util-Dasb2Zaf.js +0 -43
- package/dist/selu_util-BLhIqRkw.js +0 -44
- package/dist/shared-3agzAqQ_.js +0 -53
- package/dist/shared-CagdqkLh.js +0 -2143
- package/dist/slice-BzS11Qh0.js +0 -12
- package/dist/slice_util-CC35pLmT.js +0 -153
- package/dist/softmax-D4q1LJN7.js +0 -12
- package/dist/split-C2Sj255c.js +0 -9
- package/dist/squeeze-ho4wLUek.js +0 -10
- package/dist/stack-DudVrtmG.js +0 -11
- package/dist/step-BTxPtq1r.js +0 -261
- package/dist/sum-BpiwSWvg.js +0 -11
- package/dist/tensor-BWFldCso.js +0 -8
- package/dist/tensor1d-LMGMIUlr.js +0 -11
- package/dist/tensor2d-BnXMKScO.js +0 -14
- package/dist/tensor4d-C6UCG_u8.js +0 -14
- package/dist/tfjs_backend-BGnG-ppu.js +0 -654
- package/dist/tile-CFy-xTO6.js +0 -11
- package/dist/transpose-9kRxIXWR.js +0 -36
- package/dist/unsorted_segment_sum-DJvk5xnh.js +0 -277
- package/dist/variable-Ck482e3n.js +0 -7
- package/dist/webgpu_program-B4HmApL1.js +0 -525
- package/dist/webgpu_util-DYlGSwOJ.js +0 -64
- package/dist/zeros-DvZpK8s6.js +0 -13
- package/dist/zeros_like-CWjDdwr-.js +0 -721
|
@@ -0,0 +1,362 @@
|
|
|
1
|
+
import { Dn as e, En as t, Io as n, Ks as r, Ms as i, Si as a, Tn as o, nc as s, oc as c, wn as l, xn as u } from "./dist-BewPQWjc.js";
|
|
2
|
+
import { L as d } from "./backend_util-Cg-roD1p.js";
|
|
3
|
+
import { o as f } from "./gpgpu_math-DvLcCH6u.js";
|
|
4
|
+
import { J as p, b as m } from "./shared-DOgWaqvL.js";
|
|
5
|
+
import { S as h, n as g } from "./kernel_funcs_utils-HiXOOx3f.js";
|
|
6
|
+
import { t as _ } from "./Reshape-CABOPB9d.js";
|
|
7
|
+
//#region node_modules/@tensorflow/tfjs-backend-webgl/dist/mean_gpu.js
|
|
8
|
+
var v = class {
|
|
9
|
+
constructor(e, t) {
|
|
10
|
+
this.variableNames = ["x"];
|
|
11
|
+
let { windowSize: n, batchSize: i, inSize: a, outSize: o } = e;
|
|
12
|
+
this.outputShape = [i, o];
|
|
13
|
+
let s = Math.floor(n / 4) * 4, c = n % 4, l = "sumValue += dot(values, ones);";
|
|
14
|
+
if (t != null) {
|
|
15
|
+
let e = 1 / t;
|
|
16
|
+
l = `sumValue += dot(values * ${r(e) ? e.toPrecision(2) : e}, ones);`;
|
|
17
|
+
}
|
|
18
|
+
let u = "";
|
|
19
|
+
a % n > 0 && (u = `
|
|
20
|
+
if (inIdx < 0 || inIdx >= ${a}) {
|
|
21
|
+
return 0.0;
|
|
22
|
+
}
|
|
23
|
+
`), this.userCode = `
|
|
24
|
+
const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);
|
|
25
|
+
|
|
26
|
+
float getValue(int batch, int inIdx) {
|
|
27
|
+
${u}
|
|
28
|
+
return getX(batch, inIdx);
|
|
29
|
+
}
|
|
30
|
+
|
|
31
|
+
void main() {
|
|
32
|
+
ivec2 coords = getOutputCoords();
|
|
33
|
+
int batch = coords[0];
|
|
34
|
+
int outIdx = coords[1];
|
|
35
|
+
int inOffset = outIdx * ${n};
|
|
36
|
+
|
|
37
|
+
float sumValue = 0.0;
|
|
38
|
+
|
|
39
|
+
for (int i = 0; i < ${s}; i += 4) {
|
|
40
|
+
int inIdx = inOffset + i;
|
|
41
|
+
vec4 values = vec4(
|
|
42
|
+
getValue(batch, inIdx),
|
|
43
|
+
getValue(batch, inIdx + 1),
|
|
44
|
+
getValue(batch, inIdx + 2),
|
|
45
|
+
getValue(batch, inIdx + 3)
|
|
46
|
+
);
|
|
47
|
+
|
|
48
|
+
${l}
|
|
49
|
+
}
|
|
50
|
+
|
|
51
|
+
int inIdx = inOffset + ${s};
|
|
52
|
+
if (${c === 1}) {
|
|
53
|
+
vec4 values = vec4(getValue(batch, inIdx), 0.0, 0.0, 0.0);
|
|
54
|
+
|
|
55
|
+
${l}
|
|
56
|
+
} else if (${c === 2}) {
|
|
57
|
+
vec4 values = vec4(
|
|
58
|
+
getValue(batch, inIdx),
|
|
59
|
+
getValue(batch, inIdx + 1), 0.0, 0.0);
|
|
60
|
+
|
|
61
|
+
${l}
|
|
62
|
+
} else if (${c === 3}) {
|
|
63
|
+
vec4 values = vec4(
|
|
64
|
+
getValue(batch, inIdx),
|
|
65
|
+
getValue(batch, inIdx + 1),
|
|
66
|
+
getValue(batch, inIdx + 2), 0.0);
|
|
67
|
+
|
|
68
|
+
${l}
|
|
69
|
+
}
|
|
70
|
+
setOutput(sumValue);
|
|
71
|
+
}
|
|
72
|
+
`;
|
|
73
|
+
}
|
|
74
|
+
}, y = class {
|
|
75
|
+
constructor(e, t) {
|
|
76
|
+
this.variableNames = ["x"];
|
|
77
|
+
let { windowSize: n, batchSize: r, inSize: i, outSize: a } = e;
|
|
78
|
+
this.outputShape = [r, a];
|
|
79
|
+
let o = "0.0", s = "";
|
|
80
|
+
t === "prod" ? o = "1.0" : t === "min" ? (o = "1.0 / 1e-20", s = "min") : t === "max" && (o = "-1.0 / 1e-20", s = "max");
|
|
81
|
+
let c = `${t}(${t}(${t}(minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])`;
|
|
82
|
+
t === "sum" ? c = "sumValue" : t === "prod" ? c = "prodValue" : t === "all" ? c = "allValue" : t === "any" && (c = "anyValue");
|
|
83
|
+
let l = Math.floor(n / 4) * 4, u = n % 4, d = `
|
|
84
|
+
if (${t === "sum"}) {
|
|
85
|
+
sumValue += dot(values, ones);
|
|
86
|
+
} else if (${t === "prod"}) {
|
|
87
|
+
vec2 tmp = vec2(values[0], values[1]) * vec2(values[2], values[3]);
|
|
88
|
+
prodValue *= tmp[0] * tmp[1];
|
|
89
|
+
} else {
|
|
90
|
+
minMaxValue = ${s}(values, minMaxValue);
|
|
91
|
+
if (${t === "min"} || ${t === "max"}) {
|
|
92
|
+
minMaxValue = ${s}(values, minMaxValue);
|
|
93
|
+
bvec4 isNaN = isnan(values);
|
|
94
|
+
if (isNaN.r || isNaN.g || isNaN.b || isNaN.a) {
|
|
95
|
+
minMaxValue = vec4(NAN);
|
|
96
|
+
}
|
|
97
|
+
}
|
|
98
|
+
}
|
|
99
|
+
`, f = "vec4";
|
|
100
|
+
t === "all" ? (o = "1.0", d = "\n bool reducedAllValue = all(values);\n float floatedReducedAllValue = float(reducedAllValue);\n allValue = float(allValue >= 1.0 && floatedReducedAllValue >= 1.0);\n ", f = "bvec4") : t === "any" && (o = "0.0", d = "\n bool reducedAnyValue = any(values);\n float floatedReducedAnyValue = float(reducedAnyValue);\n anyValue = float(anyValue >= 1.0 || floatedReducedAnyValue >= 1.0);\n ", f = "bvec4");
|
|
101
|
+
let p = "";
|
|
102
|
+
i % n > 0 && (p = `
|
|
103
|
+
if (inIdx < 0 || inIdx >= ${i}) {
|
|
104
|
+
return initializationValue;
|
|
105
|
+
}
|
|
106
|
+
`), this.userCode = `
|
|
107
|
+
const float initializationValue = ${o};
|
|
108
|
+
const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);
|
|
109
|
+
|
|
110
|
+
float getValue(int batch, int inIdx) {
|
|
111
|
+
${p}
|
|
112
|
+
return getX(batch, inIdx);
|
|
113
|
+
}
|
|
114
|
+
|
|
115
|
+
void main() {
|
|
116
|
+
ivec2 coords = getOutputCoords();
|
|
117
|
+
int batch = coords[0];
|
|
118
|
+
int outIdx = coords[1];
|
|
119
|
+
int inOffset = outIdx * ${n};
|
|
120
|
+
|
|
121
|
+
vec4 minMaxValue = vec4(${o});
|
|
122
|
+
float prodValue = 1.0;
|
|
123
|
+
float sumValue = 0.0;
|
|
124
|
+
float allValue = 1.0;
|
|
125
|
+
float anyValue = 0.0;
|
|
126
|
+
|
|
127
|
+
for (int i = 0; i < ${l}; i += 4) {
|
|
128
|
+
int inIdx = inOffset + i;
|
|
129
|
+
${f} values = ${f}(
|
|
130
|
+
getValue(batch, inIdx),
|
|
131
|
+
getValue(batch, inIdx + 1),
|
|
132
|
+
getValue(batch, inIdx + 2),
|
|
133
|
+
getValue(batch, inIdx + 3)
|
|
134
|
+
);
|
|
135
|
+
|
|
136
|
+
${d}
|
|
137
|
+
}
|
|
138
|
+
|
|
139
|
+
int inIdx = inOffset + ${l};
|
|
140
|
+
if (${u === 1}) {
|
|
141
|
+
${f} values = ${f}(
|
|
142
|
+
getValue(batch, inIdx),
|
|
143
|
+
initializationValue,
|
|
144
|
+
initializationValue,
|
|
145
|
+
initializationValue
|
|
146
|
+
);
|
|
147
|
+
|
|
148
|
+
${d}
|
|
149
|
+
} else if (${u === 2}) {
|
|
150
|
+
${f} values = ${f}(
|
|
151
|
+
getValue(batch, inIdx),
|
|
152
|
+
getValue(batch, inIdx + 1),
|
|
153
|
+
initializationValue,
|
|
154
|
+
initializationValue
|
|
155
|
+
);
|
|
156
|
+
|
|
157
|
+
${d}
|
|
158
|
+
} else if (${u === 3}) {
|
|
159
|
+
${f} values = ${f}(
|
|
160
|
+
getValue(batch, inIdx),
|
|
161
|
+
getValue(batch, inIdx + 1),
|
|
162
|
+
getValue(batch, inIdx + 2),
|
|
163
|
+
initializationValue
|
|
164
|
+
);
|
|
165
|
+
|
|
166
|
+
${d}
|
|
167
|
+
}
|
|
168
|
+
setOutput(${c});
|
|
169
|
+
}
|
|
170
|
+
`;
|
|
171
|
+
}
|
|
172
|
+
};
|
|
173
|
+
//#endregion
|
|
174
|
+
//#region node_modules/@tensorflow/tfjs-backend-webgl/dist/kernel_utils/reduce.js
|
|
175
|
+
function b(e) {
|
|
176
|
+
let t = [];
|
|
177
|
+
for (; t.length === 0 || t[t.length - 1].outSize !== 1;) {
|
|
178
|
+
let n = t.length ? t[t.length - 1].outSize : e[1], r = d(n);
|
|
179
|
+
t.push({
|
|
180
|
+
inSize: n,
|
|
181
|
+
windowSize: r,
|
|
182
|
+
outSize: Math.ceil(n / r)
|
|
183
|
+
});
|
|
184
|
+
}
|
|
185
|
+
return t;
|
|
186
|
+
}
|
|
187
|
+
function x(e, t, n, r) {
|
|
188
|
+
let i = b(e.shape), a = e;
|
|
189
|
+
for (let o = 0; o < i.length; o++) {
|
|
190
|
+
let { inSize: s, windowSize: c, outSize: l } = i[o], u, d;
|
|
191
|
+
u = n === "mean" ? o === 0 ? new v({
|
|
192
|
+
windowSize: c,
|
|
193
|
+
inSize: s,
|
|
194
|
+
batchSize: e.shape[0],
|
|
195
|
+
outSize: l
|
|
196
|
+
}, s) : new v({
|
|
197
|
+
windowSize: c,
|
|
198
|
+
inSize: s,
|
|
199
|
+
batchSize: e.shape[0],
|
|
200
|
+
outSize: l
|
|
201
|
+
}) : new y({
|
|
202
|
+
windowSize: c,
|
|
203
|
+
inSize: s,
|
|
204
|
+
batchSize: e.shape[0],
|
|
205
|
+
outSize: l
|
|
206
|
+
}, n), d = a, a = r.runWebGLProgram(u, [a], t), d.dataId !== e.dataId && r.disposeIntermediateTensorInfo(d);
|
|
207
|
+
}
|
|
208
|
+
return a;
|
|
209
|
+
}
|
|
210
|
+
//#endregion
|
|
211
|
+
//#region node_modules/@tensorflow/tfjs-backend-webgl/dist/transpose_gpu.js
|
|
212
|
+
var S = class {
|
|
213
|
+
constructor(e, t) {
|
|
214
|
+
this.variableNames = ["A"];
|
|
215
|
+
let n = Array(e.length);
|
|
216
|
+
for (let r = 0; r < n.length; r++) n[r] = e[t[r]];
|
|
217
|
+
this.outputShape = n, this.rank = n.length;
|
|
218
|
+
let r = f(this.rank), i = C(t);
|
|
219
|
+
this.userCode = `
|
|
220
|
+
void main() {
|
|
221
|
+
${r} resRC = getOutputCoords();
|
|
222
|
+
setOutput(getA(${i}));
|
|
223
|
+
}
|
|
224
|
+
`;
|
|
225
|
+
}
|
|
226
|
+
};
|
|
227
|
+
function C(e) {
|
|
228
|
+
let t = e.length;
|
|
229
|
+
if (t > 6) throw Error(`Transpose for rank ${t} is not yet supported`);
|
|
230
|
+
let n = [
|
|
231
|
+
"resRC.x",
|
|
232
|
+
"resRC.y",
|
|
233
|
+
"resRC.z",
|
|
234
|
+
"resRC.w",
|
|
235
|
+
"resRC.u",
|
|
236
|
+
"resRC.v"
|
|
237
|
+
], r = Array(t);
|
|
238
|
+
for (let t = 0; t < e.length; t++) r[e[t]] = n[t];
|
|
239
|
+
return r.join();
|
|
240
|
+
}
|
|
241
|
+
//#endregion
|
|
242
|
+
//#region node_modules/@tensorflow/tfjs-backend-webgl/dist/transpose_packed_gpu.js
|
|
243
|
+
var w = class {
|
|
244
|
+
constructor(e, t) {
|
|
245
|
+
this.variableNames = ["A"], this.packedInputs = !0, this.packedOutput = !0;
|
|
246
|
+
let n = Array(e.length);
|
|
247
|
+
for (let r = 0; r < n.length; r++) n[r] = e[t[r]];
|
|
248
|
+
if (this.outputShape = n, this.rank = n.length, this.rank > 6) throw Error(`Packed transpose for rank ${this.rank} is not yet supported.`);
|
|
249
|
+
let r = f(this.rank), i = h("rc", this.rank), a = Array(this.rank);
|
|
250
|
+
for (let e = 0; e < t.length; e++) a[t[e]] = i[e];
|
|
251
|
+
let o = `vec2(${a.slice(-2).join()})`, s = `++${i[this.rank - 1]} < ${n[this.rank - 1]}`, c = `getChannel(getA(${a.join()}), ${o})`;
|
|
252
|
+
this.userCode = `
|
|
253
|
+
void main() {
|
|
254
|
+
${r} rc = getOutputCoords();
|
|
255
|
+
vec4 result = vec4(0.);
|
|
256
|
+
result[0] = ${c};
|
|
257
|
+
if(${s}) {
|
|
258
|
+
result[1] = ${c};
|
|
259
|
+
}
|
|
260
|
+
--${i[this.rank - 1]};
|
|
261
|
+
if(++${i[this.rank - 2]} < ${n[this.rank - 2]}) {
|
|
262
|
+
result[2] = ${c};
|
|
263
|
+
if(${s}) {
|
|
264
|
+
result[3] = ${c};
|
|
265
|
+
}
|
|
266
|
+
}
|
|
267
|
+
setOutput(result);
|
|
268
|
+
}
|
|
269
|
+
`;
|
|
270
|
+
}
|
|
271
|
+
};
|
|
272
|
+
//#endregion
|
|
273
|
+
//#region node_modules/@tensorflow/tfjs-backend-webgl/dist/kernels/Transpose_impl.js
|
|
274
|
+
function T(e, t, n) {
|
|
275
|
+
let r = i().getBool("WEBGL_PACK_ARRAY_OPERATIONS") ? new w(e.shape, t) : new S(e.shape, t);
|
|
276
|
+
return n.runWebGLProgram(r, [e], e.dtype);
|
|
277
|
+
}
|
|
278
|
+
//#endregion
|
|
279
|
+
//#region node_modules/@tensorflow/tfjs-backend-webgl/dist/kernels/Sum_impl.js
|
|
280
|
+
function E(n, r, i, d) {
|
|
281
|
+
let f = r, p = n.shape.length, m = s(f, n.shape), h = m, g = t(h, p), v = g != null, y = n;
|
|
282
|
+
v && (y = T(n, g, d), h = e(h.length, p)), u("sum", h, p);
|
|
283
|
+
let [b, S] = l(y.shape, h), C = b;
|
|
284
|
+
i && (C = o(b, m));
|
|
285
|
+
let w = c(S), E = c(n.shape) / w, D = _({
|
|
286
|
+
inputs: { x: y },
|
|
287
|
+
attrs: { shape: [E, w] },
|
|
288
|
+
backend: d
|
|
289
|
+
}), O = x(D, a(n.dtype), "sum", d), k = _({
|
|
290
|
+
inputs: { x: O },
|
|
291
|
+
attrs: { shape: C },
|
|
292
|
+
backend: d
|
|
293
|
+
});
|
|
294
|
+
return d.disposeIntermediateTensorInfo(D), d.disposeIntermediateTensorInfo(O), v && d.disposeIntermediateTensorInfo(y), k;
|
|
295
|
+
}
|
|
296
|
+
//#endregion
|
|
297
|
+
//#region node_modules/@tensorflow/tfjs-backend-webgl/dist/kernels/Sum.js
|
|
298
|
+
function D(e) {
|
|
299
|
+
let { inputs: t, backend: n, attrs: r } = e, { x: i } = t, { axis: a, keepDims: o } = r;
|
|
300
|
+
return E(i, a, o, n);
|
|
301
|
+
}
|
|
302
|
+
var O = {
|
|
303
|
+
kernelName: "Sum",
|
|
304
|
+
backendName: "webgl",
|
|
305
|
+
kernelFunc: D
|
|
306
|
+
};
|
|
307
|
+
//#endregion
|
|
308
|
+
//#region node_modules/@tensorflow/tfjs-backend-webgl/dist/kernels/Max_impl.js
|
|
309
|
+
function k(e, t, n, r) {
|
|
310
|
+
let i = c(t), a = c(e.shape) / i, o = _({
|
|
311
|
+
inputs: { x: e },
|
|
312
|
+
attrs: { shape: [a, i] },
|
|
313
|
+
backend: r
|
|
314
|
+
}), s = x(o, e.dtype, "max", r), l = _({
|
|
315
|
+
inputs: { x: s },
|
|
316
|
+
attrs: { shape: n },
|
|
317
|
+
backend: r
|
|
318
|
+
});
|
|
319
|
+
return r.disposeIntermediateTensorInfo(o), r.disposeIntermediateTensorInfo(s), l;
|
|
320
|
+
}
|
|
321
|
+
//#endregion
|
|
322
|
+
//#region node_modules/@tensorflow/tfjs-backend-webgl/dist/kernels/Max.js
|
|
323
|
+
function A(n) {
|
|
324
|
+
let { inputs: r, backend: i, attrs: a } = n, { x: d } = r, { reductionIndices: f, keepDims: h } = a, g = d.shape.length, _ = s(f, d.shape), v = _, y = t(v, g), b = y != null, x = i.shouldExecuteOnCPU([d]), S = d;
|
|
325
|
+
if (b) {
|
|
326
|
+
if (x) {
|
|
327
|
+
let e = i.texData.get(S.dataId).values, t = Array(g);
|
|
328
|
+
for (let e = 0; e < t.length; e++) t[e] = d.shape[y[e]];
|
|
329
|
+
let n = p(e, d.shape, d.dtype, y, t);
|
|
330
|
+
S = i.makeTensorInfo(t, d.dtype);
|
|
331
|
+
let r = i.texData.get(S.dataId);
|
|
332
|
+
r.values = n;
|
|
333
|
+
} else S = T(d, y, i);
|
|
334
|
+
v = e(v.length, g);
|
|
335
|
+
}
|
|
336
|
+
u("max", v, g);
|
|
337
|
+
let [C, w] = l(S.shape, v), E = C;
|
|
338
|
+
h && (E = o(C, _));
|
|
339
|
+
let D;
|
|
340
|
+
if (x) {
|
|
341
|
+
let e = i.texData.get(S.dataId).values, t = m(e, c(w), E, d.dtype);
|
|
342
|
+
D = i.makeTensorInfo(E, d.dtype);
|
|
343
|
+
let n = i.texData.get(D.dataId);
|
|
344
|
+
n.values = t;
|
|
345
|
+
} else D = k(S, w, E, i);
|
|
346
|
+
return b && i.disposeIntermediateTensorInfo(S), D;
|
|
347
|
+
}
|
|
348
|
+
var j = {
|
|
349
|
+
kernelName: "Max",
|
|
350
|
+
backendName: "webgl",
|
|
351
|
+
kernelFunc: A
|
|
352
|
+
}, M = g({
|
|
353
|
+
opSnippet: "\nif (a == b) {\n return 1.0;\n};\nreturn a / b;",
|
|
354
|
+
packedOpSnippet: "\n // vec4 one = vec4(equal(a, b));\n // return one + (vec4(1.0) - one) * a / b;\n vec4 result = a / b;\n if(a.x == b.x) {\n result.x = 1.;\n }\n if(a.y == b.y) {\n result.y = 1.;\n }\n if(a.z == b.z) {\n result.z = 1.;\n }\n if(a.w == b.w) {\n result.w = 1.;\n }\n\n return result;\n",
|
|
355
|
+
checkOutOfBounds: !0
|
|
356
|
+
}), N = {
|
|
357
|
+
kernelName: n,
|
|
358
|
+
backendName: "webgl",
|
|
359
|
+
kernelFunc: M
|
|
360
|
+
};
|
|
361
|
+
//#endregion
|
|
362
|
+
export { D as a, x as c, j as i, N as n, O as o, A as r, T as s, M as t };
|
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
import { Bo as e, Gs as t, Ps as n, oc as r } from "./dist-BewPQWjc.js";
|
|
2
|
+
import { E as i, a, c as o, d as s, j as c, l, u, z as d } from "./gpgpu_math-DvLcCH6u.js";
|
|
3
|
+
//#region node_modules/@tensorflow/tfjs-backend-webgl/dist/reshape_packed_gpu.js
|
|
4
|
+
var f = class {
|
|
5
|
+
constructor(e, t) {
|
|
6
|
+
this.variableNames = ["A"], this.packedInputs = !0, this.packedOutput = !0, this.customUniforms = [{
|
|
7
|
+
name: "inputShape",
|
|
8
|
+
type: "ivec3"
|
|
9
|
+
}], this.outputShape = e, this.enableShapeUniforms = a(this.outputShape.length);
|
|
10
|
+
let n = "";
|
|
11
|
+
for (let e = 0; e < 4; e++) {
|
|
12
|
+
let t = "thisRC = rc;";
|
|
13
|
+
e % 2 == 1 && (t += "thisRC.z += 1;"), e > 1 && (t += "thisRC.y += 1;"), n += `
|
|
14
|
+
${t}
|
|
15
|
+
${e > 0 ? "if(thisRC.y < rows && thisRC.z < cols){" : ""}
|
|
16
|
+
int flatIndex = getFlatIndex(thisRC);
|
|
17
|
+
|
|
18
|
+
ivec3 inputRC = inputCoordsFromReshapedOutCoords(flatIndex);
|
|
19
|
+
vec2 inputRCInnerDims = vec2(float(inputRC.y),float(inputRC.z));
|
|
20
|
+
|
|
21
|
+
result[${e}] =
|
|
22
|
+
getChannel(getA(inputRC.x, inputRC.y, inputRC.z), inputRCInnerDims);
|
|
23
|
+
${e > 0 ? "}" : ""}
|
|
24
|
+
`;
|
|
25
|
+
}
|
|
26
|
+
this.userCode = `
|
|
27
|
+
${p(t, this.enableShapeUniforms)}
|
|
28
|
+
${this.enableShapeUniforms ? l() : o(e)}
|
|
29
|
+
|
|
30
|
+
void main() {
|
|
31
|
+
ivec3 rc = getOutputCoords();
|
|
32
|
+
|
|
33
|
+
vec4 result = vec4(0.);
|
|
34
|
+
|
|
35
|
+
ivec3 thisRC;
|
|
36
|
+
int rows = ${this.enableShapeUniforms ? "outShape[1]" : e[1]};
|
|
37
|
+
int cols = ${this.enableShapeUniforms ? "outShape[2]" : e[2]};
|
|
38
|
+
|
|
39
|
+
${n}
|
|
40
|
+
|
|
41
|
+
setOutput(result);
|
|
42
|
+
}
|
|
43
|
+
`;
|
|
44
|
+
}
|
|
45
|
+
};
|
|
46
|
+
function p(e, t) {
|
|
47
|
+
return `
|
|
48
|
+
ivec3 inputCoordsFromReshapedOutCoords(int index) {
|
|
49
|
+
${t ? s([
|
|
50
|
+
"r",
|
|
51
|
+
"c",
|
|
52
|
+
"d"
|
|
53
|
+
], "inputShape") : u([
|
|
54
|
+
"r",
|
|
55
|
+
"c",
|
|
56
|
+
"d"
|
|
57
|
+
], e)}
|
|
58
|
+
return ivec3(r, c, d);
|
|
59
|
+
}
|
|
60
|
+
`;
|
|
61
|
+
}
|
|
62
|
+
//#endregion
|
|
63
|
+
//#region node_modules/@tensorflow/tfjs-backend-webgl/dist/kernel_utils/reshape.js
|
|
64
|
+
function m(e, t, n) {
|
|
65
|
+
let r = [i(e.shape), ...c(e.shape)], a = {
|
|
66
|
+
dtype: e.dtype,
|
|
67
|
+
shape: r,
|
|
68
|
+
dataId: e.dataId
|
|
69
|
+
}, o = new f([i(t), ...c(t)], r), s = [r], l = n.runWebGLProgram(o, [a], e.dtype, s, !0);
|
|
70
|
+
return {
|
|
71
|
+
dataId: l.dataId,
|
|
72
|
+
shape: t,
|
|
73
|
+
dtype: l.dtype
|
|
74
|
+
};
|
|
75
|
+
}
|
|
76
|
+
//#endregion
|
|
77
|
+
//#region node_modules/@tensorflow/tfjs-backend-webgl/dist/kernels/Reshape.js
|
|
78
|
+
function h(e) {
|
|
79
|
+
let { inputs: i, backend: a, attrs: o } = e, { x: s } = i, { shape: c } = o, l = a, u = r(s.shape), f = t(c, u), p = r(f);
|
|
80
|
+
n(u === p, () => `The new shape (${f}) has ${p} elements and the old shape (${s.shape}) has ${u} elements. The new shape and old shape must have the same number of elements.`);
|
|
81
|
+
let h = l.texData.get(s.dataId);
|
|
82
|
+
return h.isPacked && !d(s.shape, f) && !(h.texture !== null && d(h.shape, f)) ? m(s, f, l) : (l.incRef(s.dataId), {
|
|
83
|
+
dataId: s.dataId,
|
|
84
|
+
shape: f,
|
|
85
|
+
dtype: s.dtype
|
|
86
|
+
});
|
|
87
|
+
}
|
|
88
|
+
var g = {
|
|
89
|
+
kernelName: e,
|
|
90
|
+
backendName: "webgl",
|
|
91
|
+
kernelFunc: h
|
|
92
|
+
};
|
|
93
|
+
//#endregion
|
|
94
|
+
export { g as n, f as r, h as t };
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
import { Bo as e, Gs as t, Ps as n, oc as r } from "./dist-BewPQWjc.js";
|
|
2
|
+
//#region node_modules/@tensorflow/tfjs-backend-webgpu/dist/kernels/Reshape.js
|
|
3
|
+
function i(e) {
|
|
4
|
+
let { inputs: i, attrs: a } = e, { x: o } = i, { shape: s } = a, c = r(o.shape), l = t(s, c), u = r(l);
|
|
5
|
+
return n(c === u, () => `The new shape (${l}) has ${u} elements and the old shape (${o.shape}) has ${c} elements. The new shape and old shape must have the same number of elements.`), e.backend.incRef(o.dataId), {
|
|
6
|
+
dataId: o.dataId,
|
|
7
|
+
shape: l,
|
|
8
|
+
dtype: o.dtype
|
|
9
|
+
};
|
|
10
|
+
}
|
|
11
|
+
var a = {
|
|
12
|
+
kernelName: e,
|
|
13
|
+
backendName: "webgpu",
|
|
14
|
+
kernelFunc: i
|
|
15
|
+
};
|
|
16
|
+
//#endregion
|
|
17
|
+
export { a as n, i as t };
|
package/dist/TeachableLLM.d.ts
CHANGED
|
@@ -8,7 +8,7 @@ import { default as MemoryProfiler } from './utilities/profile';
|
|
|
8
8
|
import { default as Model, ModelForwardAttributes } from './models/model';
|
|
9
9
|
import { Task } from './training/tasks/Task';
|
|
10
10
|
import { TrainingLogEntry, TrainingOptions } from './training/types';
|
|
11
|
-
import {
|
|
11
|
+
import { ModelMode, TransformersMetadata } from './loader/types';
|
|
12
12
|
type TeachableLLMStatus = 'warmup' | 'awaitingTokens' | 'ready' | 'training' | 'loading' | 'busy' | 'error';
|
|
13
13
|
export default class TeachableLLM {
|
|
14
14
|
private ee;
|
|
@@ -22,8 +22,8 @@ export default class TeachableLLM {
|
|
|
22
22
|
constructor(tokeniser?: ITokeniser, model?: Model<ModelForwardAttributes, GPTConfig>);
|
|
23
23
|
get currentTrainer(): Trainer | null;
|
|
24
24
|
get vocab(): string[];
|
|
25
|
-
get
|
|
26
|
-
set
|
|
25
|
+
get mode(): ModelMode;
|
|
26
|
+
set mode(mode: ModelMode);
|
|
27
27
|
/** Model is fully loaded */
|
|
28
28
|
get loaded(): boolean;
|
|
29
29
|
get config(): GPTConfig;
|
|
@@ -57,12 +57,12 @@ export default class TeachableLLM {
|
|
|
57
57
|
generateText(options?: IGenerateOptions): Promise<Conversation[]>;
|
|
58
58
|
dispose(): void;
|
|
59
59
|
on(event: 'status', listener: (status: TeachableLLMStatus) => void): void;
|
|
60
|
-
on(event: '
|
|
60
|
+
on(event: 'mode', listener: (mode: ModelMode) => void): void;
|
|
61
61
|
on(event: 'error', listener: (error: Error) => void): void;
|
|
62
62
|
on(event: 'trainStep', listener: (step: TrainingLogEntry) => void): void;
|
|
63
63
|
on(event: 'loaded' | 'changeLoRA', listener: () => void): void;
|
|
64
64
|
off(event: 'status', listener: (status: TeachableLLMStatus) => void): void;
|
|
65
|
-
off(event: '
|
|
65
|
+
off(event: 'mode', listener: (mode: ModelMode) => void): void;
|
|
66
66
|
off(event: 'error', listener: (error: Error) => void): void;
|
|
67
67
|
off(event: 'trainStep', listener: (step: TrainingLogEntry) => void): void;
|
|
68
68
|
off(event: 'loaded' | 'changeLoRA', listener: () => void): void;
|