@genai-fi/nanogpt 0.8.5 → 0.9.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/Generator.d.ts +4 -1
- package/dist/Generator.js +144 -124
- package/dist/{RealDiv-D_q39E3A.js → RealDiv-D4EzDsC0.js} +7 -7
- package/dist/{Reshape-Bh_jzKzV.js → Reshape-Bowtk9BP.js} +2 -2
- package/dist/{Reshape-41YpQqEo.js → Reshape-DUqYftGC.js} +1 -1
- package/dist/TeachableLLM.js +5 -5
- package/dist/Trainer.d.ts +1 -0
- package/dist/Trainer.js +3 -0
- package/dist/{axis_util-Did9235A.js → axis_util-TbGYJ208.js} +1 -1
- package/dist/backend.js +2 -2
- package/dist/{backend_util-yC3YH1jo.js → backend_util-CJIiDoV1.js} +4 -4
- package/dist/{broadcast_to-CUvOdOT5.js → broadcast_to-DzlNweb8.js} +2 -2
- package/dist/checks/appendCache.js +2 -2
- package/dist/checks/attentionMask.js +3 -3
- package/dist/checks/gelu.js +2 -2
- package/dist/checks/matMulGelu.js +5 -5
- package/dist/checks/normRMS.js +4 -4
- package/dist/checks/normRMSGrad.js +3 -3
- package/dist/checks/qkv.js +2 -2
- package/dist/checks/rope.js +2 -2
- package/dist/{concat-pHiVqR3L.js → concat-B912vBbo.js} +1 -1
- package/dist/{dataset-DPPl-iLT.js → dataset-DlZtKmBq.js} +3 -3
- package/dist/{dropout-CcKSfOYE.js → dropout-C-csYCLj.js} +6 -6
- package/dist/{exports_initializers-DKk7-bsx.js → exports_initializers-B8iZMgQ0.js} +1 -1
- package/dist/{gather-CPg6ZlQA.js → gather-Dnpgw-YQ.js} +1 -1
- package/dist/{gelu-BkcmEEyD.js → gelu-Bp_-935b.js} +1 -1
- package/dist/{gpgpu_math-D_ODOLix.js → gpgpu_math-CDaYiyE_.js} +2 -2
- package/dist/{index-DdmHGZjq.js → index-BzFyqcy-.js} +13 -13
- package/dist/{index-evZ57wr4.js → index-C1rx_Ajs.js} +10 -10
- package/dist/{kernel_funcs_utils-CDfFpUab.js → kernel_funcs_utils-DKLK0Mg3.js} +3 -3
- package/dist/layers/BaseLayer.js +2 -2
- package/dist/layers/CausalSelfAttention.js +6 -6
- package/dist/layers/MLP.js +5 -5
- package/dist/layers/PositionEmbedding.js +5 -5
- package/dist/layers/RMSNorm.js +3 -3
- package/dist/layers/RoPECache.js +4 -4
- package/dist/layers/TiedEmbedding.js +5 -5
- package/dist/layers/TransformerBlock.js +1 -1
- package/dist/loader/loadTransformers.js +1 -1
- package/dist/loader/oldZipLoad.js +5 -5
- package/dist/{log_sum_exp-C8yFJfZz.js → log_sum_exp-DO6z8tSE.js} +9 -9
- package/dist/main.d.ts +1 -0
- package/dist/main.js +18 -16
- package/dist/{mat_mul-Dpy2mMRu.js → mat_mul-DzjTFx-u.js} +1 -1
- package/dist/{mod-CbibJi3D.js → mod-Dobti4j4.js} +1 -1
- package/dist/models/NanoGPTV1.d.ts +1 -0
- package/dist/models/NanoGPTV1.js +12 -9
- package/dist/models/model.d.ts +1 -0
- package/dist/models/model.js +5 -5
- package/dist/{mulmat_packed_gpu-q_Gmwyld.js → mulmat_packed_gpu-BT60jmzP.js} +1 -1
- package/dist/{ones-BAqVh-eA.js → ones-tIJeHlq-.js} +2 -2
- package/dist/ops/adamAdjust.js +1 -1
- package/dist/ops/adamMoments.js +1 -1
- package/dist/ops/appendCache.js +3 -3
- package/dist/ops/attentionMask.js +1 -1
- package/dist/ops/cpu/adamAdjust.js +1 -1
- package/dist/ops/cpu/adamMoments.js +2 -2
- package/dist/ops/cpu/appendCache.js +2 -2
- package/dist/ops/cpu/attentionMask.js +5 -5
- package/dist/ops/cpu/fusedSoftmax.js +2 -2
- package/dist/ops/cpu/gatherSub.js +5 -5
- package/dist/ops/cpu/gelu.js +1 -1
- package/dist/ops/cpu/matMulGelu.js +2 -2
- package/dist/ops/cpu/matMulMul.js +1 -1
- package/dist/ops/cpu/mulDropout.js +1 -1
- package/dist/ops/cpu/normRMS.js +1 -1
- package/dist/ops/cpu/qkv.js +3 -3
- package/dist/ops/cpu/rope.js +5 -5
- package/dist/ops/cpu/scatterSub.js +13 -13
- package/dist/ops/fusedSoftmax.js +1 -1
- package/dist/ops/gatherSub.js +1 -1
- package/dist/ops/gelu.js +2 -2
- package/dist/ops/grads/attentionMask.js +1 -1
- package/dist/ops/grads/fusedSoftmax.js +2 -2
- package/dist/ops/grads/gelu.js +2 -2
- package/dist/ops/grads/matMulGelu.js +1 -1
- package/dist/ops/grads/normRMS.js +1 -1
- package/dist/ops/grads/qkv.js +1 -1
- package/dist/ops/grads/rope.js +1 -1
- package/dist/ops/matMulGelu.js +1 -1
- package/dist/ops/matMulMul.js +1 -1
- package/dist/ops/mulDrop.js +1 -1
- package/dist/ops/normRMS.js +1 -1
- package/dist/ops/qkv.js +1 -1
- package/dist/ops/rope.js +4 -4
- package/dist/ops/scatterSub.js +1 -1
- package/dist/ops/webgl/adamAdjust.js +2 -2
- package/dist/ops/webgl/adamMoments.js +1 -1
- package/dist/ops/webgl/appendCache.js +1 -1
- package/dist/ops/webgl/attentionMask.js +1 -1
- package/dist/ops/webgl/fusedSoftmax.js +4 -4
- package/dist/ops/webgl/gatherSub.js +1 -1
- package/dist/ops/webgl/gelu.js +2 -2
- package/dist/ops/webgl/log.js +3 -3
- package/dist/ops/webgl/matMulGelu.js +4 -4
- package/dist/ops/webgl/matMulMul.js +1 -1
- package/dist/ops/webgl/mulDropout.js +1 -1
- package/dist/ops/webgl/normRMS.js +2 -2
- package/dist/ops/webgl/qkv.js +1 -1
- package/dist/ops/webgl/rope.js +1 -1
- package/dist/ops/webgl/scatterSub.js +1 -1
- package/dist/ops/webgpu/adamAdjust.js +3 -3
- package/dist/ops/webgpu/adamMoments.js +3 -3
- package/dist/ops/webgpu/appendCache.js +3 -3
- package/dist/ops/webgpu/attentionMask.js +3 -3
- package/dist/ops/webgpu/gatherSub.js +3 -3
- package/dist/ops/webgpu/gelu.js +37 -35
- package/dist/ops/webgpu/normRMS.js +2 -2
- package/dist/ops/webgpu/normRMSGrad.js +5 -5
- package/dist/ops/webgpu/qkv.js +3 -3
- package/dist/ops/webgpu/rope.js +3 -3
- package/dist/ops/webgpu/scatterSub.js +3 -3
- package/dist/ops/webgpu/utils/reductions.js +4 -4
- package/dist/{ops-542ai2vG.js → ops-LuCMAnmM.js} +65 -65
- package/dist/{random_width-DKGeiFuR.js → random_width-CXVRloNK.js} +23 -23
- package/dist/{range-BcUvLuf5.js → range-CWcz7xFA.js} +3 -3
- package/dist/{reciprocal-DhDWSKiD.js → reciprocal-C4rNcM-S.js} +1 -1
- package/dist/{register_all_kernels-Do9VvZmo.js → register_all_kernels-DIGpEwcf.js} +31 -31
- package/dist/{relu-B1AXs7p5.js → relu-BjCh_SYb.js} +1 -1
- package/dist/{reshape-WeJkT3ja.js → reshape-CnIwVG1c.js} +1 -1
- package/dist/{scatter_nd_util-B7yDhiQr.js → scatter_nd_util-BQdz--Gn.js} +1 -1
- package/dist/{selu_util-BgUO9gHY.js → selu_util-OtRzVwW5.js} +23 -23
- package/dist/{shared-V6D_md-c.js → shared-DmRsFyaJ.js} +6 -6
- package/dist/{shared-CZiWmQCI.js → shared-DuP7ue-R.js} +1 -1
- package/dist/{sin-CPxad7Am.js → sin-gpDNRxE0.js} +1 -1
- package/dist/{slice-B7jXtPnp.js → slice-d0Vo9XTN.js} +1 -1
- package/dist/{softmax-BfsyI4As.js → softmax-D7Jj3p_P.js} +1 -1
- package/dist/{split-BPxr8_8m.js → split-DK2k5eHf.js} +1 -1
- package/dist/{stack-BNwLzE43.js → stack-DFatutCx.js} +1 -1
- package/dist/{sum-ByFINZgi.js → sum-CJ0ULhmt.js} +1 -1
- package/dist/{tensor-DbqgIV9B.js → tensor-CZr4dh61.js} +1 -1
- package/dist/{tensor1d-CtJq5BOv.js → tensor1d-vML0r3q6.js} +1 -1
- package/dist/{tensor2d-CObBWBkW.js → tensor2d-D76QGjF3.js} +1 -1
- package/dist/{tensor4d-DLtk7Nxh.js → tensor4d-Df1WlVDY.js} +1 -1
- package/dist/training/Adam.js +2 -2
- package/dist/training/AdamExt.js +1 -1
- package/dist/training/DatasetBuilder.js +2 -2
- package/dist/training/FullTrainer.js +1 -1
- package/dist/training/Trainer.js +2 -2
- package/dist/training/sparseCrossEntropy.js +3 -3
- package/dist/utilities/dummy.js +2 -2
- package/dist/utilities/multinomialCPU.js +2 -2
- package/dist/utilities/performance.js +1 -1
- package/dist/utilities/profile.js +1 -1
- package/dist/utilities/safetensors.js +2 -2
- package/dist/utilities/topP.d.ts +1 -0
- package/dist/utilities/topP.js +13 -0
- package/dist/utilities/weights.js +2 -2
- package/dist/{variable-DPFOJyRG.js → variable-Bm2OFwGI.js} +1 -1
- package/dist/{webgpu_program-Dhk9R5aG.js → webgpu_program-DkQJOJSd.js} +1 -1
- package/dist/{webgpu_util-BqGnZg8t.js → webgpu_util-pLEV9tks.js} +1 -1
- package/dist/{zeros-Dnwix0p4.js → zeros-Bj5rMYA7.js} +1 -1
- package/package.json +1 -1
package/dist/ops/webgpu/gelu.js
CHANGED
|
@@ -1,8 +1,8 @@
|
|
|
1
|
-
import { f as
|
|
2
|
-
import { g as
|
|
3
|
-
import { f as
|
|
4
|
-
const
|
|
5
|
-
class
|
|
1
|
+
import { f as s } from "../../index-BzFyqcy-.js";
|
|
2
|
+
import { g as a } from "../../webgpu_program-DkQJOJSd.js";
|
|
3
|
+
import { f as o, c as p } from "../../webgpu_util-pLEV9tks.js";
|
|
4
|
+
const u = 0.7978845608028654, i = 0.044715;
|
|
5
|
+
class d {
|
|
6
6
|
outputShape;
|
|
7
7
|
shaderKey;
|
|
8
8
|
dispatchLayout;
|
|
@@ -11,32 +11,33 @@ class h {
|
|
|
11
11
|
workgroupSize;
|
|
12
12
|
size = !0;
|
|
13
13
|
constructor(e) {
|
|
14
|
-
this.workgroupSize = [128, 1, 1], this.outputShape = e, this.dispatchLayout =
|
|
14
|
+
this.workgroupSize = [128, 1, 1], this.outputShape = e, this.dispatchLayout = o(this.outputShape), this.dispatch = p(this.dispatchLayout, this.outputShape, this.workgroupSize), this.shaderKey = "unary_gelu";
|
|
15
15
|
}
|
|
16
16
|
getUserCode() {
|
|
17
17
|
return `
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
18
|
+
// TODO: revisit after https://github.com/gpuweb/gpuweb/issues/4458 is resolved
|
|
19
|
+
fn tanhComplete(x: f32) -> f32 {
|
|
20
|
+
return select(tanh(x), sign(x), abs(x) > 15.0);
|
|
21
|
+
}
|
|
22
|
+
fn unaryOperation(x : f32) -> f32 {
|
|
23
|
+
let x3 = x * x * x;
|
|
24
|
+
var inner = fma(${i}, x3, x);
|
|
25
|
+
inner = ${u} * inner;
|
|
26
|
+
inner = tanhComplete(inner);
|
|
27
|
+
inner = 0.5 * (1.0 + inner);
|
|
28
|
+
return x * inner;
|
|
29
|
+
}
|
|
30
|
+
${a("index")} {
|
|
31
|
+
if (index < uniforms.size) {
|
|
32
|
+
let a = getAByOutputIndex(index);
|
|
33
|
+
setOutputAtIndex(index, unaryOperation(a));
|
|
34
|
+
}
|
|
35
|
+
}
|
|
35
36
|
`;
|
|
36
37
|
}
|
|
37
38
|
}
|
|
38
39
|
function c(t) {
|
|
39
|
-
const { x: e } = t.inputs, n = t.backend, r = new
|
|
40
|
+
const { x: e } = t.inputs, n = t.backend, r = new d(e.shape);
|
|
40
41
|
return n.runWebGPUProgram(r, [e], "float32");
|
|
41
42
|
}
|
|
42
43
|
const l = {
|
|
@@ -44,7 +45,7 @@ const l = {
|
|
|
44
45
|
backendName: "webgpu",
|
|
45
46
|
kernelFunc: c
|
|
46
47
|
};
|
|
47
|
-
|
|
48
|
+
s(l);
|
|
48
49
|
class x {
|
|
49
50
|
// Inputs: dy, x
|
|
50
51
|
variableNames = ["dy", "x"];
|
|
@@ -55,22 +56,23 @@ class x {
|
|
|
55
56
|
workgroupSize = [128, 1, 1];
|
|
56
57
|
size = !0;
|
|
57
58
|
constructor(e) {
|
|
58
|
-
this.outputShape = e, this.dispatchLayout =
|
|
59
|
+
this.outputShape = e, this.dispatchLayout = o(this.outputShape), this.dispatch = p(this.dispatchLayout, this.outputShape, this.workgroupSize);
|
|
59
60
|
}
|
|
60
61
|
getUserCode() {
|
|
61
62
|
return `
|
|
62
|
-
|
|
63
|
+
// TODO: revisit after https://github.com/gpuweb/gpuweb/issues/4458 is resolved
|
|
64
|
+
fn tanhComplete(x: f32) -> f32 {
|
|
63
65
|
return select(tanh(x), sign(x), abs(x) > 15.0);
|
|
64
66
|
}
|
|
65
|
-
${
|
|
67
|
+
${a("index")} {
|
|
66
68
|
if (index < uniforms.size) {
|
|
67
69
|
let X = getXByOutputIndex(index);
|
|
68
70
|
let x2 = X * X;
|
|
69
71
|
let x3 = x2 * X;
|
|
70
|
-
let u = ${
|
|
71
|
-
let t =
|
|
72
|
+
let u = ${u} * (X + ${i} * x3);
|
|
73
|
+
let t = tanhComplete(u);
|
|
72
74
|
let sech2 = 1.0 - t * t;
|
|
73
|
-
let du_dx = ${
|
|
75
|
+
let du_dx = ${u} * (1.0 + 3.0 * ${i} * x2);
|
|
74
76
|
let dgelu = 0.5 * (1.0 + t) + 0.5 * X * sech2 * du_dx;
|
|
75
77
|
let DY = getDyByOutputIndex(index);
|
|
76
78
|
setOutputAtIndex(index, DY * dgelu);
|
|
@@ -79,15 +81,15 @@ class x {
|
|
|
79
81
|
}
|
|
80
82
|
}
|
|
81
83
|
function g(t) {
|
|
82
|
-
const { dy: e, x: n } = t.inputs, r = t.backend,
|
|
83
|
-
return r.runWebGPUProgram(
|
|
84
|
+
const { dy: e, x: n } = t.inputs, r = t.backend, h = new x(n.shape);
|
|
85
|
+
return r.runWebGPUProgram(h, [e, n], "float32");
|
|
84
86
|
}
|
|
85
87
|
const f = {
|
|
86
88
|
kernelName: "GeluGrad",
|
|
87
89
|
backendName: "webgpu",
|
|
88
90
|
kernelFunc: g
|
|
89
91
|
};
|
|
90
|
-
|
|
92
|
+
s(f);
|
|
91
93
|
export {
|
|
92
|
-
|
|
94
|
+
d as GeluProgram
|
|
93
95
|
};
|
|
@@ -1,5 +1,5 @@
|
|
|
1
|
-
import { f as n } from "../../webgpu_util-
|
|
2
|
-
import { f as p, a4 as h } from "../../index-
|
|
1
|
+
import { f as n } from "../../webgpu_util-pLEV9tks.js";
|
|
2
|
+
import { f as p, a4 as h } from "../../index-BzFyqcy-.js";
|
|
3
3
|
import { createReduceInfo as u, reduce as c, createReductionShader as m } from "./utils/reductions.js";
|
|
4
4
|
class d {
|
|
5
5
|
outputShape;
|
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
import { f, a4 as c, e as g } from "../../index-
|
|
1
|
+
import { f, a4 as c, e as g } from "../../index-BzFyqcy-.js";
|
|
2
2
|
import { createReduceInfo as k } from "./utils/reductions.js";
|
|
3
|
-
import { f as x } from "../../webgpu_util-
|
|
4
|
-
import { g as z } from "../../webgpu_program-
|
|
5
|
-
import { s as d } from "../../slice-
|
|
6
|
-
import { s as w } from "../../sum-
|
|
3
|
+
import { f as x } from "../../webgpu_util-pLEV9tks.js";
|
|
4
|
+
import { g as z } from "../../webgpu_program-DkQJOJSd.js";
|
|
5
|
+
import { s as d } from "../../slice-d0Vo9XTN.js";
|
|
6
|
+
import { s as w } from "../../sum-CJ0ULhmt.js";
|
|
7
7
|
class y {
|
|
8
8
|
outputShape;
|
|
9
9
|
shaderKey = "RMSNormGrad";
|
package/dist/ops/webgpu/qkv.js
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
import { g as h } from "../../webgpu_program-
|
|
2
|
-
import { f as c, c as d } from "../../webgpu_util-
|
|
3
|
-
import { f as p, a4 as m } from "../../index-
|
|
1
|
+
import { g as h } from "../../webgpu_program-DkQJOJSd.js";
|
|
2
|
+
import { f as c, c as d } from "../../webgpu_util-pLEV9tks.js";
|
|
3
|
+
import { f as p, a4 as m } from "../../index-BzFyqcy-.js";
|
|
4
4
|
class l {
|
|
5
5
|
variableNames = ["x", "kernel"];
|
|
6
6
|
outputShape;
|
package/dist/ops/webgpu/rope.js
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
import { g as c } from "../../webgpu_program-
|
|
2
|
-
import { f as m, c as l } from "../../webgpu_util-
|
|
3
|
-
import { f as x, a4 as f } from "../../index-
|
|
1
|
+
import { g as c } from "../../webgpu_program-DkQJOJSd.js";
|
|
2
|
+
import { f as m, c as l } from "../../webgpu_util-pLEV9tks.js";
|
|
3
|
+
import { f as x, a4 as f } from "../../index-BzFyqcy-.js";
|
|
4
4
|
class S {
|
|
5
5
|
variableNames = ["x", "sin", "cos"];
|
|
6
6
|
outputShape;
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
import { g as p } from "../../webgpu_program-
|
|
2
|
-
import { f as u, c as d } from "../../webgpu_util-
|
|
3
|
-
import { f as h, a4 as o } from "../../index-
|
|
1
|
+
import { g as p } from "../../webgpu_program-DkQJOJSd.js";
|
|
2
|
+
import { f as u, c as d } from "../../webgpu_util-pLEV9tks.js";
|
|
3
|
+
import { f as h, a4 as o } from "../../index-BzFyqcy-.js";
|
|
4
4
|
class b {
|
|
5
5
|
variableNames = ["labels", "softmaxProbs", "dy"];
|
|
6
6
|
outputShape;
|
|
@@ -1,7 +1,7 @@
|
|
|
1
|
-
import { p as l, j as d } from "../../../index-
|
|
2
|
-
import { g as p } from "../../../webgpu_program-
|
|
3
|
-
import { r as f } from "../../../Reshape-
|
|
4
|
-
import { c as x } from "../../../axis_util-
|
|
1
|
+
import { p as l, j as d } from "../../../index-BzFyqcy-.js";
|
|
2
|
+
import { g as p } from "../../../webgpu_program-DkQJOJSd.js";
|
|
3
|
+
import { r as f } from "../../../Reshape-DUqYftGC.js";
|
|
4
|
+
import { c as x } from "../../../axis_util-TbGYJ208.js";
|
|
5
5
|
function I(e, r, t, s, u) {
|
|
6
6
|
return `
|
|
7
7
|
fn DIV_CEIL(a : u32, b : u32) -> u32 {
|
|
@@ -1,22 +1,22 @@
|
|
|
1
|
-
import {
|
|
2
|
-
import { t as w } from "./tensor1d-
|
|
1
|
+
import { E as f, F as c, n as p, G as h, aS as Is, l as D, aT as Ss, aU as Ts, aV as xs, p as Ms, aW as M, w as ls, x as As, b as k, aX as gs, aY as ws, X as qs, ap as Ds, Q as K, aZ as Bs, a_ as Gs, a$ as zs, b0 as Os, b1 as Ls, b2 as Rs, b3 as Ws, b4 as Ks, b5 as vs, a0 as q, b6 as Vs, b7 as Cs, b8 as Ps, b9 as Ys, ba as js, aJ as Fs, bb as Us, t as Zs, bc as Q, bd as Hs, be as Xs, bf as Js, bg as Qs, bh as sn, bi as nn, bj as en, bk as tn, a as g, q as z, o as S, h as rn, c as d, bl as on, J as ss, d as an, a4 as O, z as cn } from "./index-BzFyqcy-.js";
|
|
2
|
+
import { t as w } from "./tensor1d-vML0r3q6.js";
|
|
3
3
|
import { n as un, a as ln, b as pn } from "./non_max_suppression_impl-CsEgBuMA.js";
|
|
4
|
-
import { r as y } from "./reshape-
|
|
5
|
-
import { s as ds } from "./split-
|
|
6
|
-
import { s as E } from "./sum-
|
|
7
|
-
import { b as ns } from "./broadcast_to-
|
|
8
|
-
import { s as x } from "./slice-
|
|
9
|
-
import { r as Z } from "./range-
|
|
10
|
-
import { t as fn } from "./tensor-
|
|
11
|
-
import { s as H } from "./stack-
|
|
12
|
-
import { c as mn, z as hn } from "./zeros-
|
|
13
|
-
import { e as $s } from "./axis_util-
|
|
14
|
-
import { m as es, a as ps, e as os, l as bn } from "./log_sum_exp-
|
|
15
|
-
import { c as ts } from "./concat-
|
|
16
|
-
import { m as G } from "./mat_mul-
|
|
17
|
-
import { t as rs } from "./tensor2d-
|
|
18
|
-
import { o as gn } from "./ones-
|
|
19
|
-
import { r as Es } from "./relu-
|
|
4
|
+
import { r as y } from "./reshape-CnIwVG1c.js";
|
|
5
|
+
import { s as ds } from "./split-DK2k5eHf.js";
|
|
6
|
+
import { s as E } from "./sum-CJ0ULhmt.js";
|
|
7
|
+
import { b as ns } from "./broadcast_to-DzlNweb8.js";
|
|
8
|
+
import { s as x } from "./slice-d0Vo9XTN.js";
|
|
9
|
+
import { r as Z } from "./range-CWcz7xFA.js";
|
|
10
|
+
import { t as fn } from "./tensor-CZr4dh61.js";
|
|
11
|
+
import { s as H } from "./stack-DFatutCx.js";
|
|
12
|
+
import { c as mn, z as hn } from "./zeros-Bj5rMYA7.js";
|
|
13
|
+
import { e as $s } from "./axis_util-TbGYJ208.js";
|
|
14
|
+
import { m as es, a as ps, e as os, l as bn } from "./log_sum_exp-DO6z8tSE.js";
|
|
15
|
+
import { c as ts } from "./concat-B912vBbo.js";
|
|
16
|
+
import { m as G } from "./mat_mul-DzjTFx-u.js";
|
|
17
|
+
import { t as rs } from "./tensor2d-D76QGjF3.js";
|
|
18
|
+
import { o as gn } from "./ones-tIJeHlq-.js";
|
|
19
|
+
import { r as Es } from "./relu-BjCh_SYb.js";
|
|
20
20
|
/**
|
|
21
21
|
* @license
|
|
22
22
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -312,7 +312,7 @@ const An = /* @__PURE__ */ f({ greaterEqual_: Mn });
|
|
|
312
312
|
*/
|
|
313
313
|
function wn(e) {
|
|
314
314
|
const s = { input: c(e, "input", "imag") };
|
|
315
|
-
return h.runKernel(
|
|
315
|
+
return h.runKernel(zs, s);
|
|
316
316
|
}
|
|
317
317
|
const qn = /* @__PURE__ */ f({ imag_: wn });
|
|
318
318
|
/**
|
|
@@ -335,7 +335,7 @@ function Dn(e, r) {
|
|
|
335
335
|
let s = c(e, "a", "less", "string_or_numeric"), t = c(r, "b", "less", "string_or_numeric");
|
|
336
336
|
[s, t] = K(s, t), D(s.shape, t.shape);
|
|
337
337
|
const n = { a: s, b: t };
|
|
338
|
-
return h.runKernel(
|
|
338
|
+
return h.runKernel(Os, n);
|
|
339
339
|
}
|
|
340
340
|
const ms = /* @__PURE__ */ f({ less_: Dn });
|
|
341
341
|
/**
|
|
@@ -381,7 +381,7 @@ function Gn(e) {
|
|
|
381
381
|
const s = { x: c(e, "x", "log1p") };
|
|
382
382
|
return h.runKernel(Rs, s);
|
|
383
383
|
}
|
|
384
|
-
const
|
|
384
|
+
const zn = /* @__PURE__ */ f({ log1p_: Gn });
|
|
385
385
|
/**
|
|
386
386
|
* @license
|
|
387
387
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -398,11 +398,11 @@ const On = /* @__PURE__ */ f({ log1p_: Gn });
|
|
|
398
398
|
* limitations under the License.
|
|
399
399
|
* =============================================================================
|
|
400
400
|
*/
|
|
401
|
-
function
|
|
401
|
+
function On(e) {
|
|
402
402
|
const s = { x: c(e, "x", "neg") };
|
|
403
403
|
return h.runKernel(Ws, s);
|
|
404
404
|
}
|
|
405
|
-
const
|
|
405
|
+
const V = /* @__PURE__ */ f({ neg_: On });
|
|
406
406
|
/**
|
|
407
407
|
* @license
|
|
408
408
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -467,7 +467,7 @@ function vn(e, r) {
|
|
|
467
467
|
let s = c(e, "a", "minimum"), t = c(r, "b", "minimum");
|
|
468
468
|
[s, t] = K(s, t), s.dtype === "bool" && (s = q(s, "int32"), t = q(t, "int32")), D(s.shape, t.shape);
|
|
469
469
|
const n = { a: s, b: t };
|
|
470
|
-
return h.runKernel(
|
|
470
|
+
return h.runKernel(Vs, n);
|
|
471
471
|
}
|
|
472
472
|
const as = /* @__PURE__ */ f({ minimum_: vn });
|
|
473
473
|
/**
|
|
@@ -486,13 +486,13 @@ const as = /* @__PURE__ */ f({ minimum_: vn });
|
|
|
486
486
|
* limitations under the License.
|
|
487
487
|
* =============================================================================
|
|
488
488
|
*/
|
|
489
|
-
function
|
|
489
|
+
function Vn(e, r) {
|
|
490
490
|
let s = c(e, "a", "notEqual", "string_or_numeric"), t = c(r, "b", "notEqual", "string_or_numeric");
|
|
491
491
|
[s, t] = K(s, t), D(s.shape, t.shape);
|
|
492
492
|
const n = { a: s, b: t };
|
|
493
|
-
return h.runKernel(
|
|
493
|
+
return h.runKernel(Cs, n);
|
|
494
494
|
}
|
|
495
|
-
const
|
|
495
|
+
const Cn = /* @__PURE__ */ f({ notEqual_: Vn });
|
|
496
496
|
/**
|
|
497
497
|
* @license
|
|
498
498
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -627,7 +627,7 @@ function Qn(e, r, s) {
|
|
|
627
627
|
const n = { x: t }, o = { perm: r };
|
|
628
628
|
return t.dtype === "complex64" ? Zs(() => {
|
|
629
629
|
let a = Yn(t), i = qn(t);
|
|
630
|
-
return a = h.runKernel(Q, { x: a }, o), i = h.runKernel(Q, { x: i }, o), s && (i =
|
|
630
|
+
return a = h.runKernel(Q, { x: a }, o), i = h.runKernel(Q, { x: i }, o), s && (i = V(i)), mn(a, i);
|
|
631
631
|
}) : h.runKernel(Q, n, o);
|
|
632
632
|
}
|
|
633
633
|
const hs = /* @__PURE__ */ f({ transpose_: Qn });
|
|
@@ -782,7 +782,7 @@ const ue = /* @__PURE__ */ f({ rotateWithOffset_: ce });
|
|
|
782
782
|
* limitations under the License.
|
|
783
783
|
* =============================================================================
|
|
784
784
|
*/
|
|
785
|
-
function
|
|
785
|
+
function C(e, r, s, t, n, o) {
|
|
786
786
|
t == null && (t = 0.5), n == null && (n = Number.NEGATIVE_INFINITY), o == null && (o = 0);
|
|
787
787
|
const a = e.shape[0];
|
|
788
788
|
return s = Math.min(s, a), p(0 <= t && t <= 1, () => `iouThreshold must be in [0, 1], but was '${t}'`), p(e.rank === 2, () => `boxes must be a 2D tensor, but was of rank '${e.rank}'`), p(e.shape[1] === 4, () => `boxes must have 4 columns, but 2nd dimension was ${e.shape[1]}`), p(r.rank === 1, () => "scores must be a 1D tensor"), p(r.shape[0] === a, () => `scores has incompatible shape with boxes. Expected ${a}, but was ${r.shape[0]}`), p(0 <= o && o <= 1, () => `softNmsSigma must be in [0, 1], but was '${o}'`), { maxOutputSize: s, iouThreshold: t, scoreThreshold: n, softNmsSigma: o };
|
|
@@ -804,7 +804,7 @@ function V(e, r, s, t, n, o) {
|
|
|
804
804
|
* =============================================================================
|
|
805
805
|
*/
|
|
806
806
|
function le(e, r, s, t = 0.5, n = Number.NEGATIVE_INFINITY) {
|
|
807
|
-
const o = c(e, "boxes", "nonMaxSuppression", "float32"), a = c(r, "scores", "nonMaxSuppression", "float32"), i =
|
|
807
|
+
const o = c(e, "boxes", "nonMaxSuppression", "float32"), a = c(r, "scores", "nonMaxSuppression", "float32"), i = C(o, a, s, t, n);
|
|
808
808
|
s = i.maxOutputSize, t = i.iouThreshold, n = i.scoreThreshold;
|
|
809
809
|
const l = { maxOutputSize: s, iouThreshold: t, scoreThreshold: n };
|
|
810
810
|
return h.runKernel(Qs, { boxes: o, scores: a }, l);
|
|
@@ -827,7 +827,7 @@ const pe = /* @__PURE__ */ f({ nonMaxSuppression_: le });
|
|
|
827
827
|
* =============================================================================
|
|
828
828
|
*/
|
|
829
829
|
async function fe(e, r, s, t = 0.5, n = Number.NEGATIVE_INFINITY) {
|
|
830
|
-
const o = c(e, "boxes", "nonMaxSuppressionAsync"), a = c(r, "scores", "nonMaxSuppressionAsync"), i =
|
|
830
|
+
const o = c(e, "boxes", "nonMaxSuppressionAsync"), a = c(r, "scores", "nonMaxSuppressionAsync"), i = C(o, a, s, t, n);
|
|
831
831
|
s = i.maxOutputSize, t = i.iouThreshold, n = i.scoreThreshold;
|
|
832
832
|
const l = await Promise.all([o.data(), a.data()]), u = l[0], m = l[1], { selectedIndices: b } = un(u, m, s, t, n);
|
|
833
833
|
return o !== e && o.dispose(), a !== r && a.dispose(), w(b, "int32");
|
|
@@ -850,7 +850,7 @@ const me = fe;
|
|
|
850
850
|
* =============================================================================
|
|
851
851
|
*/
|
|
852
852
|
function he(e, r, s, t = 0.5, n = Number.NEGATIVE_INFINITY, o = 0) {
|
|
853
|
-
const a = c(e, "boxes", "nonMaxSuppression"), i = c(r, "scores", "nonMaxSuppression"), l =
|
|
853
|
+
const a = c(e, "boxes", "nonMaxSuppression"), i = c(r, "scores", "nonMaxSuppression"), l = C(a, i, s, t, n, o);
|
|
854
854
|
s = l.maxOutputSize, t = l.iouThreshold, n = l.scoreThreshold, o = l.softNmsSigma;
|
|
855
855
|
const u = { boxes: a, scores: i }, m = { maxOutputSize: s, iouThreshold: t, scoreThreshold: n, softNmsSigma: o }, b = h.runKernel(sn, u, m);
|
|
856
856
|
return { selectedIndices: b[0], selectedScores: b[1] };
|
|
@@ -873,7 +873,7 @@ const be = /* @__PURE__ */ f({ nonMaxSuppressionWithScore_: he });
|
|
|
873
873
|
* =============================================================================
|
|
874
874
|
*/
|
|
875
875
|
async function ge(e, r, s, t = 0.5, n = Number.NEGATIVE_INFINITY, o = 0) {
|
|
876
|
-
const a = c(e, "boxes", "nonMaxSuppressionAsync"), i = c(r, "scores", "nonMaxSuppressionAsync"), l =
|
|
876
|
+
const a = c(e, "boxes", "nonMaxSuppressionAsync"), i = c(r, "scores", "nonMaxSuppressionAsync"), l = C(a, i, s, t, n, o);
|
|
877
877
|
s = l.maxOutputSize, t = l.iouThreshold, n = l.scoreThreshold, o = l.softNmsSigma;
|
|
878
878
|
const u = await Promise.all([a.data(), i.data()]), m = u[0], b = u[1], { selectedIndices: _, selectedScores: $ } = ln(m, b, s, t, n, o);
|
|
879
879
|
return a !== e && a.dispose(), i !== r && i.dispose(), {
|
|
@@ -899,7 +899,7 @@ const de = ge;
|
|
|
899
899
|
* =============================================================================
|
|
900
900
|
*/
|
|
901
901
|
function $e(e, r, s, t = 0.5, n = Number.NEGATIVE_INFINITY, o = !1) {
|
|
902
|
-
const a = c(e, "boxes", "nonMaxSuppression"), i = c(r, "scores", "nonMaxSuppression"), l =
|
|
902
|
+
const a = c(e, "boxes", "nonMaxSuppression"), i = c(r, "scores", "nonMaxSuppression"), l = C(
|
|
903
903
|
a,
|
|
904
904
|
i,
|
|
905
905
|
s,
|
|
@@ -933,7 +933,7 @@ const Ee = /* @__PURE__ */ f({ nonMaxSuppressionPadded_: $e });
|
|
|
933
933
|
* =============================================================================
|
|
934
934
|
*/
|
|
935
935
|
async function _e(e, r, s, t = 0.5, n = Number.NEGATIVE_INFINITY, o = !1) {
|
|
936
|
-
const a = c(e, "boxes", "nonMaxSuppressionAsync"), i = c(r, "scores", "nonMaxSuppressionAsync"), l =
|
|
936
|
+
const a = c(e, "boxes", "nonMaxSuppressionAsync"), i = c(r, "scores", "nonMaxSuppressionAsync"), l = C(
|
|
937
937
|
a,
|
|
938
938
|
i,
|
|
939
939
|
s,
|
|
@@ -1020,7 +1020,7 @@ function Te(e, r = "binary", s = !1, t = 0.5) {
|
|
|
1020
1020
|
if (p(n.rank === 3, () => `Error in threshold: image must be rank 3,but got rank ${n.rank}.`), p(n.shape[2] === 3 || n.shape[2] === 1, () => `Error in threshold: image color channel must be equal to 3 or 1but got ${n.shape[2]}.`), p(n.dtype === "int32" || n.dtype === "float32", () => `Error in dtype: image dtype must be int32 or float32,but got dtype ${n.dtype}.`), p(r === "otsu" || r === "binary", () => `Method must be binary or otsu, but was ${r}`), n.shape[2] === 3) {
|
|
1021
1021
|
[m, b, _] = ds(n, [1, 1, 1], -1);
|
|
1022
1022
|
const T = g(m, o), B = g(b, a), R = g(_, i);
|
|
1023
|
-
$ =
|
|
1023
|
+
$ = z(z(T, B), R);
|
|
1024
1024
|
} else
|
|
1025
1025
|
$ = e;
|
|
1026
1026
|
if (r === "otsu") {
|
|
@@ -1036,7 +1036,7 @@ function xe(e, r) {
|
|
|
1036
1036
|
o = x(e, 0, b + 1), a = x(e, b + 1), u = S(E(o), r), m = S(E(a), r);
|
|
1037
1037
|
const _ = E(g(o, Z(0, o.size)));
|
|
1038
1038
|
i = S(_, E(o));
|
|
1039
|
-
const $ = rn(a.shape, o.size), N =
|
|
1039
|
+
const $ = rn(a.shape, o.size), N = z(Z(0, a.size), $), A = g(a, N);
|
|
1040
1040
|
l = S(E(A), E(a));
|
|
1041
1041
|
const T = d(i, l), B = d(i, l), R = g(u, m);
|
|
1042
1042
|
n = g(g(R, T), B);
|
|
@@ -1091,7 +1091,7 @@ function qe(e, r, s) {
|
|
|
1091
1091
|
const n = t.shape, [o, a] = t.shape.slice(-2);
|
|
1092
1092
|
let i, l;
|
|
1093
1093
|
typeof r == "number" ? (p(r % 1 === 0, () => `bandPart(): numLower must be an integer, got ${r}.`), p(r <= o, () => `bandPart(): numLower (${r}) must not be greater than the number of rows (${o}).`), i = c(r < 0 ? o : r, "numLower", "bandPart")) : (p(r.dtype === "int32", () => "bandPart(): numLower's dtype must be an int32."), i = v(ms(r, 0), o, as(r, o))), typeof s == "number" ? (p(s % 1 === 0, () => `bandPart(): numUpper must be an integer, got ${s}.`), p(s <= a, () => `bandPart(): numUpper (${s}) must not be greater than the number of columns (${a}).`), l = c(s < 0 ? a : s, "numUpper", "bandPart")) : (p(s.dtype === "int32", () => "bandPart(): numUpper's dtype must be an int32."), l = v(ms(s, 0), a, as(s, a)));
|
|
1094
|
-
const u = y(Z(0, o, 1, "int32"), [-1, 1]), m = Z(0, a, 1, "int32"), b = d(u, m), _ = Rn(ys(b, i), An(b,
|
|
1094
|
+
const u = y(Z(0, o, 1, "int32"), [-1, 1]), m = Z(0, a, 1, "int32"), b = d(u, m), _ = Rn(ys(b, i), An(b, V(l))), $ = hn([o, a], t.dtype);
|
|
1095
1095
|
return y(H(Ns(y(t, [-1, o, a])).map((N) => v(_, N, $))), n);
|
|
1096
1096
|
}
|
|
1097
1097
|
const De = /* @__PURE__ */ f({ bandPart_: qe });
|
|
@@ -1151,7 +1151,7 @@ const Ge = /* @__PURE__ */ f({ gramSchmidt_: Be });
|
|
|
1151
1151
|
* limitations under the License.
|
|
1152
1152
|
* =============================================================================
|
|
1153
1153
|
*/
|
|
1154
|
-
function
|
|
1154
|
+
function ze(e, r = !1) {
|
|
1155
1155
|
if (p(e.rank >= 2, () => `qr() requires input tensor to have a rank >= 2, but got rank ${e.rank}`), e.rank === 2)
|
|
1156
1156
|
return bs(e, r);
|
|
1157
1157
|
{
|
|
@@ -1184,7 +1184,7 @@ function bs(e, r = !1) {
|
|
|
1184
1184
|
a,
|
|
1185
1185
|
x(R, [1, 0], [R.shape[0] - 1, R.shape[1]])
|
|
1186
1186
|
], 0);
|
|
1187
|
-
const Y =
|
|
1187
|
+
const Y = V(S(G(T, B), N)), j = x(o, [u, 0], [s - u, t]), X = g(Y, i), cs = hs(i);
|
|
1188
1188
|
if (u === 0)
|
|
1189
1189
|
o = d(j, G(X, G(cs, j)));
|
|
1190
1190
|
else {
|
|
@@ -1204,7 +1204,7 @@ function bs(e, r = !1) {
|
|
|
1204
1204
|
return !r && s > t && (n = x(n, [0, 0], [s, t]), o = x(o, [0, 0], [t, t])), [n, o];
|
|
1205
1205
|
});
|
|
1206
1206
|
}
|
|
1207
|
-
const
|
|
1207
|
+
const Oe = /* @__PURE__ */ f({ qr_: ze });
|
|
1208
1208
|
/**
|
|
1209
1209
|
* @license
|
|
1210
1210
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -1246,7 +1246,7 @@ function Le(e, r, s = I.SUM_BY_NONZERO_WEIGHTS) {
|
|
|
1246
1246
|
if (n == null)
|
|
1247
1247
|
return S(E(o), k(t.size));
|
|
1248
1248
|
{
|
|
1249
|
-
const a = g(n, gn(t.shape)), i = q(E(
|
|
1249
|
+
const a = g(n, gn(t.shape)), i = q(E(Cn(a, k(0))), "float32");
|
|
1250
1250
|
return S(E(o), i);
|
|
1251
1251
|
}
|
|
1252
1252
|
}
|
|
@@ -1272,7 +1272,7 @@ const L = /* @__PURE__ */ f({ computeWeightedLoss_: Le });
|
|
|
1272
1272
|
function Re(e, r, s, t = I.SUM_BY_NONZERO_WEIGHTS) {
|
|
1273
1273
|
const n = c(e, "labels", "absoluteDifference"), o = c(r, "predictions", "absoluteDifference");
|
|
1274
1274
|
let a = null;
|
|
1275
|
-
s != null && (a = c(s, "weights", "absoluteDifference")),
|
|
1275
|
+
s != null && (a = c(s, "weights", "absoluteDifference")), O(n.shape, o.shape, "Error in absoluteDifference: ");
|
|
1276
1276
|
const i = M(d(n, o));
|
|
1277
1277
|
return L(i, a, t);
|
|
1278
1278
|
}
|
|
@@ -1280,22 +1280,22 @@ const We = /* @__PURE__ */ f({ absoluteDifference_: Re });
|
|
|
1280
1280
|
function Ke(e, r, s, t, n = I.SUM_BY_NONZERO_WEIGHTS) {
|
|
1281
1281
|
const o = c(e, "labels", "cosineDistance"), a = c(r, "predictions", "cosineDistance");
|
|
1282
1282
|
let i = null;
|
|
1283
|
-
t != null && (i = c(t, "weights", "cosineDistance")),
|
|
1283
|
+
t != null && (i = c(t, "weights", "cosineDistance")), O(o.shape, a.shape, "Error in cosineDistance: ");
|
|
1284
1284
|
const l = k(1), u = d(l, E(g(o, a), s, !0));
|
|
1285
1285
|
return L(u, i, n);
|
|
1286
1286
|
}
|
|
1287
1287
|
const ve = /* @__PURE__ */ f({ cosineDistance_: Ke });
|
|
1288
|
-
function
|
|
1288
|
+
function Ve(e, r, s, t = I.SUM_BY_NONZERO_WEIGHTS) {
|
|
1289
1289
|
let n = c(e, "labels", "hingeLoss");
|
|
1290
1290
|
const o = c(r, "predictions", "hingeLoss");
|
|
1291
1291
|
let a = null;
|
|
1292
|
-
s != null && (a = c(s, "weights", "hingeLoss")),
|
|
1292
|
+
s != null && (a = c(s, "weights", "hingeLoss")), O(n.shape, o.shape, "Error in hingeLoss: ");
|
|
1293
1293
|
const i = k(1);
|
|
1294
1294
|
n = d(g(k(2), n), i);
|
|
1295
1295
|
const l = Es(d(i, g(n, o)));
|
|
1296
1296
|
return L(l, a, t);
|
|
1297
1297
|
}
|
|
1298
|
-
const
|
|
1298
|
+
const Ce = /* @__PURE__ */ f({ hingeLoss_: Ve });
|
|
1299
1299
|
/**
|
|
1300
1300
|
* @license
|
|
1301
1301
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -1315,8 +1315,8 @@ const Ve = /* @__PURE__ */ f({ hingeLoss_: Ce });
|
|
|
1315
1315
|
function Pe(e, r, s, t = 1, n = I.SUM_BY_NONZERO_WEIGHTS) {
|
|
1316
1316
|
const o = c(e, "labels", "huberLoss"), a = c(r, "predictions", "huberLoss");
|
|
1317
1317
|
let i = null;
|
|
1318
|
-
s != null && (i = c(s, "weights", "huberLoss")),
|
|
1319
|
-
const l = k(t), u = M(d(a, o)), m = as(u, l), b = d(u, m), _ =
|
|
1318
|
+
s != null && (i = c(s, "weights", "huberLoss")), O(o.shape, a.shape, "Error in huberLoss: ");
|
|
1319
|
+
const l = k(t), u = M(d(a, o)), m = as(u, l), b = d(u, m), _ = z(g(k(0.5), gs(m)), g(l, b));
|
|
1320
1320
|
return L(_, i, n);
|
|
1321
1321
|
}
|
|
1322
1322
|
const Ye = /* @__PURE__ */ f({ huberLoss_: Pe });
|
|
@@ -1339,8 +1339,8 @@ const Ye = /* @__PURE__ */ f({ huberLoss_: Pe });
|
|
|
1339
1339
|
function je(e, r, s, t = 1e-7, n = I.SUM_BY_NONZERO_WEIGHTS) {
|
|
1340
1340
|
const o = c(e, "labels", "logLoss"), a = c(r, "predictions", "logLoss");
|
|
1341
1341
|
let i = null;
|
|
1342
|
-
s != null && (i = c(s, "weights", "logLoss")),
|
|
1343
|
-
const l = k(1), u = k(t), m =
|
|
1342
|
+
s != null && (i = c(s, "weights", "logLoss")), O(o.shape, a.shape, "Error in logLoss: ");
|
|
1343
|
+
const l = k(1), u = k(t), m = V(g(o, ps(z(a, u)))), b = g(d(l, o), ps(z(d(l, a), u))), _ = d(m, b);
|
|
1344
1344
|
return L(_, i, n);
|
|
1345
1345
|
}
|
|
1346
1346
|
const Fe = /* @__PURE__ */ f({ logLoss_: je });
|
|
@@ -1363,7 +1363,7 @@ const Fe = /* @__PURE__ */ f({ logLoss_: je });
|
|
|
1363
1363
|
function Ue(e, r, s, t = I.SUM_BY_NONZERO_WEIGHTS) {
|
|
1364
1364
|
const n = c(e, "labels", "meanSquaredError"), o = c(r, "predictions", "meanSquaredError");
|
|
1365
1365
|
let a = null;
|
|
1366
|
-
s != null && (a = c(s, "weights", "meanSquaredError")),
|
|
1366
|
+
s != null && (a = c(s, "weights", "meanSquaredError")), O(n.shape, o.shape, "Error in meanSquaredError: ");
|
|
1367
1367
|
const i = Zn(n, o);
|
|
1368
1368
|
return L(i, a, t);
|
|
1369
1369
|
}
|
|
@@ -1386,17 +1386,17 @@ const Ze = /* @__PURE__ */ f({ meanSquaredError_: Ue });
|
|
|
1386
1386
|
*/
|
|
1387
1387
|
function He(e, r) {
|
|
1388
1388
|
const s = c(e, "labels", "sigmoidCrossEntropyWithLogits"), t = c(r, "logits", "sigmoidCrossEntropyWithLogits");
|
|
1389
|
-
|
|
1390
|
-
const n = Es(t), o = g(t, s), a =
|
|
1391
|
-
return
|
|
1389
|
+
O(s.shape, t.shape, "Error in sigmoidCrossEntropyWithLogits: ");
|
|
1390
|
+
const n = Es(t), o = g(t, s), a = zn(os(V(M(t))));
|
|
1391
|
+
return z(d(n, o), a);
|
|
1392
1392
|
}
|
|
1393
1393
|
function Xe(e, r, s, t = 0, n = I.SUM_BY_NONZERO_WEIGHTS) {
|
|
1394
1394
|
let o = c(e, "multiClassLabels", "sigmoidCrossEntropy");
|
|
1395
1395
|
const a = c(r, "logits", "sigmoidCrossEntropy");
|
|
1396
1396
|
let i = null;
|
|
1397
|
-
if (s != null && (i = c(s, "weights", "sigmoidCrossEntropy")),
|
|
1397
|
+
if (s != null && (i = c(s, "weights", "sigmoidCrossEntropy")), O(o.shape, a.shape, "Error in sigmoidCrossEntropy: "), t > 0) {
|
|
1398
1398
|
const u = k(t), m = k(1), b = k(0.5);
|
|
1399
|
-
o =
|
|
1399
|
+
o = z(g(o, d(m, u)), g(b, u));
|
|
1400
1400
|
}
|
|
1401
1401
|
const l = He(o, a);
|
|
1402
1402
|
return L(l, i, n);
|
|
@@ -1424,7 +1424,7 @@ function Qe(e, r, s = -1) {
|
|
|
1424
1424
|
return cn((n, o, a) => {
|
|
1425
1425
|
const l = bn(o, [s], !0), u = d(q(o, "float32"), l);
|
|
1426
1426
|
a([n, u]);
|
|
1427
|
-
const m =
|
|
1427
|
+
const m = V(g(u, n));
|
|
1428
1428
|
return { value: E(m, [s]), gradFunc: ($, N) => {
|
|
1429
1429
|
const [A, T] = N, B = $s($.shape, [s]);
|
|
1430
1430
|
return [
|
|
@@ -1438,9 +1438,9 @@ function st(e, r, s, t = 0, n = I.SUM_BY_NONZERO_WEIGHTS) {
|
|
|
1438
1438
|
let o = c(e, "onehotLabels", "softmaxCrossEntropy");
|
|
1439
1439
|
const a = c(r, "logits", "softmaxCrossEntropy");
|
|
1440
1440
|
let i = null;
|
|
1441
|
-
if (s != null && (i = c(s, "weights", "softmaxCrossEntropy")),
|
|
1441
|
+
if (s != null && (i = c(s, "weights", "softmaxCrossEntropy")), O(o.shape, a.shape, "Error in softmaxCrossEntropy: "), t > 0) {
|
|
1442
1442
|
const u = k(t), m = k(1), b = k(o.shape[1]);
|
|
1443
|
-
o =
|
|
1443
|
+
o = z(g(o, d(m, u)), S(u, b));
|
|
1444
1444
|
}
|
|
1445
1445
|
const l = Qe(o, a);
|
|
1446
1446
|
return L(l, i, n);
|
|
@@ -1481,12 +1481,12 @@ const kt = {
|
|
|
1481
1481
|
}, yt = {
|
|
1482
1482
|
bandPart: De,
|
|
1483
1483
|
gramSchmidt: Ge,
|
|
1484
|
-
qr:
|
|
1484
|
+
qr: Oe
|
|
1485
1485
|
}, Nt = {
|
|
1486
1486
|
absoluteDifference: We,
|
|
1487
1487
|
computeWeightedLoss: L,
|
|
1488
1488
|
cosineDistance: ve,
|
|
1489
|
-
hingeLoss:
|
|
1489
|
+
hingeLoss: Ce,
|
|
1490
1490
|
huberLoss: Ye,
|
|
1491
1491
|
logLoss: Fe,
|
|
1492
1492
|
meanSquaredError: Ze,
|
|
@@ -1505,13 +1505,13 @@ export {
|
|
|
1505
1505
|
An as g,
|
|
1506
1506
|
W as h,
|
|
1507
1507
|
Tn as i,
|
|
1508
|
-
|
|
1508
|
+
zn as j,
|
|
1509
1509
|
as as k,
|
|
1510
1510
|
yt as l,
|
|
1511
1511
|
Kn as m,
|
|
1512
|
-
|
|
1512
|
+
V as n,
|
|
1513
1513
|
kt as o,
|
|
1514
|
-
|
|
1514
|
+
Cn as p,
|
|
1515
1515
|
fs as q,
|
|
1516
1516
|
ks as r,
|
|
1517
1517
|
Xn as s,
|
|
@@ -1,26 +1,26 @@
|
|
|
1
|
-
import {
|
|
2
|
-
import { k as ke, c as xt, o as ze, s as Qi, b as tr, d as Bu, m as er, t as In, l as nr, v as os, a as Wu, S as Gu, p as Pu, w as as, x as sr, y as Uu, z as Vu } from "./selu_util-
|
|
3
|
-
import { m as nt, n as wt, w as ne, b as qe, g as Je, c as ls, t as K, d as Ce, e as Yt, f as ju, u as dn, h as ye, i as Ku, l as Hu, j as qu, s as us, k as ir, o as qt, p as Hn, q as Ju } from "./ops-
|
|
4
|
-
import { r as N } from "./reshape-
|
|
5
|
-
import { s as W } from "./sum-
|
|
6
|
-
import { m as ct } from "./mat_mul-
|
|
7
|
-
import { s as Qt } from "./split-
|
|
8
|
-
import { s as Zu, c as rr } from "./sin-
|
|
9
|
-
import { e as qn, g as or, h as cs, c as Yu } from "./axis_util-
|
|
10
|
-
import { m as Ee, a as se, e as ie, l as Xu } from "./log_sum_exp-
|
|
11
|
-
import { s as Dn } from "./stack-
|
|
12
|
-
import { o as xe } from "./ones-
|
|
13
|
-
import { s as Dt } from "./slice-
|
|
14
|
-
import { M as Qu, f as ar, r as tc, d as ec, a as $n } from "./dropout-
|
|
15
|
-
import { z as Nt } from "./zeros-
|
|
16
|
-
import { c as pe } from "./concat-
|
|
17
|
-
import { g as lr } from "./gather-
|
|
18
|
-
import { t as fn } from "./tensor1d-
|
|
19
|
-
import { r as Ze } from "./relu-
|
|
20
|
-
import { s as ur } from "./softmax-
|
|
21
|
-
import { t as nc } from "./tensor-
|
|
22
|
-
import { r as sc } from "./range-
|
|
23
|
-
import { v as ic } from "./variable-
|
|
1
|
+
import { E as z, F as D, G as O, bJ as La, bK as Fa, bL as Ci, n as b, bM as Ii, a0 as L, bN as Di, bO as $i, bP as Ti, bQ as zi, h as Ei, bR as Li, bS as Fi, bT as Oi, bU as Mi, bV as Oa, bW as Ri, bX as Ma, bY as _i, bZ as Ra, b_ as Bi, Q as hn, l as vt, bx as _a, b$ as Wi, c0 as Gi, z as Ge, c as V, a as w, c1 as Ba, c2 as Pi, c3 as Ui, p as ce, aX as pt, c4 as Vi, c5 as ji, c6 as Ki, c7 as Hi, c8 as qi, bD as Ji, c9 as Zi, ca as Yi, M as Wa, ap as Ga, aq as Pa, cb as Xi, cc as Ua, q as T, cd as Ps, ce as Va, cf as ja, j as pn, cg as Us, ch as Ka, ci as Ha, cj as qa, ck as Ja, cl as Za, cm as Ya, cn as Xa, bm as Qa, co as tl, w as he, b as et, o as U, cp as el, bt as nl, ax as ht, cq as sl, A as Q, cr as il, cs as rl, ct as ol, cu as al, cv as ll, cw as ul, cx as cl, cy as hl, U as pl, cz as dl, bq as fl, bw as ml, cA as gl, K as bl, cB as yl, ac as wl, cC as kl, cD as xl, cE as Nl, as as vl, cF as Al, ai as Sl, aY as Cl, by as Il, ao as Dl, bz as $l, H as Tl, a_ as zl, az as El, cG as Ll, cH as Fl, cI as Ol, at as Ml, b2 as Rl, aj as _l, cJ as Bl, cK as Wl, cL as Gl, ah as Pl, bA as Ul, cM as Vl, cN as jl, b5 as Kl, aV as Hl, b6 as ql, cO as Jl, Y as Zl, bB as Yl, b3 as Xl, P as Ql, cP as tu, x as rs, au as eu, bC as nu, aC as su, Z as iu, av as ru, _ as ou, V as au, bj as lu, cQ as uu, bk as cu, cR as hu, b9 as pu, aT as du, ar as fu, cS as mu, ad as gu, $ as bu, S as yu, W as wu, bF as ku, cT as xu, ba as Nu, aw as vu, bH as Au, a1 as Su, cU as Cu, X as Iu, bc as Du, bb as $u, cV as Oe, cW as Tu, i as zu, an as Vs, cX as Eu, t as x, aW as $e, cY as S, cZ as Ke, c_ as He, af as Vt, d as J, ag as Lu, c$ as js, k as Jt, J as Fu, T as Te, O as Ou, d0 as Mu, m as Ks, d1 as Ru, d2 as Hs, d3 as _u } from "./index-BzFyqcy-.js";
|
|
2
|
+
import { k as ke, c as xt, o as ze, s as Qi, b as tr, d as Bu, m as er, t as In, l as nr, v as os, a as Wu, S as Gu, p as Pu, w as as, x as sr, y as Uu, z as Vu } from "./selu_util-OtRzVwW5.js";
|
|
3
|
+
import { m as nt, n as wt, w as ne, b as qe, g as Je, c as ls, t as K, d as Ce, e as Yt, f as ju, u as dn, h as ye, i as Ku, l as Hu, j as qu, s as us, k as ir, o as qt, p as Hn, q as Ju } from "./ops-LuCMAnmM.js";
|
|
4
|
+
import { r as N } from "./reshape-CnIwVG1c.js";
|
|
5
|
+
import { s as W } from "./sum-CJ0ULhmt.js";
|
|
6
|
+
import { m as ct } from "./mat_mul-DzjTFx-u.js";
|
|
7
|
+
import { s as Qt } from "./split-DK2k5eHf.js";
|
|
8
|
+
import { s as Zu, c as rr } from "./sin-gpDNRxE0.js";
|
|
9
|
+
import { e as qn, g as or, h as cs, c as Yu } from "./axis_util-TbGYJ208.js";
|
|
10
|
+
import { m as Ee, a as se, e as ie, l as Xu } from "./log_sum_exp-DO6z8tSE.js";
|
|
11
|
+
import { s as Dn } from "./stack-DFatutCx.js";
|
|
12
|
+
import { o as xe } from "./ones-tIJeHlq-.js";
|
|
13
|
+
import { s as Dt } from "./slice-d0Vo9XTN.js";
|
|
14
|
+
import { M as Qu, f as ar, r as tc, d as ec, a as $n } from "./dropout-C-csYCLj.js";
|
|
15
|
+
import { z as Nt } from "./zeros-Bj5rMYA7.js";
|
|
16
|
+
import { c as pe } from "./concat-B912vBbo.js";
|
|
17
|
+
import { g as lr } from "./gather-Dnpgw-YQ.js";
|
|
18
|
+
import { t as fn } from "./tensor1d-vML0r3q6.js";
|
|
19
|
+
import { r as Ze } from "./relu-BjCh_SYb.js";
|
|
20
|
+
import { s as ur } from "./softmax-D7Jj3p_P.js";
|
|
21
|
+
import { t as nc } from "./tensor-CZr4dh61.js";
|
|
22
|
+
import { r as sc } from "./range-CWcz7xFA.js";
|
|
23
|
+
import { v as ic } from "./variable-Bm2OFwGI.js";
|
|
24
24
|
/**
|
|
25
25
|
* @license
|
|
26
26
|
* Copyright 2020 Google LLC. All Rights Reserved.
|