@genai-fi/nanogpt 0.4.3 → 0.4.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/Generator.js +3 -3
- package/dist/NanoGPTModel.js +8 -8
- package/dist/Reshape-CiAY8ltP.js +212 -0
- package/dist/TeachableLLM.js +14 -5
- package/dist/{TiedEmbedding-CnJ1bx4q.js → TiedEmbedding-DznFwzcB.js} +244 -244
- package/dist/{axis_util-BgTGy5w8.js → axis_util-QP0LdI1v.js} +1 -1
- package/dist/{concat-CuRsVY-K.js → concat-DvWM7HGZ.js} +1 -1
- package/dist/data/parquet.js +9 -6
- package/dist/data/textLoader.js +6 -5
- package/dist/{dropout-DfDdklfL.js → dropout-DFEXTPV0.js} +4 -4
- package/dist/{gather-ZYRWhmXR.js → gather-C5D8PxwA.js} +1 -1
- package/dist/gpgpu_math-CUzjlO9A.js +23 -0
- package/dist/{index-C4JCoBvj.js → index--6vO-cOz.js} +87 -87
- package/dist/{kernel_funcs_utils-CAd1h9X1.js → kernel_funcs_utils-C6YBCuOt.js} +72 -91
- package/dist/layers/CausalSelfAttention.js +47 -46
- package/dist/layers/MLP.js +31 -33
- package/dist/layers/RMSNorm.d.ts +1 -2
- package/dist/layers/RMSNorm.js +10 -10
- package/dist/layers/RoPECache.js +3 -3
- package/dist/layers/TiedEmbedding.js +5 -5
- package/dist/layers/TransformerBlock.js +2 -2
- package/dist/{log_sum_exp-BswFnwOb.js → log_sum_exp-CiEy1aUe.js} +7 -7
- package/dist/main.js +28 -19
- package/dist/{mat_mul-415y5Qn2.js → mat_mul-BEHRPMh0.js} +1 -1
- package/dist/{max-CP_9O2Yd.js → max-BUShNgfh.js} +1 -1
- package/dist/{moments-CjeIaVdp.js → moments-DYOHXoRV.js} +5 -5
- package/dist/{norm-CZM380I3.js → norm-DSva3hI3.js} +13 -13
- package/dist/{ones-Bf3YR48P.js → ones-D6kB8bdY.js} +2 -2
- package/dist/ops/appendCache.js +3 -3
- package/dist/ops/attentionMask.js +1 -1
- package/dist/ops/cpu/appendCache.js +2 -2
- package/dist/ops/cpu/attentionMask.js +2 -2
- package/dist/ops/cpu/fusedSoftmax.js +2 -2
- package/dist/ops/cpu/gatherSub.js +4 -4
- package/dist/ops/cpu/gelu.js +1 -1
- package/dist/ops/cpu/matMulGelu.d.ts +1 -0
- package/dist/ops/cpu/matMulGelu.js +40 -0
- package/dist/ops/cpu/mulDropout.js +1 -1
- package/dist/ops/cpu/normRMS.d.ts +1 -0
- package/dist/ops/cpu/normRMS.js +39 -0
- package/dist/ops/cpu/qkv.js +3 -3
- package/dist/ops/cpu/rope.js +5 -5
- package/dist/ops/cpu/scatterSub.js +4 -4
- package/dist/ops/fusedSoftmax.js +1 -1
- package/dist/ops/gatherSub.js +1 -1
- package/dist/ops/gelu.js +2 -2
- package/dist/ops/grads/attentionMask.js +1 -1
- package/dist/ops/grads/fusedSoftmax.js +2 -2
- package/dist/ops/grads/gelu.js +24 -3
- package/dist/ops/grads/matMulGelu.d.ts +1 -0
- package/dist/ops/grads/matMulGelu.js +17 -0
- package/dist/ops/grads/normRMS.d.ts +2 -0
- package/dist/ops/grads/normRMS.js +20 -0
- package/dist/ops/grads/qkv.js +1 -1
- package/dist/ops/grads/rope.js +1 -1
- package/dist/ops/matMulGelu.d.ts +3 -0
- package/dist/ops/matMulGelu.js +14 -0
- package/dist/ops/mulDrop.js +1 -1
- package/dist/ops/node/sparseCrossEntropy.js +1 -1
- package/dist/ops/normRMS.d.ts +2 -0
- package/dist/ops/normRMS.js +10 -0
- package/dist/ops/qkv.js +1 -1
- package/dist/ops/scatterSub.js +1 -1
- package/dist/ops/webgl/appendCache.js +1 -1
- package/dist/ops/webgl/attentionMask.js +1 -1
- package/dist/ops/webgl/fusedSoftmax.js +689 -895
- package/dist/ops/webgl/gatherSub.js +1 -1
- package/dist/ops/webgl/gelu.js +2 -2
- package/dist/ops/webgl/matMulGelu.d.ts +21 -0
- package/dist/ops/webgl/matMulGelu.js +168 -0
- package/dist/ops/webgl/mulDropout.js +1 -1
- package/dist/ops/webgl/normRMS.d.ts +1 -0
- package/dist/ops/webgl/normRMS.js +78 -0
- package/dist/ops/webgl/qkv.js +1 -1
- package/dist/ops/webgl/rope.js +1 -1
- package/dist/ops/webgl/scatterSub.js +1 -1
- package/dist/{range-9AzeApCc.js → range-C_vpUjBu.js} +1 -1
- package/dist/{reshape-Boe4DuIO.js → reshape-z51Eu-re.js} +1 -1
- package/dist/{sin-KmhiDuMa.js → sin-H567uayl.js} +1 -1
- package/dist/{slice_util-19zDNNSn.js → slice_util-BdhYwFY_.js} +2 -2
- package/dist/{softmax-Cujsg4ay.js → softmax-Dsxflvdl.js} +1 -1
- package/dist/{split-DbcNm1-i.js → split-B_k_jwud.js} +1 -1
- package/dist/{stack-D1YjmgKN.js → stack-CmqSdsfs.js} +1 -1
- package/dist/{sum-R28pucR5.js → sum-DdkDf2MG.js} +1 -1
- package/dist/{tensor-BVeHdl7V.js → tensor-BGYi41cj.js} +1 -1
- package/dist/{tensor2d-DqFGNs_K.js → tensor2d-DUr_htjt.js} +1 -1
- package/dist/{tfjs_backend-Cug-PH75.js → tfjs_backend-DuKis_xG.js} +46 -46
- package/dist/training/AdamExt.js +1 -1
- package/dist/training/DatasetBuilder.js +18 -18
- package/dist/training/FullTrainer.js +1 -1
- package/dist/training/Trainer.js +5 -5
- package/dist/training/sparseCrossEntropy.js +4 -4
- package/dist/utilities/dummy.js +2 -2
- package/dist/utilities/generate.js +3 -3
- package/dist/utilities/load.js +1 -1
- package/dist/utilities/profile.js +1 -1
- package/dist/utilities/weights.js +2 -2
- package/dist/{variable-LJT9Ld63.js → variable-BJTZ3jOy.js} +1 -1
- package/dist/{zeros-dnQxFgAD.js → zeros-8xl-W2DC.js} +1 -1
- package/package.json +1 -1
- package/dist/gelu-CnCt17Lk.js +0 -26
package/dist/data/parquet.js
CHANGED
|
@@ -1,14 +1,17 @@
|
|
|
1
1
|
import { B as n } from "../index-Tf7vU29b.js";
|
|
2
2
|
const p = 100 * 1024 * 1024;
|
|
3
|
-
async function
|
|
4
|
-
const
|
|
3
|
+
async function d(i, s = p, e = "text") {
|
|
4
|
+
const r = await (await import("../parquet-C0Tlmv9c.js").then((t) => t.p)).ParquetReader.openBuffer(n.from(await i.arrayBuffer())), a = [], f = r.getCursor([[e]]);
|
|
5
5
|
let o = 0;
|
|
6
6
|
for (; ; ) {
|
|
7
|
-
const t = await
|
|
8
|
-
if (!t ||
|
|
7
|
+
const t = await f.next();
|
|
8
|
+
if (!t || t[e] === void 0 || typeof t[e] != "string")
|
|
9
|
+
break;
|
|
10
|
+
if (t[e].length !== 0 && (a.push(t[e]), o += t[e].length, o > s))
|
|
11
|
+
break;
|
|
9
12
|
}
|
|
10
|
-
return
|
|
13
|
+
return r.close(), a;
|
|
11
14
|
}
|
|
12
15
|
export {
|
|
13
|
-
|
|
16
|
+
d as loadParquet
|
|
14
17
|
};
|
package/dist/data/textLoader.js
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import { p as s } from "../papaparse.min-C8l2Kvo1.js";
|
|
2
|
-
import { loadParquet as
|
|
3
|
-
import { loadPDF as
|
|
2
|
+
import { loadParquet as m } from "./parquet.js";
|
|
3
|
+
import { loadPDF as u } from "./pdf.js";
|
|
4
4
|
import { loadDOCX as f } from "./docx.js";
|
|
5
5
|
function l(e, t) {
|
|
6
6
|
const r = e.findIndex((n) => n.toLowerCase() === t.toLowerCase());
|
|
@@ -31,9 +31,9 @@ function h(e) {
|
|
|
31
31
|
async function C(e, t) {
|
|
32
32
|
const r = e.type !== "" ? e.type : h(e.name);
|
|
33
33
|
if (r === "application/parquet")
|
|
34
|
-
return
|
|
34
|
+
return m(e, t?.maxSize, t?.column);
|
|
35
35
|
if (r === "application/pdf")
|
|
36
|
-
return
|
|
36
|
+
return u(e, t?.maxSize);
|
|
37
37
|
if (r === "application/vnd.openxmlformats-officedocument.wordprocessingml.document")
|
|
38
38
|
return f(e);
|
|
39
39
|
if (r === "text/csv") {
|
|
@@ -42,9 +42,10 @@ async function C(e, t) {
|
|
|
42
42
|
s.parse(n, {
|
|
43
43
|
header: !1,
|
|
44
44
|
skipEmptyLines: !0,
|
|
45
|
+
delimiter: ",",
|
|
45
46
|
complete: (a) => {
|
|
46
47
|
if (a.errors.length > 0)
|
|
47
|
-
o(new Error("Error parsing file"));
|
|
48
|
+
console.error(a.errors), o(new Error("Error parsing file"));
|
|
48
49
|
else {
|
|
49
50
|
const p = l(a.data[0], t?.column || "text"), i = t?.hasHeader ?? x(a.data[0]) ? a.data.slice(1) : a.data;
|
|
50
51
|
c(i.map((d) => d[p]));
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { o as l, h, E as m,
|
|
1
|
+
import { o as l, h, E as m, af as p, k as c, ag as d, ad as g, j as u, ah as V, ai as v, a8 as N, b as w } from "./index--6vO-cOz.js";
|
|
2
2
|
import { s as f } from "./index-C4L8Cm77.js";
|
|
3
3
|
/**
|
|
4
4
|
* @license
|
|
@@ -16,11 +16,11 @@ import { s as f } from "./index-C4L8Cm77.js";
|
|
|
16
16
|
* limitations under the License.
|
|
17
17
|
* =============================================================================
|
|
18
18
|
*/
|
|
19
|
-
function
|
|
19
|
+
function b(r) {
|
|
20
20
|
const e = { x: h(r, "x", "floor", "float32") };
|
|
21
21
|
return m.runKernel(p, e);
|
|
22
22
|
}
|
|
23
|
-
const x = /* @__PURE__ */ l({ floor_:
|
|
23
|
+
const x = /* @__PURE__ */ l({ floor_: b });
|
|
24
24
|
/**
|
|
25
25
|
* @license
|
|
26
26
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -181,7 +181,7 @@ function R(r, t, e, s) {
|
|
|
181
181
|
if (u(n.dtype === "float32", () => `x has to be a floating point tensor since it's going to be scaled, but got a ${n.dtype} tensor instead.`), u(t >= 0 && t < 1, () => `rate must be a float in the range [0, 1), but got ${t}.`), t === 0)
|
|
182
182
|
return r instanceof V ? n.clone() : n;
|
|
183
183
|
const o = E(n, e), a = 1 - t, i = v(x(N(D(o, 0, 1, "float32", s), a)), a);
|
|
184
|
-
return
|
|
184
|
+
return w(n, i);
|
|
185
185
|
}
|
|
186
186
|
const G = /* @__PURE__ */ l({ dropout_: R });
|
|
187
187
|
export {
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
import { L as e } from "./index--6vO-cOz.js";
|
|
2
|
+
/**
|
|
3
|
+
* @license
|
|
4
|
+
* Copyright 2017 Google LLC. All Rights Reserved.
|
|
5
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
* you may not use this file except in compliance with the License.
|
|
7
|
+
* You may obtain a copy of the License at
|
|
8
|
+
*
|
|
9
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
*
|
|
11
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
* See the License for the specific language governing permissions and
|
|
15
|
+
* limitations under the License.
|
|
16
|
+
* =============================================================================
|
|
17
|
+
*/
|
|
18
|
+
function n(o) {
|
|
19
|
+
return e().getBool("WEBGL_USE_SHAPES_UNIFORMS") && o <= 4;
|
|
20
|
+
}
|
|
21
|
+
export {
|
|
22
|
+
n as u
|
|
23
|
+
};
|
|
@@ -4001,7 +4001,7 @@ function As() {
|
|
|
4001
4001
|
*/
|
|
4002
4002
|
As();
|
|
4003
4003
|
export {
|
|
4004
|
-
|
|
4004
|
+
xe as $,
|
|
4005
4005
|
Ss as A,
|
|
4006
4006
|
Zs as B,
|
|
4007
4007
|
or as C,
|
|
@@ -4013,99 +4013,99 @@ export {
|
|
|
4013
4013
|
Fs as I,
|
|
4014
4014
|
kn as J,
|
|
4015
4015
|
En as K,
|
|
4016
|
-
|
|
4016
|
+
k as L,
|
|
4017
4017
|
ta as M,
|
|
4018
|
-
|
|
4019
|
-
|
|
4018
|
+
Lr as N,
|
|
4019
|
+
rs as O,
|
|
4020
4020
|
ba as P,
|
|
4021
|
-
|
|
4021
|
+
de as Q,
|
|
4022
4022
|
Ia as R,
|
|
4023
4023
|
qa as S,
|
|
4024
|
-
|
|
4025
|
-
|
|
4026
|
-
|
|
4027
|
-
|
|
4028
|
-
|
|
4029
|
-
|
|
4030
|
-
|
|
4031
|
-
|
|
4024
|
+
Ea as T,
|
|
4025
|
+
Zt as U,
|
|
4026
|
+
dr as V,
|
|
4027
|
+
Oa as W,
|
|
4028
|
+
De as X,
|
|
4029
|
+
ar as Y,
|
|
4030
|
+
ne as Z,
|
|
4031
|
+
aa as _,
|
|
4032
4032
|
M as a,
|
|
4033
|
-
|
|
4034
|
-
|
|
4035
|
-
|
|
4036
|
-
|
|
4037
|
-
|
|
4038
|
-
|
|
4039
|
-
|
|
4040
|
-
|
|
4041
|
-
|
|
4042
|
-
|
|
4043
|
-
|
|
4044
|
-
|
|
4045
|
-
|
|
4046
|
-
|
|
4047
|
-
|
|
4048
|
-
|
|
4049
|
-
|
|
4050
|
-
|
|
4051
|
-
|
|
4052
|
-
|
|
4053
|
-
|
|
4054
|
-
|
|
4055
|
-
|
|
4056
|
-
|
|
4057
|
-
|
|
4058
|
-
|
|
4059
|
-
|
|
4060
|
-
|
|
4061
|
-
|
|
4062
|
-
|
|
4063
|
-
|
|
4064
|
-
|
|
4065
|
-
|
|
4066
|
-
|
|
4067
|
-
|
|
4068
|
-
|
|
4069
|
-
|
|
4070
|
-
|
|
4071
|
-
|
|
4072
|
-
|
|
4073
|
-
|
|
4074
|
-
|
|
4075
|
-
|
|
4076
|
-
|
|
4077
|
-
|
|
4078
|
-
|
|
4079
|
-
|
|
4080
|
-
|
|
4081
|
-
|
|
4082
|
-
|
|
4083
|
-
|
|
4084
|
-
|
|
4085
|
-
|
|
4086
|
-
|
|
4087
|
-
|
|
4088
|
-
|
|
4089
|
-
|
|
4090
|
-
|
|
4091
|
-
|
|
4092
|
-
|
|
4093
|
-
|
|
4094
|
-
|
|
4095
|
-
|
|
4096
|
-
|
|
4033
|
+
xs as a$,
|
|
4034
|
+
V as a0,
|
|
4035
|
+
oa as a1,
|
|
4036
|
+
ns as a2,
|
|
4037
|
+
nt as a3,
|
|
4038
|
+
Qa as a4,
|
|
4039
|
+
Ca as a5,
|
|
4040
|
+
Fr as a6,
|
|
4041
|
+
qr as a7,
|
|
4042
|
+
S as a8,
|
|
4043
|
+
la as a9,
|
|
4044
|
+
Wr as aA,
|
|
4045
|
+
jr as aB,
|
|
4046
|
+
Kr as aC,
|
|
4047
|
+
ha as aD,
|
|
4048
|
+
Jr as aE,
|
|
4049
|
+
ia as aF,
|
|
4050
|
+
Sa as aG,
|
|
4051
|
+
Ta as aH,
|
|
4052
|
+
Aa as aI,
|
|
4053
|
+
Ra as aJ,
|
|
4054
|
+
$a as aK,
|
|
4055
|
+
Ds as aL,
|
|
4056
|
+
ro as aM,
|
|
4057
|
+
no as aN,
|
|
4058
|
+
eo as aO,
|
|
4059
|
+
Io as aP,
|
|
4060
|
+
oo as aQ,
|
|
4061
|
+
yr as aR,
|
|
4062
|
+
$r as aS,
|
|
4063
|
+
ao as aT,
|
|
4064
|
+
da as aU,
|
|
4065
|
+
ma as aV,
|
|
4066
|
+
ga as aW,
|
|
4067
|
+
Na as aX,
|
|
4068
|
+
va as aY,
|
|
4069
|
+
to as aZ,
|
|
4070
|
+
yo as a_,
|
|
4071
|
+
ua as aa,
|
|
4072
|
+
Za as ab,
|
|
4073
|
+
$t as ac,
|
|
4074
|
+
Rt as ad,
|
|
4075
|
+
Rs as ae,
|
|
4076
|
+
xr as af,
|
|
4077
|
+
Wn as ag,
|
|
4078
|
+
D as ah,
|
|
4079
|
+
x as ai,
|
|
4080
|
+
F as aj,
|
|
4081
|
+
pe as ak,
|
|
4082
|
+
fo as al,
|
|
4083
|
+
dt as am,
|
|
4084
|
+
jt as an,
|
|
4085
|
+
ue as ao,
|
|
4086
|
+
za as ap,
|
|
4087
|
+
_a as aq,
|
|
4088
|
+
er as ar,
|
|
4089
|
+
rr as as,
|
|
4090
|
+
Pa as at,
|
|
4091
|
+
Ar as au,
|
|
4092
|
+
Br as av,
|
|
4093
|
+
Rr as aw,
|
|
4094
|
+
_r as ax,
|
|
4095
|
+
Or as ay,
|
|
4096
|
+
Gr as az,
|
|
4097
4097
|
b,
|
|
4098
4098
|
Vs as b$,
|
|
4099
|
-
|
|
4100
|
-
|
|
4101
|
-
|
|
4102
|
-
|
|
4103
|
-
|
|
4104
|
-
|
|
4105
|
-
|
|
4106
|
-
|
|
4107
|
-
|
|
4108
|
-
|
|
4099
|
+
$s as b0,
|
|
4100
|
+
ko as b1,
|
|
4101
|
+
Ps as b2,
|
|
4102
|
+
Cs as b3,
|
|
4103
|
+
Lt as b4,
|
|
4104
|
+
te as b5,
|
|
4105
|
+
uo as b6,
|
|
4106
|
+
dn as b7,
|
|
4107
|
+
Re as b8,
|
|
4108
|
+
$e as b9,
|
|
4109
4109
|
ya as bA,
|
|
4110
4110
|
pa as bB,
|
|
4111
4111
|
wa as bC,
|
|
@@ -1,4 +1,5 @@
|
|
|
1
|
-
import {
|
|
1
|
+
import { an as D, ao as N, O as w, n as R, Q as v, L as P } from "./index--6vO-cOz.js";
|
|
2
|
+
import { u as g } from "./gpgpu_math-CUzjlO9A.js";
|
|
2
3
|
/**
|
|
3
4
|
* @license
|
|
4
5
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -15,15 +16,15 @@ import { aj as E, ak as D, af as B, al as w, n as N, am as v } from "./index-C4J
|
|
|
15
16
|
* limitations under the License.
|
|
16
17
|
* =============================================================================
|
|
17
18
|
*/
|
|
18
|
-
function
|
|
19
|
+
function B(t) {
|
|
19
20
|
try {
|
|
20
|
-
return t.map((e) =>
|
|
21
|
+
return t.map((e) => D(e));
|
|
21
22
|
} catch (e) {
|
|
22
23
|
throw new Error(`Failed to decode encoded string bytes into utf-8, error: ${e}`);
|
|
23
24
|
}
|
|
24
25
|
}
|
|
25
|
-
function
|
|
26
|
-
return t.map((e) =>
|
|
26
|
+
function H(t) {
|
|
27
|
+
return t.map((e) => N(e));
|
|
27
28
|
}
|
|
28
29
|
/**
|
|
29
30
|
* @license
|
|
@@ -41,7 +42,7 @@ function F(t) {
|
|
|
41
42
|
* limitations under the License.
|
|
42
43
|
* =============================================================================
|
|
43
44
|
*/
|
|
44
|
-
function
|
|
45
|
+
function k(t) {
|
|
45
46
|
if (t <= 1)
|
|
46
47
|
return "int";
|
|
47
48
|
if (t === 2)
|
|
@@ -56,25 +57,6 @@ function R(t) {
|
|
|
56
57
|
return "ivec6";
|
|
57
58
|
throw Error(`GPU for rank ${t} is not yet supported`);
|
|
58
59
|
}
|
|
59
|
-
/**
|
|
60
|
-
* @license
|
|
61
|
-
* Copyright 2017 Google LLC. All Rights Reserved.
|
|
62
|
-
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
63
|
-
* you may not use this file except in compliance with the License.
|
|
64
|
-
* You may obtain a copy of the License at
|
|
65
|
-
*
|
|
66
|
-
* http://www.apache.org/licenses/LICENSE-2.0
|
|
67
|
-
*
|
|
68
|
-
* Unless required by applicable law or agreed to in writing, software
|
|
69
|
-
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
70
|
-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
71
|
-
* See the License for the specific language governing permissions and
|
|
72
|
-
* limitations under the License.
|
|
73
|
-
* =============================================================================
|
|
74
|
-
*/
|
|
75
|
-
function y(t) {
|
|
76
|
-
return B().getBool("WEBGL_USE_SHAPES_UNIFORMS") && t <= 4;
|
|
77
|
-
}
|
|
78
60
|
/**
|
|
79
61
|
* @license
|
|
80
62
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -91,11 +73,11 @@ function y(t) {
|
|
|
91
73
|
* limitations under the License.
|
|
92
74
|
* =============================================================================
|
|
93
75
|
*/
|
|
94
|
-
function
|
|
95
|
-
return ["x", "y", "z", "w", "u", "v"].slice(0, e).map((
|
|
76
|
+
function E(t, e) {
|
|
77
|
+
return ["x", "y", "z", "w", "u", "v"].slice(0, e).map((o) => `${t}.${o}`);
|
|
96
78
|
}
|
|
97
|
-
function
|
|
98
|
-
return e === 1 ? [t] :
|
|
79
|
+
function z(t, e) {
|
|
80
|
+
return e === 1 ? [t] : E(t, e);
|
|
99
81
|
}
|
|
100
82
|
/**
|
|
101
83
|
* @license
|
|
@@ -113,9 +95,9 @@ function _(t, e) {
|
|
|
113
95
|
* limitations under the License.
|
|
114
96
|
* =============================================================================
|
|
115
97
|
*/
|
|
116
|
-
class
|
|
117
|
-
constructor(e,
|
|
118
|
-
this.variableNames = ["A", "B"], this.outputShape = w(
|
|
98
|
+
class C {
|
|
99
|
+
constructor(e, o, u) {
|
|
100
|
+
this.variableNames = ["A", "B"], this.outputShape = w(o, u), this.enableShapeUniforms = g(this.outputShape.length), this.userCode = `
|
|
119
101
|
float binaryOperation(float a, float b) {
|
|
120
102
|
${e}
|
|
121
103
|
}
|
|
@@ -144,22 +126,22 @@ class A {
|
|
|
144
126
|
* limitations under the License.
|
|
145
127
|
* =============================================================================
|
|
146
128
|
*/
|
|
147
|
-
class
|
|
148
|
-
constructor(e,
|
|
149
|
-
this.variableNames = ["A", "B"], this.supportsBroadcasting = !0, this.packedInputs = !0, this.packedOutput = !0, this.outputShape = w(
|
|
150
|
-
const
|
|
151
|
-
this.enableShapeUniforms =
|
|
129
|
+
class _ {
|
|
130
|
+
constructor(e, o, u, d = !1) {
|
|
131
|
+
this.variableNames = ["A", "B"], this.supportsBroadcasting = !0, this.packedInputs = !0, this.packedOutput = !0, this.outputShape = w(o, u);
|
|
132
|
+
const a = this.outputShape.length;
|
|
133
|
+
this.enableShapeUniforms = g(a);
|
|
152
134
|
let n = "";
|
|
153
135
|
if (d)
|
|
154
|
-
if (
|
|
136
|
+
if (a === 0 || R(this.outputShape) === 1)
|
|
155
137
|
n = `
|
|
156
138
|
result.y = 0.;
|
|
157
139
|
result.z = 0.;
|
|
158
140
|
result.w = 0.;
|
|
159
141
|
`;
|
|
160
142
|
else if (n = `
|
|
161
|
-
${
|
|
162
|
-
`,
|
|
143
|
+
${k(a)} coords = getOutputCoords();
|
|
144
|
+
`, a === 1)
|
|
163
145
|
this.enableShapeUniforms ? n += `
|
|
164
146
|
result.y = (coords + 1) >= outShape ? 0. : result.y;
|
|
165
147
|
result.z = 0.;
|
|
@@ -170,20 +152,20 @@ class z {
|
|
|
170
152
|
result.w = 0.;
|
|
171
153
|
`;
|
|
172
154
|
else {
|
|
173
|
-
const s =
|
|
155
|
+
const s = z("coords", a);
|
|
174
156
|
this.enableShapeUniforms ? n += `
|
|
175
157
|
bool nextRowOutOfBounds =
|
|
176
|
-
(${s[
|
|
158
|
+
(${s[a - 2]} + 1) >= outShape[${a} - 2];
|
|
177
159
|
bool nextColOutOfBounds =
|
|
178
|
-
(${s[
|
|
160
|
+
(${s[a - 1]} + 1) >= outShape[${a} - 1];
|
|
179
161
|
result.y = nextColOutOfBounds ? 0. : result.y;
|
|
180
162
|
result.z = nextRowOutOfBounds ? 0. : result.z;
|
|
181
163
|
result.w = nextColOutOfBounds || nextRowOutOfBounds ? 0. : result.w;
|
|
182
164
|
` : n += `
|
|
183
165
|
bool nextRowOutOfBounds =
|
|
184
|
-
(${s[
|
|
166
|
+
(${s[a - 2]} + 1) >= ${this.outputShape[a - 2]};
|
|
185
167
|
bool nextColOutOfBounds =
|
|
186
|
-
(${s[
|
|
168
|
+
(${s[a - 1]} + 1) >= ${this.outputShape[a - 1]};
|
|
187
169
|
result.y = nextColOutOfBounds ? 0. : result.y;
|
|
188
170
|
result.z = nextRowOutOfBounds ? 0. : result.z;
|
|
189
171
|
result.w = nextColOutOfBounds || nextRowOutOfBounds ? 0. : result.w;
|
|
@@ -222,9 +204,9 @@ class z {
|
|
|
222
204
|
* limitations under the License.
|
|
223
205
|
* =============================================================================
|
|
224
206
|
*/
|
|
225
|
-
function
|
|
226
|
-
const { inputs: e, backend:
|
|
227
|
-
return
|
|
207
|
+
function A(t) {
|
|
208
|
+
const { inputs: e, backend: o } = t, { x: u } = e;
|
|
209
|
+
return o.incRef(u.dataId), { dataId: u.dataId, shape: u.shape, dtype: u.dtype };
|
|
228
210
|
}
|
|
229
211
|
/**
|
|
230
212
|
* @license
|
|
@@ -243,8 +225,8 @@ function P(t) {
|
|
|
243
225
|
* =============================================================================
|
|
244
226
|
*/
|
|
245
227
|
function G(t) {
|
|
246
|
-
const { inputs: e, backend:
|
|
247
|
-
return n.complexTensorInfos = { real: l, imag: s },
|
|
228
|
+
const { inputs: e, backend: o } = t, { real: u, imag: d } = e, a = o.makeTensorInfo(u.shape, "complex64"), n = o.texData.get(a.dataId), l = A({ inputs: { x: u }, backend: o }), s = A({ inputs: { x: d }, backend: o });
|
|
229
|
+
return n.complexTensorInfos = { real: l, imag: s }, a;
|
|
248
230
|
}
|
|
249
231
|
/**
|
|
250
232
|
* @license
|
|
@@ -263,10 +245,10 @@ function G(t) {
|
|
|
263
245
|
* =============================================================================
|
|
264
246
|
*/
|
|
265
247
|
class V {
|
|
266
|
-
constructor(e,
|
|
267
|
-
this.variableNames = ["A"], this.outputShape = e, this.enableShapeUniforms =
|
|
248
|
+
constructor(e, o) {
|
|
249
|
+
this.variableNames = ["A"], this.outputShape = e, this.enableShapeUniforms = g(this.outputShape.length), this.userCode = `
|
|
268
250
|
float unaryOperation(float x) {
|
|
269
|
-
${
|
|
251
|
+
${o}
|
|
270
252
|
}
|
|
271
253
|
|
|
272
254
|
void main() {
|
|
@@ -278,7 +260,7 @@ class V {
|
|
|
278
260
|
`;
|
|
279
261
|
}
|
|
280
262
|
}
|
|
281
|
-
const
|
|
263
|
+
const K = "if (isnan(x)) return x;";
|
|
282
264
|
/**
|
|
283
265
|
* @license
|
|
284
266
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -296,10 +278,10 @@ const H = "if (isnan(x)) return x;";
|
|
|
296
278
|
* =============================================================================
|
|
297
279
|
*/
|
|
298
280
|
class L {
|
|
299
|
-
constructor(e,
|
|
300
|
-
this.variableNames = ["A"], this.packedInputs = !0, this.packedOutput = !0, this.outputShape = e, this.enableShapeUniforms =
|
|
281
|
+
constructor(e, o) {
|
|
282
|
+
this.variableNames = ["A"], this.packedInputs = !0, this.packedOutput = !0, this.outputShape = e, this.enableShapeUniforms = g(this.outputShape.length), this.userCode = `
|
|
301
283
|
vec4 unaryOperation(vec4 x) {
|
|
302
|
-
${
|
|
284
|
+
${o}
|
|
303
285
|
}
|
|
304
286
|
|
|
305
287
|
void main() {
|
|
@@ -327,62 +309,61 @@ class L {
|
|
|
327
309
|
* limitations under the License.
|
|
328
310
|
* =============================================================================
|
|
329
311
|
*/
|
|
330
|
-
function
|
|
331
|
-
return ({ inputs: d, backend:
|
|
332
|
-
const { x: n } = d, l =
|
|
333
|
-
if (l.shouldExecuteOnCPU([n]) &&
|
|
334
|
-
const
|
|
312
|
+
function Y({ opSnippet: t, packedOpSnippet: e, cpuKernelImpl: o, dtype: u }) {
|
|
313
|
+
return ({ inputs: d, backend: a }) => {
|
|
314
|
+
const { x: n } = d, l = a, s = u || n.dtype;
|
|
315
|
+
if (l.shouldExecuteOnCPU([n]) && o != null) {
|
|
316
|
+
const c = l.texData.get(n.dataId), x = o(c.values, s);
|
|
335
317
|
return l.makeTensorInfo(n.shape, s, x);
|
|
336
318
|
}
|
|
337
|
-
const i =
|
|
319
|
+
const i = P().getBool("WEBGL_PACK_UNARY_OPERATIONS") && e != null;
|
|
338
320
|
let r;
|
|
339
321
|
return i ? r = new L(n.shape, e) : r = new V(n.shape, t), l.runWebGLProgram(r, [n], s);
|
|
340
322
|
};
|
|
341
323
|
}
|
|
342
|
-
function
|
|
324
|
+
function Q({ opSnippet: t, packedOpSnippet: e, checkOutOfBounds: o = !1, supportsComplex: u = !1, cpuKernelImpl: d, dtype: a }) {
|
|
343
325
|
return ({ inputs: n, backend: l }) => {
|
|
344
326
|
const { a: s, b: i } = n, r = l;
|
|
345
327
|
if (u && s.dtype === "complex64") {
|
|
346
|
-
const h = r.texData.get(s.dataId), f = r.texData.get(i.dataId), [
|
|
328
|
+
const h = r.texData.get(s.dataId), f = r.texData.get(i.dataId), [O, y] = [
|
|
347
329
|
[h.complexTensorInfos.real, f.complexTensorInfos.real],
|
|
348
330
|
[h.complexTensorInfos.imag, f.complexTensorInfos.imag]
|
|
349
331
|
].map((S) => {
|
|
350
|
-
const [
|
|
351
|
-
dataId:
|
|
352
|
-
dtype:
|
|
332
|
+
const [p, m] = S, $ = {
|
|
333
|
+
dataId: p.dataId,
|
|
334
|
+
dtype: p.dtype,
|
|
353
335
|
shape: s.shape
|
|
354
336
|
}, T = {
|
|
355
|
-
dataId:
|
|
356
|
-
dtype:
|
|
337
|
+
dataId: m.dataId,
|
|
338
|
+
dtype: m.dtype,
|
|
357
339
|
shape: i.shape
|
|
358
|
-
}, U = new
|
|
359
|
-
return r.runWebGLProgram(U, [$, T], v(
|
|
360
|
-
}), I = G({ inputs: { real:
|
|
361
|
-
return r.disposeIntermediateTensorInfo(
|
|
340
|
+
}, U = new C(t, s.shape, i.shape);
|
|
341
|
+
return r.runWebGLProgram(U, [$, T], v(p.dtype, m.dtype));
|
|
342
|
+
}), I = G({ inputs: { real: O, imag: y }, backend: r });
|
|
343
|
+
return r.disposeIntermediateTensorInfo(O), r.disposeIntermediateTensorInfo(y), I;
|
|
362
344
|
}
|
|
363
|
-
const
|
|
345
|
+
const c = a || v(s.dtype, i.dtype);
|
|
364
346
|
if ((s.dtype === "string" || i.dtype === "string" || r.shouldExecuteOnCPU([s, i])) && d != null) {
|
|
365
|
-
const h = r.texData.get(s.dataId).values, f = r.texData.get(i.dataId).values,
|
|
347
|
+
const h = r.texData.get(s.dataId).values, f = r.texData.get(i.dataId).values, O = s.dtype === "string" ? (
|
|
366
348
|
// tslint:disable-next-line: no-any
|
|
367
|
-
|
|
368
|
-
) : h,
|
|
349
|
+
B(h)
|
|
350
|
+
) : h, y = s.dtype === "string" ? (
|
|
369
351
|
// tslint:disable-next-line: no-any
|
|
370
|
-
|
|
371
|
-
) : f, [I, S] = d(s.shape, i.shape,
|
|
372
|
-
return
|
|
352
|
+
B(f)
|
|
353
|
+
) : f, [I, S] = d(s.shape, i.shape, O, y, c), p = r.makeTensorInfo(S, c), m = r.texData.get(p.dataId);
|
|
354
|
+
return m.values = I, p;
|
|
373
355
|
}
|
|
374
|
-
const x =
|
|
356
|
+
const x = P().getBool("WEBGL_PACK_BINARY_OPERATIONS") && e != null;
|
|
375
357
|
let b;
|
|
376
|
-
return x ? b = new
|
|
358
|
+
return x ? b = new _(e, s.shape, i.shape, o) : b = new C(t, s.shape, i.shape), r.runWebGLProgram(b, [s, i], c);
|
|
377
359
|
};
|
|
378
360
|
}
|
|
379
361
|
export {
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
K as u
|
|
362
|
+
K as C,
|
|
363
|
+
H as a,
|
|
364
|
+
E as b,
|
|
365
|
+
Q as c,
|
|
366
|
+
B as f,
|
|
367
|
+
k as g,
|
|
368
|
+
Y as u
|
|
388
369
|
};
|