@genai-fi/nanogpt 0.9.1 → 0.10.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (343)
  1. package/README.md +352 -14
  2. package/dist/Generator.js +69 -78
  3. package/dist/{RealDiv-D4EzDsC0.js → RealDiv-DgA3z9oO.js} +32 -206
  4. package/dist/Reshape-CF6odzV4.js +16 -0
  5. package/dist/Reshape-_kILl6tK.js +81 -0
  6. package/dist/TeachableLLM.js +28 -22
  7. package/dist/Trainer.d.ts +2 -0
  8. package/dist/Trainer.js +3 -2
  9. package/dist/{axis_util-TbGYJ208.js → axis_util-BvHEw88j.js} +7 -23
  10. package/dist/backend.d.ts +2 -1
  11. package/dist/backend.js +10 -4
  12. package/dist/backend_util-D-rUb2ty.js +474 -0
  13. package/dist/backend_webgpu-B0u2ndUn.js +547 -0
  14. package/dist/binary_op_util-pKXltfxI.js +192 -0
  15. package/dist/broadcast_to-CwF7XIeu.js +30 -0
  16. package/dist/checks/appendCache.js +2 -2
  17. package/dist/checks/attentionMask.js +3 -3
  18. package/dist/checks/check.d.ts +1 -1
  19. package/dist/checks/check.js +8 -8
  20. package/dist/checks/gelu.js +2 -2
  21. package/dist/checks/index.d.ts +2 -0
  22. package/dist/checks/index.js +7 -5
  23. package/dist/checks/matMulGelu.js +6 -6
  24. package/dist/checks/normRMS.js +7 -7
  25. package/dist/checks/normRMSGrad.js +3 -3
  26. package/dist/checks/packUnpack.d.ts +1 -0
  27. package/dist/checks/packUnpack.js +18 -0
  28. package/dist/checks/qkv.js +12 -27
  29. package/dist/checks/rope.js +2 -2
  30. package/dist/checks/weights.js +18 -16
  31. package/dist/complex-CSlYz-2T.js +13 -0
  32. package/dist/complex_util-Yc1A_gV1.js +55 -0
  33. package/dist/concat-BHlIJeyT.js +19 -0
  34. package/dist/concat_util-DcJk7YHS.js +22 -0
  35. package/dist/data/docx.js +1 -1
  36. package/dist/data/parquet.js +2 -2
  37. package/dist/data/pdf.js +1 -1
  38. package/dist/data/textLoader.js +1 -1
  39. package/dist/{dataset-DlZtKmBq.js → dataset-0xP8GjwI.js} +136 -236
  40. package/dist/dropout-C1pM3f11.js +99 -0
  41. package/dist/expand_dims-BPG4fwBP.js +13 -0
  42. package/dist/exports_initializers-xuidcwI4.js +7 -0
  43. package/dist/gather-DykLGqmW.js +10 -0
  44. package/dist/{gelu-Bp_-935b.js → gelu-CNLFZWea.js} +11 -10
  45. package/dist/{gpgpu_math-CDaYiyE_.js → gpgpu_math-DDVJCn6-.js} +90 -265
  46. package/dist/{index-C4L8Cm77.js → index-CieiGp4Y.js} +14 -14
  47. package/dist/index-CjOj7j-u.js +7308 -0
  48. package/dist/{index-Tf7vU29b.js → index-Cp39cXWe.js} +3 -10
  49. package/dist/{index-Dwqa6Zy2.js → index-DvYrXKkX.js} +2 -2
  50. package/dist/index-ZyQhjEPo.js +2157 -0
  51. package/dist/{jszip.min-CjP2V1VV.js → jszip.min-Bz5-11Bk.js} +56 -57
  52. package/dist/kernel_funcs_utils-Dg_-E44D.js +308 -0
  53. package/dist/layers/BaseLayer.d.ts +1 -0
  54. package/dist/layers/BaseLayer.js +7 -6
  55. package/dist/layers/CausalSelfAttention.d.ts +0 -1
  56. package/dist/layers/CausalSelfAttention.js +56 -55
  57. package/dist/layers/MLP.js +15 -16
  58. package/dist/layers/PositionEmbedding.js +5 -14
  59. package/dist/layers/RMSNorm.js +3 -3
  60. package/dist/layers/RoPECache.d.ts +2 -0
  61. package/dist/layers/RoPECache.js +22 -17
  62. package/dist/layers/TiedEmbedding.js +22 -17
  63. package/dist/layers/TransformerBlock.js +21 -20
  64. package/dist/loader/load.js +1 -1
  65. package/dist/loader/loadTransformers.js +1 -1
  66. package/dist/loader/oldZipLoad.js +39 -33
  67. package/dist/loader/save.js +1 -1
  68. package/dist/log_sum_exp-DWI-76TI.js +41 -0
  69. package/dist/main.d.ts +8 -0
  70. package/dist/main.js +63 -52
  71. package/dist/matMul16--R5hOwDG.js +77 -0
  72. package/dist/mat_mul-DeAh4uTH.js +12 -0
  73. package/dist/mod-Gt1rMB4n.js +12 -0
  74. package/dist/models/NanoGPTV1.js +40 -31
  75. package/dist/models/model.d.ts +2 -0
  76. package/dist/models/model.js +37 -29
  77. package/dist/{mulmat_packed_gpu-BT60jmzP.js → mulmat_packed_gpu-BMFhLwta.js} +1 -17
  78. package/dist/{non_max_suppression_impl-CsEgBuMA.js → non_max_suppression_impl-B2W7YjZB.js} +0 -32
  79. package/dist/ones-CAMiP4I2.js +15 -0
  80. package/dist/ops/adamAdjust.js +1 -1
  81. package/dist/ops/adamMoments.d.ts +1 -1
  82. package/dist/ops/adamMoments.js +4 -4
  83. package/dist/ops/add16.d.ts +2 -0
  84. package/dist/ops/add16.js +9 -0
  85. package/dist/ops/appendCache.js +16 -9
  86. package/dist/ops/attentionMask.js +4 -4
  87. package/dist/ops/concat16.d.ts +2 -0
  88. package/dist/ops/concat16.js +9 -0
  89. package/dist/ops/cpu/adamAdjust.js +14 -13
  90. package/dist/ops/cpu/adamMoments.js +10 -9
  91. package/dist/ops/cpu/appendCache.js +9 -8
  92. package/dist/ops/cpu/attentionMask.js +15 -14
  93. package/dist/ops/cpu/fusedSoftmax.js +13 -12
  94. package/dist/ops/cpu/gatherSub.js +9 -24
  95. package/dist/ops/cpu/gelu.js +13 -12
  96. package/dist/ops/cpu/matMul16.d.ts +1 -0
  97. package/dist/ops/cpu/matMul16.js +16 -0
  98. package/dist/ops/cpu/matMulGelu.js +18 -16
  99. package/dist/ops/cpu/matMulMul.js +8 -7
  100. package/dist/ops/cpu/mulDropout.js +4 -3
  101. package/dist/ops/cpu/normRMS.js +11 -10
  102. package/dist/ops/cpu/qkv.js +17 -13
  103. package/dist/ops/cpu/rope.js +23 -22
  104. package/dist/ops/cpu/scatterSub.js +16 -30
  105. package/dist/ops/dot16.d.ts +2 -0
  106. package/dist/ops/dot16.js +42 -0
  107. package/dist/ops/gatherSub.js +1 -1
  108. package/dist/ops/gelu.js +2 -2
  109. package/dist/ops/grads/add16.d.ts +1 -0
  110. package/dist/ops/grads/add16.js +27 -0
  111. package/dist/ops/grads/attentionMask.js +12 -19
  112. package/dist/ops/grads/gelu.js +4 -3
  113. package/dist/ops/grads/matMul16.d.ts +2 -0
  114. package/dist/ops/grads/matMul16.js +9 -0
  115. package/dist/ops/grads/matMulGelu.js +8 -7
  116. package/dist/ops/grads/normRMS.js +8 -7
  117. package/dist/ops/grads/{fusedSoftmax.d.ts → pack16.d.ts} +1 -1
  118. package/dist/ops/grads/pack16.js +7 -0
  119. package/dist/ops/grads/qkv.d.ts +3 -1
  120. package/dist/ops/grads/qkv.js +28 -22
  121. package/dist/ops/grads/rope.d.ts +2 -1
  122. package/dist/ops/grads/rope.js +6 -13
  123. package/dist/ops/grads/softmax16.d.ts +2 -0
  124. package/dist/ops/grads/softmax16.js +26 -0
  125. package/dist/ops/grads/unpack16.d.ts +2 -0
  126. package/dist/ops/grads/unpack16.js +6 -0
  127. package/dist/ops/grads/utils.d.ts +3 -0
  128. package/dist/ops/grads/utils.js +10 -0
  129. package/dist/ops/matMul16.d.ts +15 -0
  130. package/dist/ops/matMul16.js +13 -0
  131. package/dist/ops/matMulGelu.js +1 -1
  132. package/dist/ops/matMulMul.js +1 -1
  133. package/dist/ops/mul16.d.ts +2 -0
  134. package/dist/ops/mul16.js +8 -0
  135. package/dist/ops/mulDrop.js +1 -1
  136. package/dist/ops/normRMS.js +1 -1
  137. package/dist/ops/pack16.d.ts +2 -0
  138. package/dist/ops/pack16.js +6 -0
  139. package/dist/ops/qkv.d.ts +1 -1
  140. package/dist/ops/qkv.js +8 -4
  141. package/dist/ops/reshape16.d.ts +2 -0
  142. package/dist/ops/reshape16.js +43 -0
  143. package/dist/ops/rope.d.ts +1 -1
  144. package/dist/ops/rope.js +8 -10
  145. package/dist/ops/scatterSub.js +1 -1
  146. package/dist/ops/slice16.d.ts +2 -0
  147. package/dist/ops/slice16.js +9 -0
  148. package/dist/ops/softmax16.d.ts +2 -0
  149. package/dist/ops/softmax16.js +12 -0
  150. package/dist/ops/sub16.d.ts +2 -0
  151. package/dist/ops/sub16.js +8 -0
  152. package/dist/ops/sum16.d.ts +2 -0
  153. package/dist/ops/sum16.js +13 -0
  154. package/dist/ops/transpose16.d.ts +3 -0
  155. package/dist/ops/transpose16.js +41 -0
  156. package/dist/ops/unpack16.d.ts +2 -0
  157. package/dist/ops/unpack16.js +6 -0
  158. package/dist/ops/webgl/adamAdjust.js +3 -2
  159. package/dist/ops/webgl/adamMoments.js +2 -1
  160. package/dist/ops/webgl/appendCache.js +2 -1
  161. package/dist/ops/webgl/attentionMask.js +5 -4
  162. package/dist/ops/webgl/fusedSoftmax.js +6 -4
  163. package/dist/ops/webgl/gatherSub.js +7 -6
  164. package/dist/ops/webgl/gelu.js +3 -2
  165. package/dist/ops/webgl/log.js +12 -27
  166. package/dist/ops/webgl/matMul16.d.ts +1 -0
  167. package/dist/ops/webgl/matMul16.js +37 -0
  168. package/dist/ops/webgl/matMulGelu.js +17 -15
  169. package/dist/ops/webgl/matMulMul.js +13 -12
  170. package/dist/ops/webgl/mulDropout.js +9 -8
  171. package/dist/ops/webgl/normRMS.js +8 -7
  172. package/dist/ops/webgl/qkv.js +6 -5
  173. package/dist/ops/webgl/rope.js +11 -10
  174. package/dist/ops/webgl/scatterSub.js +6 -5
  175. package/dist/ops/webgpu/adamAdjust.js +12 -10
  176. package/dist/ops/webgpu/adamMoments.js +27 -22
  177. package/dist/ops/webgpu/add16.d.ts +1 -0
  178. package/dist/ops/webgpu/add16.js +14 -0
  179. package/dist/ops/webgpu/appendCache.js +64 -17
  180. package/dist/ops/webgpu/attentionMask.js +19 -62
  181. package/dist/ops/webgpu/attentionMask32_program.d.ts +19 -0
  182. package/dist/ops/webgpu/attentionMask32_program.js +54 -0
  183. package/dist/ops/webgpu/concat16.d.ts +19 -0
  184. package/dist/ops/webgpu/concat16.js +128 -0
  185. package/dist/ops/webgpu/gatherSub.js +9 -7
  186. package/dist/ops/webgpu/gelu.js +78 -31
  187. package/dist/ops/webgpu/index.js +12 -0
  188. package/dist/ops/webgpu/matMul16.d.ts +1 -0
  189. package/dist/ops/webgpu/matMul16.js +58 -0
  190. package/dist/ops/webgpu/matMul16_program.d.ts +42 -0
  191. package/dist/ops/webgpu/matMul16_program.js +336 -0
  192. package/dist/ops/webgpu/mul16.d.ts +1 -0
  193. package/dist/ops/webgpu/mul16.js +14 -0
  194. package/dist/ops/webgpu/normRMS.js +21 -40
  195. package/dist/ops/webgpu/normRMS16_program.d.ts +9 -0
  196. package/dist/ops/webgpu/normRMS16_program.js +24 -0
  197. package/dist/ops/webgpu/normRMS32_program.d.ts +9 -0
  198. package/dist/ops/webgpu/normRMS32_program.js +24 -0
  199. package/dist/ops/webgpu/normRMSGrad.js +113 -64
  200. package/dist/ops/webgpu/pack16.d.ts +1 -0
  201. package/dist/ops/webgpu/pack16.js +19 -0
  202. package/dist/ops/webgpu/pack16_program.d.ts +19 -0
  203. package/dist/ops/webgpu/pack16_program.js +92 -0
  204. package/dist/ops/webgpu/qkv.js +20 -55
  205. package/dist/ops/webgpu/rope.js +77 -22
  206. package/dist/ops/webgpu/scatterSub.js +9 -7
  207. package/dist/ops/webgpu/slice16.d.ts +7 -0
  208. package/dist/ops/webgpu/slice16.js +71 -0
  209. package/dist/{variable-Bm2OFwGI.js → ops/webgpu/softmax16.d.ts} +2 -8
  210. package/dist/ops/webgpu/softmax16.js +23 -0
  211. package/dist/ops/webgpu/softmax16_program.d.ts +13 -0
  212. package/dist/ops/webgpu/softmax16_program.js +73 -0
  213. package/dist/ops/webgpu/softmax16_subgroup_program.d.ts +17 -0
  214. package/dist/ops/webgpu/softmax16_subgroup_program.js +75 -0
  215. package/dist/ops/webgpu/softmax16grad.d.ts +1 -0
  216. package/dist/ops/webgpu/softmax16grad.js +38 -0
  217. package/dist/ops/webgpu/sub16.d.ts +1 -0
  218. package/dist/ops/webgpu/sub16.js +14 -0
  219. package/dist/ops/webgpu/sum16.d.ts +1 -0
  220. package/dist/ops/webgpu/sum16.js +40 -0
  221. package/dist/ops/webgpu/transpose16.d.ts +1 -0
  222. package/dist/ops/webgpu/transpose16.js +35 -0
  223. package/dist/ops/webgpu/transpose16_program.d.ts +16 -0
  224. package/dist/ops/webgpu/transpose16_program.js +50 -0
  225. package/dist/ops/webgpu/transpose16_shared_program.d.ts +15 -0
  226. package/dist/ops/webgpu/transpose16_shared_program.js +71 -0
  227. package/dist/ops/webgpu/unpack16.d.ts +1 -0
  228. package/dist/ops/webgpu/unpack16.js +49 -0
  229. package/dist/ops/webgpu/utils/binary_op.d.ts +19 -0
  230. package/dist/ops/webgpu/utils/binary_op.js +79 -0
  231. package/dist/ops/webgpu/utils/deviceInfo.d.ts +7 -0
  232. package/dist/ops/webgpu/utils/deviceInfo.js +11 -0
  233. package/dist/ops/webgpu/utils/reductions.d.ts +32 -4
  234. package/dist/ops/webgpu/utils/reductions.js +236 -45
  235. package/dist/ops-CNI3TwqM.js +645 -0
  236. package/dist/pack16-CFUqumar.js +41 -0
  237. package/dist/{papaparse.min-C8l2Kvo1.js → papaparse.min-C0cScC2i.js} +2 -8
  238. package/dist/{parquet-C0Tlmv9c.js → parquet-BE8MU_ge.js} +201 -278
  239. package/dist/patches/PackedTensor.d.ts +12 -0
  240. package/dist/patches/PackedTensor.js +11 -0
  241. package/dist/patches/engine.d.ts +261 -0
  242. package/dist/patches/engine.js +10 -0
  243. package/dist/patches/tape.d.ts +12 -0
  244. package/dist/patches/tape.js +5 -0
  245. package/dist/patches/webgpu_backend.d.ts +18 -0
  246. package/dist/patches/webgpu_backend.js +57 -0
  247. package/dist/{tensor-CZr4dh61.js → patches/webgpu_base.d.ts} +5 -8
  248. package/dist/patches/webgpu_base.js +34 -0
  249. package/dist/patches/webgpu_program.d.ts +36 -0
  250. package/dist/patches/webgpu_program.js +401 -0
  251. package/dist/{pdf-kJD-f258.js → pdf-NIhmP3sq.js} +424 -428
  252. package/dist/random_width-DY6Kk2Dl.js +10051 -0
  253. package/dist/range-BMS52eQi.js +11 -0
  254. package/dist/reciprocal-CTmshQ9J.js +10 -0
  255. package/dist/{register_all_kernels-DIGpEwcf.js → register_all_kernels-Bwu1PTuU.js} +719 -9766
  256. package/dist/relu-yZ2-7WxU.js +10 -0
  257. package/dist/reshape-DevtBWtf.js +10 -0
  258. package/dist/rope-B5UUMsPi.js +32 -0
  259. package/dist/{scatter_nd_util-BQdz--Gn.js → scatter_nd_util-5EL-8VAQ.js} +1 -1
  260. package/dist/selu_util-D1w6yyTO.js +303 -0
  261. package/dist/{shared-DuP7ue-R.js → shared-BRksrJb3.js} +1 -17
  262. package/dist/shared-BuAXb4CI.js +2145 -0
  263. package/dist/sin-BGfy2HZo.js +16 -0
  264. package/dist/slice-D_gkkqZK.js +13 -0
  265. package/dist/slice_util-DtEldBfK.js +261 -0
  266. package/dist/softmax-ZHVebtR1.js +13 -0
  267. package/dist/split-DrfihRpZ.js +10 -0
  268. package/dist/squeeze-DZEpeblb.js +11 -0
  269. package/dist/stack-yOIAalTq.js +13 -0
  270. package/dist/sum-_fzj5ZTB.js +12 -0
  271. package/dist/tensor-DdQUJZlz.js +909 -0
  272. package/dist/tensor-f35l8Odg.js +8 -0
  273. package/dist/tensor1d-CeZuc-Rv.js +12 -0
  274. package/dist/tensor2d-G4Ys2GxX.js +15 -0
  275. package/dist/tensor4d-B8roDgtc.js +15 -0
  276. package/dist/tensor_util-DV-FP5Q3.js +523 -0
  277. package/dist/tfjs_backend-kNyO5L2d.js +653 -0
  278. package/dist/tile-BzyEiF-F.js +13 -0
  279. package/dist/tokeniser/CharTokeniser.js +1 -1
  280. package/dist/tokeniser/bpe.js +1 -1
  281. package/dist/training/Adam.d.ts +2 -1
  282. package/dist/training/Adam.js +12 -28
  283. package/dist/training/AdamExt.d.ts +1 -0
  284. package/dist/training/AdamExt.js +2 -2
  285. package/dist/training/DatasetBuilder.js +3 -20
  286. package/dist/training/FullTrainer.js +55 -48
  287. package/dist/training/Trainer.d.ts +11 -6
  288. package/dist/training/Trainer.js +51 -39
  289. package/dist/training/sparseCrossEntropy.js +3 -3
  290. package/dist/transpose-DKELTqhe.js +38 -0
  291. package/dist/utilities/arrayClose.js +7 -7
  292. package/dist/utilities/dummy.js +35 -27
  293. package/dist/utilities/multinomialCPU.js +2 -2
  294. package/dist/utilities/packed.d.ts +7 -0
  295. package/dist/utilities/packed.js +716 -0
  296. package/dist/utilities/performance.js +1 -1
  297. package/dist/utilities/profile.js +1 -1
  298. package/dist/utilities/safetensors.js +2 -2
  299. package/dist/utilities/sentences.d.ts +5 -0
  300. package/dist/utilities/sentences.js +41 -0
  301. package/dist/utilities/weights.js +2 -2
  302. package/dist/variable-Bhn5bHYv.js +7 -0
  303. package/dist/{webgpu_program-DkQJOJSd.js → webgpu_program-Cigz-7RF.js} +15 -44
  304. package/dist/webgpu_util-BBCnKm2X.js +65 -0
  305. package/dist/zeros-2gldETuK.js +14 -0
  306. package/package.json +4 -3
  307. package/dist/Reshape-Bowtk9BP.js +0 -127
  308. package/dist/Reshape-DUqYftGC.js +0 -30
  309. package/dist/backend_util-CJIiDoV1.js +0 -749
  310. package/dist/broadcast_to-DzlNweb8.js +0 -44
  311. package/dist/concat-B912vBbo.js +0 -33
  312. package/dist/dropout-C-csYCLj.js +0 -193
  313. package/dist/exports_initializers-B8iZMgQ0.js +0 -16
  314. package/dist/gather-Dnpgw-YQ.js +0 -25
  315. package/dist/index-BzFyqcy-.js +0 -4457
  316. package/dist/index-C1rx_Ajs.js +0 -12076
  317. package/dist/kernel_funcs_utils-DKLK0Mg3.js +0 -466
  318. package/dist/log_sum_exp-DO6z8tSE.js +0 -103
  319. package/dist/mat_mul-DzjTFx-u.js +0 -27
  320. package/dist/mod-Dobti4j4.js +0 -27
  321. package/dist/ones-tIJeHlq-.js +0 -29
  322. package/dist/ops/fusedSoftmax.d.ts +0 -2
  323. package/dist/ops/fusedSoftmax.js +0 -10
  324. package/dist/ops/grads/fusedSoftmax.js +0 -22
  325. package/dist/ops-LuCMAnmM.js +0 -1525
  326. package/dist/random_width-CXVRloNK.js +0 -13670
  327. package/dist/range-CWcz7xFA.js +0 -26
  328. package/dist/reciprocal-C4rNcM-S.js +0 -25
  329. package/dist/relu-BjCh_SYb.js +0 -25
  330. package/dist/reshape-CnIwVG1c.js +0 -25
  331. package/dist/selu_util-OtRzVwW5.js +0 -719
  332. package/dist/shared-DmRsFyaJ.js +0 -3134
  333. package/dist/sin-gpDNRxE0.js +0 -47
  334. package/dist/slice-d0Vo9XTN.js +0 -28
  335. package/dist/softmax-D7Jj3p_P.js +0 -28
  336. package/dist/split-DK2k5eHf.js +0 -25
  337. package/dist/stack-DFatutCx.js +0 -27
  338. package/dist/sum-CJ0ULhmt.js +0 -27
  339. package/dist/tensor1d-vML0r3q6.js +0 -27
  340. package/dist/tensor2d-D76QGjF3.js +0 -30
  341. package/dist/tensor4d-Df1WlVDY.js +0 -30
  342. package/dist/webgpu_util-pLEV9tks.js +0 -80
  343. package/dist/zeros-Bj5rMYA7.js +0 -52
package/dist/utilities/packed.js (new file)
@@ -0,0 +1,716 @@
+ import { w as T, o as R, p as x, K as j, I as M, q as V, s as $, t as O, v as L, x as _, A as U } from "../tensor_util-DV-FP5Q3.js";
+ import { b as q, o as W, E as H, q as X, a as p, r as Y, t as Z, u as J, v as K, V as Q, w as N, T as B, x as P, f as tt, m as et, s as st, y as nt } from "../tensor-DdQUJZlz.js";
+ import { PackableTensor as A, PackableVariable as rt } from "../patches/PackedTensor.js";
+ function lt() {
+ return y.backendName === "webgpu";
+ }
+ function E(o) {
+ return o.packed !== void 0;
+ }
+ function w(o) {
+ return E(o) && o.packed;
+ }
+ function z(o) {
+ if (E(o)) {
+ if (o.dtype !== "int32")
+ throw new Error("packTensor: only int32 tensors can be packed.");
+ return o.packed = !0, o;
+ } else
+ throw console.error("Tensor:", o), new Error("Tensor is not packable");
+ }
+ function ft(o) {
+ if (E(o)) {
+ if (o.dtype !== "float32")
+ throw new Error("unpackTensor: only float32 tensors can be unpacked.");
+ o.packed = !1;
+ }
+ return o;
+ }
+ function it(o, t, e, s) {
+ for (let n = t.length - 1; n >= 0; n--) {
+ const r = t[n], c = [];
+ if (r.outputs.forEach((a) => {
+ const i = o[a.id];
+ i != null ? c.push(i) : c.push(null);
+ }), r.gradient == null)
+ throw new Error(`Cannot compute gradient: gradient function not found for ${r.kernelName}.`);
+ const h = r.gradient(c);
+ for (const a in r.inputs) {
+ if (!(a in h))
+ throw new Error(
+ `Cannot backprop through input ${a}. Available gradients found: ${Object.keys(h)}.`
+ );
+ const i = e(() => h[a]()), d = w(i);
+ if (i.dtype !== "float32" && (!d || i.dtype !== "int32"))
+ throw new Error(
+ `Error in gradient for op ${r.kernelName}. The gradient of input ${a} must have 'float32' dtype, but has '${i.dtype}'`
+ );
+ const l = r.inputs[a];
+ if (!q(i.shape, l.shape))
+ throw new Error(
+ `Error in gradient for op ${r.kernelName}. The gradient of input '${a}' has shape '${i.shape}', which does not match the shape of the input '${l.shape}'`
+ );
+ if (o[l.id] == null)
+ o[l.id] = i;
+ else {
+ const u = o[l.id];
+ o[l.id] = s(u, i), u.dispose();
+ }
+ }
+ }
+ }
+ function S(o) {
+ return o.kernelName != null;
+ }
+ class C {
+ // Public since optimizers will use it.
+ registeredVariables = {};
+ nextTapeNodeId = 0;
+ numBytes = 0;
+ numTensors = 0;
+ numStringTensors = 0;
+ numDataBuffers = 0;
+ activeTape;
+ // Number of nested tf.grad() statements when computing higher-order
+ // gradients. E.g. `1` for first-order gradients and `2` for second-order
+ // gradients. Used to track if the tape should be removed after a backprop.
+ gradientDepth = 0;
+ // Number of nested kernel calls. When kernel depth is greater than 1, we turn
+ // off the tape.
+ kernelDepth = 0;
+ // Keep Tensors that parallel the tapes.
+ activeScope;
+ scopeStack = [];
+ /**
+ * Keeps track of the number of data moves during a kernel execution. We
+ * maintain a stack since kernels can call other kernels, recursively.
+ */
+ numDataMovesStack = [];
+ nextScopeId = 0;
+ tensorInfo = /* @__PURE__ */ new WeakMap();
+ profiling = !1;
+ activeProfile = {
+ newBytes: 0,
+ newTensors: 0,
+ peakBytes: 0,
+ kernels: [],
+ result: null,
+ get kernelNames() {
+ return Array.from(new Set(this.kernels.map((t) => t.name)));
+ }
+ };
+ dispose() {
+ for (const t in this.registeredVariables)
+ this.registeredVariables[t].dispose();
+ }
+ }
+ class v {
+ constructor(t) {
+ this.ENV = t, this.state = new C(), console.log("GenAI Patched Engine Initialized");
+ }
+ version = "GENAI_PATCHED_ENGINE";
+ state;
+ backendName;
+ registry = {};
+ registryFactory = {};
+ profiler;
+ backendInstance = null;
+ pendingBackendInit;
+ pendingBackendInitId = 0;
+ async ready() {
+ if (this.pendingBackendInit != null)
+ return this.pendingBackendInit.then(() => {
+ });
+ if (this.backendInstance != null)
+ return;
+ const t = this.getSortedBackends();
+ for (let e = 0; e < t.length; e++) {
+ const s = t[e];
+ if (await this.initializeBackend(s).success) {
+ await this.setBackend(s);
+ return;
+ }
+ }
+ throw new Error("Could not initialize any backends, all backend initializations failed.");
+ }
+ get backend() {
+ if (this.pendingBackendInit != null)
+ throw new Error(
+ `Backend '${this.backendName}' has not yet been initialized. Make sure to await tf.ready() or await tf.setBackend() before calling other methods`
+ );
+ if (this.backendInstance == null) {
+ const { name: t, asyncInit: e } = this.initializeBackendsAndReturnBest();
+ if (e)
+ throw new Error(
+ `The highest priority backend '${t}' has not yet been initialized. Make sure to await tf.ready() or await tf.setBackend() before calling other methods`
+ );
+ this.setBackend(t);
+ }
+ return this.backendInstance;
+ }
+ backendNames() {
+ return Object.keys(this.registryFactory);
+ }
+ findBackend(t) {
+ if (!(t in this.registry))
+ if (t in this.registryFactory) {
+ const { asyncInit: e } = this.initializeBackend(t);
+ if (e)
+ return null;
+ } else
+ return null;
+ return this.registry[t];
+ }
+ findBackendFactory(t) {
+ return t in this.registryFactory ? this.registryFactory[t].factory : null;
+ }
+ registerBackend(t, e, s = 1) {
+ return t in this.registryFactory ? (T(`${t} backend was already registered. Reusing existing backend factory.`), !1) : (this.registryFactory[t] = { factory: e, priority: s }, console.log("Registered backend", t), !0);
+ }
+ async setBackend(t) {
+ if (this.registryFactory[t] == null)
+ throw new Error(`Backend name '${t}' not found in registry`);
+ if (this.backendName = t, this.registry[t] == null) {
+ this.backendInstance = null;
+ const { success: e, asyncInit: s } = this.initializeBackend(t);
+ if (!(s ? await e : e))
+ return !1;
+ }
+ return this.backendInstance = this.registry[t], this.setupRegisteredKernels(), this.profiler = new R(this.backendInstance), !0;
+ }
+ setupRegisteredKernels() {
+ x(this.backendName).forEach((e) => {
+ e.setupFunc != null && e.setupFunc(this.backendInstance);
+ });
+ }
+ disposeRegisteredKernels(t) {
+ x(t).forEach((s) => {
+ s.disposeFunc != null && s.disposeFunc(this.registry[t]);
+ });
+ }
+ /**
+ * Initializes a backend by looking up the backend name in the factory
+ * registry and calling the factory method. Returns a boolean representing
+ * whether the initialization of the backend succeeded. Throws an error if
+ * there is no backend in the factory registry.
+ */
+ initializeBackend(t) {
+ const e = this.registryFactory[t];
+ if (e == null)
+ throw new Error(`Cannot initialize backend ${t}, no registration found.`);
+ try {
+ const s = e.factory();
+ if (s && !(s instanceof j) && typeof s.then == "function") {
+ const n = ++this.pendingBackendInitId, r = s.then((c) => n < this.pendingBackendInitId ? !1 : (this.registry[t] = c, this.pendingBackendInit = null, !0)).catch((c) => (n < this.pendingBackendInitId || (this.pendingBackendInit = null, T(`Initialization of backend ${t} failed`), T(c.stack || c.message)), !1));
+ return this.pendingBackendInit = r, { success: r, asyncInit: !0 };
+ } else
+ return this.registry[t] = s, { success: !0, asyncInit: !1 };
+ } catch (s) {
+ return T(`Initialization of backend ${t} failed`), T(s.stack || s.message), { success: !1, asyncInit: !1 };
+ }
+ }
+ removeBackend(t) {
+ if (!(t in this.registryFactory))
+ throw new Error(`${t} backend not found in registry`);
+ this.backendName === t && this.pendingBackendInit != null && this.pendingBackendInitId++, t in this.registry && (this.disposeRegisteredKernels(t), this.registry[t].dispose(), delete this.registry[t]), delete this.registryFactory[t], this.backendName === t && (this.pendingBackendInit = null, this.backendName = null, this.backendInstance = null);
+ }
+ getSortedBackends() {
+ if (Object.keys(this.registryFactory).length === 0)
+ throw new Error("No backend found in registry.");
+ return Object.keys(this.registryFactory).sort((t, e) => this.registryFactory[e].priority - this.registryFactory[t].priority);
+ }
+ initializeBackendsAndReturnBest() {
+ const t = this.getSortedBackends();
+ for (let e = 0; e < t.length; e++) {
+ const s = t[e], { success: n, asyncInit: r } = this.initializeBackend(s);
+ if (r || n)
+ return { name: s, asyncInit: r };
+ }
+ throw new Error("Could not initialize any backends, all backend initializations failed.");
+ }
+ moveData(t, e) {
+ const s = this.state.tensorInfo.get(e);
+ s || console.warn("Tried to move data that does not exist", this.state, e);
+ const n = s.backend, r = this.readSync(e), c = n.refCount(e);
+ n.disposeData(e, !0), s.backend = t, t.move(e, r, s.shape, s.dtype, c), this.shouldCheckForMemLeaks() && this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1]++;
+ }
+ tidy(t, e) {
+ let s = null;
+ if (e == null) {
+ if (typeof t != "function")
+ throw new Error("Please provide a function to tidy()");
+ e = t;
+ } else {
+ if (typeof t != "string" && !(t instanceof String))
+ throw new Error("When calling with two arguments, the first argument to tidy() must be a string");
+ if (typeof e != "function")
+ throw new Error("When calling with two arguments, the 2nd argument to tidy() must be a function");
+ s = t;
+ }
+ let n;
+ return this.scopedRun(
+ () => this.startScope(s),
+ () => this.endScope(n),
+ () => (n = e(), n instanceof Promise && console.error("Cannot return a Promise inside of tidy."), n)
+ );
+ }
+ scopedRun(t, e, s) {
+ t();
+ try {
+ const n = s();
+ return e(), n;
+ } catch (n) {
+ throw e(), n;
+ }
+ }
+ static nextTensorId = 0;
+ nextTensorId() {
+ return v.nextTensorId++;
+ }
+ static nextVariableId = 0;
+ nextVariableId() {
+ return v.nextVariableId++;
+ }
+ /**
+ * This method is called instead of the public-facing tensor.clone() when
+ * saving a tensor for backwards pass. It makes sure to add the clone
+ * operation to the tape regardless of being called inside a kernel
+ * execution.
+ */
+ clone(t) {
+ const s = w(t) ? z(y.runKernel(M, { x: t })) : y.runKernel(M, { x: t }), n = { x: t }, r = (h) => ({
+ x: () => {
+ const a = "float32", i = { x: h }, d = { dtype: a }, l = w(t), u = y.runKernel(
+ _,
+ i,
+ // tslint:disable-next-line: no-unnecessary-type-assertion
+ d
+ );
+ return l && z(u), u;
+ }
+ }), c = [];
+ return this.addTapeNode(this.state.activeScope.name, n, [s], r, c, {}), s;
+ }
+ /**
+ * Execute a kernel with the given name and return the output tensor.
+ *
+ * @param kernelName The name of the kernel to execute.
+ * @param inputs A map of input names to tensors.
+ * @param attrs A map of attribute names to their values. An attribute is a
+ * primitive (non-tensor) input to the kernel.
+ * @param inputsToSave A list of tensors, inputs to save for the backprop
+ * computation.
+ * @param outputsToSave A list of booleans, specifying which output to save
+ * for the backprop computation. These are booleans since the output
+ * tensors are not visible to the user.
+ */
+ runKernel(t, e, s) {
+ if (this.backendName == null && this.backend, !(V(t, this.backendName) != null))
+ throw new Error(`Kernel '${t}' not registered for backend '${this.backendName}'`);
+ return this.runKernelFunc({ kernelName: t, inputs: e, attrs: s });
+ }
+ shouldCheckForMemLeaks() {
+ return this.ENV.getBool("IS_TEST");
+ }
+ checkKernelForMemLeak(t, e, s) {
+ const n = this.backend.numDataIds();
+ let r = 0;
+ s.forEach((a) => {
+ r += a.dtype === "complex64" ? 3 : 1;
+ });
+ const c = this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1], h = n - e - r - c;
+ if (h > 0)
+ throw new Error(
+ `Backend '${this.backendName}' has an internal memory leak (${h} data ids) after running '${t}'`
+ );
+ }
+ /**
+ * Internal helper method to execute a kernel Func
+ *
+ * Use `runKernel` to execute kernels from outside of engine.
+ */
+ runKernelFunc(t) {
+ let e, s = [];
+ const n = this.isTapeOn(), r = this.state.numBytes, c = this.state.numTensors;
+ this.shouldCheckForMemLeaks() && this.state.numDataMovesStack.push(0);
+ let h;
+ this.backendName == null && this.backend;
+ let a;
+ const i = S(t) ? t.kernelName : this.state.activeScope != null ? this.state.activeScope.name : "";
+ if (S(t)) {
+ const { kernelName: f, inputs: I, attrs: m } = t;
+ this.backendName == null && this.backend;
+ const g = V(f, this.backendName);
+ p(
+ g != null,
+ () => `Cannot find registered kernel '${f}' for backend '${this.backendName}'`
+ ), h = () => {
+ const G = this.backend.numDataIds();
+ a = g.kernelFunc({ inputs: I, attrs: m, backend: this.backend });
+ const F = Array.isArray(a) ? a : [a];
+ this.shouldCheckForMemLeaks() && this.checkKernelForMemLeak(f, G, F);
+ const D = F.map((b) => b.rank != null ? b : this.makeTensorFromTensorInfo(b));
+ if (n) {
+ const b = this.getTensorsForGradient(f, I, D);
+ s = this.saveTensorsForBackwardMode(b ?? []);
+ }
+ return D;
+ };
+ } else {
+ const { forwardFunc: f } = t, I = (m) => {
+ n && (s = m.map((g) => this.keep(this.clone(g))));
+ };
+ h = () => {
+ const m = this.backend.numDataIds();
+ a = this.tidy(() => f(this.backend, I));
+ const g = Array.isArray(a) ? a : [a];
+ return this.shouldCheckForMemLeaks() && this.checkKernelForMemLeak(i, m, g), g;
+ };
+ }
+ const { inputs: d, attrs: l } = t, u = S(t) ? null : t.backwardsFunc;
+ let k;
+ return this.scopedRun(
+ // Stop recording to a tape when running a kernel.
+ () => this.state.kernelDepth++,
+ () => this.state.kernelDepth--,
+ () => {
+ !this.ENV.getBool("DEBUG") && !this.state.profiling ? e = h() : (k = this.profiler.profileKernel(i, d, () => h()), this.ENV.getBool("DEBUG") && this.profiler.logKernelProfile(k), e = k.outputs);
+ }
+ ), n && this.addTapeNode(
+ i,
+ d,
+ e,
+ u,
+ s,
+ l ?? {}
+ ), this.state.profiling && this.state.activeProfile.kernels.push({
+ name: i,
+ bytesAdded: this.state.numBytes - r,
+ totalBytesSnapshot: this.state.numBytes,
+ tensorsAdded: this.state.numTensors - c,
+ totalTensorsSnapshot: this.state.numTensors,
+ inputShapes: Object.keys(d).map(
+ (f) => d[f] != null ? d[f].shape : null
+ ),
+ outputShapes: e.map((f) => f.shape),
+ kernelTimeMs: k.timeMs,
+ extraInfo: k.extraInfo
+ }), Array.isArray(a) ? e : e[0];
+ }
+ /**
+ * Saves tensors used in forward mode for use in backward mode.
+ *
+ * @param tensors the list of tensors to save.
+ */
+ saveTensorsForBackwardMode(t) {
+ return t.map((s) => this.keep(this.clone(s)));
+ }
+ /**
+ * Returns a list of tensors to save for a given gradient calculation.
+ *
+ * @param kernelName name of kernel to look up gradient for.
+ * @param inputs a map of input tensors.
+ * @param outputs an array of output tensors from forward mode of kernel.
+ */
+ getTensorsForGradient(t, e, s) {
+ const n = $(t);
+ if (n != null) {
+ const r = n.inputsToSave || [], c = n.outputsToSave || [];
+ let h;
+ n.saveAllInputs ? (p(Array.isArray(e), () => "saveAllInputs is true, expected inputs to be an array."), h = Object.keys(e).map((i) => e[i])) : h = r.map((i) => e[i]);
+ const a = s.filter((i, d) => c[d]);
+ return h.concat(a);
+ }
+ return [];
+ }
+ /**
+ * Internal method used by public APIs for tensor creation. Makes a new
+ * tensor with the provided shape, dtype and values. It always
+ * creates a new data id and writes the values to the underlying backend.
+ */
+ makeTensor(t, e, s, n) {
+ if (t == null)
+ throw new Error("Values passed to engine.makeTensor() are null");
+ s = s || "float32", n = n || this.backend;
+ let r = t;
+ s === "string" && Y(t[0]) && (r = t.map((a) => Z(a)));
+ const c = n.write(r, e, s), h = new A(e, s, c, this.nextTensorId());
+ if (this.trackTensor(h, n), s === "string") {
+ const a = this.state.tensorInfo.get(c), i = J(r);
+ this.state.numBytes += i - a.bytes, a.bytes = i;
+ }
+ return h;
+ }
+ /**
+ * Internal method used by backends. Makes a new tensor
+ * that is a wrapper around an existing data id. It doesn't create
+ * a new data id, only increments the ref count used in memory tracking.
+ * @deprecated
+ */
+ makeTensorFromDataId(t, e, s, n) {
+ s = s || "float32";
+ const r = { dataId: t, shape: e, dtype: s };
+ return this.makeTensorFromTensorInfo(r, n);
+ }
+ /**
+ * Internal method used by backends. Makes a new tensor that is a wrapper
+ * around an existing data id in TensorInfo. It doesn't create a new data id,
+ * only increments the ref count used in memory tracking.
+ */
+ makeTensorFromTensorInfo(t, e) {
+ const { dataId: s, shape: n, dtype: r } = t, c = new A(n, r, s, this.nextTensorId());
+ if (c.packed = t.packed || !1, c.packed && r !== "int32")
+ throw new Error("Only int32 tensors can be packed.");
+ return this.trackTensor(c, e ?? this.backend), c;
+ }
+ makeVariable(t, e = !0, s, n) {
+ s = s || this.nextVariableId().toString(), n != null && n !== t.dtype && (t = t.cast(n));
+ const r = new rt(t, e, s, this.nextTensorId());
+ if (this.state.registeredVariables[r.name] != null)
+ throw new Error(`Variable with name ${r.name} was already registered`);
+ return this.state.registeredVariables[r.name] = r, this.incRef(r, this.backend), r;
+ }
+ trackTensor(t, e) {
+ this.state.numTensors++, t.dtype === "string" && this.state.numStringTensors++;
+ let s = 0;
+ t.dtype !== "complex64" && t.dtype !== "string" && (s = t.size * K(t.dtype)), this.state.numBytes += s, this.state.tensorInfo.has(t.dataId) || (this.state.numDataBuffers++, this.state.tensorInfo.set(t.dataId, {
+ backend: e || this.backend,
+ dtype: t.dtype,
+ shape: t.shape,
+ bytes: s
+ })), t instanceof Q || this.track(t);
+ }
+ // Track the tensor by dataId and increase the refCount for the dataId in the
+ // backend.
+ // TODO(pyu10055): This is currently used by makeVariable method, to increase
+ // refCount on the backend for the dataId. It can potentially be replaced with
+ // Identity op indead of calling backend directly.
+ incRef(t, e) {
+ this.trackTensor(t, e), this.backend.incRef(t.dataId);
+ }
+ removeDataId(t, e) {
+ this.state.tensorInfo.has(t) && this.state.tensorInfo.get(t).backend === e && (this.state.tensorInfo.delete(t), this.state.numDataBuffers--);
+ }
+ disposeTensor(t) {
+ if (!this.state.tensorInfo.has(t.dataId))
+ return;
+ const e = this.state.tensorInfo.get(t.dataId);
+ if (this.state.numTensors--, t.dtype === "string" && (this.state.numStringTensors--, this.state.numBytes -= e.bytes), t.dtype !== "complex64" && t.dtype !== "string") {
+ const s = t.size * K(t.dtype);
+ this.state.numBytes -= s;
+ }
+ e.backend.disposeData(t.dataId) && this.removeDataId(t.dataId, e.backend);
+ }
+ disposeVariables() {
+ for (const t in this.state.registeredVariables) {
+ const e = this.state.registeredVariables[t];
+ this.disposeVariable(e);
+ }
+ }
+ disposeVariable(t) {
+ this.disposeTensor(t), this.state.registeredVariables[t.name] != null && delete this.state.registeredVariables[t.name];
+ }
+ memory() {
+ const t = this.backend.memory();
+ return t.numTensors = this.state.numTensors, t.numDataBuffers = this.state.numDataBuffers, t.numBytes = this.state.numBytes, this.state.numStringTensors > 0 && (t.unreliable = !0, t.reasons == null && (t.reasons = []), t.reasons.push("Memory usage by string tensors is approximate (2 bytes per character)")), t;
+ }
+ async profile(t) {
+ this.state.profiling = !0;
+ const e = this.state.numBytes, s = this.state.numTensors;
+ this.state.activeProfile.kernels = [], this.state.activeProfile.result = await t(), this.state.profiling = !1, this.state.activeProfile.peakBytes = Math.max(
+ ...this.state.activeProfile.kernels.map((n) => n.totalBytesSnapshot)
+ ), this.state.activeProfile.newBytes = this.state.numBytes - e, this.state.activeProfile.newTensors = this.state.numTensors - s;
+ for (const n of this.state.activeProfile.kernels)
+ n.kernelTimeMs = await n.kernelTimeMs, n.extraInfo = await n.extraInfo;
+ return this.state.activeProfile;
+ }
+ isTapeOn() {
+ return this.state.gradientDepth > 0 && this.state.kernelDepth === 0;
+ }
+ addTapeNode(t, e, s, n, r, c) {
+ const h = { id: this.state.nextTapeNodeId++, kernelName: t, inputs: e, outputs: s, saved: r }, a = $(t);
+ a != null && (n = a.gradFunc), n != null && (h.gradient = (i) => (i = i.map((d, l) => {
+ if (d == null) {
+ const u = s[l], k = tt(u.size, u.dtype);
+ return this.makeTensor(k, u.shape, u.dtype);
+ }
+ return d;
+ }), n(i.length > 1 ? i : i[0], r, c))), this.state.activeTape.push(h);
+ }
+ keep(t) {
+ return t.kept = !0, t;
+ }
+ startTape() {
+ this.state.gradientDepth === 0 && (this.state.activeTape = []), this.state.gradientDepth++;
+ }
+ endTape() {
+ this.state.gradientDepth--;
+ }
+ /**
+ * Start a scope. Use this with endScope() to achieve the same functionality
+ * as scope() without the need for a function closure.
+ */
+ startScope(t) {
+ const e = {
+ track: [],
+ name: "unnamed scope",
+ id: this.state.nextScopeId++
+ };
+ t && (e.name = t), this.state.scopeStack.push(e), this.state.activeScope = e;
+ }
+ /**
+ * End a scope. Use this with startScope() to achieve the same functionality
+ * as scope() without the need for a function closure.
+ */
+ endScope(t) {
+ const e = O(t), s = new Set(e.map((r) => r.id));
+ for (let r = 0; r < this.state.activeScope.track.length; r++) {
+ const c = this.state.activeScope.track[r];
+ !c.kept && !s.has(c.id) && c.dispose();
+ }
+ const n = this.state.scopeStack.pop();
+ this.state.activeScope = this.state.scopeStack.length === 0 ? null : this.state.scopeStack[this.state.scopeStack.length - 1], e.forEach((r) => {
+ !r.kept && r.scopeId === n?.id && this.track(r);
+ });
+ }
+ /**
+ * Returns gradients of `f` with respect to each of the `xs`. The gradients
+ * returned are of the same length as `xs`, but some might be null if `f`
+ * was not a function of that `x`. It also takes optional dy to multiply the
+ * gradient, which defaults to `1`.
+ */
+ gradients(t, e, s, n = !1) {
+ if (p(e.length > 0, () => "gradients() received an empty list of xs."), s != null && s.dtype !== "float32")
+ throw new Error(`dy must have 'float32' dtype, but has '${s.dtype}'`);
+ const r = this.scopedRun(
+ () => this.startTape(),
+ () => this.endTape(),
+ () => this.tidy("forward", t)
+ );
+ p(r instanceof B, () => "The result y returned by f() must be a tensor.");
+ const c = L(this.state.activeTape, e, r);
+ if (!n && c.length === 0 && e.length > 0)
+ throw new Error(
+ "Cannot compute gradient of y=f(x) with respect to x. Make sure that the f you passed encloses all operations that lead from x to y."
+ );
+ return this.tidy("backward", () => {
+ const h = {};
+ h[r.id] = s ?? at(r.shape), it(
+ h,
+ c,
+ // Pass the tidy function to avoid circular dep with `tape.ts`.
+ (i) => this.tidy(i),
+ // Pass an add function to avoide a circular dep with `tape.ts`.
+ ct
+ );
+ const a = e.map((i) => h[i.id]);
+ return this.state.gradientDepth === 0 && (this.state.activeTape.forEach((i) => {
+ if (i.saved !== void 0)
+ for (const d of i.saved)
+ d.dispose();
+ }), this.state.activeTape = null), { value: r, grads: a };
+ });
+ }
+ customGrad(t) {
+ return p(N(t), () => "The f passed in customGrad(f) must be a function."), (...e) => {
+ p(
+ e.every((h) => h instanceof B),
+ () => "The args passed in customGrad(f)(x1, x2,...) must all be tensors"
+ );
+ let s;
+ const n = {};
+ e.forEach((h, a) => {
+ n[a] = h;
+ });
+ const r = (h, a) => (s = t(...e, a), p(
+ s.value instanceof B,
+ () => "The function f passed in customGrad(f) must return an object where `obj.value` is a tensor"
+ ), p(
+ N(s.gradFunc),
+ () => "The function f passed in customGrad(f) must return an object where `obj.gradFunc` is a function."
+ ), s.value), c = (h, a) => {
+ const i = s.gradFunc(h, a), d = Array.isArray(i) ? i : [i];
+ p(
+ d.length === e.length,
+ () => "The function f passed in customGrad(f) must return an object where `obj.gradFunc` is a function that returns the same number of tensors as inputs passed to f(...)."
+ ), p(
+ d.every((u) => u instanceof B),
+ () => "The function f passed in customGrad(f) must return an object where `obj.gradFunc` is a function that returns a list of only tensors."
+ );
+ const l = {};
+ return d.forEach((u, k) => {
+ l[k] = () => u;
+ }), l;
+ };
+ return this.runKernelFunc({
+ forwardFunc: r,
+ backwardsFunc: c,
+ inputs: n
+ });
+ };
+ }
+ readSync(t) {
+ return this.state.tensorInfo.get(t).backend.readSync(t);
+ }
+ read(t) {
+ return this.state.tensorInfo.get(t).backend.read(t);
+ }
+ readToGPU(t, e) {
+ return this.state.tensorInfo.get(t).backend.readToGPU(t, e);
+ }
+ async time(t) {
+ const e = P(), s = await this.backend.time(t);
+ return s.wallMs = P() - e, s;
+ }
+ /**
+ * Tracks a Tensor in the current scope to be automatically cleaned up
+ * when the current scope ends, and returns the value.
+ *
+ * @param result The Tensor to track in the current scope.
+ */
+ track(t) {
+ return this.state.activeScope != null && (t.scopeId = this.state.activeScope.id, this.state.activeScope.track.push(t)), t;
+ }
+ get registeredVariables() {
+ return this.state.registeredVariables;
+ }
+ /**
+ * Resets the engine state. Removes all backends but does not remove
+ * registered backend factories.
+ */
+ reset() {
+ this.pendingBackendInitId++, this.state.dispose(), this.ENV.reset(), this.state = new C();
+ for (const t in this.registry)
+ this.disposeRegisteredKernels(t), this.registry[t].dispose(), delete this.registry[t];
+ this.backendName = null, this.backendInstance = null, this.pendingBackendInit = null;
+ }
+ }
+ function at(o) {
+ const t = et(st(o), "float32");
+ return y.makeTensor(t, o, "float32");
+ }
+ function ot() {
+ const o = W();
+ if (o._tfengine == null) {
+ const t = new H(o);
+ o._tfengine = new v(t);
+ }
+ return X(o._tfengine.ENV), nt(() => o._tfengine), o._tfengine;
+ }
+ const y = ot();
+ function ct(o, t) {
+ const e = w(o) || w(t), s = { a: o, b: t };
+ return y.runKernel(e ? "Add16" : U, s);
+ }
+ export {
+ v as E,
+ y as a,
+ it as b,
+ ct as c,
+ ot as g,
+ E as isPackableTensor,
+ w as isPackedTensor,
+ z as packTensor,
+ lt as packingSupported,
+ ft as unpackTensor
+ };
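
The named exports at the end of this hunk suggest a small public surface: packingSupported() reports whether the active backend is "webgpu", packTensor() flags a packable int32 tensor as packed, unpackTensor() clears the flag on a float32 tensor, and isPackedTensor()/isPackableTensor() are the matching predicates. Note also that the gradient-accumulation helper ct dispatches to the new "Add16" kernel whenever either operand is packed. The sketch below shows how these helpers could be exercised; it is illustrative only, since the diff does not show the package's public entry points, and the import path plus the tensor placeholder are assumptions.

// Sketch only: assumes the helpers are importable from the built file.
import {
  packTensor,
  unpackTensor,
  isPackedTensor,
  packingSupported
} from "@genai-fi/nanogpt/dist/utilities/packed.js";

if (packingSupported()) { // true only when the active backend is "webgpu"
  // Hypothetical placeholder: `t` must be a PackableTensor produced by the
  // patched engine (its `packed` flag defined) with dtype "int32";
  // packTensor throws otherwise.
  const p = packTensor(t);
  isPackedTensor(p); // -> true
  // unpackTensor(x) requires dtype "float32" and clears the packed flag.
}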