@genai-fi/nanogpt 0.19.1 → 0.20.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (347)
  1. package/dist/BaseTokeniser-DSg9zcYq.js +221 -0
  2. package/dist/DatasetBuilder-DgURD85T.js +712 -0
  3. package/dist/Generator.js +2 -11941
  4. package/dist/RealDiv-DBu0FQqT.js +362 -0
  5. package/dist/Reshape-CABOPB9d.js +94 -0
  6. package/dist/Reshape-DqO3r8BC.js +17 -0
  7. package/dist/TeachableLLM.d.ts +5 -5
  8. package/dist/TeachableLLM.js +2 -273
  9. package/dist/Trainer.js +2 -244
  10. package/dist/backend.js +12 -12
  11. package/dist/backend_util-Cg-roD1p.js +399 -0
  12. package/dist/binary_op_util-CrYk9LXL.js +103 -0
  13. package/dist/checks/appendCache.js +54 -21
  14. package/dist/checks/attentionMask.js +55 -36
  15. package/dist/checks/check.js +31 -19
  16. package/dist/checks/gelu.js +45 -17
  17. package/dist/checks/index.js +25 -25
  18. package/dist/checks/matMulGelu.js +83 -27
  19. package/dist/checks/normRMS.js +27 -15
  20. package/dist/checks/normRMSGrad.js +21 -11
  21. package/dist/checks/packUnpack.js +45 -17
  22. package/dist/checks/qkv.js +33 -33
  23. package/dist/checks/rope.js +29 -35
  24. package/dist/checks/weights.d.ts +1 -1
  25. package/dist/checks/weights.js +25 -29
  26. package/dist/chunk-BPntVaq0.js +23 -0
  27. package/dist/complex_util-CkazZsaH.js +60 -0
  28. package/dist/concat_util-CWDZCBlA.js +19 -0
  29. package/dist/data/docx.js +3044 -13
  30. package/dist/data/pdf.js +16 -13
  31. package/dist/data/textLoader.js +607 -112
  32. package/dist/dist-BewPQWjc.js +7572 -0
  33. package/dist/dist-DVmq73nz.js +8775 -0
  34. package/dist/dist-DXwIvKxl.js +896 -0
  35. package/dist/dist-VEU5mfO0.js +7545 -0
  36. package/dist/gelu-Bf1HW1RY.js +27 -0
  37. package/dist/gpgpu_math-DvLcCH6u.js +1612 -0
  38. package/dist/inference/types.js +0 -1
  39. package/dist/kernel_funcs_utils-HiXOOx3f.js +229 -0
  40. package/dist/layers/BaseLayer.d.ts +2 -2
  41. package/dist/layers/BaseLayer.js +75 -73
  42. package/dist/layers/CausalSelfAttention.d.ts +1 -1
  43. package/dist/layers/CausalSelfAttention.js +98 -85
  44. package/dist/layers/LoRA.js +47 -57
  45. package/dist/layers/MLP.d.ts +1 -1
  46. package/dist/layers/MLP.js +33 -43
  47. package/dist/layers/PositionEmbedding.d.ts +1 -1
  48. package/dist/layers/PositionEmbedding.js +26 -30
  49. package/dist/layers/RMSNorm.d.ts +1 -1
  50. package/dist/layers/RMSNorm.js +19 -21
  51. package/dist/layers/RoPECache.js +336 -49
  52. package/dist/layers/TiedEmbedding.d.ts +1 -1
  53. package/dist/layers/TiedEmbedding.js +30 -34
  54. package/dist/layers/TransformerBlock.d.ts +1 -1
  55. package/dist/layers/TransformerBlock.js +50 -39
  56. package/dist/layers/WeightStore.js +68 -75
  57. package/dist/loader/load.js +2 -68
  58. package/dist/loader/loadHF.d.ts +2 -2
  59. package/dist/loader/loadHF.js +2 -22
  60. package/dist/loader/loadTransformers.d.ts +1 -1
  61. package/dist/loader/loadTransformers.js +2 -44
  62. package/dist/loader/loadZipMeta.js +15 -15
  63. package/dist/loader/newZipLoad.js +2 -31
  64. package/dist/loader/oldZipLoad.d.ts +2 -2
  65. package/dist/loader/oldZipLoad.js +2 -80
  66. package/dist/loader/save.d.ts +5 -5
  67. package/dist/loader/save.js +2 -90
  68. package/dist/loader/types.d.ts +9 -8
  69. package/dist/loader/types.js +0 -1
  70. package/dist/main-CPjeMv0G.js +13500 -0
  71. package/dist/main.d.ts +1 -1
  72. package/dist/main.js +16 -109
  73. package/dist/matMul16-BNfZSnNM.js +81 -0
  74. package/dist/matMulGelu-CPTntosE.js +162 -0
  75. package/dist/models/NanoGPTV1.js +2 -99
  76. package/dist/models/NanoGPTV2.js +2 -90
  77. package/dist/models/config.js +34 -47
  78. package/dist/models/factory.d.ts +1 -1
  79. package/dist/models/factory.js +2 -16
  80. package/dist/models/model.d.ts +2 -2
  81. package/dist/models/model.js +2 -134
  82. package/dist/ops/adamAdjust.js +15 -6
  83. package/dist/ops/adamMoments.js +13 -6
  84. package/dist/ops/add16.js +9 -6
  85. package/dist/ops/appendCache.js +22 -19
  86. package/dist/ops/attentionMask.js +12 -6
  87. package/dist/ops/concat16.js +7 -8
  88. package/dist/ops/cpu/adamAdjust.js +15 -17
  89. package/dist/ops/cpu/adamMoments.js +15 -15
  90. package/dist/ops/cpu/appendCache.js +64 -22
  91. package/dist/ops/cpu/attentionMask.js +15 -21
  92. package/dist/ops/cpu/fusedSoftmax.js +21 -28
  93. package/dist/ops/cpu/gatherSub.js +11 -17
  94. package/dist/ops/cpu/gelu.js +34 -38
  95. package/dist/ops/cpu/matMul16.js +13 -14
  96. package/dist/ops/cpu/matMulGelu.js +39 -51
  97. package/dist/ops/cpu/matMulMul.js +19 -22
  98. package/dist/ops/cpu/mulDropout.js +19 -22
  99. package/dist/ops/cpu/normRMS.js +33 -37
  100. package/dist/ops/cpu/qkv.js +72 -40
  101. package/dist/ops/cpu/rope.js +79 -36
  102. package/dist/ops/cpu/scatterSub.js +11 -22
  103. package/dist/ops/dot16.js +28 -41
  104. package/dist/ops/dropout.js +10 -13
  105. package/dist/ops/dropout16.js +20 -23
  106. package/dist/ops/gatherSub.js +10 -6
  107. package/dist/ops/gelu.js +2 -8
  108. package/dist/ops/globalNorm.js +18 -12
  109. package/dist/ops/grads/add16.js +27 -26
  110. package/dist/ops/grads/attentionMask.js +26 -21
  111. package/dist/ops/grads/dropout16.js +0 -1
  112. package/dist/ops/grads/gelu.js +2 -5
  113. package/dist/ops/grads/matMul16.js +2 -9
  114. package/dist/ops/grads/matMulGelu.js +21 -16
  115. package/dist/ops/grads/mul16.js +0 -3
  116. package/dist/ops/grads/normRMS.js +34 -30
  117. package/dist/ops/grads/pack16.js +2 -6
  118. package/dist/ops/grads/qkv.js +44 -32
  119. package/dist/ops/grads/rope.js +2 -5
  120. package/dist/ops/grads/softmax16.js +21 -23
  121. package/dist/ops/grads/unpack16.js +2 -5
  122. package/dist/ops/grads/utils.js +9 -11
  123. package/dist/ops/matMul16.js +2 -13
  124. package/dist/ops/matMulGelu.js +16 -10
  125. package/dist/ops/matMulMul.js +13 -6
  126. package/dist/ops/mul16.js +42 -38
  127. package/dist/ops/mulDrop.js +12 -6
  128. package/dist/ops/normRMS.js +18 -15
  129. package/dist/ops/pack16.js +2 -5
  130. package/dist/ops/qkv.js +12 -6
  131. package/dist/ops/reshape16.js +31 -39
  132. package/dist/ops/rope.d.ts +1 -1
  133. package/dist/ops/rope.js +2 -7
  134. package/dist/ops/scatterSub.js +10 -6
  135. package/dist/ops/slice16.js +9 -7
  136. package/dist/ops/softmax16.js +7 -7
  137. package/dist/ops/sub16.js +9 -6
  138. package/dist/ops/sum16.js +12 -12
  139. package/dist/ops/transpose16.js +29 -37
  140. package/dist/ops/unpack16.js +2 -6
  141. package/dist/ops/webgl/adamAdjust.js +62 -29
  142. package/dist/ops/webgl/adamMoments.js +30 -26
  143. package/dist/ops/webgl/appendCache.js +30 -21
  144. package/dist/ops/webgl/attentionMask.js +43 -24
  145. package/dist/ops/webgl/dropout16.js +11 -10
  146. package/dist/ops/webgl/fusedSoftmax.js +69 -79
  147. package/dist/ops/webgl/gatherSub.js +27 -26
  148. package/dist/ops/webgl/gelu.js +32 -34
  149. package/dist/ops/webgl/log.js +14 -23
  150. package/dist/ops/webgl/matMul16.js +36 -44
  151. package/dist/ops/webgl/matMulGelu.js +2 -9
  152. package/dist/ops/webgl/matMulMul.js +23 -27
  153. package/dist/ops/webgl/mulDropout.js +31 -40
  154. package/dist/ops/webgl/normRMS.js +92 -71
  155. package/dist/ops/webgl/qkv.js +35 -27
  156. package/dist/ops/webgl/rope.js +37 -21
  157. package/dist/ops/webgl/scatterSub.js +27 -26
  158. package/dist/ops/webgpu/adamAdjust.js +59 -39
  159. package/dist/ops/webgpu/adamMoments.js +62 -46
  160. package/dist/ops/webgpu/add16.js +13 -12
  161. package/dist/ops/webgpu/appendCache.js +79 -54
  162. package/dist/ops/webgpu/attentionMask.js +41 -25
  163. package/dist/ops/webgpu/attentionMask32_program.js +34 -26
  164. package/dist/ops/webgpu/clipScale.js +44 -57
  165. package/dist/ops/webgpu/concat16.js +96 -111
  166. package/dist/ops/webgpu/dropout16.js +40 -32
  167. package/dist/ops/webgpu/gatherSub.js +43 -30
  168. package/dist/ops/webgpu/gelu.js +88 -82
  169. package/dist/ops/webgpu/index.js +16 -16
  170. package/dist/ops/webgpu/matMul16.js +69 -64
  171. package/dist/ops/webgpu/matMul16_program.js +152 -192
  172. package/dist/ops/webgpu/mul16.js +13 -12
  173. package/dist/ops/webgpu/norm2.js +45 -75
  174. package/dist/ops/webgpu/normRMS.js +25 -33
  175. package/dist/ops/webgpu/normRMS16_program.js +21 -18
  176. package/dist/ops/webgpu/normRMS32_program.js +21 -18
  177. package/dist/ops/webgpu/normRMSGrad.js +125 -184
  178. package/dist/ops/webgpu/pack16.js +20 -17
  179. package/dist/ops/webgpu/pack16_program.js +48 -47
  180. package/dist/ops/webgpu/qkv.js +63 -23
  181. package/dist/ops/webgpu/rope.js +85 -57
  182. package/dist/ops/webgpu/scatterSub.js +43 -30
  183. package/dist/ops/webgpu/slice16.js +66 -61
  184. package/dist/ops/webgpu/softmax16.js +17 -20
  185. package/dist/ops/webgpu/softmax16_program.js +34 -18
  186. package/dist/ops/webgpu/softmax16_subgroup_program.js +40 -45
  187. package/dist/ops/webgpu/softmax16grad.js +30 -36
  188. package/dist/ops/webgpu/sub16.js +13 -12
  189. package/dist/ops/webgpu/sum16.js +28 -37
  190. package/dist/ops/webgpu/transpose16.js +36 -33
  191. package/dist/ops/webgpu/transpose16_program.js +40 -39
  192. package/dist/ops/webgpu/transpose16_shared_program.js +53 -44
  193. package/dist/ops/webgpu/unpack16.js +49 -37
  194. package/dist/ops/webgpu/utils/binary_op.js +70 -68
  195. package/dist/ops/webgpu/utils/deviceInfo.d.ts +1 -1
  196. package/dist/ops/webgpu/utils/deviceInfo.js +10 -10
  197. package/dist/ops/webgpu/utils/reductions.js +136 -148
  198. package/dist/pack16-Ck-spx_F.js +39 -0
  199. package/dist/patches/webgpu_backend.d.ts +2 -2
  200. package/dist/patches/webgpu_backend.js +42 -55
  201. package/dist/patches/webgpu_base.js +21 -33
  202. package/dist/patches/webgpu_program.js +213 -320
  203. package/dist/pdf-UoDqCYzz.js +16726 -0
  204. package/dist/picomatch-3tUnMMbd.js +1063 -0
  205. package/dist/rope-CbeGlsV8.js +25 -0
  206. package/dist/selu_util-zkAx5doH.js +24 -0
  207. package/dist/shared-D1coEFea.js +1314 -0
  208. package/dist/shared-DOgWaqvL.js +5 -0
  209. package/dist/slice_util-Dgb3ANWI.js +208 -0
  210. package/dist/tfjs_backend-BjuQ5FqB.js +614 -0
  211. package/dist/tokeniser/BaseTokeniser.js +2 -124
  212. package/dist/tokeniser/CharTokeniser.js +91 -106
  213. package/dist/tokeniser/bpe.js +163 -166
  214. package/dist/tokeniser/messages.js +0 -1
  215. package/dist/tokeniser/type.js +0 -1
  216. package/dist/training/AdamW.js +127 -137
  217. package/dist/training/BasicTrainer.d.ts +1 -1
  218. package/dist/training/BasicTrainer.js +264 -264
  219. package/dist/training/DatasetBuilder.js +2 -86
  220. package/dist/training/Evaluator.d.ts +1 -1
  221. package/dist/training/Evaluator.js +47 -38
  222. package/dist/training/LRScheduler.js +37 -33
  223. package/dist/training/PreTrainer.d.ts +2 -2
  224. package/dist/training/PreTrainer.js +21 -19
  225. package/dist/training/SFTTrainer.d.ts +2 -2
  226. package/dist/training/SFTTrainer.js +23 -21
  227. package/dist/training/loss.js +17 -22
  228. package/dist/training/orthoGrad.js +9 -9
  229. package/dist/training/sparseCrossEntropy.js +45 -67
  230. package/dist/training/tasks/ConversationTask.d.ts +1 -1
  231. package/dist/training/tasks/ConversationTask.js +36 -38
  232. package/dist/training/tasks/PretrainingTask.d.ts +1 -1
  233. package/dist/training/tasks/PretrainingTask.js +41 -46
  234. package/dist/training/tasks/StartSentenceTask.d.ts +1 -1
  235. package/dist/training/tasks/StartSentenceTask.js +44 -48
  236. package/dist/training/tasks/Task.d.ts +1 -1
  237. package/dist/training/tasks/Task.js +53 -66
  238. package/dist/training/tasks/splitter.js +17 -20
  239. package/dist/training/types.d.ts +2 -2
  240. package/dist/training/types.js +0 -1
  241. package/dist/training/validation.d.ts +1 -1
  242. package/dist/training/validation.js +2 -84
  243. package/dist/utilities/arrayClose.js +15 -19
  244. package/dist/utilities/datasetID.js +17 -20
  245. package/dist/utilities/dummy.d.ts +1 -1
  246. package/dist/utilities/dummy.js +33 -40
  247. package/dist/utilities/multinomialCPU.js +8 -12
  248. package/dist/utilities/naming.js +0 -1
  249. package/dist/utilities/packed.js +10 -12
  250. package/dist/utilities/parameters.d.ts +1 -1
  251. package/dist/utilities/parameters.js +32 -51
  252. package/dist/utilities/performance.js +15 -15
  253. package/dist/utilities/profile.js +32 -37
  254. package/dist/utilities/safetensors.js +49 -79
  255. package/dist/utilities/sentences.d.ts +1 -1
  256. package/dist/utilities/sentences.js +29 -38
  257. package/dist/utilities/tokenParse.js +16 -20
  258. package/dist/utilities/topP.js +11 -12
  259. package/dist/utilities/waitForModel.d.ts +1 -1
  260. package/dist/utilities/waitForModel.js +11 -11
  261. package/dist/utilities/weights.js +37 -42
  262. package/dist/utilities/yielder.js +6 -6
  263. package/dist/webgpu-Dt7BMzWz.js +525 -0
  264. package/dist/webgpu_program-WOyIVMlZ.js +392 -0
  265. package/dist/webgpu_util-B_F3SShA.js +106 -0
  266. package/package.json +9 -10
  267. package/dist/RealDiv-CGwv0liw.js +0 -365
  268. package/dist/Reshape-BW__R4mZ.js +0 -79
  269. package/dist/Reshape-CPBkTIH2.js +0 -14
  270. package/dist/_commonjsHelpers-ByX85dGu.js +0 -33
  271. package/dist/axis_util-GTVlo58H.js +0 -55
  272. package/dist/backend_util-GaFarB78.js +0 -425
  273. package/dist/backend_webgpu-BqASlsbV.js +0 -545
  274. package/dist/binary_op_util-pKXltfxI.js +0 -192
  275. package/dist/broadcast_to-eS93CCN_.js +0 -28
  276. package/dist/clip_by_value-DDA7rrcT.js +0 -12
  277. package/dist/complex-DI35Q-gW.js +0 -11
  278. package/dist/complex_util-Yc1A_gV1.js +0 -55
  279. package/dist/concat-CAQpCret.js +0 -17
  280. package/dist/concat_util-D18dJ4fD.js +0 -22
  281. package/dist/data/parquet.d.ts +0 -2
  282. package/dist/data/parquet.js +0 -17
  283. package/dist/dataset-CGGp1z9P.js +0 -1124
  284. package/dist/dropout_util--NxWuYg2.js +0 -27
  285. package/dist/expand_dims-Bkd1YD5x.js +0 -11
  286. package/dist/exports_initializers-CYzKLjN7.js +0 -7
  287. package/dist/floor-BQtb-Azg.js +0 -9
  288. package/dist/gather-qIqEqaGn.js +0 -9
  289. package/dist/gelu-B220X1Go.js +0 -26
  290. package/dist/gpgpu_math-BwvV12df.js +0 -2022
  291. package/dist/index-CUXkjxiT.js +0 -3516
  292. package/dist/index-CieiGp4Y.js +0 -349
  293. package/dist/index-CjOWnMXP.js +0 -7308
  294. package/dist/index-Cp39cXWe.js +0 -1016
  295. package/dist/index-D5v913EJ.js +0 -4
  296. package/dist/index-DmeWGGmS.js +0 -1074
  297. package/dist/index-DvYrXKkX.js +0 -113
  298. package/dist/index-Ksja3su6.js +0 -151
  299. package/dist/index-xuotMAFm.js +0 -118
  300. package/dist/jszip.min-BZhlzntC.js +0 -2313
  301. package/dist/kernel_funcs_utils-pq0CK9co.js +0 -306
  302. package/dist/matMul16-BcVC_E62.js +0 -80
  303. package/dist/matMulGelu-JNLZqKQp.js +0 -163
  304. package/dist/mat_mul-DhG0Newp.js +0 -11
  305. package/dist/mod-CSdCpRjf.js +0 -11
  306. package/dist/non_max_suppression_impl-B2W7YjZB.js +0 -102
  307. package/dist/not_equal-hurPF26l.js +0 -64
  308. package/dist/ones-BytntneX.js +0 -14
  309. package/dist/ops-CsXeTq1P.js +0 -476
  310. package/dist/pack16-bqltoUlR.js +0 -39
  311. package/dist/papaparse.min-C0cScC2i.js +0 -418
  312. package/dist/parquet-Bqjmp2vo.js +0 -44231
  313. package/dist/pdf-NIhmP3sq.js +0 -19477
  314. package/dist/rand_util-CZ7yLoUm.js +0 -50
  315. package/dist/random_normal-IBRrha8a.js +0 -14
  316. package/dist/random_width-DN5ZtQkM.js +0 -9796
  317. package/dist/range-C-CjF-LI.js +0 -10
  318. package/dist/relu-J_X6MUzx.js +0 -9
  319. package/dist/reshape-BDOuCSNW.js +0 -9
  320. package/dist/resize_nearest_neighbor-BojqlfRe.js +0 -150
  321. package/dist/rope-DcrZM_e6.js +0 -24
  322. package/dist/scatter_nd_util-ByNJaL6I.js +0 -46
  323. package/dist/segment_util-Dasb2Zaf.js +0 -43
  324. package/dist/selu_util-BLhIqRkw.js +0 -44
  325. package/dist/shared-3agzAqQ_.js +0 -53
  326. package/dist/shared-CagdqkLh.js +0 -2143
  327. package/dist/slice-BzS11Qh0.js +0 -12
  328. package/dist/slice_util-CC35pLmT.js +0 -153
  329. package/dist/softmax-D4q1LJN7.js +0 -12
  330. package/dist/split-C2Sj255c.js +0 -9
  331. package/dist/squeeze-ho4wLUek.js +0 -10
  332. package/dist/stack-DudVrtmG.js +0 -11
  333. package/dist/step-BTxPtq1r.js +0 -261
  334. package/dist/sum-BpiwSWvg.js +0 -11
  335. package/dist/tensor-BWFldCso.js +0 -8
  336. package/dist/tensor1d-LMGMIUlr.js +0 -11
  337. package/dist/tensor2d-BnXMKScO.js +0 -14
  338. package/dist/tensor4d-C6UCG_u8.js +0 -14
  339. package/dist/tfjs_backend-BGnG-ppu.js +0 -654
  340. package/dist/tile-CFy-xTO6.js +0 -11
  341. package/dist/transpose-9kRxIXWR.js +0 -36
  342. package/dist/unsorted_segment_sum-DJvk5xnh.js +0 -277
  343. package/dist/variable-Ck482e3n.js +0 -7
  344. package/dist/webgpu_program-B4HmApL1.js +0 -525
  345. package/dist/webgpu_util-DYlGSwOJ.js +0 -64
  346. package/dist/zeros-DvZpK8s6.js +0 -13
  347. package/dist/zeros_like-CWjDdwr-.js +0 -721
@@ -0,0 +1,362 @@
1
+ import { Dn as e, En as t, Io as n, Ks as r, Ms as i, Si as a, Tn as o, nc as s, oc as c, wn as l, xn as u } from "./dist-BewPQWjc.js";
2
+ import { L as d } from "./backend_util-Cg-roD1p.js";
3
+ import { o as f } from "./gpgpu_math-DvLcCH6u.js";
4
+ import { J as p, b as m } from "./shared-DOgWaqvL.js";
5
+ import { S as h, n as g } from "./kernel_funcs_utils-HiXOOx3f.js";
6
+ import { t as _ } from "./Reshape-CABOPB9d.js";
7
+ //#region node_modules/@tensorflow/tfjs-backend-webgl/dist/mean_gpu.js
8
/**
 * WebGL program that sums `x` over fixed-size windows of its inner dimension,
 * optionally scaling each window by `1 / divisor` (the reduce driver passes the
 * full inner size on the first pass of a mean).
 * Vendored from @tensorflow/tfjs-backend-webgl (mean_gpu.js).
 */
var v = class {
  /**
   * @param reduceInfo `{ windowSize, batchSize, inSize, outSize }` for this pass.
   * @param divisor When not null/undefined, each summed vec4 is multiplied by `1 / divisor`.
   */
  constructor(reduceInfo, divisor) {
    this.variableNames = ["x"];
    const { windowSize, batchSize, inSize, outSize } = reduceInfo;
    this.outputShape = [batchSize, outSize];

    // The main loop consumes whole vec4s; the 0-3 leftover values are handled after it.
    const vec4Span = Math.floor(windowSize / 4) * 4;
    const leftover = windowSize % 4;

    let accumulate = "sumValue += dot(values, ones);";
    if (divisor != null) {
      const scale = 1 / divisor;
      // `r` is presumably tfjs' isInt: an integral scale is printed via toPrecision(2)
      // (e.g. "1.0") so that the emitted GLSL literal is a float.
      const scaleLiteral = r(scale) ? scale.toPrecision(2) : scale;
      accumulate = `sumValue += dot(values * ${scaleLiteral}, ones);`;
    }

    // A bounds check is only needed when the last window runs past `inSize`.
    const boundsCheck = inSize % windowSize > 0 ? `
      if (inIdx < 0 || inIdx >= ${inSize}) {
        return 0.0;
      }
    ` : "";

    this.userCode = `
      const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);

      float getValue(int batch, int inIdx) {
        ${boundsCheck}
        return getX(batch, inIdx);
      }

      void main() {
        ivec2 coords = getOutputCoords();
        int batch = coords[0];
        int outIdx = coords[1];
        int inOffset = outIdx * ${windowSize};

        float sumValue = 0.0;

        for (int i = 0; i < ${vec4Span}; i += 4) {
          int inIdx = inOffset + i;
          vec4 values = vec4(
            getValue(batch, inIdx),
            getValue(batch, inIdx + 1),
            getValue(batch, inIdx + 2),
            getValue(batch, inIdx + 3)
          );

          ${accumulate}
        }

        int inIdx = inOffset + ${vec4Span};
        if (${leftover === 1}) {
          vec4 values = vec4(getValue(batch, inIdx), 0.0, 0.0, 0.0);

          ${accumulate}
        } else if (${leftover === 2}) {
          vec4 values = vec4(
            getValue(batch, inIdx),
            getValue(batch, inIdx + 1), 0.0, 0.0);

          ${accumulate}
        } else if (${leftover === 3}) {
          vec4 values = vec4(
            getValue(batch, inIdx),
            getValue(batch, inIdx + 1),
            getValue(batch, inIdx + 2), 0.0);

          ${accumulate}
        }
        setOutput(sumValue);
      }
    `;
  }
};

/**
 * WebGL program performing one windowed pass of a sum/prod/min/max/all/any
 * reduction over the inner dimension of `x`.
 * Vendored from @tensorflow/tfjs-backend-webgl (reduce_gpu.js).
 */
var y = class {
  /**
   * @param reduceInfo `{ windowSize, batchSize, inSize, outSize }` for this pass.
   * @param reduceType Reduction name; "sum", "prod", "min", "max", "all" and "any" get dedicated handling.
   */
  constructor(reduceInfo, reduceType) {
    this.variableNames = ["x"];
    const { windowSize, batchSize, inSize, outSize } = reduceInfo;
    this.outputShape = [batchSize, outSize];

    // Identity element for padding lanes, plus the GLSL builtin used by min/max.
    let initValue = "0.0";
    let compareOp = "";
    if (reduceType === "prod") {
      initValue = "1.0";
    } else if (reduceType === "min") {
      initValue = "1.0 / 1e-20";
      compareOp = "min";
    } else if (reduceType === "max") {
      initValue = "-1.0 / 1e-20";
      compareOp = "max";
    }

    // GLSL expression written by setOutput().
    let finalValue;
    switch (reduceType) {
      case "sum":
        finalValue = "sumValue";
        break;
      case "prod":
        finalValue = "prodValue";
        break;
      case "all":
        finalValue = "allValue";
        break;
      case "any":
        finalValue = "anyValue";
        break;
      default:
        finalValue = `${reduceType}(${reduceType}(${reduceType}(minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])`;
    }

    const vec4Span = Math.floor(windowSize / 4) * 4;
    const leftover = windowSize % 4;

    let valueType = "vec4";
    let accumulate;
    if (reduceType === "all") {
      initValue = "1.0";
      valueType = "bvec4";
      accumulate = "\n bool reducedAllValue = all(values);\n float floatedReducedAllValue = float(reducedAllValue);\n allValue = float(allValue >= 1.0 && floatedReducedAllValue >= 1.0);\n ";
    } else if (reduceType === "any") {
      initValue = "0.0";
      valueType = "bvec4";
      accumulate = "\n bool reducedAnyValue = any(values);\n float floatedReducedAnyValue = float(reducedAnyValue);\n anyValue = float(anyValue >= 1.0 || floatedReducedAnyValue >= 1.0);\n ";
    } else {
      // NOTE(review): for min/max the compare is applied twice (once
      // unconditionally, once in the min/max branch). min/max are idempotent, so
      // this is redundant but harmless; kept verbatim from the vendored source.
      accumulate = `
        if (${reduceType === "sum"}) {
          sumValue += dot(values, ones);
        } else if (${reduceType === "prod"}) {
          vec2 tmp = vec2(values[0], values[1]) * vec2(values[2], values[3]);
          prodValue *= tmp[0] * tmp[1];
        } else {
          minMaxValue = ${compareOp}(values, minMaxValue);
          if (${reduceType === "min"} || ${reduceType === "max"}) {
            minMaxValue = ${compareOp}(values, minMaxValue);
            bvec4 isNaN = isnan(values);
            if (isNaN.r || isNaN.g || isNaN.b || isNaN.a) {
              minMaxValue = vec4(NAN);
            }
          }
        }
      `;
    }

    // Out-of-range reads yield the identity element so they do not affect the result.
    const boundsCheck = inSize % windowSize > 0 ? `
      if (inIdx < 0 || inIdx >= ${inSize}) {
        return initializationValue;
      }
    ` : "";

    this.userCode = `
      const float initializationValue = ${initValue};
      const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);

      float getValue(int batch, int inIdx) {
        ${boundsCheck}
        return getX(batch, inIdx);
      }

      void main() {
        ivec2 coords = getOutputCoords();
        int batch = coords[0];
        int outIdx = coords[1];
        int inOffset = outIdx * ${windowSize};

        vec4 minMaxValue = vec4(${initValue});
        float prodValue = 1.0;
        float sumValue = 0.0;
        float allValue = 1.0;
        float anyValue = 0.0;

        for (int i = 0; i < ${vec4Span}; i += 4) {
          int inIdx = inOffset + i;
          ${valueType} values = ${valueType}(
            getValue(batch, inIdx),
            getValue(batch, inIdx + 1),
            getValue(batch, inIdx + 2),
            getValue(batch, inIdx + 3)
          );

          ${accumulate}
        }

        int inIdx = inOffset + ${vec4Span};
        if (${leftover === 1}) {
          ${valueType} values = ${valueType}(
            getValue(batch, inIdx),
            initializationValue,
            initializationValue,
            initializationValue
          );

          ${accumulate}
        } else if (${leftover === 2}) {
          ${valueType} values = ${valueType}(
            getValue(batch, inIdx),
            getValue(batch, inIdx + 1),
            initializationValue,
            initializationValue
          );

          ${accumulate}
        } else if (${leftover === 3}) {
          ${valueType} values = ${valueType}(
            getValue(batch, inIdx),
            getValue(batch, inIdx + 1),
            getValue(batch, inIdx + 2),
            initializationValue
          );

          ${accumulate}
        }
        setOutput(${finalValue});
      }
    `;
  }
};
173
+ //#endregion
174
+ //#region node_modules/@tensorflow/tfjs-backend-webgl/dist/kernel_utils/reduce.js
175
/**
 * Plans the chain of windowed passes that collapses the inner dimension of a
 * `[batch, inSize]` shape down to a single value per batch row.
 *
 * @param shape Input shape; only `shape[1]` is read.
 * @returns Stages `{ inSize, windowSize, outSize }`. Each stage's `inSize` is
 *   the previous stage's `outSize`, and the last stage has `outSize === 1`.
 *   At least one stage is always produced.
 */
function b(shape) {
  const stages = [];
  let remaining = shape[1];
  do {
    // `d` presumably picks the window size for this pass (tfjs computeOptimalWindowSize).
    const windowSize = d(remaining);
    const outSize = Math.ceil(remaining / windowSize);
    stages.push({ inSize: remaining, windowSize, outSize });
    remaining = outSize;
  } while (remaining !== 1);
  return stages;
}
187
/**
 * Reduces the inner dimension of a 2-D tensor info through the chain of WebGL
 * passes planned by `b`, feeding each pass's output into the next.
 *
 * @param input Tensor info; `shape[0]` is used as the batch size and `shape[1]` is reduced.
 * @param dtype Output dtype passed to every `runWebGLProgram` call.
 * @param reductionType "mean" selects the `v` program; any other value is forwarded to `y`.
 * @param backend Object providing `runWebGLProgram` and `disposeIntermediateTensorInfo`.
 * @returns The tensor info from the final pass. Intermediate results are
 *   disposed; `input` itself never is.
 */
function x(input, dtype, reductionType, backend) {
  let current = input;
  for (const [stageIndex, stage] of b(input.shape).entries()) {
    const { inSize, windowSize, outSize } = stage;
    const reduceInfo = { windowSize, inSize, batchSize: input.shape[0], outSize };

    let program;
    if (reductionType !== "mean") {
      program = new y(reduceInfo, reductionType);
    } else if (stageIndex === 0) {
      // Only the first pass scales, by the original inner size, so the later
      // pure-sum passes finish the mean.
      program = new v(reduceInfo, inSize);
    } else {
      program = new v(reduceInfo);
    }

    const previous = current;
    current = backend.runWebGLProgram(program, [current], dtype);
    if (previous.dataId !== input.dataId) {
      backend.disposeIntermediateTensorInfo(previous);
    }
  }
  return current;
}
210
+ //#endregion
211
+ //#region node_modules/@tensorflow/tfjs-backend-webgl/dist/transpose_gpu.js
212
/**
 * Unpacked WebGL transpose program: output element `resRC` reads input `A`
 * at the coordinates produced by `C(perm)`.
 * Vendored from @tensorflow/tfjs-backend-webgl (transpose_gpu.js).
 */
var S = class {
  /**
   * @param inputShape Shape of the input tensor.
   * @param perm Permutation; output dim k has size `inputShape[perm[k]]`.
   */
  constructor(inputShape, perm) {
    this.variableNames = ["A"];
    const outputShape = Array.from({ length: inputShape.length }, (_unused, dim) => inputShape[perm[dim]]);
    this.outputShape = outputShape;
    this.rank = outputShape.length;
    // `f` presumably maps a rank to its GLSL coordinate type (int/ivec2/...).
    const coordsType = f(this.rank);
    const sourceCoords = C(perm);
    this.userCode = `
      void main() {
        ${coordsType} resRC = getOutputCoords();
        setOutput(getA(${sourceCoords}));
      }
    `;
  }
};
227
/**
 * Builds the comma-separated GLSL argument list used to read the transpose
 * input: the output coordinate for axis k is placed at position `perm[k]`.
 *
 * @param perm Axis permutation (at most 6 axes).
 * @returns e.g. `"resRC.y,resRC.x"` for `perm = [1, 0]`.
 * @throws Error when `perm.length > 6`.
 */
function C(perm) {
  const rank = perm.length;
  if (rank > 6) {
    throw Error(`Transpose for rank ${rank} is not yet supported`);
  }
  const outputCoords = ["resRC.x", "resRC.y", "resRC.z", "resRC.w", "resRC.u", "resRC.v"];
  const switched = new Array(rank);
  perm.forEach((inputAxis, outputAxis) => {
    switched[inputAxis] = outputCoords[outputAxis];
  });
  return switched.join(",");
}
241
+ //#endregion
242
+ //#region node_modules/@tensorflow/tfjs-backend-webgl/dist/transpose_packed_gpu.js
243
/**
 * Packed-texture WebGL transpose program. Each output texel covers a 2x2 block
 * of the two innermost output dims, so up to four source values are gathered
 * per invocation, with bounds checks on the extra column/row.
 * Vendored from @tensorflow/tfjs-backend-webgl (transpose_packed_gpu.js).
 */
var w = class {
  /**
   * @param inputShape Shape of the input tensor.
   * @param perm Permutation; output dim k has size `inputShape[perm[k]]`.
   * @throws Error when the rank exceeds 6.
   */
  constructor(inputShape, perm) {
    this.variableNames = ["A"];
    this.packedInputs = true;
    this.packedOutput = true;
    const outputShape = Array.from({ length: inputShape.length }, (_unused, dim) => inputShape[perm[dim]]);
    this.outputShape = outputShape;
    this.rank = outputShape.length;
    if (this.rank > 6) {
      throw Error(`Packed transpose for rank ${this.rank} is not yet supported.`);
    }

    const coordsType = f(this.rank);
    // `h` presumably yields per-axis channel names such as ["rc.x", "rc.y", ...].
    const outputCoords = h("rc", this.rank);
    const sourceCoords = new Array(this.rank);
    perm.forEach((inputAxis, outputAxis) => {
      sourceCoords[inputAxis] = outputCoords[outputAxis];
    });

    const innerDims = `vec2(${sourceCoords.slice(-2).join(",")})`;
    const lastCoord = outputCoords[this.rank - 1];
    const secondLastCoord = outputCoords[this.rank - 2];
    // Pre-increments the column and tests it against the output width.
    const nextColumnInBounds = `++${lastCoord} < ${outputShape[this.rank - 1]}`;
    const fetchValue = `getChannel(getA(${sourceCoords.join(",")}), ${innerDims})`;

    this.userCode = `
      void main() {
        ${coordsType} rc = getOutputCoords();
        vec4 result = vec4(0.);
        result[0] = ${fetchValue};
        if(${nextColumnInBounds}) {
          result[1] = ${fetchValue};
        }
        --${lastCoord};
        if(++${secondLastCoord} < ${outputShape[this.rank - 2]}) {
          result[2] = ${fetchValue};
          if(${nextColumnInBounds}) {
            result[3] = ${fetchValue};
          }
        }
        setOutput(result);
      }
    `;
  }
};
272
+ //#endregion
273
+ //#region node_modules/@tensorflow/tfjs-backend-webgl/dist/kernels/Transpose_impl.js
274
/**
 * Runs a WebGL transpose of `input` by `perm`. The packed program is used when
 * the `WEBGL_PACK_ARRAY_OPERATIONS` flag is set; otherwise the unpacked one is used.
 *
 * @returns The tensor info returned by `backend.runWebGLProgram`, with the input's dtype.
 */
function T(input, perm, backend) {
  let program;
  if (i().getBool("WEBGL_PACK_ARRAY_OPERATIONS")) {
    program = new w(input.shape, perm);
  } else {
    program = new S(input.shape, perm);
  }
  return backend.runWebGLProgram(program, [input], input.dtype);
}
278
+ //#endregion
279
+ //#region node_modules/@tensorflow/tfjs-backend-webgl/dist/kernels/Sum_impl.js
280
/**
 * WebGL sum over `axis` of `input`.
 *
 * If the reduced axes are not innermost, the input is first transposed. It is
 * then viewed as `[batch, inSize]`, reduced with `x`, and reshaped to the
 * output shape. When `keepDims` is truthy, the reduced axes are kept as size-1 dims.
 * All intermediates created here are disposed.
 *
 * @param input Tensor info to reduce.
 * @param axis Axis spec, normalized by `s` (presumably tfjs parseAxisParam).
 * @param keepDims Whether to keep the reduced dims.
 * @param backend WebGL backend.
 * @returns The reshaped result tensor info.
 */
function E(input, axis, keepDims, backend) {
  const rank = input.shape.length;
  const originalAxes = s(axis, input.shape);
  let axes = originalAxes;
  // `t` presumably returns a permutation moving `axes` innermost, or null if already innermost.
  const permutation = t(axes, rank);
  const needsTranspose = permutation != null;

  let source = input;
  if (needsTranspose) {
    source = T(input, permutation, backend);
    axes = e(axes.length, rank);
  }
  u("sum", axes, rank);

  const [reducedShape, reduceShape] = l(source.shape, axes);
  const outShape = keepDims ? o(reducedShape, originalAxes) : reducedShape;

  const inSize = c(reduceShape);
  const batchSize = c(input.shape) / inSize;
  const as2D = _({ inputs: { x: source }, attrs: { shape: [batchSize, inSize] }, backend });
  // `a` presumably maps the input dtype to the sum's output dtype.
  const reduced = x(as2D, a(input.dtype), "sum", backend);
  const output = _({ inputs: { x: reduced }, attrs: { shape: outShape }, backend });

  backend.disposeIntermediateTensorInfo(as2D);
  backend.disposeIntermediateTensorInfo(reduced);
  if (needsTranspose) {
    backend.disposeIntermediateTensorInfo(source);
  }
  return output;
}
296
+ //#endregion
297
+ //#region node_modules/@tensorflow/tfjs-backend-webgl/dist/kernels/Sum.js
298
/**
 * WebGL "Sum" kernel function: unpacks the kernel arguments and delegates to `E`.
 */
function D(e) {
  const {
    inputs: { x: input },
    backend,
    attrs: { axis, keepDims },
  } = e;
  return E(input, axis, keepDims, backend);
}
302
/** Kernel config registering `D` as the WebGL "Sum" kernel. */
var O = { kernelName: "Sum", backendName: "webgl", kernelFunc: D };
307
+ //#endregion
308
+ //#region node_modules/@tensorflow/tfjs-backend-webgl/dist/kernels/Max_impl.js
309
/**
 * GPU max over the innermost dims of `input`.
 *
 * The input is viewed as `[batch, prod(reduceShape)]`, reduced with `x(..., "max")`,
 * and reshaped to `outShape`. Both intermediates are disposed.
 *
 * @returns The reshaped result tensor info.
 */
function k(input, reduceShape, outShape, backend) {
  const inSize = c(reduceShape);
  const batchSize = c(input.shape) / inSize;
  const as2D = _({ inputs: { x: input }, attrs: { shape: [batchSize, inSize] }, backend });
  const reduced = x(as2D, input.dtype, "max", backend);
  const output = _({ inputs: { x: reduced }, attrs: { shape: outShape }, backend });
  backend.disposeIntermediateTensorInfo(as2D);
  backend.disposeIntermediateTensorInfo(reduced);
  return output;
}
321
+ //#endregion
322
+ //#region node_modules/@tensorflow/tfjs-backend-webgl/dist/kernels/Max.js
323
/**
 * WebGL "Max" kernel function.
 *
 * Moves the reduced axes innermost if needed (on the CPU via `p`, or on the GPU
 * via `T`) and then takes the max. When `backend.shouldExecuteOnCPU([input])`
 * is true, the max is computed on the CPU with `m`; otherwise it runs on the GPU
 * with `k`. A transposed copy made here is disposed before returning.
 *
 * @param n Kernel args `{ inputs: { x }, backend, attrs: { reductionIndices, keepDims } }`.
 * @returns Tensor info holding the reduced values.
 */
function A(n) {
  const {
    inputs: { x: input },
    backend,
    attrs: { reductionIndices, keepDims },
  } = n;
  const rank = input.shape.length;
  const originalAxes = s(reductionIndices, input.shape);
  let axes = originalAxes;
  const permutation = t(axes, rank);
  const needsTranspose = permutation != null;
  const onCPU = backend.shouldExecuteOnCPU([input]);

  let source = input;
  if (needsTranspose) {
    if (onCPU) {
      // Transpose the raw values held in texData and wrap them in a new tensor info.
      const values = backend.texData.get(source.dataId).values;
      const transposedShape = Array.from({ length: rank }, (_unused, dim) => input.shape[permutation[dim]]);
      const transposedValues = p(values, input.shape, input.dtype, permutation, transposedShape);
      source = backend.makeTensorInfo(transposedShape, input.dtype);
      backend.texData.get(source.dataId).values = transposedValues;
    } else {
      source = T(input, permutation, backend);
    }
    axes = e(axes.length, rank);
  }
  u("max", axes, rank);

  const [reducedShape, reduceShape] = l(source.shape, axes);
  const outShape = keepDims ? o(reducedShape, originalAxes) : reducedShape;

  let result;
  if (onCPU) {
    const values = backend.texData.get(source.dataId).values;
    const outValues = m(values, c(reduceShape), outShape, input.dtype);
    result = backend.makeTensorInfo(outShape, input.dtype);
    backend.texData.get(result.dataId).values = outValues;
  } else {
    result = k(source, reduceShape, outShape, backend);
  }

  if (needsTranspose) {
    backend.disposeIntermediateTensorInfo(source);
  }
  return result;
}
348
/** Kernel config registering `A` as the WebGL "Max" kernel. */
var j = { kernelName: "Max", backendName: "webgl", kernelFunc: A };

/**
 * Element-wise division kernel function built by `g` (presumably tfjs
 * binaryKernelFunc). Both snippets return exactly 1 wherever `a == b`
 * (lane-by-lane in the packed form) and `a / b` otherwise.
 */
var M = g({
  opSnippet: "\nif (a == b) {\n return 1.0;\n};\nreturn a / b;",
  packedOpSnippet: "\n // vec4 one = vec4(equal(a, b));\n // return one + (vec4(1.0) - one) * a / b;\n vec4 result = a / b;\n if(a.x == b.x) {\n result.x = 1.;\n }\n if(a.y == b.y) {\n result.y = 1.;\n }\n if(a.z == b.z) {\n result.z = 1.;\n }\n if(a.w == b.w) {\n result.w = 1.;\n }\n\n return result;\n",
  checkOutOfBounds: true,
});

/** Kernel config registering `M` on the WebGL backend under the imported kernel name `n`. */
var N = { kernelName: n, backendName: "webgl", kernelFunc: M };
361
+ //#endregion
362
+ export { D as a, x as c, j as i, N as n, O as o, A as r, T as s, M as t };
@@ -0,0 +1,94 @@
1
+ import { Bo as e, Gs as t, Ps as n, oc as r } from "./dist-BewPQWjc.js";
2
+ import { E as i, a, c as o, d as s, j as c, l, u, z as d } from "./gpgpu_math-DvLcCH6u.js";
3
+ //#region node_modules/@tensorflow/tfjs-backend-webgl/dist/reshape_packed_gpu.js
4
/**
 * Packed-texture WebGL reshape program working on 3-D views of the input and
 * output. Each of the four texel channels is mapped to a flat index and then
 * back to input coordinates.
 * Vendored from @tensorflow/tfjs-backend-webgl (reshape_packed_gpu.js).
 */
var f = class {
  /**
   * @param outputShape 3-D output view.
   * @param inputShape 3-D input view, baked into the shader unless shape uniforms are enabled.
   */
  constructor(outputShape, inputShape) {
    this.variableNames = ["A"];
    this.packedInputs = true;
    this.packedOutput = true;
    this.customUniforms = [{ name: "inputShape", type: "ivec3" }];
    this.outputShape = outputShape;
    this.enableShapeUniforms = a(this.outputShape.length);

    // Channel 1 steps one column (z), channel 2 one row (y), channel 3 both.
    // Channels other than 0 are skipped when they fall outside rows/cols.
    const channelSnippets = [0, 1, 2, 3].map((channel) => {
      let moveToChannel = "thisRC = rc;";
      if (channel % 2 === 1) moveToChannel += "thisRC.z += 1;";
      if (channel > 1) moveToChannel += "thisRC.y += 1;";
      return `
        ${moveToChannel}
        ${channel > 0 ? "if(thisRC.y < rows && thisRC.z < cols){" : ""}
        int flatIndex = getFlatIndex(thisRC);

        ivec3 inputRC = inputCoordsFromReshapedOutCoords(flatIndex);
        vec2 inputRCInnerDims = vec2(float(inputRC.y),float(inputRC.z));

        result[${channel}] =
          getChannel(getA(inputRC.x, inputRC.y, inputRC.z), inputRCInnerDims);
        ${channel > 0 ? "}" : ""}
      `;
    });

    this.userCode = `
      ${p(inputShape, this.enableShapeUniforms)}
      ${this.enableShapeUniforms ? l() : o(outputShape)}

      void main() {
        ivec3 rc = getOutputCoords();

        vec4 result = vec4(0.);

        ivec3 thisRC;
        int rows = ${this.enableShapeUniforms ? "outShape[1]" : outputShape[1]};
        int cols = ${this.enableShapeUniforms ? "outShape[2]" : outputShape[2]};

        ${channelSnippets.join("")}

        setOutput(result);
      }
    `;
  }
};
46
/**
 * Emits the GLSL helper `inputCoordsFromReshapedOutCoords`, which converts a
 * flat index into `ivec3(r, c, d)` input coordinates.
 *
 * @param inputShape 3-D input shape; used when `useUniforms` is falsy.
 * @param useUniforms When truthy, coordinates come from the `inputShape` uniform instead.
 */
function p(inputShape, useUniforms) {
  const coordNames = ["r", "c", "d"];
  const coordsSnippet = useUniforms ? s(coordNames, "inputShape") : u(coordNames, inputShape);
  return `
    ivec3 inputCoordsFromReshapedOutCoords(int index) {
      ${coordsSnippet}
      return ivec3(r, c, d);
    }
  `;
}
62
+ //#endregion
63
+ //#region node_modules/@tensorflow/tfjs-backend-webgl/dist/kernel_utils/reshape.js
64
/**
 * Reshapes a packed tensor by running `f` on 3-D (batch, rows, cols) views of
 * the input and target shapes. `i` and `c` presumably compute the batch dim and
 * the rows/cols pair.
 *
 * @returns `{ dataId, shape: afterShape, dtype }` taken from the program output.
 */
function m(input, afterShape, backend) {
  const input3DShape = [i(input.shape), ...c(input.shape)];
  const input3D = { dtype: input.dtype, shape: input3DShape, dataId: input.dataId };
  const afterShapeAs3D = [i(afterShape), ...c(afterShape)];
  const program = new f(afterShapeAs3D, input3DShape);
  // Values for the program's `inputShape` custom uniform.
  const customValues = [input3DShape];
  const preventEagerUnpackingOfOutput = true;
  const output = backend.runWebGLProgram(program, [input3D], input.dtype, customValues, preventEagerUnpackingOfOutput);
  return { dataId: output.dataId, shape: afterShape, dtype: output.dtype };
}
76
+ //#endregion
77
+ //#region node_modules/@tensorflow/tfjs-backend-webgl/dist/kernels/Reshape.js
78
/**
 * WebGL "Reshape" kernel function.
 *
 * Resolves the target shape with `t`, asserts the element counts match, and
 * then either runs the packed reshape program `m` or aliases the input data
 * (incrementing its ref count). `m` is used only when the data is packed, the
 * reshape is not free for `x.shape`, and it is not free for an existing
 * texture's shape either.
 *
 * @throws Whatever `n` throws when the element counts differ.
 */
function h(e) {
  const {
    inputs: { x: input },
    backend,
    attrs: { shape },
  } = e;
  const inSize = r(input.shape);
  const newShape = t(shape, inSize);
  const newSize = r(newShape);
  n(inSize === newSize, () => `The new shape (${newShape}) has ${newSize} elements and the old shape (${input.shape}) has ${inSize} elements. The new shape and old shape must have the same number of elements.`);

  const texData = backend.texData.get(input.dataId);
  const needsPackedReshape =
    texData.isPacked &&
    !d(input.shape, newShape) &&
    !(texData.texture !== null && d(texData.shape, newShape));
  if (needsPackedReshape) {
    return m(input, newShape, backend);
  }
  backend.incRef(input.dataId);
  return { dataId: input.dataId, shape: newShape, dtype: input.dtype };
}
88
/** Kernel config registering `h` on the WebGL backend under the imported Reshape kernel name `e`. */
var g = { kernelName: e, backendName: "webgl", kernelFunc: h };
93
+ //#endregion
94
+ export { g as n, f as r, h as t };
@@ -0,0 +1,17 @@
1
+ import { Bo as e, Gs as t, Ps as n, oc as r } from "./dist-BewPQWjc.js";
2
+ //#region node_modules/@tensorflow/tfjs-backend-webgpu/dist/kernels/Reshape.js
3
/**
 * WebGPU "Reshape" kernel function. It never copies: after checking that the
 * element counts agree, it bumps the input's ref count and returns a tensor
 * info that shares the input's `dataId` with the new shape.
 *
 * @throws Whatever `n` throws when the element counts differ.
 */
function i(args) {
  const {
    inputs: { x: input },
    attrs: { shape },
  } = args;
  const inSize = r(input.shape);
  const newShape = t(shape, inSize);
  const newSize = r(newShape);
  n(inSize === newSize, () => `The new shape (${newShape}) has ${newSize} elements and the old shape (${input.shape}) has ${inSize} elements. The new shape and old shape must have the same number of elements.`);
  args.backend.incRef(input.dataId);
  return { dataId: input.dataId, shape: newShape, dtype: input.dtype };
}
11
/** Kernel config registering `i` on the WebGPU backend under the imported Reshape kernel name `e`. */
var a = { kernelName: e, backendName: "webgpu", kernelFunc: i };
16
+ //#endregion
17
+ export { a as n, i as t };
@@ -8,7 +8,7 @@ import { default as MemoryProfiler } from './utilities/profile';
8
8
  import { default as Model, ModelForwardAttributes } from './models/model';
9
9
  import { Task } from './training/tasks/Task';
10
10
  import { TrainingLogEntry, TrainingOptions } from './training/types';
11
- import { ModelPhase, TransformersMetadata } from './loader/types';
11
+ import { ModelMode, TransformersMetadata } from './loader/types';
12
12
  type TeachableLLMStatus = 'warmup' | 'awaitingTokens' | 'ready' | 'training' | 'loading' | 'busy' | 'error';
13
13
  export default class TeachableLLM {
14
14
  private ee;
@@ -22,8 +22,8 @@ export default class TeachableLLM {
22
22
  constructor(tokeniser?: ITokeniser, model?: Model<ModelForwardAttributes, GPTConfig>);
23
23
  get currentTrainer(): Trainer | null;
24
24
  get vocab(): string[];
25
- get phase(): ModelPhase;
26
- set phase(phase: ModelPhase);
25
+ get mode(): ModelMode;
26
+ set mode(mode: ModelMode);
27
27
  /** Model is fully loaded */
28
28
  get loaded(): boolean;
29
29
  get config(): GPTConfig;
@@ -57,12 +57,12 @@ export default class TeachableLLM {
57
57
  generateText(options?: IGenerateOptions): Promise<Conversation[]>;
58
58
  dispose(): void;
59
59
  on(event: 'status', listener: (status: TeachableLLMStatus) => void): void;
60
- on(event: 'phase', listener: (phase: ModelPhase) => void): void;
60
+ on(event: 'mode', listener: (mode: ModelMode) => void): void;
61
61
  on(event: 'error', listener: (error: Error) => void): void;
62
62
  on(event: 'trainStep', listener: (step: TrainingLogEntry) => void): void;
63
63
  on(event: 'loaded' | 'changeLoRA', listener: () => void): void;
64
64
  off(event: 'status', listener: (status: TeachableLLMStatus) => void): void;
65
- off(event: 'phase', listener: (phase: ModelPhase) => void): void;
65
+ off(event: 'mode', listener: (mode: ModelMode) => void): void;
66
66
  off(event: 'error', listener: (error: Error) => void): void;
67
67
  off(event: 'trainStep', listener: (step: TrainingLogEntry) => void): void;
68
68
  off(event: 'loaded' | 'changeLoRA', listener: () => void): void;