@genai-fi/nanogpt 0.7.3 → 0.8.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/Generator.d.ts +25 -2
- package/dist/Generator.js +152 -49
- package/dist/{RealDiv-Dy0p8Bvo.js → RealDiv-D_q39E3A.js} +13 -13
- package/dist/{Reshape-DvudQDvJ.js → Reshape-41YpQqEo.js} +1 -1
- package/dist/{Reshape-DH5srBP0.js → Reshape-Bh_jzKzV.js} +5 -5
- package/dist/TeachableLLM.d.ts +6 -6
- package/dist/TeachableLLM.js +33 -31
- package/dist/Trainer.d.ts +13 -2
- package/dist/Trainer.js +21 -12
- package/dist/{axis_util-BzbKo31C.js → axis_util-Did9235A.js} +3 -3
- package/dist/backend.js +2 -2
- package/dist/{backend_util-TE7aTPhZ.js → backend_util-yC3YH1jo.js} +58 -58
- package/dist/{broadcast_to-CdbwV-Dj.js → broadcast_to-CUvOdOT5.js} +2 -2
- package/dist/checks/appendCache.d.ts +1 -0
- package/dist/checks/appendCache.js +22 -0
- package/dist/checks/attentionMask.d.ts +1 -0
- package/dist/checks/attentionMask.js +37 -0
- package/dist/checks/check.d.ts +9 -0
- package/dist/checks/check.js +20 -0
- package/dist/checks/gelu.d.ts +1 -0
- package/dist/checks/gelu.js +18 -0
- package/dist/checks/index.d.ts +19 -0
- package/dist/checks/index.js +21 -0
- package/dist/checks/normRMS.d.ts +1 -0
- package/dist/checks/normRMS.js +16 -0
- package/dist/checks/normRMSGrad.d.ts +1 -0
- package/dist/checks/normRMSGrad.js +12 -0
- package/dist/checks/qkv.d.ts +1 -0
- package/dist/checks/qkv.js +25 -0
- package/dist/checks/rope.d.ts +1 -0
- package/dist/checks/rope.js +21 -0
- package/dist/{concat-CsxrgovM.js → concat-pHiVqR3L.js} +1 -1
- package/dist/{dataset-CtdBYwjo.js → dataset-DPPl-iLT.js} +9 -9
- package/dist/{dropout-DYs5QFGQ.js → dropout-CcKSfOYE.js} +18 -18
- package/dist/exports_initializers-DKk7-bsx.js +16 -0
- package/dist/{gather-CMMy2KEG.js → gather-CPg6ZlQA.js} +1 -1
- package/dist/{gelu-C-dPj6Ku.js → gelu-BkcmEEyD.js} +1 -1
- package/dist/{gpgpu_math-DGNLNL4I.js → gpgpu_math-D_ODOLix.js} +26 -26
- package/dist/{index-BoWRt-10.js → index-DdmHGZjq.js} +659 -650
- package/dist/{index-CLthM0TO.js → index-evZ57wr4.js} +185 -185
- package/dist/{kernel_funcs_utils-BYKWV8Aa.js → kernel_funcs_utils-CDfFpUab.js} +21 -21
- package/dist/layers/BaseLayer.d.ts +8 -13
- package/dist/layers/BaseLayer.js +25 -13
- package/dist/layers/CausalSelfAttention.d.ts +3 -2
- package/dist/layers/CausalSelfAttention.js +28 -28
- package/dist/layers/MLP.d.ts +3 -2
- package/dist/layers/MLP.js +16 -20
- package/dist/layers/PositionEmbedding.d.ts +9 -0
- package/dist/layers/PositionEmbedding.js +45 -0
- package/dist/layers/RMSNorm.d.ts +3 -2
- package/dist/layers/RMSNorm.js +6 -6
- package/dist/layers/RoPECache.d.ts +1 -1
- package/dist/layers/RoPECache.js +4 -4
- package/dist/layers/TiedEmbedding.d.ts +3 -2
- package/dist/layers/TiedEmbedding.js +29 -7
- package/dist/layers/TransformerBlock.d.ts +3 -2
- package/dist/layers/TransformerBlock.js +1 -1
- package/dist/loader/load.d.ts +2 -2
- package/dist/loader/loadHF.d.ts +2 -2
- package/dist/loader/loadTransformers.d.ts +4 -2
- package/dist/loader/loadTransformers.js +10 -9
- package/dist/loader/newZipLoad.d.ts +2 -2
- package/dist/loader/oldZipLoad.d.ts +2 -2
- package/dist/loader/oldZipLoad.js +44 -51
- package/dist/loader/save.d.ts +8 -0
- package/dist/loader/save.js +62 -0
- package/dist/{log_sum_exp-DbjkV734.js → log_sum_exp-C8yFJfZz.js} +45 -24
- package/dist/main.d.ts +6 -4
- package/dist/main.js +24 -18
- package/dist/{mat_mul-8m8pfdcx.js → mat_mul-Dpy2mMRu.js} +1 -1
- package/dist/mod-CbibJi3D.js +27 -0
- package/dist/models/NanoGPTV1.d.ts +15 -0
- package/dist/models/NanoGPTV1.js +71 -0
- package/dist/{config.d.ts → models/config.d.ts} +1 -0
- package/dist/{config.js → models/config.js} +1 -0
- package/dist/models/factory.d.ts +3 -0
- package/dist/models/factory.js +14 -0
- package/dist/models/model.d.ts +26 -0
- package/dist/models/model.js +70 -0
- package/dist/{mulmat_packed_gpu-VSekgsNv.js → mulmat_packed_gpu-q_Gmwyld.js} +1 -1
- package/dist/{ones-Dj0SDhHf.js → ones-BAqVh-eA.js} +2 -2
- package/dist/ops/adamAdjust.js +1 -1
- package/dist/ops/adamMoments.js +1 -1
- package/dist/ops/appendCache.js +3 -3
- package/dist/ops/attentionMask.js +1 -1
- package/dist/ops/cpu/adamAdjust.js +9 -9
- package/dist/ops/cpu/adamMoments.js +2 -2
- package/dist/ops/cpu/appendCache.js +2 -2
- package/dist/ops/cpu/attentionMask.js +5 -5
- package/dist/ops/cpu/fusedSoftmax.js +2 -2
- package/dist/ops/cpu/gatherSub.js +5 -5
- package/dist/ops/cpu/gelu.js +1 -1
- package/dist/ops/cpu/matMulGelu.js +2 -2
- package/dist/ops/cpu/matMulMul.js +1 -1
- package/dist/ops/cpu/mulDropout.js +1 -1
- package/dist/ops/cpu/normRMS.js +1 -1
- package/dist/ops/cpu/qkv.js +3 -3
- package/dist/ops/cpu/rope.js +5 -5
- package/dist/ops/cpu/scatterSub.js +7 -7
- package/dist/ops/fusedSoftmax.js +1 -1
- package/dist/ops/gatherSub.js +1 -1
- package/dist/ops/gelu.js +2 -2
- package/dist/ops/grads/attentionMask.js +1 -1
- package/dist/ops/grads/fusedSoftmax.js +2 -2
- package/dist/ops/grads/gelu.js +2 -2
- package/dist/ops/grads/matMulGelu.js +1 -1
- package/dist/ops/grads/normRMS.js +1 -1
- package/dist/ops/grads/qkv.js +1 -1
- package/dist/ops/grads/rope.js +1 -1
- package/dist/ops/matMulGelu.js +1 -1
- package/dist/ops/matMulMul.js +1 -1
- package/dist/ops/mulDrop.js +1 -1
- package/dist/ops/normRMS.js +1 -1
- package/dist/ops/qkv.js +1 -1
- package/dist/ops/rope.js +4 -4
- package/dist/ops/scatterSub.js +1 -1
- package/dist/ops/webgl/adamAdjust.js +2 -2
- package/dist/ops/webgl/adamMoments.js +1 -1
- package/dist/ops/webgl/appendCache.js +1 -1
- package/dist/ops/webgl/attentionMask.js +1 -1
- package/dist/ops/webgl/fusedSoftmax.js +4 -4
- package/dist/ops/webgl/gatherSub.js +1 -1
- package/dist/ops/webgl/gelu.js +2 -2
- package/dist/ops/webgl/log.js +3 -3
- package/dist/ops/webgl/matMulGelu.js +10 -10
- package/dist/ops/webgl/matMulMul.js +1 -1
- package/dist/ops/webgl/mulDropout.js +1 -1
- package/dist/ops/webgl/normRMS.js +2 -2
- package/dist/ops/webgl/qkv.js +1 -1
- package/dist/ops/webgl/rope.js +1 -1
- package/dist/ops/webgl/scatterSub.js +1 -1
- package/dist/ops/webgpu/adamAdjust.js +3 -3
- package/dist/ops/webgpu/adamMoments.js +3 -3
- package/dist/ops/webgpu/appendCache.js +3 -3
- package/dist/ops/webgpu/attentionMask.js +3 -3
- package/dist/ops/webgpu/gatherSub.js +3 -3
- package/dist/ops/webgpu/gelu.js +3 -3
- package/dist/ops/webgpu/normRMS.js +2 -2
- package/dist/ops/webgpu/normRMSGrad.js +5 -5
- package/dist/ops/webgpu/qkv.js +3 -3
- package/dist/ops/webgpu/rope.js +3 -3
- package/dist/ops/webgpu/scatterSub.js +3 -3
- package/dist/ops/webgpu/utils/reductions.js +4 -4
- package/dist/ops-542ai2vG.js +1525 -0
- package/dist/{random_width-sZORGo5k.js → random_width-DKGeiFuR.js} +1471 -1538
- package/dist/{range-CRuAh-gd.js → range-BcUvLuf5.js} +1 -1
- package/dist/{reciprocal-BvGAyKyu.js → reciprocal-DhDWSKiD.js} +1 -1
- package/dist/{register_all_kernels-BwDSRN-f.js → register_all_kernels-Do9VvZmo.js} +2488 -2534
- package/dist/{max-Ddnnb5xe.js → relu-B1AXs7p5.js} +6 -6
- package/dist/{reshape-CdBq1WJ6.js → reshape-WeJkT3ja.js} +1 -1
- package/dist/{scatter_nd_util-DUstGbU1.js → scatter_nd_util-B7yDhiQr.js} +1 -1
- package/dist/{selu_util-BJEXVvjX.js → selu_util-BgUO9gHY.js} +125 -146
- package/dist/{shared-wS99K7_n.js → shared-CZiWmQCI.js} +1 -1
- package/dist/{shared-B8ztnyEk.js → shared-V6D_md-c.js} +72 -72
- package/dist/{sin-BeA3tsEd.js → sin-CPxad7Am.js} +1 -1
- package/dist/{slice-BiOsknYS.js → slice-B7jXtPnp.js} +1 -1
- package/dist/{softmax-Bv_6lyMX.js → softmax-BfsyI4As.js} +1 -1
- package/dist/{split-B-dikLRw.js → split-BPxr8_8m.js} +1 -1
- package/dist/{stack-B17UN2nn.js → stack-BNwLzE43.js} +1 -1
- package/dist/{sum-66ew2byf.js → sum-ByFINZgi.js} +3 -3
- package/dist/{tensor-JwS7ZYY6.js → tensor-DbqgIV9B.js} +1 -1
- package/dist/tensor1d-CtJq5BOv.js +27 -0
- package/dist/{tensor2d-wxPAnDQy.js → tensor2d-CObBWBkW.js} +1 -1
- package/dist/tensor3d-BOukqWwr.js +30 -0
- package/dist/tensor4d-DLtk7Nxh.js +30 -0
- package/dist/training/Adam.js +2 -2
- package/dist/training/AdamExt.js +1 -1
- package/dist/training/DatasetBuilder.js +2 -2
- package/dist/training/Evaluator.d.ts +2 -2
- package/dist/training/FullTrainer.d.ts +3 -3
- package/dist/training/FullTrainer.js +61 -69
- package/dist/training/Trainer.d.ts +15 -3
- package/dist/training/Trainer.js +39 -47
- package/dist/training/sparseCrossEntropy.js +12 -13
- package/dist/utilities/arrayClose.d.ts +1 -1
- package/dist/utilities/arrayClose.js +16 -7
- package/dist/utilities/dummy.d.ts +4 -4
- package/dist/utilities/dummy.js +13 -13
- package/dist/utilities/multinomialCPU.js +2 -2
- package/dist/utilities/parameters.d.ts +1 -1
- package/dist/utilities/performance.js +1 -1
- package/dist/utilities/profile.js +1 -1
- package/dist/utilities/safetensors.js +2 -2
- package/dist/utilities/weights.js +2 -2
- package/dist/{variable-BuddVFLa.js → variable-DPFOJyRG.js} +1 -1
- package/dist/{webgpu_program-PFzf1hAQ.js → webgpu_program-Dhk9R5aG.js} +1 -1
- package/dist/{webgpu_util-D____QpY.js → webgpu_util-BqGnZg8t.js} +27 -27
- package/dist/{zeros--BdLQ3oG.js → zeros-Dnwix0p4.js} +1 -1
- package/package.json +2 -3
- package/dist/NanoGPTModel.d.ts +0 -52
- package/dist/NanoGPTModel.js +0 -203
- package/dist/TiedEmbedding-BxOerUmB.js +0 -43
- package/dist/ops-BFGCx8Ri.js +0 -1202
- package/dist/utilities/generate.d.ts +0 -3
- package/dist/utilities/generate.js +0 -22
- package/dist/utilities/save.d.ts +0 -9
- package/dist/utilities/save.js +0 -61
|
@@ -1,14 +1,14 @@
|
|
|
1
|
-
import {
|
|
2
|
-
import { m as hr, c as pr, P as He, t as
|
|
3
|
-
import { i as wt, G as it, a as gr, c as S, f as I, M as j, b as bt, d as yt, e as St } from "./webgpu_util-
|
|
4
|
-
import { m as rt, E as xr, u as Cr, w as wr, x as br, y as yr, z as Sr, f as at, A as vt, B as It, C as kt, D as vr, F as Ir, G as xe, H as kr, I as Rr, J as Pr, K as Dr, L as Nr, M as $r, N as zr, O as Ar, P as Fr, Q as Lr, S as
|
|
5
|
-
import { S as
|
|
6
|
-
import { r as R, a as Mr } from "./Reshape-
|
|
7
|
-
import { s as Or } from "./shared-
|
|
8
|
-
import { c as qe, g as Se, a as ve, b as Ye, e as Gr, h as Pt } from "./axis_util-
|
|
9
|
-
import { z as Hr } from "./zeros
|
|
1
|
+
import { ag as U, d4 as Jt, d5 as es, e as We, n as L, d6 as _e, j as $, aO as Ge, ap as de, d7 as st, d8 as ts, d9 as ss, aP as os, bn as Ct, a3 as Re, da as is, db as rs, dc as as, bv as ns, l as Z, cg as us, dd as ot, az as ds, aa as ls, u as ge, bm as cs, co as hs, cp as ps, bt as fs, cq as ms, bc as gs, af as ze, p as se, aB as xs, bJ as Cs, bK as ws, bL as bs, cr as ys, cs as Ss, ct as vs, cv as Is, cu as ks, cw as Rs, ah as Ps, b5 as Ds, bM as Ns, bN as $s, cx as zs, cy as As, U as Fs, $ as Ls, bP as Bs, aS as Es, de as Ts, b7 as Ws, b8 as _s, bq as Vs, br as Us, bw as Ms, bR as Os, cA as Gs, a$ as Hs, I as Xs, bS as Ks, cc as qs, bT as Ys, bU as Qs, cB as js, bV as Zs, ac as Js, bW as eo, bd as to, bX as so, bY as oo, bZ as io, df as ro, b_ as ao, ce as no, cf as uo, dg as lo, cC as co, cD as ho, cE as po, dh as fo, bB as mo, a1 as go, aU as xo, as as Co, cF as wo, bx as bo, b$ as yo, ai as So, aY as vo, by as Io, di as ko, be as Ro, ao as Po, bz as Do, dj as No, bQ as $o, cd as zo, dk as Ao, a8 as Fo, G as Lo, aZ as Bo, a_ as Eo, dl as To, cG as Wo, cH as _o, cI as Vo, at as Uo, b0 as Mo, b1 as Oo, dm as Go, aj as Ho, b2 as Xo, b4 as Ko, c1 as qo, dn as Yo, cL as Qo, cK as jo, bA as Zo, c2 as Jo, c3 as ei, cM as ti, cN as si, dp as oi, aV as ii, b6 as ri, cO as ai, Y as ni, S as ui, a6 as di, b3 as li, bg as ci, bh as hi, c4 as pi, cW as fi, c5 as mi, P as gi, a4 as xi, c6 as Ci, cP as wi, au as bi, bC as yi, R as Si, aC as vi, Z as Ii, _ as ki, av as Ri, bj as Pi, cQ as Di, bk as Ni, cR as $i, c8 as zi, bf as Ai, b9 as Fi, bD as Li, a7 as Bi, dq as Ei, aT as Ti, c9 as Wi, ar as _i, cS as Vi, ad as Ui, ca as Mi, c0 as Oi, c7 as Gi, dr as Hi, ds as Xi, X as Ki, dt as qi, ab as Yi, W as Qi, bF as ji, cT as Zi, ba as Ji, aw as er, du as tr, dv as sr, bH as or, cU as ir, bO as rr, dw as ar, dx as nr, bl as ur, bb as dr, cb as lr, f as cr } from "./index-DdmHGZjq.js";
|
|
2
|
+
import { m as hr, c as pr, P as He, t as B, g as y, a as J, b as q, d as Pe, e as fr, f as mr } from "./webgpu_program-Dhk9R5aG.js";
|
|
3
|
+
import { i as wt, G as it, a as gr, c as S, f as I, M as j, b as bt, d as yt, e as St } from "./webgpu_util-BqGnZg8t.js";
|
|
4
|
+
import { m as rt, E as xr, u as Cr, w as wr, x as br, y as yr, z as Sr, f as at, A as vt, B as It, C as kt, D as vr, F as Ir, G as xe, H as kr, I as Rr, J as Pr, K as Dr, L as Nr, M as $r, N as zr, O as Ar, P as Fr, Q as Lr, S as Br } from "./backend_util-yC3YH1jo.js";
|
|
5
|
+
import { S as Er, a as Tr, h as be, i as Ae, p as Wr, q as _r, j as ye, d as ee, e as Xe, g as Ke, k as Rt, A as Vr, B as Ur } from "./selu_util-BgUO9gHY.js";
|
|
6
|
+
import { r as R, a as Mr } from "./Reshape-41YpQqEo.js";
|
|
7
|
+
import { s as Or } from "./shared-V6D_md-c.js";
|
|
8
|
+
import { c as qe, g as Se, a as ve, b as Ye, e as Gr, h as Pt } from "./axis_util-Did9235A.js";
|
|
9
|
+
import { z as Hr } from "./zeros-Dnwix0p4.js";
|
|
10
10
|
import { n as Xr, a as Kr } from "./non_max_suppression_impl-CsEgBuMA.js";
|
|
11
|
-
import { c as Qe } from "./scatter_nd_util-
|
|
11
|
+
import { c as Qe } from "./scatter_nd_util-B7yDhiQr.js";
|
|
12
12
|
/**
|
|
13
13
|
* @license
|
|
14
14
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
|
@@ -230,7 +230,7 @@ class Fe extends Jt {
|
|
|
230
230
|
constructor(t, e) {
|
|
231
231
|
if (super(), this.commandQueueOwnedIds = /* @__PURE__ */ new WeakSet(), this.dispatchCountInPass = 0, this.disposed = !1, this.downloadWaitMs = 0, this.tensorDataPendingDisposal = [], this.queryResolveBuffer = null, this.querySet = null, this.querySetCount = 2, this.stagingPendingDisposal = [], this.uniformPendingDisposal = [], this.uploadWaitMs = 0, this.hasReadSyncWarned = !1, this.hasTimestampQueryWarned = !1, !wt())
|
|
232
232
|
throw new Error("WebGPU is not supported on this device");
|
|
233
|
-
this.pipelineCache = {}, this.device = t, this.queue = t.queue, this.commandEncoder = null, this.computePassEncoder = null, this.adapterInfo = new qr(e), this.supportTimestampQuery = this.device.features.has("timestamp-query"), this.thresholdToIncreaseWorkgroups = this.adapterInfo.intelGPUGeneration >= 12 ? 16 : 8, this.bufferManager = new Yr(this.device), this.textureManager = new Qr(this.device), this.tensorMap = new es(this,
|
|
233
|
+
this.pipelineCache = {}, this.device = t, this.queue = t.queue, this.commandEncoder = null, this.computePassEncoder = null, this.adapterInfo = new qr(e), this.supportTimestampQuery = this.device.features.has("timestamp-query"), this.thresholdToIncreaseWorkgroups = this.adapterInfo.intelGPUGeneration >= 12 ? 16 : 8, this.bufferManager = new Yr(this.device), this.textureManager = new Qr(this.device), this.tensorMap = new es(this, We()), U().getBool("WEBGPU_USE_PROFILE_TOOL") && (this.dummyCanvas = document.createElement("canvas"), this.dummyCanvas.width = 1, this.dummyCanvas.height = 1, this.dummyContext = this.dummyCanvas.getContext("webgpu"), this.dummyContext.configure({
|
|
234
234
|
device: t,
|
|
235
235
|
format: "bgra8unorm"
|
|
236
236
|
}), document.body.appendChild(this.dummyCanvas));
|
|
@@ -351,28 +351,28 @@ class Fe extends Jt {
|
|
|
351
351
|
alphaMode: r[g]
|
|
352
352
|
}), x.getCurrentTexture();
|
|
353
353
|
}).map((m, g) => {
|
|
354
|
-
const x = h * 4, C = (A, z,
|
|
354
|
+
const x = h * 4, C = (A, z, E) => {
|
|
355
355
|
this.ensureCommandEncoderReady(), this.commandEncoder.copyBufferToTexture({
|
|
356
356
|
buffer: a,
|
|
357
357
|
bytesPerRow: x,
|
|
358
|
-
offset:
|
|
358
|
+
offset: E
|
|
359
359
|
}, {
|
|
360
360
|
texture: m
|
|
361
361
|
}, {
|
|
362
362
|
width: A,
|
|
363
363
|
height: z
|
|
364
364
|
}), this.submitQueue();
|
|
365
|
-
const
|
|
365
|
+
const T = p.getContext("2d", {
|
|
366
366
|
willReadFrequently: !0
|
|
367
367
|
});
|
|
368
|
-
|
|
369
|
-
const G =
|
|
370
|
-
for (let V = 0; V <
|
|
368
|
+
T.clearRect(0, 0, A, z), T.drawImage(c[g], 0, 0);
|
|
369
|
+
const G = T.getImageData(0, 0, A, z).data, M = r[g], W = new Uint8ClampedArray(d, E, A * z * 4);
|
|
370
|
+
for (let V = 0; V < W.length; V += 4)
|
|
371
371
|
if (M === "premultiplied")
|
|
372
|
-
|
|
372
|
+
W[V + 3] = G[V + 3];
|
|
373
373
|
else {
|
|
374
374
|
const O = G[V];
|
|
375
|
-
|
|
375
|
+
W[V] = G[V + 2], W[V + 1] = G[V + 1], W[V + 2] = O;
|
|
376
376
|
}
|
|
377
377
|
}, w = Math.floor(u / (h * l));
|
|
378
378
|
let v = h, k = l, P = 0;
|
|
@@ -429,7 +429,7 @@ class Fe extends Jt {
|
|
|
429
429
|
throw new Error(`GPUBuffer size(${t.buffer.size}) is smaller than tensor size(${n})!`);
|
|
430
430
|
if ((t.buffer.usage & (GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC)) !== (GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC))
|
|
431
431
|
throw new Error("GPUBuffer.usage should include GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC!");
|
|
432
|
-
return t.zeroCopy !== !0 && (s = this.copyBuffer(s)), a.resource = s,
|
|
432
|
+
return t.zeroCopy !== !0 && (s = this.copyBuffer(s)), a.resource = s, We().makeTensorFromDataId(r, e, o, this);
|
|
433
433
|
}
|
|
434
434
|
/**
|
|
435
435
|
* Read tensor to a new GPUBuffer.
|
|
@@ -443,7 +443,7 @@ class Fe extends Jt {
|
|
|
443
443
|
throw o != null ? new Error("Data is not on GPU but on CPU.") : new Error("There is no data on GPU or CPU.");
|
|
444
444
|
const n = a, u = n.size, d = n.usage, h = this.bufferManager.acquireBuffer(u, d);
|
|
445
445
|
this.ensureCommandEncoderReady(), this.endComputePassEncoder(), this.commandEncoder.copyBufferToBuffer(a, 0, h, 0, u), this.submitQueue();
|
|
446
|
-
const l = this.makeTensorInfo(r, s), c =
|
|
446
|
+
const l = this.makeTensorInfo(r, s), c = We().makeTensorFromTensorInfo(l), p = this.tensorMap.get(l.dataId);
|
|
447
447
|
return p.resource = h, { tensorRef: c, buffer: h };
|
|
448
448
|
}
|
|
449
449
|
bufferSync(t) {
|
|
@@ -920,10 +920,10 @@ const Da = "return abs(a);", Na = `
|
|
|
920
920
|
return -uniforms.INFINITY;
|
|
921
921
|
}
|
|
922
922
|
return atanh(a);
|
|
923
|
-
`,
|
|
923
|
+
`, Ba = "return ceil(a);", Ea = "return cos(a);", Ta = `
|
|
924
924
|
let e2x = exp(-a);
|
|
925
925
|
return (e2x + 1.0 / e2x) / 2.0;
|
|
926
|
-
`,
|
|
926
|
+
`, Wa = "return exp(a) - 1.0;", _a = "if (a >= 0.0) { return a; } return (exp(a) - 1.0);", Va = `
|
|
927
927
|
var resFloat = exp(a) - vec4<f32>(1.0);
|
|
928
928
|
if (a.r >= 0.0) {
|
|
929
929
|
resFloat.r = a.r;
|
|
@@ -964,9 +964,9 @@ const Da = "return abs(a);", Na = `
|
|
|
964
964
|
return select(a, vec4<f32>(0.0), a < vec4<f32>(0.0));
|
|
965
965
|
`, an = "return round(a);", nn = "return inverseSqrt(a);", un = `
|
|
966
966
|
if (a >= 0.0) {
|
|
967
|
-
return ${
|
|
967
|
+
return ${Er} * a;
|
|
968
968
|
} else {
|
|
969
|
-
return ${
|
|
969
|
+
return ${Tr} * (exp(a) - 1.0);
|
|
970
970
|
}
|
|
971
971
|
`, dn = "return 1.0 / (1.0 + exp(-1.0 * a));", ln = "return sign(a);", cn = "return sin(a);", hn = `
|
|
972
972
|
let e2x = exp(a);
|
|
@@ -1013,11 +1013,11 @@ function te(i, t) {
|
|
|
1013
1013
|
case b.ATANH:
|
|
1014
1014
|
return La;
|
|
1015
1015
|
case b.COS:
|
|
1016
|
-
return
|
|
1016
|
+
return Ea;
|
|
1017
1017
|
case b.COSH:
|
|
1018
|
-
return
|
|
1018
|
+
return Ta;
|
|
1019
1019
|
case b.CEIL:
|
|
1020
|
-
return
|
|
1020
|
+
return Ba;
|
|
1021
1021
|
case b.ELU:
|
|
1022
1022
|
return t ? Va : _a;
|
|
1023
1023
|
case b.ERF:
|
|
@@ -1025,7 +1025,7 @@ function te(i, t) {
|
|
|
1025
1025
|
case b.EXP:
|
|
1026
1026
|
return Ma;
|
|
1027
1027
|
case b.EXPM1:
|
|
1028
|
-
return
|
|
1028
|
+
return Wa;
|
|
1029
1029
|
case b.FLOOR:
|
|
1030
1030
|
return Oa;
|
|
1031
1031
|
case b.IS_FINITE:
|
|
@@ -1120,7 +1120,7 @@ function Q(i, t = !1, e = !1, o = 3) {
|
|
|
1120
1120
|
s = te(b.LEAKYRELU, e);
|
|
1121
1121
|
else
|
|
1122
1122
|
throw new Error(`Activation ${i} has not been implemented for the WebGPU backend.`);
|
|
1123
|
-
const a =
|
|
1123
|
+
const a = B(e ? 4 : 1);
|
|
1124
1124
|
let n = "";
|
|
1125
1125
|
return t ? n = `
|
|
1126
1126
|
fn activation(a : ${a}, coords : vec${o}<i32>) -> ${a} {
|
|
@@ -1160,8 +1160,8 @@ function Dt(i, t, e = !1, o = !1, s = !1, r = 1) {
|
|
|
1160
1160
|
|
|
1161
1161
|
`, n = t ? "value = getB(batch, col, row);" : "value = getB(batch, row, col);";
|
|
1162
1162
|
return `
|
|
1163
|
-
fn mm_readA(batch: i32, row: i32, col: i32) -> ${
|
|
1164
|
-
var value = ${
|
|
1163
|
+
fn mm_readA(batch: i32, row: i32, col: i32) -> ${B(r)} {
|
|
1164
|
+
var value = ${B(r)}(0.0);
|
|
1165
1165
|
${e && s ? a : `
|
|
1166
1166
|
${i ? "if(row < uniforms.dimAOuter && col < uniforms.dimInner)" : "if(row < uniforms.aShape[1] && col < uniforms.aShape[2])"}
|
|
1167
1167
|
{
|
|
@@ -1171,8 +1171,8 @@ function Dt(i, t, e = !1, o = !1, s = !1, r = 1) {
|
|
|
1171
1171
|
return value;
|
|
1172
1172
|
}
|
|
1173
1173
|
|
|
1174
|
-
fn mm_readB(batch: i32, row: i32, col: i32) -> ${
|
|
1175
|
-
var value = ${
|
|
1174
|
+
fn mm_readB(batch: i32, row: i32, col: i32) -> ${B(r)} {
|
|
1175
|
+
var value = ${B(r)}(0.0);
|
|
1176
1176
|
${n}
|
|
1177
1177
|
return value;
|
|
1178
1178
|
}
|
|
@@ -1181,7 +1181,7 @@ function Dt(i, t, e = !1, o = !1, s = !1, r = 1) {
|
|
|
1181
1181
|
function Ze(i, t, e, o, s = !1, r = !1, a = !1, n = 1) {
|
|
1182
1182
|
return `
|
|
1183
1183
|
${Dt(e, o, s, r, a, n)}
|
|
1184
|
-
fn mm_write(batch: i32, row: i32, col: i32, valueIn: ${
|
|
1184
|
+
fn mm_write(batch: i32, row: i32, col: i32, valueIn: ${B(n)}) {
|
|
1185
1185
|
${s && r ? "" : "if (row < uniforms.dimAOuter && col < uniforms.dimBOuter)"}
|
|
1186
1186
|
{
|
|
1187
1187
|
var value = valueIn;
|
|
@@ -1287,7 +1287,7 @@ const lt = (i) => i ? `
|
|
|
1287
1287
|
globalRowStart + inputRow,
|
|
1288
1288
|
kStart + inputCol);
|
|
1289
1289
|
`, Sn = (i) => i ? "let ACached = mm_Asub[k][tileRow + innerRow];" : "let ACached = mm_Asub[tileRow + innerRow][k];";
|
|
1290
|
-
function
|
|
1290
|
+
function Be(i, t, e = !1, o = 32, s = !1, r = 32, a = !1, n = !1) {
|
|
1291
1291
|
const u = i[1] * t[1], d = i[0] * t[0], h = e ? u : o, l = e ? o : u;
|
|
1292
1292
|
L(l % t[1] === 0 && h % t[0] === 0 && o % t[1] === 0, () => `tileAHight ${l} must be divisible by workgroupSize[1]${t[1]}, tileAWidth ${h} must be divisible by workgroupSize[0]${t[0]}, tileInner ${o} must be divisible by workgroupSize[1]${t[1]}`);
|
|
1293
1293
|
const c = l / t[1], p = h / t[0], f = o / t[1], m = i[1], g = i[0], x = a ? `
|
|
@@ -1501,7 +1501,7 @@ class kn {
|
|
|
1501
1501
|
return `
|
|
1502
1502
|
${Q(this.activation, this.hasPreluActivationWeights, this.isVec4)}
|
|
1503
1503
|
${Ze(this.addBias, this.activation, !1, this.transposeB, this.fitAOuter, this.fitBOuter, this.fitInner, this.isVec4 ? 4 : 1)}
|
|
1504
|
-
${this.isVec4 ? Le(this.elementsPerThread, this.workgroupSize, this.transposeA, this.tileInner, !1, null, !0) : this.isVectorA ? In(this.workgroupSize, this.transposeA) :
|
|
1504
|
+
${this.isVec4 ? Le(this.elementsPerThread, this.workgroupSize, this.transposeA, this.tileInner, !1, null, !0) : this.isVectorA ? In(this.workgroupSize, this.transposeA) : Be(this.elementsPerThread, this.workgroupSize, this.transposeA, this.tileInner, !1, null, this.sequentialAccessByThreads, !0)}
|
|
1505
1505
|
`;
|
|
1506
1506
|
}
|
|
1507
1507
|
}
|
|
@@ -1694,7 +1694,7 @@ class $n {
|
|
|
1694
1694
|
const t = this.outputComponent;
|
|
1695
1695
|
return `
|
|
1696
1696
|
${Dt(!1, this.transposeB, !1, !1, !1, t)}
|
|
1697
|
-
fn mm_write(batch: i32, row : i32, col : i32, value : ${
|
|
1697
|
+
fn mm_write(batch: i32, row : i32, col : i32, value : ${B(t)}) {
|
|
1698
1698
|
if (row < uniforms.dimAOuter && col < uniforms.dimBOuter) {
|
|
1699
1699
|
let coords = vec3<i32>(batch, row, col);
|
|
1700
1700
|
let flatIndex = getOutputIndexFromCoords(coords);
|
|
@@ -1705,7 +1705,7 @@ class $n {
|
|
|
1705
1705
|
}
|
|
1706
1706
|
}
|
|
1707
1707
|
}
|
|
1708
|
-
${t === 4 ? Le(this.elementsPerThread, this.workgroupSize, this.transposeA, 32, !0, this.splitedDimInner) :
|
|
1708
|
+
${t === 4 ? Le(this.elementsPerThread, this.workgroupSize, this.transposeA, 32, !0, this.splitedDimInner) : Be(this.elementsPerThread, this.workgroupSize, this.transposeA, 32, !0, this.splitedDimInner)}
|
|
1709
1709
|
`;
|
|
1710
1710
|
}
|
|
1711
1711
|
}
|
|
@@ -1805,34 +1805,34 @@ const Fn = {
|
|
|
1805
1805
|
* limitations under the License.
|
|
1806
1806
|
* =============================================================================
|
|
1807
1807
|
*/
|
|
1808
|
-
function
|
|
1808
|
+
function Ee({ a: i, b: t, transposeA: e, transposeB: o, backend: s, bias: r = null, preluActivationWeights: a = null, leakyreluAlpha: n = 0, activation: u = null }) {
|
|
1809
1809
|
const d = i.shape.length, h = t.shape.length, l = e ? i.shape[d - 2] : i.shape[d - 1], c = o ? t.shape[h - 1] : t.shape[h - 2], p = e ? i.shape[d - 1] : i.shape[d - 2], f = o ? t.shape[h - 2] : t.shape[h - 1], m = i.shape.slice(0, -2), g = t.shape.slice(0, -2), x = $(m), C = $(g), v = Z(i.shape.slice(0, -2), t.shape.slice(0, -2)).concat([p, f]);
|
|
1810
1810
|
L(l === c, () => `Error in matMul: inner shapes (${l}) and (${c}) of Tensors with shapes ${i.shape} and ${t.shape} and transposeA=${e} and transposeB=${o} must match.`);
|
|
1811
|
-
const k = e ? [x, l, p] : [x, p, l], P = o ? [C, f, c] : [C, c, f], N = R({ inputs: { x: i }, backend: s, attrs: { shape: k } }), A = R({ inputs: { x: t }, backend: s, attrs: { shape: P } }), z = [N, A],
|
|
1811
|
+
const k = e ? [x, l, p] : [x, p, l], P = o ? [C, f, c] : [C, c, f], N = R({ inputs: { x: i }, backend: s, attrs: { shape: k } }), A = R({ inputs: { x: t }, backend: s, attrs: { shape: P } }), z = [N, A], E = Math.max(x, C), T = [N, A], G = [
|
|
1812
1812
|
{ type: "int32", data: [p] },
|
|
1813
1813
|
{ type: "int32", data: [f] },
|
|
1814
1814
|
{ type: "int32", data: [l] }
|
|
1815
1815
|
];
|
|
1816
|
-
let M,
|
|
1817
|
-
const V = [
|
|
1816
|
+
let M, W;
|
|
1817
|
+
const V = [E, p, f];
|
|
1818
1818
|
let O = U().get("WEBGPU_MATMUL_PROGRAM_TYPE");
|
|
1819
1819
|
if (O < 0) {
|
|
1820
|
-
const ae = U().getNumber("WEBGPU_THRESHOLD_TO_INCREASE_WORKGROUPS_FOR_MATMUL"), he = ae > 0 ? ae : s.thresholdToIncreaseWorkgroups, pe =
|
|
1821
|
-
pe <= he || p <= 8 && pe <= he * 2 ?
|
|
1820
|
+
const ae = U().getNumber("WEBGPU_THRESHOLD_TO_INCREASE_WORKGROUPS_FOR_MATMUL"), he = ae > 0 ? ae : s.thresholdToIncreaseWorkgroups, pe = E * Math.ceil(p / 32) * Math.ceil(f / 32);
|
|
1821
|
+
pe <= he || p <= 8 && pe <= he * 2 ? E * p * f <= 128 ? O = j.MatMulReduceProgram : E === 1 && c >= 2e3 ? O = j.MatMulSplitKProgram : O = j.MatMulSmallOutputSizeProgram : O = j.MatMulPackedProgram;
|
|
1822
1822
|
}
|
|
1823
1823
|
switch (O) {
|
|
1824
1824
|
case j.MatMulReduceProgram:
|
|
1825
1825
|
M = new Pn(V, e, o, r, u, a);
|
|
1826
1826
|
break;
|
|
1827
1827
|
case j.MatMulSplitKProgram: {
|
|
1828
|
-
if (
|
|
1829
|
-
|
|
1830
|
-
const he = new zn(
|
|
1828
|
+
if (W = H({ backend: s, attrs: { shape: V, value: 0, dtype: i.dtype } }), M = new $n(V, c, e, o), r || u) {
|
|
1829
|
+
W = s.runWebGPUProgram(M, T, i.dtype, G, W);
|
|
1830
|
+
const he = new zn(W.shape, r, u, a);
|
|
1831
1831
|
let pe = null;
|
|
1832
|
-
const ke = [
|
|
1832
|
+
const ke = [W];
|
|
1833
1833
|
r && ke.push(r), a && ke.push(a), u === "leakyrelu" && (pe = [{ type: "float32", data: [n] }], he.uniforms += " alpha : f32,");
|
|
1834
|
-
const tt = s.runWebGPUProgram(he, ke,
|
|
1835
|
-
z.push(
|
|
1834
|
+
const tt = s.runWebGPUProgram(he, ke, W.dtype, pe);
|
|
1835
|
+
z.push(W);
|
|
1836
1836
|
const jt = R({ inputs: { x: tt }, backend: s, attrs: { shape: v } });
|
|
1837
1837
|
z.push(tt);
|
|
1838
1838
|
for (const Zt of z)
|
|
@@ -1851,9 +1851,9 @@ function Be({ a: i, b: t, transposeA: e, transposeB: o, backend: s, bias: r = nu
|
|
|
1851
1851
|
default:
|
|
1852
1852
|
throw new Error(`Unsupported MatMulProgramType ${O}.`);
|
|
1853
1853
|
}
|
|
1854
|
-
r &&
|
|
1855
|
-
const Qt = R({ inputs: { x:
|
|
1856
|
-
z.push(
|
|
1854
|
+
r && T.push(r), a && T.push(a), u === "leakyrelu" && (G.push({ type: "float32", data: [n] }), M.uniforms += " alpha : f32,"), W = s.runWebGPUProgram(M, T, i.dtype, G, W);
|
|
1855
|
+
const Qt = R({ inputs: { x: W }, backend: s, attrs: { shape: v } });
|
|
1856
|
+
z.push(W);
|
|
1857
1857
|
for (const ae of z)
|
|
1858
1858
|
s.disposeData(ae.dataId);
|
|
1859
1859
|
return Qt;
|
|
@@ -1876,7 +1876,7 @@ function Be({ a: i, b: t, transposeA: e, transposeB: o, backend: s, bias: r = nu
|
|
|
1876
1876
|
*/
|
|
1877
1877
|
function Ln(i) {
|
|
1878
1878
|
const { inputs: t, backend: e, attrs: o } = i, { a: s, b: r, bias: a, preluActivationWeights: n } = t, { transposeA: u, transposeB: d, activation: h, leakyreluAlpha: l } = o;
|
|
1879
|
-
return
|
|
1879
|
+
return Ee({
|
|
1880
1880
|
a: s,
|
|
1881
1881
|
b: r,
|
|
1882
1882
|
transposeA: u,
|
|
@@ -1888,7 +1888,7 @@ function Ln(i) {
|
|
|
1888
1888
|
activation: h
|
|
1889
1889
|
});
|
|
1890
1890
|
}
|
|
1891
|
-
const
|
|
1891
|
+
const Bn = {
|
|
1892
1892
|
kernelName: us,
|
|
1893
1893
|
backendName: "webgpu",
|
|
1894
1894
|
kernelFunc: Ln
|
|
@@ -2022,7 +2022,7 @@ function X(i) {
|
|
|
2022
2022
|
const { inputs: t } = i, { x: e } = t;
|
|
2023
2023
|
return i.backend.incRef(e.dataId), { dataId: e.dataId, shape: e.shape, dtype: e.dtype };
|
|
2024
2024
|
}
|
|
2025
|
-
const
|
|
2025
|
+
const En = {
|
|
2026
2026
|
kernelName: ds,
|
|
2027
2027
|
backendName: "webgpu",
|
|
2028
2028
|
kernelFunc: X
|
|
@@ -2047,7 +2047,7 @@ function ie(i) {
|
|
|
2047
2047
|
const { inputs: t, backend: e } = i, { real: o, imag: s } = t, r = e.makeTensorInfo(o.shape, "complex64"), a = e.tensorMap.get(r.dataId), n = X({ inputs: { x: o }, backend: e }), u = X({ inputs: { x: s }, backend: e });
|
|
2048
2048
|
return a.complexTensorInfos = { real: n, imag: u }, r;
|
|
2049
2049
|
}
|
|
2050
|
-
const
|
|
2050
|
+
const Tn = {
|
|
2051
2051
|
kernelName: ls,
|
|
2052
2052
|
backendName: "webgpu",
|
|
2053
2053
|
kernelFunc: ie
|
|
@@ -2196,7 +2196,7 @@ function _({ opType: i, cpuKernelImpl: t, supportsComplex: e = !1, dtype: o }) {
|
|
|
2196
2196
|
* limitations under the License.
|
|
2197
2197
|
* =============================================================================
|
|
2198
2198
|
*/
|
|
2199
|
-
const { addImpl:
|
|
2199
|
+
const { addImpl: Wn, castImpl: _n, ceilImpl: Vn, concatImpl: Un, equalImpl: Mn, expImpl: On, expm1Impl: Gn, floorImpl: Hn, floorDivImpl: Xn, gatherNdImpl: Kn, gatherV2Impl: qn, greaterEqualImpl: Yn, greaterImpl: Qn, lessEqualImpl: jn, lessImpl: Zn, logImpl: Jn, maxImpl: eu, maximumImpl: tu, minimumImpl: su, multiplyImpl: ou, negImpl: iu, notEqualImpl: ru, prodImpl: au, rangeImpl: nu, rsqrtImpl: uu, scatterImpl: du, simpleAbsImpl: lu, sliceImpl: cu, stridedSliceImpl: hu, stringNGramsImpl: pu, subImpl: fu, tileImpl: mu, topKImpl: gu, transposeImpl: xu } = Or;
|
|
2200
2200
|
/**
|
|
2201
2201
|
* @license
|
|
2202
2202
|
* Copyright 2021 Google LLC. All Rights Reserved.
|
|
@@ -2276,7 +2276,7 @@ const Su = F({ opType: b.ACOSH }), vu = {
|
|
|
2276
2276
|
* limitations under the License.
|
|
2277
2277
|
* =============================================================================
|
|
2278
2278
|
*/
|
|
2279
|
-
const Iu = _({ opType: D.ADD, cpuKernelImpl:
|
|
2279
|
+
const Iu = _({ opType: D.ADD, cpuKernelImpl: Wn, supportsComplex: !0 }), ku = {
|
|
2280
2280
|
kernelName: fs,
|
|
2281
2281
|
backendName: "webgpu",
|
|
2282
2282
|
kernelFunc: Iu
|
|
@@ -2638,7 +2638,7 @@ function Lu(i) {
|
|
|
2638
2638
|
const { inputs: t, backend: e, attrs: o } = i, { x: s } = t, { keepDims: r, axis: a } = o;
|
|
2639
2639
|
return re(s, a, r, "all", e);
|
|
2640
2640
|
}
|
|
2641
|
-
const
|
|
2641
|
+
const Bu = {
|
|
2642
2642
|
kernelName: Cs,
|
|
2643
2643
|
backendName: "webgpu",
|
|
2644
2644
|
kernelFunc: Lu
|
|
@@ -2659,14 +2659,14 @@ const Eu = {
|
|
|
2659
2659
|
* limitations under the License.
|
|
2660
2660
|
* =============================================================================
|
|
2661
2661
|
*/
|
|
2662
|
-
function
|
|
2662
|
+
function Eu(i) {
|
|
2663
2663
|
const { inputs: t, backend: e, attrs: o } = i, { x: s } = t, { keepDims: r, axis: a } = o;
|
|
2664
2664
|
return re(s, a, r, "any", e);
|
|
2665
2665
|
}
|
|
2666
|
-
const
|
|
2666
|
+
const Tu = {
|
|
2667
2667
|
kernelName: ws,
|
|
2668
2668
|
backendName: "webgpu",
|
|
2669
|
-
kernelFunc:
|
|
2669
|
+
kernelFunc: Eu
|
|
2670
2670
|
};
|
|
2671
2671
|
/**
|
|
2672
2672
|
* @license
|
|
@@ -2787,7 +2787,7 @@ class $t {
|
|
|
2787
2787
|
* limitations under the License.
|
|
2788
2788
|
* =============================================================================
|
|
2789
2789
|
*/
|
|
2790
|
-
function
|
|
2790
|
+
function Wu(i) {
|
|
2791
2791
|
const { inputs: t, backend: e, attrs: o } = i, { x: s } = t, { axis: r } = o;
|
|
2792
2792
|
let a = se(r, s.shape);
|
|
2793
2793
|
const n = Se(a, s.shape.length);
|
|
@@ -2800,7 +2800,7 @@ function Tu(i) {
|
|
|
2800
2800
|
const _u = {
|
|
2801
2801
|
kernelName: bs,
|
|
2802
2802
|
backendName: "webgpu",
|
|
2803
|
-
kernelFunc:
|
|
2803
|
+
kernelFunc: Wu
|
|
2804
2804
|
};
|
|
2805
2805
|
/**
|
|
2806
2806
|
* @license
|
|
@@ -3512,7 +3512,7 @@ const ld = {
|
|
|
3512
3512
|
*/
|
|
3513
3513
|
function cd(i) {
|
|
3514
3514
|
const { inputs: t, backend: e, attrs: o } = i, { a: s, b: r } = t, { transposeA: a, transposeB: n } = o;
|
|
3515
|
-
return
|
|
3515
|
+
return Ee({ a: s, b: r, transposeA: a, transposeB: n, backend: e });
|
|
3516
3516
|
}
|
|
3517
3517
|
const hd = {
|
|
3518
3518
|
kernelName: Fs,
|
|
@@ -3580,7 +3580,7 @@ function fd(i) {
|
|
|
3580
3580
|
* =============================================================================
|
|
3581
3581
|
*/
|
|
3582
3582
|
function ce(i) {
|
|
3583
|
-
const { inputs: t, backend: e, attrs: o } = i, { x: s } = t, { begin: r, size: a } = o, [n, u] =
|
|
3583
|
+
const { inputs: t, backend: e, attrs: o } = i, { x: s } = t, { begin: r, size: a } = o, [n, u] = Wr(s, r, a);
|
|
3584
3584
|
if (_r(s, n, u), e.shouldExecuteOnCPU([s]) || s.dtype === "string") {
|
|
3585
3585
|
const l = e.tensorMap.get(s.dataId), c = cu(l.values, n, u, s.shape, s.dtype);
|
|
3586
3586
|
return e.makeTensorInfo(u, s.dtype, c);
|
|
@@ -3625,7 +3625,7 @@ const gd = (i) => {
|
|
|
3625
3625
|
});
|
|
3626
3626
|
return p.push(f), p.push(m), p.push(g), p.forEach((C) => e.disposeData(C.dataId)), x;
|
|
3627
3627
|
}, xd = {
|
|
3628
|
-
kernelName:
|
|
3628
|
+
kernelName: Bs,
|
|
3629
3629
|
backendName: "webgpu",
|
|
3630
3630
|
kernelFunc: gd
|
|
3631
3631
|
};
|
|
@@ -3701,7 +3701,7 @@ function bd(i) {
|
|
|
3701
3701
|
return e.runWebGPUProgram(p, m, l, f, c);
|
|
3702
3702
|
}
|
|
3703
3703
|
const yd = {
|
|
3704
|
-
kernelName:
|
|
3704
|
+
kernelName: Es,
|
|
3705
3705
|
backendName: "webgpu",
|
|
3706
3706
|
kernelFunc: bd
|
|
3707
3707
|
};
|
|
@@ -3780,7 +3780,7 @@ function vd(i) {
|
|
|
3780
3780
|
return e.runWebGPUProgram(u, [o, s], "int32", d);
|
|
3781
3781
|
}
|
|
3782
3782
|
const Id = {
|
|
3783
|
-
kernelName:
|
|
3783
|
+
kernelName: Ts,
|
|
3784
3784
|
backendName: "webgpu",
|
|
3785
3785
|
kernelFunc: vd
|
|
3786
3786
|
};
|
|
@@ -3800,14 +3800,14 @@ const Id = {
|
|
|
3800
3800
|
* limitations under the License.
|
|
3801
3801
|
* =============================================================================
|
|
3802
3802
|
*/
|
|
3803
|
-
const
|
|
3803
|
+
const Bt = _({
|
|
3804
3804
|
opType: D.NOT_EQUAL,
|
|
3805
3805
|
dtype: "bool",
|
|
3806
3806
|
cpuKernelImpl: ru
|
|
3807
3807
|
}), kd = {
|
|
3808
|
-
kernelName:
|
|
3808
|
+
kernelName: Ws,
|
|
3809
3809
|
backendName: "webgpu",
|
|
3810
|
-
kernelFunc:
|
|
3810
|
+
kernelFunc: Bt
|
|
3811
3811
|
};
|
|
3812
3812
|
/**
|
|
3813
3813
|
* @license
|
|
@@ -3893,7 +3893,7 @@ function Me(i) {
|
|
|
3893
3893
|
if (r === "int32")
|
|
3894
3894
|
return Pd(s, e);
|
|
3895
3895
|
if (r === "bool") {
|
|
3896
|
-
const a = e.makeTensorInfo([], "bool", Ct("bool", 1)), u =
|
|
3896
|
+
const a = e.makeTensorInfo([], "bool", Ct("bool", 1)), u = Bt({ inputs: { a: s, b: a }, backend: e });
|
|
3897
3897
|
return e.disposeData(a.dataId), u;
|
|
3898
3898
|
}
|
|
3899
3899
|
throw new Error(`Error in Cast: failed to cast ${s.dtype} to ${r}`);
|
|
@@ -4039,7 +4039,7 @@ const Ld = {
|
|
|
4039
4039
|
* limitations under the License.
|
|
4040
4040
|
* =============================================================================
|
|
4041
4041
|
*/
|
|
4042
|
-
class
|
|
4042
|
+
class Bd {
|
|
4043
4043
|
constructor(t) {
|
|
4044
4044
|
this.outputShape = [], this.variableNames = ["real", "imag"], this.workgroupSize = [64, 1, 1], this.size = !0, this.outputShape = t, this.dispatchLayout = I(this.outputShape), this.dispatch = S(this.dispatchLayout, this.outputShape, this.workgroupSize), this.shaderKey = "complexAbs";
|
|
4045
4045
|
}
|
|
@@ -4082,17 +4082,17 @@ function ht(i, t) {
|
|
|
4082
4082
|
shape: i.shape
|
|
4083
4083
|
};
|
|
4084
4084
|
}
|
|
4085
|
-
function
|
|
4086
|
-
const { inputs: t, backend: e } = i, { x: o } = t, s = e.tensorMap.get(o.dataId), r = new
|
|
4085
|
+
function Ed(i) {
|
|
4086
|
+
const { inputs: t, backend: e } = i, { x: o } = t, s = e.tensorMap.get(o.dataId), r = new Bd(o.shape), a = [
|
|
4087
4087
|
ht(o, s.complexTensorInfos.real),
|
|
4088
4088
|
ht(o, s.complexTensorInfos.imag)
|
|
4089
4089
|
];
|
|
4090
4090
|
return e.runWebGPUProgram(r, a, a[0].dtype);
|
|
4091
4091
|
}
|
|
4092
|
-
const
|
|
4092
|
+
const Td = {
|
|
4093
4093
|
kernelName: Gs,
|
|
4094
4094
|
backendName: "webgpu",
|
|
4095
|
-
kernelFunc:
|
|
4095
|
+
kernelFunc: Ed
|
|
4096
4096
|
};
|
|
4097
4097
|
/**
|
|
4098
4098
|
* @license
|
|
@@ -4110,7 +4110,7 @@ const Wd = {
|
|
|
4110
4110
|
* limitations under the License.
|
|
4111
4111
|
* =============================================================================
|
|
4112
4112
|
*/
|
|
4113
|
-
class
|
|
4113
|
+
class Wd {
|
|
4114
4114
|
constructor(t) {
|
|
4115
4115
|
this.uniforms = "", this.workPerThread = 1, this.workgroupSize = [64, 1, 1], this.size = !0, this.outputShape = xe(
|
|
4116
4116
|
t,
|
|
@@ -4164,14 +4164,14 @@ class Td {
|
|
|
4164
4164
|
* limitations under the License.
|
|
4165
4165
|
* =============================================================================
|
|
4166
4166
|
*/
|
|
4167
|
-
function
|
|
4167
|
+
function Te(i) {
|
|
4168
4168
|
const { inputs: t, backend: e } = i, { input: o } = t, s = e.tensorMap.get(o.dataId);
|
|
4169
4169
|
return X({ inputs: { x: s.complexTensorInfos.imag }, backend: e });
|
|
4170
4170
|
}
|
|
4171
4171
|
const _d = {
|
|
4172
4172
|
kernelName: Hs,
|
|
4173
4173
|
backendName: "webgpu",
|
|
4174
|
-
kernelFunc:
|
|
4174
|
+
kernelFunc: Te
|
|
4175
4175
|
};
|
|
4176
4176
|
/**
|
|
4177
4177
|
* @license
|
|
@@ -4192,7 +4192,7 @@ const _d = {
|
|
|
4192
4192
|
function fe(i, t, e) {
|
|
4193
4193
|
const o = i[0].dtype;
|
|
4194
4194
|
if (o === "complex64") {
|
|
4195
|
-
const f = i.map((w) => Ie({ inputs: { input: w }, backend: e })), m = i.map((w) =>
|
|
4195
|
+
const f = i.map((w) => Ie({ inputs: { input: w }, backend: e })), m = i.map((w) => Te({ inputs: { input: w }, backend: e })), g = fe(f, t, e), x = fe(m, t, e), C = ie({ inputs: { real: g, imag: x }, backend: e });
|
|
4196
4196
|
return f.forEach((w) => e.disposeData(w.dataId)), m.forEach((w) => e.disposeData(w.dataId)), e.disposeData(g.dataId), e.disposeData(x.dataId), C;
|
|
4197
4197
|
}
|
|
4198
4198
|
let s = e.shouldExecuteOnCPU(i);
|
|
@@ -4219,7 +4219,7 @@ function fe(i, t, e) {
|
|
|
4219
4219
|
e.disposeData(g.dataId);
|
|
4220
4220
|
return m;
|
|
4221
4221
|
}
|
|
4222
|
-
const { tensors2D: a, outShape: n } = Vd(i, t, e), u = a.map((f) => f.shape), d = new
|
|
4222
|
+
const { tensors2D: a, outShape: n } = Vd(i, t, e), u = a.map((f) => f.shape), d = new Wd(u), h = [], l = new Array(u.length - 1);
|
|
4223
4223
|
if (l.length > 0) {
|
|
4224
4224
|
l[0] = u[0][1], h.push({ type: "int32", data: [l[0]] });
|
|
4225
4225
|
for (let f = 1; f < l.length; f++)
|
|
@@ -4259,7 +4259,7 @@ function Vd(i, t, e) {
|
|
|
4259
4259
|
* limitations under the License.
|
|
4260
4260
|
* =============================================================================
|
|
4261
4261
|
*/
|
|
4262
|
-
function
|
|
4262
|
+
function Et(i) {
|
|
4263
4263
|
const { inputs: t, backend: e, attrs: o } = i, { axis: s } = o, r = se(s, t[0].shape)[0], a = t.map((d) => d.shape);
|
|
4264
4264
|
kr(a, r);
|
|
4265
4265
|
const n = xe(t.map((d) => d.shape), r);
|
|
@@ -4271,7 +4271,7 @@ function Bt(i) {
|
|
|
4271
4271
|
const Ud = {
|
|
4272
4272
|
kernelName: Xs,
|
|
4273
4273
|
backendName: "webgpu",
|
|
4274
|
-
kernelFunc:
|
|
4274
|
+
kernelFunc: Et
|
|
4275
4275
|
};
|
|
4276
4276
|
/**
|
|
4277
4277
|
* @license
|
|
@@ -4337,7 +4337,7 @@ function Md(i, t, e, o, s = !1, r = null, a = !1, n = 4, u = 4, d = 4) {
|
|
|
4337
4337
|
let xRow = outRow * uniforms.strides[0] + uniforms.dilations[0] * WRow - uniforms.pads[0];
|
|
4338
4338
|
let xCol = outCol * uniforms.strides[1] + uniforms.dilations[1] * WCol - uniforms.pads[1];
|
|
4339
4339
|
let xCh = ${x} % inChannels;
|
|
4340
|
-
var resData = ${
|
|
4340
|
+
var resData = ${B(n)}(0.0);
|
|
4341
4341
|
// The bounds checking is always needed since we use it to pad zero for
|
|
4342
4342
|
// the 'same' padding type.
|
|
4343
4343
|
if (xRow >= 0 && xRow < ${f} && xCol >= 0 && xCol < ${m}) {
|
|
@@ -4350,12 +4350,12 @@ function Md(i, t, e, o, s = !1, r = null, a = !1, n = 4, u = 4, d = 4) {
|
|
|
4350
4350
|
if (row < uniforms.dimAOuter && col < uniforms.dimInner) {
|
|
4351
4351
|
${C}
|
|
4352
4352
|
}
|
|
4353
|
-
return ${
|
|
4353
|
+
return ${B(n)}(0.0);` : o && e ? `
|
|
4354
4354
|
${C}` : `
|
|
4355
4355
|
if (row < uniforms.dimInner && col < uniforms.dimBOuter) {
|
|
4356
4356
|
${C}
|
|
4357
4357
|
}
|
|
4358
|
-
return ${
|
|
4358
|
+
return ${B(n)}(0.0);`, v = `${l(u)}`, k = B(d), P = i ? B(n) : B(u), N = i ? B(u) : B(n);
|
|
4359
4359
|
return `
|
|
4360
4360
|
${Q(r, a, d === 4, 4)}
|
|
4361
4361
|
fn mm_readA(batch: i32, row : i32, col : i32) -> ${P} {
|
|
@@ -4382,7 +4382,7 @@ class Od {
|
|
|
4382
4382
|
this.variableNames = ["x", "W"], this.uniforms = "filterDims : vec2<i32>, pads : vec2<i32>, strides : vec2<i32>, dilations : vec2<i32>, dimAOuter : i32, dimBOuter : i32, dimInner : i32,", this.outputShape = t.outShape, this.isChannelsLast = t.dataFormat === "channelsLast", this.isVec4 = ((t.inChannels % 4 === 0 || t.inChannels % 3 === 0) && this.isChannelsLast || t.outWidth % 4 === 0 && !this.isChannelsLast) && t.outChannels % 4 === 0, this.dispatchLayout = this.isChannelsLast ? { x: [3], y: [1, 2], z: [0] } : { x: [2, 3], y: [1], z: [0] }, this.workgroupSize = yt(this.dispatchLayout, this.outputShape, this.isVec4), this.elementsPerThread = St(this.dispatchLayout, this.outputShape, this.isVec4), this.dispatch = S(this.dispatchLayout, this.outputShape, this.workgroupSize, this.elementsPerThread), this.isVec4 ? (this.outputComponent = 4, this.isChannelsLast && t.inChannels % 4 !== 0 ? (this.innerElementSize = 3, this.variableComponents = [1, 4]) : (this.innerElementSize = 4, this.variableComponents = [4, 4]), r && (this.variableNames.push("bias"), this.variableComponents.push(4)), n && (this.variableNames.push("preluActivationWeights"), this.variableComponents.push(4))) : (this.innerElementSize = this.elementsPerThread[0], r && this.variableNames.push("bias"), n && this.variableNames.push("preluActivationWeights")), this.sequentialAccessByThreads = u, this.addBias = r, this.activation = a, this.hasPreluActivationWeights = n, this.tileAOuter = this.workgroupSize[1] * this.elementsPerThread[1], this.tileBOuter = this.workgroupSize[0] * this.elementsPerThread[0], this.tileInner = Math.max(this.workgroupSize[0] * this.innerElementSize, this.workgroupSize[1]), this.fitAOuter = e % this.tileAOuter === 0, this.fitBOuter = o % this.tileBOuter === 0, this.fitInner = s % this.tileInner === 0, this.shaderKey = `conv2DMM_${this.elementsPerThread}_${this.activation}}_${this.fitAOuter}_${this.fitBOuter}_${this.fitInner}_${this.isVec4}_${this.innerElementSize}_${this.isChannelsLast}_${this.sequentialAccessByThreads}`;
|
|
4383
4383
|
}
|
|
4384
4384
|
getUserCode() {
|
|
4385
|
-
const t = this.isVec4 ? Le(this.elementsPerThread, this.workgroupSize, !this.isChannelsLast, this.tileInner) :
|
|
4385
|
+
const t = this.isVec4 ? Le(this.elementsPerThread, this.workgroupSize, !this.isChannelsLast, this.tileInner) : Be(this.elementsPerThread, this.workgroupSize, !this.isChannelsLast, this.tileInner, !1, null, this.sequentialAccessByThreads), e = this.isVec4 ? [this.innerElementSize, 4, 4] : [1, 1, 1];
|
|
4386
4386
|
return `
|
|
4387
4387
|
${Md(this.isChannelsLast, this.fitAOuter, this.fitBOuter, this.fitInner, this.addBias, this.activation, this.hasPreluActivationWeights, e[0], e[1], e[2])}
|
|
4388
4388
|
${t}
|
|
@@ -4584,7 +4584,7 @@ function Xd({ x: i, filter: t, convInfo: e, backend: o, bias: s = null, preluAct
|
|
|
4584
4584
|
const x = Ne(s.shape, u);
|
|
4585
4585
|
x != null && (s = R({ inputs: { x: s }, backend: o, attrs: { shape: x } }), c.push(s));
|
|
4586
4586
|
}
|
|
4587
|
-
const m =
|
|
4587
|
+
const m = Ee({
|
|
4588
4588
|
a: u ? p : f,
|
|
4589
4589
|
b: u ? f : p,
|
|
4590
4590
|
transposeA: d,
|
|
@@ -4608,24 +4608,24 @@ function Kd({ x: i, filter: t, convInfo: e, backend: o, bias: s = null, preluAct
|
|
|
4608
4608
|
{ type: "int32", data: [f] },
|
|
4609
4609
|
{ type: "int32", data: [h * u] },
|
|
4610
4610
|
{ type: "int32", data: [h] }
|
|
4611
|
-
], z = o.runWebGPUProgram(N, [i], i.dtype, A),
|
|
4612
|
-
|
|
4613
|
-
const
|
|
4614
|
-
if (
|
|
4611
|
+
], z = o.runWebGPUProgram(N, [i], i.dtype, A), E = [];
|
|
4612
|
+
E.push(z);
|
|
4613
|
+
const T = R({ inputs: { x: t }, backend: o, attrs: { shape: [1, v, -1] } });
|
|
4614
|
+
if (E.push(T), r != null) {
|
|
4615
4615
|
const O = Ne(r.shape, w);
|
|
4616
4616
|
O != null && (r = R({
|
|
4617
4617
|
inputs: { x: r },
|
|
4618
4618
|
backend: o,
|
|
4619
4619
|
attrs: { shape: O }
|
|
4620
|
-
}),
|
|
4620
|
+
}), E.push(r));
|
|
4621
4621
|
}
|
|
4622
4622
|
if (s != null) {
|
|
4623
4623
|
const O = Ne(s.shape, w);
|
|
4624
|
-
O != null && (s = R({ inputs: { x: s }, backend: o, attrs: { shape: O } }),
|
|
4624
|
+
O != null && (s = R({ inputs: { x: s }, backend: o, attrs: { shape: O } }), E.push(s));
|
|
4625
4625
|
}
|
|
4626
|
-
const
|
|
4627
|
-
a: w ? z :
|
|
4628
|
-
b: w ?
|
|
4626
|
+
const W = Ee({
|
|
4627
|
+
a: w ? z : T,
|
|
4628
|
+
b: w ? T : z,
|
|
4629
4629
|
transposeA: !w,
|
|
4630
4630
|
transposeB: !1,
|
|
4631
4631
|
backend: o,
|
|
@@ -4633,13 +4633,13 @@ function Kd({ x: i, filter: t, convInfo: e, backend: o, bias: s = null, preluAct
|
|
|
4633
4633
|
activation: n,
|
|
4634
4634
|
preluActivationWeights: r,
|
|
4635
4635
|
leakyreluAlpha: a
|
|
4636
|
-
}), V = R({ inputs: { x:
|
|
4637
|
-
|
|
4638
|
-
for (const O of
|
|
4636
|
+
}), V = R({ inputs: { x: W }, backend: o, attrs: { shape: e.outShape } });
|
|
4637
|
+
E.push(W);
|
|
4638
|
+
for (const O of E)
|
|
4639
4639
|
o.disposeData(O.dataId);
|
|
4640
4640
|
return V;
|
|
4641
4641
|
}
|
|
4642
|
-
function
|
|
4642
|
+
function Tt({ x: i, filter: t, convInfo: e, backend: o, bias: s = null, preluActivationWeights: r = null, leakyreluAlpha: a = 0, activation: n = null }) {
|
|
4643
4643
|
const u = s != null, d = r != null, h = e.dataFormat === "channelsLast", l = h && e.filterHeight === e.inHeight && e.filterWidth === e.inWidth && e.padInfo.type === "VALID", c = U().getBool("WEBGPU_USE_NAIVE_CONV2D_DEBUG");
|
|
4644
4644
|
if (!c && (l || e.filterHeight === 1 && e.filterWidth === 1 && e.dilationHeight === 1 && e.dilationWidth === 1 && e.strideHeight === 1 && e.strideWidth === 1 && (e.padInfo.type === "SAME" || e.padInfo.type === "VALID")))
|
|
4645
4645
|
return Xd({
|
|
@@ -4708,7 +4708,7 @@ function Wt({ x: i, filter: t, convInfo: e, backend: o, bias: s = null, preluAct
|
|
|
4708
4708
|
*/
|
|
4709
4709
|
function qd(i) {
|
|
4710
4710
|
const { inputs: t, attrs: e, backend: o } = i, { x: s, filter: r } = t, { strides: a, pad: n, dataFormat: u, dilations: d, dimRoundingMode: h } = e, l = ye(u), c = ee(s.shape, r.shape, a, d, n, h, !1, l);
|
|
4711
|
-
return
|
|
4711
|
+
return Tt({ x: s, filter: r, convInfo: c, backend: o });
|
|
4712
4712
|
}
|
|
4713
4713
|
const Yd = {
|
|
4714
4714
|
kernelName: Ks,
|
|
@@ -5121,10 +5121,10 @@ function sl(i = 4) {
|
|
|
5121
5121
|
let xR = f32(outRow - uniforms.pads[0] + WRow) / f32(uniforms.strides[0]);
|
|
5122
5122
|
let xC = f32(outCol - uniforms.pads[1] + WCol) / f32(uniforms.strides[1]);
|
|
5123
5123
|
if (xR < 0.0 || xR >= f32(uniforms.outBackprop[1]) || fract(xR) > 0.0) {
|
|
5124
|
-
return ${
|
|
5124
|
+
return ${B(i)}(0.0);
|
|
5125
5125
|
}
|
|
5126
5126
|
if (xC < 0.0 || xC >= f32(uniforms.outBackprop[2]) || fract(xC) > 0.0) {
|
|
5127
|
-
return ${
|
|
5127
|
+
return ${B(i)}(0.0);
|
|
5128
5128
|
}
|
|
5129
5129
|
let coord = vec4<i32>(
|
|
5130
5130
|
batch,
|
|
@@ -5133,13 +5133,13 @@ function sl(i = 4) {
|
|
|
5133
5133
|
col % uniforms.outBackprop[3]);
|
|
5134
5134
|
return x[getIndexFromCoords4D(coord, uniforms.xShape)/${i}];`}
|
|
5135
5135
|
}
|
|
5136
|
-
return ${
|
|
5136
|
+
return ${B(i)}(0.0);`;
|
|
5137
5137
|
return `
|
|
5138
|
-
fn mm_readA(batch: i32, row : i32, col : i32) -> ${
|
|
5138
|
+
fn mm_readA(batch: i32, row : i32, col : i32) -> ${B(i)} {
|
|
5139
5139
|
${o}
|
|
5140
5140
|
}
|
|
5141
5141
|
|
|
5142
|
-
fn mm_readB(batch: i32, row : i32, col : i32) -> ${
|
|
5142
|
+
fn mm_readB(batch: i32, row : i32, col : i32) -> ${B(i)} {
|
|
5143
5143
|
let coordX = uniforms.filterDims.x - 1 -
|
|
5144
5144
|
row / (uniforms.filterDims[1] * uniforms.outBackprop[3]);
|
|
5145
5145
|
let coordY = uniforms.filterDims.y - 1 -
|
|
@@ -5150,10 +5150,10 @@ function sl(i = 4) {
|
|
|
5150
5150
|
let coord = vec4<i32>(coordX, coordY, col, rowInner);
|
|
5151
5151
|
${t(i)}
|
|
5152
5152
|
}
|
|
5153
|
-
return ${
|
|
5153
|
+
return ${B(i)}(0.0);
|
|
5154
5154
|
}
|
|
5155
5155
|
|
|
5156
|
-
fn mm_write(batch: i32, row : i32, col : i32, valueInput : ${
|
|
5156
|
+
fn mm_write(batch: i32, row : i32, col : i32, valueInput : ${B(i)}) {
|
|
5157
5157
|
if (row < uniforms.dimAOuter && col < uniforms.dimBOuter) {
|
|
5158
5158
|
var value = valueInput;
|
|
5159
5159
|
let outCoord = vec4<i32>(
|
|
@@ -5170,7 +5170,7 @@ class ol {
|
|
|
5170
5170
|
this.variableNames = ["x", "W"], this.uniforms = "filterDims : vec2<i32>, pads : vec2<i32>, strides : vec2<i32>, outBackprop : vec4<i32>, dimAOuter : i32, dimBOuter : i32, dimInner : i32,", this.outputShape = t.inShape, L(t.dataFormat === "channelsLast", () => "TODO: NCHW is unimplemented"), this.isVec4 = t.inChannels % 4 === 0 && t.outChannels % 4 === 0, this.dispatchLayout = { x: [3], y: [1, 2], z: [0] }, this.workgroupSize = yt(this.dispatchLayout, this.outputShape, this.isVec4), this.elementsPerThread = St(this.dispatchLayout, this.outputShape, this.isVec4), this.dispatch = S(this.dispatchLayout, this.outputShape, this.workgroupSize, this.elementsPerThread), this.isVec4 && (this.outputComponent = 4, this.variableComponents = [4, 1]), this.shaderKey = `conv2DDerInputMM_${this.isVec4}_${this.elementsPerThread}`;
|
|
5171
5171
|
}
|
|
5172
5172
|
getUserCode() {
|
|
5173
|
-
const t = this.isVec4 ? Le(this.elementsPerThread, this.workgroupSize) :
|
|
5173
|
+
const t = this.isVec4 ? Le(this.elementsPerThread, this.workgroupSize) : Be(this.elementsPerThread, this.workgroupSize);
|
|
5174
5174
|
return `
|
|
5175
5175
|
${sl(this.isVec4 ? 4 : 1)}
|
|
5176
5176
|
${t}
|
|
@@ -5714,7 +5714,7 @@ function mt(i, t, e) {
|
|
|
5714
5714
|
* limitations under the License.
|
|
5715
5715
|
* =============================================================================
|
|
5716
5716
|
*/
|
|
5717
|
-
function
|
|
5717
|
+
function Wt(i, t, e, o, s, r) {
|
|
5718
5718
|
const a = t.shape.length, n = Se([o], a);
|
|
5719
5719
|
let u = t;
|
|
5720
5720
|
n != null && (u = Y({ inputs: { x: t }, backend: e, attrs: { perm: n } }));
|
|
@@ -5755,7 +5755,7 @@ function Tt(i, t, e, o, s, r) {
|
|
|
5755
5755
|
*/
|
|
5756
5756
|
function bl(i) {
|
|
5757
5757
|
const { inputs: t, backend: e, attrs: o } = i, { x: s } = t, { axis: r, exclusive: a, reverse: n } = o;
|
|
5758
|
-
return
|
|
5758
|
+
return Wt(we.Prod, s, e, r, a, n);
|
|
5759
5759
|
}
|
|
5760
5760
|
const yl = {
|
|
5761
5761
|
kernelName: so,
|
|
@@ -5780,7 +5780,7 @@ const yl = {
|
|
|
5780
5780
|
*/
|
|
5781
5781
|
function Sl(i) {
|
|
5782
5782
|
const { inputs: t, backend: e, attrs: o } = i, { x: s } = t, { axis: r, exclusive: a, reverse: n } = o;
|
|
5783
|
-
return
|
|
5783
|
+
return Wt(we.Sum, s, e, r, a, n);
|
|
5784
5784
|
}
|
|
5785
5785
|
const vl = {
|
|
5786
5786
|
kernelName: oo,
|
|
@@ -6349,7 +6349,7 @@ function Ll(i) {
|
|
|
6349
6349
|
];
|
|
6350
6350
|
return e.runWebGPUProgram(c, [s, r], "float32", p);
|
|
6351
6351
|
}
|
|
6352
|
-
const
|
|
6352
|
+
const Bl = {
|
|
6353
6353
|
kernelName: no,
|
|
6354
6354
|
backendName: "webgpu",
|
|
6355
6355
|
kernelFunc: Ll
|
|
@@ -6370,7 +6370,7 @@ const El = {
|
|
|
6370
6370
|
* limitations under the License.
|
|
6371
6371
|
* =============================================================================
|
|
6372
6372
|
*/
|
|
6373
|
-
function
|
|
6373
|
+
function El(i) {
|
|
6374
6374
|
const { inputs: t, backend: e, attrs: o } = i, { dy: s, filter: r } = t, { strides: a, dilations: n, pad: u, dimRoundingMode: d, inputShape: h } = o, l = ee(
|
|
6375
6375
|
h,
|
|
6376
6376
|
r.shape,
|
|
@@ -6396,10 +6396,10 @@ function Bl(i) {
|
|
|
6396
6396
|
];
|
|
6397
6397
|
return e.runWebGPUProgram(c, [s, r], s.dtype, p);
|
|
6398
6398
|
}
|
|
6399
|
-
const
|
|
6399
|
+
const Tl = {
|
|
6400
6400
|
kernelName: uo,
|
|
6401
6401
|
backendName: "webgpu",
|
|
6402
|
-
kernelFunc:
|
|
6402
|
+
kernelFunc: El
|
|
6403
6403
|
};
|
|
6404
6404
|
/**
|
|
6405
6405
|
* @license
|
|
@@ -6417,7 +6417,7 @@ const Wl = {
|
|
|
6417
6417
|
* limitations under the License.
|
|
6418
6418
|
* =============================================================================
|
|
6419
6419
|
*/
|
|
6420
|
-
class
|
|
6420
|
+
class Wl {
|
|
6421
6421
|
constructor(t) {
|
|
6422
6422
|
this.variableNames = ["x"], this.workgroupSize = [64, 1, 1], this.size = !0, this.outputShape = [t, t], this.dispatchLayout = I(this.outputShape), this.dispatch = S(this.dispatchLayout, this.outputShape, this.workgroupSize), this.shaderKey = "diag";
|
|
6423
6423
|
}
|
|
@@ -6450,7 +6450,7 @@ class Tl {
|
|
|
6450
6450
|
* =============================================================================
|
|
6451
6451
|
*/
|
|
6452
6452
|
function _l(i) {
|
|
6453
|
-
const { inputs: t, backend: e } = i, { x: o } = t, s = [...o.shape, ...o.shape], r = $(o.shape), a = R({ inputs: { x: o }, backend: e, attrs: { shape: [r] } }), n = new
|
|
6453
|
+
const { inputs: t, backend: e } = i, { x: o } = t, s = [...o.shape, ...o.shape], r = $(o.shape), a = R({ inputs: { x: o }, backend: e, attrs: { shape: [r] } }), n = new Wl(r), u = e.runWebGPUProgram(n, [a], a.dtype), d = R({ inputs: { x: u }, backend: e, attrs: { shape: s } });
|
|
6454
6454
|
return e.disposeData(a.dataId), e.disposeData(u.dataId), d;
|
|
6455
6455
|
}
|
|
6456
6456
|
const Vl = {
|
|
@@ -7387,7 +7387,7 @@ function Rc(i) {
|
|
|
7387
7387
|
const M = U().getBool("CANVAS2D_WILL_READ_FREQUENTLY_FOR_GPU");
|
|
7388
7388
|
(ne == null || M !== Ve) && (Ve = M, ne = document.createElement("canvas").getContext("2d", { willReadFrequently: Ve })), ne.canvas.width = h, ne.canvas.height = l, ne.drawImage(s, 0, 0, h, l), s = ne.canvas;
|
|
7389
7389
|
}
|
|
7390
|
-
const
|
|
7390
|
+
const E = GPUTextureUsage.COPY_DST | GPUTextureUsage.RENDER_ATTACHMENT | GPUTextureUsage.TEXTURE_BINDING, G = e.textureManager.acquireTexture(c[1], c[0], "rgba8unorm", E);
|
|
7391
7391
|
e.queue.copyExternalImageToTexture({ source: s }, { texture: G }, [c[1], c[0]]), C = G;
|
|
7392
7392
|
}
|
|
7393
7393
|
const w = $(c), v = Re(c), k = new Ic(c, r, p), P = [
|
|
@@ -7498,7 +7498,7 @@ const Dc = {
|
|
|
7498
7498
|
*/
|
|
7499
7499
|
function Nc(i) {
|
|
7500
7500
|
const { inputs: t, backend: e, attrs: o } = i, { x: s, filter: r, bias: a, preluActivationWeights: n } = t, { strides: u, pad: d, dataFormat: h, dilations: l, dimRoundingMode: c, activation: p, leakyreluAlpha: f } = o, m = ye(h), g = ee(s.shape, r.shape, u, l, d, c, !1, m);
|
|
7501
|
-
return
|
|
7501
|
+
return Tt({
|
|
7502
7502
|
x: s,
|
|
7503
7503
|
filter: r,
|
|
7504
7504
|
convInfo: g,
|
|
@@ -7628,7 +7628,7 @@ function Lc(i) {
|
|
|
7628
7628
|
const f = new Fc(a, [d, h]), m = [{ type: "int32", data: [a] }, { type: "int32", data: l }], g = e.runWebGPUProgram(f, [p, c], p.dtype, m), x = R({ inputs: { x: g }, backend: e, attrs: { shape: u } });
|
|
7629
7629
|
return e.disposeData(c.dataId), e.disposeData(p.dataId), e.disposeData(g.dataId), x;
|
|
7630
7630
|
}
|
|
7631
|
-
const
|
|
7631
|
+
const Bc = {
|
|
7632
7632
|
kernelName: Fo,
|
|
7633
7633
|
backendName: "webgpu",
|
|
7634
7634
|
kernelFunc: Lc
|
|
@@ -7649,12 +7649,12 @@ const Ec = {
|
|
|
7649
7649
|
* limitations under the License.
|
|
7650
7650
|
* =============================================================================
|
|
7651
7651
|
*/
|
|
7652
|
-
class
|
|
7652
|
+
class Ec {
|
|
7653
7653
|
constructor(t, e) {
|
|
7654
7654
|
this.variableNames = ["A", "indices"], this.workgroupSize = [64, 1, 1], this.size = !0, this.outputShape = t.slice(), this.aShape = t, this.outputShape = e, this.dispatchLayout = I(this.outputShape), this.dispatch = S(this.dispatchLayout, this.outputShape, this.workgroupSize), this.shaderKey = "gather";
|
|
7655
7655
|
}
|
|
7656
7656
|
getUserCode() {
|
|
7657
|
-
const t =
|
|
7657
|
+
const t = Tc(this.aShape);
|
|
7658
7658
|
return `
|
|
7659
7659
|
${y("index")} {
|
|
7660
7660
|
if (index < uniforms.size) {
|
|
@@ -7667,7 +7667,7 @@ class Bc {
|
|
|
7667
7667
|
`;
|
|
7668
7668
|
}
|
|
7669
7669
|
}
|
|
7670
|
-
function
|
|
7670
|
+
function Tc(i) {
|
|
7671
7671
|
const t = ["resRC.x", "resRC.y", "resRC.z", "resRC.w"], e = [];
|
|
7672
7672
|
for (let o = 0; o < i.length; o++)
|
|
7673
7673
|
o === 2 ? e.push("indexZ") : e.push(`${t[o]}`);
|
|
@@ -7717,12 +7717,12 @@ function Gt(i) {
|
|
|
7717
7717
|
const w = e.tensorMap.get(p.dataId).values, v = de(p.shape, p.dtype, w), P = e.tensorMap.get(c.dataId).values, N = de(c.shape, c.dtype, P), A = qn(N, v, f);
|
|
7718
7718
|
return l.forEach((z) => e.disposeData(z.dataId)), e.makeTensorInfo(d.outputShape, A.dtype, A.values);
|
|
7719
7719
|
}
|
|
7720
|
-
const m = new
|
|
7720
|
+
const m = new Ec(c.shape, f), g = e.runWebGPUProgram(m, [c, p], c.dtype);
|
|
7721
7721
|
l.push(g);
|
|
7722
7722
|
const x = R({ inputs: { x: g }, backend: e, attrs: { shape: d.outputShape } });
|
|
7723
7723
|
return l.forEach((C) => e.disposeData(C.dataId)), x;
|
|
7724
7724
|
}
|
|
7725
|
-
const
|
|
7725
|
+
const Wc = {
|
|
7726
7726
|
kernelName: Lo,
|
|
7727
7727
|
backendName: "webgpu",
|
|
7728
7728
|
kernelFunc: Gt
|
|
@@ -7748,7 +7748,7 @@ const _c = _({
|
|
|
7748
7748
|
cpuKernelImpl: Qn,
|
|
7749
7749
|
dtype: "bool"
|
|
7750
7750
|
}), Vc = {
|
|
7751
|
-
kernelName:
|
|
7751
|
+
kernelName: Bo,
|
|
7752
7752
|
backendName: "webgpu",
|
|
7753
7753
|
kernelFunc: _c
|
|
7754
7754
|
};
|
|
@@ -7773,7 +7773,7 @@ const Uc = _({
|
|
|
7773
7773
|
dtype: "bool",
|
|
7774
7774
|
cpuKernelImpl: Yn
|
|
7775
7775
|
}), Mc = {
|
|
7776
|
-
kernelName:
|
|
7776
|
+
kernelName: Eo,
|
|
7777
7777
|
backendName: "webgpu",
|
|
7778
7778
|
kernelFunc: Uc
|
|
7779
7779
|
};
|
|
@@ -7798,7 +7798,7 @@ function Oc(i) {
|
|
|
7798
7798
|
return Ot(o, !0, e);
|
|
7799
7799
|
}
|
|
7800
7800
|
const Gc = {
|
|
7801
|
-
kernelName:
|
|
7801
|
+
kernelName: To,
|
|
7802
7802
|
backendName: "webgpu",
|
|
7803
7803
|
kernelFunc: Oc
|
|
7804
7804
|
};
|
|
@@ -7819,7 +7819,7 @@ const Gc = {
|
|
|
7819
7819
|
* =============================================================================
|
|
7820
7820
|
*/
|
|
7821
7821
|
const Hc = F({ opType: b.IS_FINITE, dtype: "bool" }), Xc = {
|
|
7822
|
-
kernelName:
|
|
7822
|
+
kernelName: Wo,
|
|
7823
7823
|
backendName: "webgpu",
|
|
7824
7824
|
kernelFunc: Hc
|
|
7825
7825
|
};
|
|
@@ -8716,7 +8716,7 @@ const Lh = {
|
|
|
8716
8716
|
* limitations under the License.
|
|
8717
8717
|
* =============================================================================
|
|
8718
8718
|
*/
|
|
8719
|
-
function
|
|
8719
|
+
function Bh(i) {
|
|
8720
8720
|
const { inputs: t, backend: e, attrs: o } = i, { filterSize: s, strides: r, pad: a, includeBatchInIndex: n } = o, { x: u } = t;
|
|
8721
8721
|
L(u.shape.length === 4, () => `Error in maxPool: input must be rank 4 but got rank ${u.shape.length}.`);
|
|
8722
8722
|
const d = [1, 1];
|
|
@@ -8737,10 +8737,10 @@ function Eh(i) {
|
|
|
8737
8737
|
const f = e.runWebGPUProgram(c, [u], "int32", l);
|
|
8738
8738
|
return [p, f];
|
|
8739
8739
|
}
|
|
8740
|
-
const
|
|
8740
|
+
const Eh = {
|
|
8741
8741
|
kernelName: oi,
|
|
8742
8742
|
backendName: "webgpu",
|
|
8743
|
-
kernelFunc:
|
|
8743
|
+
kernelFunc: Bh
|
|
8744
8744
|
};
|
|
8745
8745
|
/**
|
|
8746
8746
|
* @license
|
|
@@ -8758,14 +8758,14 @@ const Bh = {
|
|
|
8758
8758
|
* limitations under the License.
|
|
8759
8759
|
* =============================================================================
|
|
8760
8760
|
*/
|
|
8761
|
-
function
|
|
8761
|
+
function Th(i) {
|
|
8762
8762
|
const { inputs: t, backend: e, attrs: o } = i, { x: s } = t, { axis: r, keepDims: a } = o;
|
|
8763
8763
|
return re(s, r, a, "min", e);
|
|
8764
8764
|
}
|
|
8765
|
-
const
|
|
8765
|
+
const Wh = {
|
|
8766
8766
|
kernelName: ii,
|
|
8767
8767
|
backendName: "webgpu",
|
|
8768
|
-
kernelFunc:
|
|
8768
|
+
kernelFunc: Th
|
|
8769
8769
|
};
|
|
8770
8770
|
/**
|
|
8771
8771
|
* @license
|
|
@@ -9245,7 +9245,7 @@ const ip = {
|
|
|
9245
9245
|
function $e(i) {
|
|
9246
9246
|
const { inputs: t, backend: e } = i, { x: o } = t;
|
|
9247
9247
|
if (o.dtype === "complex64") {
|
|
9248
|
-
const s = Ie({ inputs: { input: o }, backend: e }), r = $e({ inputs: { x: s }, backend: e }), a =
|
|
9248
|
+
const s = Ie({ inputs: { input: o }, backend: e }), r = $e({ inputs: { x: s }, backend: e }), a = Te({ inputs: { input: o }, backend: e }), n = $e({ inputs: { x: a }, backend: e }), u = ie({ inputs: { real: r, imag: n }, backend: e });
|
|
9249
9249
|
return e.disposeData(s.dataId), e.disposeData(r.dataId), e.disposeData(a.dataId), e.disposeData(n.dataId), u;
|
|
9250
9250
|
} else
|
|
9251
9251
|
return H({
|
|
@@ -9283,7 +9283,7 @@ function Kt(i) {
|
|
|
9283
9283
|
if (o.dtype === "string")
|
|
9284
9284
|
throw new Error("onesLike is not supported under string dtype");
|
|
9285
9285
|
if (o.dtype === "complex64") {
|
|
9286
|
-
const s = Ie({ inputs: { input: o }, backend: e }), r = Kt({ inputs: { x: s }, backend: e }), a =
|
|
9286
|
+
const s = Ie({ inputs: { input: o }, backend: e }), r = Kt({ inputs: { x: s }, backend: e }), a = Te({ inputs: { input: o }, backend: e }), n = $e({ inputs: { x: a }, backend: e }), u = ie({ inputs: { real: r, imag: n }, backend: e });
|
|
9287
9287
|
return e.disposeData(s.dataId), e.disposeData(r.dataId), e.disposeData(a.dataId), e.disposeData(n.dataId), u;
|
|
9288
9288
|
} else
|
|
9289
9289
|
return H({ attrs: { shape: o.shape, dtype: o.dtype, value: 1 }, backend: e });
|
|
@@ -9320,7 +9320,7 @@ function np(i) {
|
|
|
9320
9320
|
const n = [], u = t.map((h) => {
|
|
9321
9321
|
const l = Oe({ inputs: { input: h }, backend: e, attrs: { dim: s } });
|
|
9322
9322
|
return n.push(l), l;
|
|
9323
|
-
}), d =
|
|
9323
|
+
}), d = Et({ inputs: u, backend: e, attrs: { axis: s } });
|
|
9324
9324
|
return n.forEach((h) => e.disposeData(h.dataId)), d;
|
|
9325
9325
|
}
|
|
9326
9326
|
const up = {
|
|
@@ -9903,17 +9903,17 @@ class Lp {
|
|
|
9903
9903
|
* limitations under the License.
|
|
9904
9904
|
* =============================================================================
|
|
9905
9905
|
*/
|
|
9906
|
-
function
|
|
9906
|
+
function Bp(i) {
|
|
9907
9907
|
const { inputs: t, backend: e, attrs: o } = i, { images: s } = t, { alignCorners: r, halfPixelCenters: a, size: n } = o, [u, d] = n, h = r && u > 1 ? 1 : 0, l = r && d > 1 ? 1 : 0, p = [
|
|
9908
9908
|
{ type: "float32", data: [h, l] },
|
|
9909
9909
|
{ type: "float32", data: [r ? 0.5 : 0] }
|
|
9910
9910
|
], f = new Lp(s.shape, u, d, a);
|
|
9911
9911
|
return e.runWebGPUProgram(f, [s], s.dtype, p);
|
|
9912
9912
|
}
|
|
9913
|
-
const
|
|
9913
|
+
const Ep = {
|
|
9914
9914
|
kernelName: Ni,
|
|
9915
9915
|
backendName: "webgpu",
|
|
9916
|
-
kernelFunc:
|
|
9916
|
+
kernelFunc: Bp
|
|
9917
9917
|
};
|
|
9918
9918
|
/**
|
|
9919
9919
|
* @license
|
|
@@ -9931,7 +9931,7 @@ const Bp = {
|
|
|
9931
9931
|
* limitations under the License.
|
|
9932
9932
|
* =============================================================================
|
|
9933
9933
|
*/
|
|
9934
|
-
class
|
|
9934
|
+
class Tp {
|
|
9935
9935
|
constructor(t, e) {
|
|
9936
9936
|
this.variableNames = ["dy"], this.uniforms = `effectiveXSize : vec2<i32>, effectiveYSize : vec2<i32>, invHeightScale : f32, invWidthScale : f32,
|
|
9937
9937
|
winHeight : i32, winWidth : i32,`, this.workgroupSize = [64, 1, 1], this.size = !0, this.outputShape = t, this.dispatchLayout = I(this.outputShape), this.dispatch = S(this.dispatchLayout, this.outputShape, this.workgroupSize), this.alignCorners = e, this.shaderKey = `resizeNearestNeigborBackprop_${e}`;
|
|
@@ -10015,14 +10015,14 @@ class Wp {
|
|
|
10015
10015
|
* limitations under the License.
|
|
10016
10016
|
* =============================================================================
|
|
10017
10017
|
*/
|
|
10018
|
-
function
|
|
10018
|
+
function Wp(i) {
|
|
10019
10019
|
const { inputs: t, backend: e, attrs: o } = i, { images: s, dy: r } = t, { alignCorners: a } = o, [, n, u] = s.shape, [, d, h] = r.shape, l = [
|
|
10020
10020
|
a && d > 1 ? n - 1 : n,
|
|
10021
10021
|
a && h > 1 ? u - 1 : u
|
|
10022
10022
|
], c = [
|
|
10023
10023
|
a && d > 1 ? d - 1 : d,
|
|
10024
10024
|
a && h > 1 ? h - 1 : h
|
|
10025
|
-
], p = l[0] / c[0], f = l[1] / c[1], m = 1 / p, g = 1 / f, x = Math.ceil(m) * 2 + 2, C = Math.ceil(g) * 2 + 2, w = new
|
|
10025
|
+
], p = l[0] / c[0], f = l[1] / c[1], m = 1 / p, g = 1 / f, x = Math.ceil(m) * 2 + 2, C = Math.ceil(g) * 2 + 2, w = new Tp(s.shape, a), v = [
|
|
10026
10026
|
{ type: "int32", data: l },
|
|
10027
10027
|
{ type: "int32", data: c },
|
|
10028
10028
|
{ type: "float32", data: [m] },
|
|
@@ -10035,7 +10035,7 @@ function Tp(i) {
|
|
|
10035
10035
|
const _p = {
|
|
10036
10036
|
kernelName: $i,
|
|
10037
10037
|
backendName: "webgpu",
|
|
10038
|
-
kernelFunc:
|
|
10038
|
+
kernelFunc: Wp
|
|
10039
10039
|
};
|
|
10040
10040
|
/**
|
|
10041
10041
|
* @license
|
|
@@ -10338,7 +10338,7 @@ function Yp(i) {
|
|
|
10338
10338
|
return e.disposeData(p.dataId), e.disposeData(f.dataId), e.disposeData(v.dataId), k;
|
|
10339
10339
|
}
|
|
10340
10340
|
const Qp = {
|
|
10341
|
-
kernelName:
|
|
10341
|
+
kernelName: Bi,
|
|
10342
10342
|
backendName: "webgpu",
|
|
10343
10343
|
kernelFunc: Yp
|
|
10344
10344
|
};
|
|
@@ -10409,7 +10409,7 @@ function Zp(i) {
|
|
|
10409
10409
|
return e.runWebGPUProgram(n, [s, r], "int32", u);
|
|
10410
10410
|
}
|
|
10411
10411
|
const Jp = {
|
|
10412
|
-
kernelName:
|
|
10412
|
+
kernelName: Ei,
|
|
10413
10413
|
backendName: "webgpu",
|
|
10414
10414
|
kernelFunc: Zp
|
|
10415
10415
|
};
|
|
@@ -10481,7 +10481,7 @@ function tf(i) {
|
|
|
10481
10481
|
return e.runWebGPUProgram(a, [o, s, r], ge(s.dtype, r.dtype));
|
|
10482
10482
|
}
|
|
10483
10483
|
const sf = {
|
|
10484
|
-
kernelName:
|
|
10484
|
+
kernelName: Ti,
|
|
10485
10485
|
backendName: "webgpu",
|
|
10486
10486
|
kernelFunc: tf
|
|
10487
10487
|
};
|
|
@@ -10502,7 +10502,7 @@ const sf = {
|
|
|
10502
10502
|
* =============================================================================
|
|
10503
10503
|
*/
|
|
10504
10504
|
const of = F({ opType: b.SELU }), rf = {
|
|
10505
|
-
kernelName:
|
|
10505
|
+
kernelName: Wi,
|
|
10506
10506
|
backendName: "webgpu",
|
|
10507
10507
|
kernelFunc: of
|
|
10508
10508
|
};
|
|
@@ -10942,8 +10942,8 @@ const Df = {
|
|
|
10942
10942
|
function Nf(i) {
|
|
10943
10943
|
const { inputs: t, backend: e, attrs: o } = i, { sparseIndices: s, sparseValues: r, defaultValue: a } = t, { outputShape: n } = o, { sliceRank: u, numUpdates: d, sliceSize: h, strides: l, outputSize: c } = Qe(r, s, n), p = !1;
|
|
10944
10944
|
if (r.dtype === "string") {
|
|
10945
|
-
const A = e.bufferSync(s), z = e.bufferSync(r),
|
|
10946
|
-
return e.makeTensorInfo(n,
|
|
10945
|
+
const A = e.bufferSync(s), z = e.bufferSync(r), E = Ge(e.readSync(a.dataId)[0]), T = du(A, z, n, c, h, d, u, l, E, p);
|
|
10946
|
+
return e.makeTensorInfo(n, T.dtype, T.values);
|
|
10947
10947
|
}
|
|
10948
10948
|
const f = [c / h, h], m = R({
|
|
10949
10949
|
inputs: { x: s },
|
|
@@ -11056,7 +11056,7 @@ const Ff = F({ opType: b.SQRT }), Lf = {
|
|
|
11056
11056
|
* limitations under the License.
|
|
11057
11057
|
* =============================================================================
|
|
11058
11058
|
*/
|
|
11059
|
-
const
|
|
11059
|
+
const Bf = {
|
|
11060
11060
|
kernelName: Zi,
|
|
11061
11061
|
backendName: "webgpu",
|
|
11062
11062
|
kernelFunc: ({ inputs: i, backend: t }) => {
|
|
@@ -11080,12 +11080,12 @@ const Ef = {
|
|
|
11080
11080
|
* limitations under the License.
|
|
11081
11081
|
* =============================================================================
|
|
11082
11082
|
*/
|
|
11083
|
-
const
|
|
11083
|
+
const Ef = _({
|
|
11084
11084
|
opType: D.SQUARED_DIFFERENCE
|
|
11085
|
-
}),
|
|
11085
|
+
}), Tf = {
|
|
11086
11086
|
kernelName: Ji,
|
|
11087
11087
|
backendName: "webgpu",
|
|
11088
|
-
kernelFunc:
|
|
11088
|
+
kernelFunc: Ef
|
|
11089
11089
|
};
|
|
11090
11090
|
/**
|
|
11091
11091
|
* @license
|
|
@@ -11103,14 +11103,14 @@ const Bf = _({
|
|
|
11103
11103
|
* limitations under the License.
|
|
11104
11104
|
* =============================================================================
|
|
11105
11105
|
*/
|
|
11106
|
-
function
|
|
11106
|
+
function Wf({ inputs: i, attrs: t, backend: e }) {
|
|
11107
11107
|
const { x: o } = i, s = new le(o.shape, b.STEP, "stepAlpha : f32,"), r = [{ type: "float32", data: [t.alpha] }];
|
|
11108
11108
|
return e.runWebGPUProgram(s, [o], o.dtype, r);
|
|
11109
11109
|
}
|
|
11110
11110
|
const _f = {
|
|
11111
11111
|
kernelName: er,
|
|
11112
11112
|
backendName: "webgpu",
|
|
11113
|
-
kernelFunc:
|
|
11113
|
+
kernelFunc: Wf
|
|
11114
11114
|
};
|
|
11115
11115
|
/**
|
|
11116
11116
|
* @license
|
|
@@ -11533,14 +11533,14 @@ function tm(i) {
|
|
|
11533
11533
|
const h = $(n) / u, l = R({ inputs: { x: s }, attrs: { shape: [h, u] }, backend: e }), c = xt(r), p = xt(u);
|
|
11534
11534
|
let f = null;
|
|
11535
11535
|
const m = () => f === null ? [l, l] : [l, f], g = (k, P, N) => {
|
|
11536
|
-
const A = m(), z = new Jf(N),
|
|
11536
|
+
const A = m(), z = new Jf(N), T = [
|
|
11537
11537
|
{ type: "int32", data: [u] },
|
|
11538
11538
|
{ type: "int32", data: [f === null ? 1 : 0] },
|
|
11539
11539
|
{ type: "float32", data: [Number.NEGATIVE_INFINITY] },
|
|
11540
11540
|
{ type: "int32", data: [k] },
|
|
11541
11541
|
{ type: "int32", data: [P] }
|
|
11542
11542
|
], G = f;
|
|
11543
|
-
f = e.runWebGPUProgram(z, A, "int32",
|
|
11543
|
+
f = e.runWebGPUProgram(z, A, "int32", T), ue(e, G);
|
|
11544
11544
|
};
|
|
11545
11545
|
for (let k = 1; k < c; k *= 2) {
|
|
11546
11546
|
const P = k * 2;
|
|
@@ -11552,10 +11552,10 @@ function tm(i) {
|
|
|
11552
11552
|
{ type: "int32", data: [u] },
|
|
11553
11553
|
{ type: "int32", data: [f === null ? 1 : 0] },
|
|
11554
11554
|
{ type: "int32", data: [c] }
|
|
11555
|
-
],
|
|
11556
|
-
f = e.runWebGPUProgram(N, P, "int32", z), ue(e,
|
|
11557
|
-
const
|
|
11558
|
-
for (let M =
|
|
11555
|
+
], E = f;
|
|
11556
|
+
f = e.runWebGPUProgram(N, P, "int32", z), ue(e, E);
|
|
11557
|
+
const T = c / 2, G = T * 2;
|
|
11558
|
+
for (let M = T; M >= 1; M /= 2)
|
|
11559
11559
|
g(G, M, f.shape);
|
|
11560
11560
|
}
|
|
11561
11561
|
let x = f;
|
|
@@ -11869,7 +11869,7 @@ function dm(i) {
|
|
|
11869
11869
|
const h = Se([d], n);
|
|
11870
11870
|
let l = s;
|
|
11871
11871
|
h != null && (l = Y({ inputs: { x: s }, backend: e, attrs: { perm: h } }), u.push(l), d = ve(1, n)[0]);
|
|
11872
|
-
const c =
|
|
11872
|
+
const c = Br(l.shape, d, a), p = $([l.shape[d]]), f = R({ inputs: { x: l }, backend: e, attrs: { shape: [-1, p] } });
|
|
11873
11873
|
u.push(f);
|
|
11874
11874
|
const m = s.dtype, g = [f.shape[0], a], x = H({ backend: e, attrs: { shape: g, value: 0, dtype: m } }), C = new um(f.shape, g, m), w = [
|
|
11875
11875
|
{ type: "int32", data: [a] },
|
|
@@ -11906,14 +11906,14 @@ const lm = {
|
|
|
11906
11906
|
* =============================================================================
|
|
11907
11907
|
*/
|
|
11908
11908
|
const cm = [
|
|
11909
|
-
|
|
11909
|
+
Bn,
|
|
11910
11910
|
wu,
|
|
11911
11911
|
yu,
|
|
11912
11912
|
vu,
|
|
11913
11913
|
ku,
|
|
11914
11914
|
Du,
|
|
11915
|
-
|
|
11916
|
-
|
|
11915
|
+
Bu,
|
|
11916
|
+
Tu,
|
|
11917
11917
|
_u,
|
|
11918
11918
|
Uu,
|
|
11919
11919
|
Ou,
|
|
@@ -11932,8 +11932,8 @@ const cm = [
|
|
|
11932
11932
|
Dd,
|
|
11933
11933
|
$d,
|
|
11934
11934
|
Ld,
|
|
11935
|
-
|
|
11936
|
-
|
|
11935
|
+
Tn,
|
|
11936
|
+
Td,
|
|
11937
11937
|
Ud,
|
|
11938
11938
|
Yd,
|
|
11939
11939
|
tl,
|
|
@@ -11948,8 +11948,8 @@ const cm = [
|
|
|
11948
11948
|
vl,
|
|
11949
11949
|
kl,
|
|
11950
11950
|
Dl,
|
|
11951
|
-
|
|
11952
|
-
|
|
11951
|
+
Bl,
|
|
11952
|
+
Tl,
|
|
11953
11953
|
zl,
|
|
11954
11954
|
Vl,
|
|
11955
11955
|
Ol,
|
|
@@ -11973,11 +11973,11 @@ const cm = [
|
|
|
11973
11973
|
Dc,
|
|
11974
11974
|
$c,
|
|
11975
11975
|
Ac,
|
|
11976
|
-
|
|
11977
|
-
|
|
11976
|
+
Bc,
|
|
11977
|
+
Wc,
|
|
11978
11978
|
Vc,
|
|
11979
11979
|
Mc,
|
|
11980
|
-
|
|
11980
|
+
En,
|
|
11981
11981
|
Gc,
|
|
11982
11982
|
_d,
|
|
11983
11983
|
Xc,
|
|
@@ -12000,9 +12000,9 @@ const cm = [
|
|
|
12000
12000
|
Lh,
|
|
12001
12001
|
Dh,
|
|
12002
12002
|
Ah,
|
|
12003
|
-
|
|
12003
|
+
Eh,
|
|
12004
12004
|
ed,
|
|
12005
|
-
|
|
12005
|
+
Wh,
|
|
12006
12006
|
Vh,
|
|
12007
12007
|
Mh,
|
|
12008
12008
|
Gh,
|
|
@@ -12028,7 +12028,7 @@ const cm = [
|
|
|
12028
12028
|
Mr,
|
|
12029
12029
|
$p,
|
|
12030
12030
|
Fp,
|
|
12031
|
-
|
|
12031
|
+
Ep,
|
|
12032
12032
|
_p,
|
|
12033
12033
|
Mp,
|
|
12034
12034
|
Gp,
|
|
@@ -12054,8 +12054,8 @@ const cm = [
|
|
|
12054
12054
|
$f,
|
|
12055
12055
|
Af,
|
|
12056
12056
|
Lf,
|
|
12057
|
-
|
|
12058
|
-
|
|
12057
|
+
Bf,
|
|
12058
|
+
Tf,
|
|
12059
12059
|
Xf,
|
|
12060
12060
|
ec,
|
|
12061
12061
|
qf,
|