@genai-fi/nanogpt 0.2.6 → 0.2.8
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/TeachableLLM.js +1 -0
- package/dist/{complex-D6Bq1XDf.js → complex-CeoYJn2o.js} +1 -1
- package/dist/{index-D1SlunD-.js → index-DQfEAU9u.js} +17 -17
- package/dist/layers/CausalSelfAttention.js +40 -39
- package/dist/layers/TiedEmbedding.js +104 -126
- package/dist/main.js +15 -14
- package/dist/mat_mul-CuHB58-H.js +27 -0
- package/dist/ops/attentionMask.d.ts +2 -0
- package/dist/ops/attentionMask.js +62 -0
- package/dist/ops/gatherSub.js +2 -2
- package/dist/ops/node/sparseCrossEntropy.js +1 -1
- package/dist/ops/scatterSub.js +7 -7
- package/dist/{stack-DB2YLlAs.js → stack-C9cTkqpq.js} +3 -3
- package/dist/{sum-02UQ5Eaq.js → sum-B-O33dgG.js} +3 -3
- package/dist/training/AdamExt.js +1 -1
- package/dist/training/sparseCrossEntropy.js +16 -16
- package/package.json +1 -1
package/dist/TeachableLLM.js
CHANGED
|
@@ -3858,29 +3858,29 @@ export {
|
|
|
3858
3858
|
Qn as V,
|
|
3859
3859
|
qs as _,
|
|
3860
3860
|
Z as a,
|
|
3861
|
-
|
|
3861
|
+
K as b,
|
|
3862
3862
|
I as c,
|
|
3863
|
-
|
|
3863
|
+
Is as d,
|
|
3864
3864
|
Xs as e,
|
|
3865
|
-
|
|
3866
|
-
|
|
3867
|
-
|
|
3868
|
-
|
|
3869
|
-
|
|
3870
|
-
|
|
3871
|
-
|
|
3865
|
+
Js as f,
|
|
3866
|
+
y as g,
|
|
3867
|
+
Ls as h,
|
|
3868
|
+
Ft as i,
|
|
3869
|
+
Nt as j,
|
|
3870
|
+
Qt as k,
|
|
3871
|
+
U as l,
|
|
3872
3872
|
p as m,
|
|
3873
|
-
|
|
3873
|
+
Ne as n,
|
|
3874
3874
|
F as o,
|
|
3875
|
-
|
|
3876
|
-
|
|
3875
|
+
Gs as p,
|
|
3876
|
+
vs as q,
|
|
3877
3877
|
Hs as r,
|
|
3878
3878
|
j as s,
|
|
3879
|
-
|
|
3880
|
-
|
|
3881
|
-
|
|
3882
|
-
|
|
3883
|
-
|
|
3879
|
+
Ts as t,
|
|
3880
|
+
w as u,
|
|
3881
|
+
js as v,
|
|
3882
|
+
Qs as w,
|
|
3883
|
+
E as x,
|
|
3884
3884
|
zs as y,
|
|
3885
3885
|
C as z
|
|
3886
3886
|
};
|
|
@@ -1,4 +1,5 @@
|
|
|
1
|
-
|
|
1
|
+
import { attentionMask as z } from "../ops/attentionMask.js";
|
|
2
|
+
class j {
|
|
2
3
|
constructor(t, i, s, e) {
|
|
3
4
|
this.ropeCache = e, this.config = s, this.tf = t, this.index = i, this.cAttn = this.tf.layers.dense({
|
|
4
5
|
units: 3 * s.nEmbed,
|
|
@@ -18,9 +19,9 @@ class S {
|
|
|
18
19
|
stddev: 0.02 / Math.sqrt(2 * s.nLayer)
|
|
19
20
|
}),
|
|
20
21
|
biasInitializer: "zeros"
|
|
21
|
-
}), this.attnDropout = this.tf.layers.dropout({ rate: s.dropout }), this.residDropout = this.tf.layers.dropout({ rate: s.dropout }), this.bias = this.tf.linalg.bandPart(this.tf.ones([s.blockSize, s.blockSize]), -1, 0).cast("bool"), this.divisor =
|
|
22
|
-
const
|
|
23
|
-
this.maskInf = this.tf.where(this.bias,
|
|
22
|
+
}), this.attnDropout = this.tf.layers.dropout({ rate: s.dropout }), this.residDropout = this.tf.layers.dropout({ rate: s.dropout }), this.bias = this.tf.linalg.bandPart(this.tf.ones([s.blockSize, s.blockSize]), -1, 0).cast("bool"), this.divisor = 1 / Math.sqrt(s.nEmbed / s.nHead);
|
|
23
|
+
const o = this.tf.zeros([s.blockSize, s.blockSize]), c = this.tf.fill([s.blockSize, s.blockSize], Number.NEGATIVE_INFINITY);
|
|
24
|
+
this.maskInf = this.tf.where(this.bias, o, c);
|
|
24
25
|
}
|
|
25
26
|
config;
|
|
26
27
|
cAttn;
|
|
@@ -52,70 +53,70 @@ class S {
|
|
|
52
53
|
this.cAttn.setWeights(t.get(`block_${this.index}_cAttn`) || []), this.cProj.setWeights(t.get(`block_${this.index}_cProj`) || []);
|
|
53
54
|
}
|
|
54
55
|
getAttentionScores(t, i, s) {
|
|
55
|
-
const e =
|
|
56
|
+
const e = z(t, i, this.maskInf, this.divisor), o = this.tf.softmax(e, -1);
|
|
56
57
|
return this.attnDropout.apply(o, { training: s });
|
|
57
58
|
}
|
|
58
59
|
// Attention with optional past. If pastLen > 0 and T_cur == 1, no mask needed.
|
|
59
60
|
getAttentionScoresWithPast(t, i, s, e) {
|
|
60
|
-
const
|
|
61
|
-
let
|
|
62
|
-
if (
|
|
61
|
+
const o = t.shape[2];
|
|
62
|
+
let r = this.tf.matMul(t, i, !1, !0).mul(this.divisor);
|
|
63
|
+
if (o > 1 && e > 0)
|
|
63
64
|
throw new Error("Cannot use past with T_cur > 1");
|
|
64
|
-
if (
|
|
65
|
-
const
|
|
66
|
-
|
|
65
|
+
if (o > 1) {
|
|
66
|
+
const a = this.maskInf.slice([0, 0], [o, o]).expandDims(0).expandDims(0);
|
|
67
|
+
r = r.add(a);
|
|
67
68
|
}
|
|
68
|
-
const
|
|
69
|
-
return this.attnDropout.apply(
|
|
69
|
+
const h = this.tf.softmax(r, -1);
|
|
70
|
+
return this.attnDropout.apply(h, { training: s });
|
|
70
71
|
}
|
|
71
72
|
getQKV(t) {
|
|
72
|
-
const [i, s, e] = t.shape,
|
|
73
|
-
|
|
74
|
-
const
|
|
75
|
-
|
|
73
|
+
const [i, s, e] = t.shape, o = this.cAttn.apply(t), [c, r, h] = this.tf.split(o, 3, -1);
|
|
74
|
+
o.dispose();
|
|
75
|
+
const a = e / this.config.nHead, u = this.tf.reshape(c, [i, s, this.config.nHead, a]);
|
|
76
|
+
c.dispose();
|
|
76
77
|
const f = u.transpose([0, 2, 1, 3]);
|
|
77
78
|
u.dispose();
|
|
78
|
-
const d = this.tf.reshape(
|
|
79
|
-
n.dispose();
|
|
80
|
-
const c = d.transpose([0, 2, 1, 3]);
|
|
81
|
-
d.dispose();
|
|
82
|
-
const l = this.tf.reshape(r, [i, s, this.config.nHead, o]);
|
|
79
|
+
const d = this.tf.reshape(r, [i, s, this.config.nHead, a]);
|
|
83
80
|
r.dispose();
|
|
81
|
+
const n = d.transpose([0, 2, 1, 3]);
|
|
82
|
+
d.dispose();
|
|
83
|
+
const l = this.tf.reshape(h, [i, s, this.config.nHead, a]);
|
|
84
|
+
h.dispose();
|
|
84
85
|
const p = l.transpose([0, 2, 1, 3]);
|
|
85
|
-
return l.dispose(), [f,
|
|
86
|
+
return l.dispose(), [f, n, p];
|
|
86
87
|
}
|
|
87
88
|
getOutputProjection(t, i) {
|
|
88
|
-
const s = t.shape[0], e = t.shape[2],
|
|
89
|
-
return this.residDropout.apply(
|
|
89
|
+
const s = t.shape[0], e = t.shape[2], o = this.config.nEmbed, c = t.transpose([0, 2, 1, 3]), r = this.tf.reshape(c, [s, e, o]), h = this.cProj.apply(r);
|
|
90
|
+
return this.residDropout.apply(h, { training: i });
|
|
90
91
|
}
|
|
91
92
|
// Added optional KV cache support (pastKV). Returns presentKV for chaining.
|
|
92
93
|
call(t, i = !1, s = !1, e) {
|
|
93
94
|
if (e && !this.config.useRope)
|
|
94
95
|
throw new Error("Cannot use pastKV without RoPE enabled");
|
|
95
96
|
return this.tf.tidy(() => {
|
|
96
|
-
const [
|
|
97
|
-
let
|
|
98
|
-
e && (p = e.length,
|
|
99
|
-
const b =
|
|
100
|
-
if (b >
|
|
101
|
-
const k = b -
|
|
102
|
-
|
|
97
|
+
const [o, c, r] = this.getQKV(t), h = o.shape[2], a = this.config.blockSize, u = e ? e.cumulativeLength : 0, [f, d] = this.ropeCache ? this.ropeCache.applyRoPE(o, c, u) : [o, c];
|
|
98
|
+
let n = d, l = r, p = 0;
|
|
99
|
+
e && (p = e.length, n = this.tf.concat([e.k, d], 2), l = this.tf.concat([e.v, r], 2));
|
|
100
|
+
const b = n.shape[2];
|
|
101
|
+
if (b > a) {
|
|
102
|
+
const k = b - a, g = n.shape[0], I = n.shape[1], _ = n.shape[3];
|
|
103
|
+
n = n.slice([0, 0, k, 0], [g, I, a, _]), l = l.slice([0, 0, k, 0], [g, I, a, _]), p = a - h;
|
|
103
104
|
}
|
|
104
105
|
let m;
|
|
105
|
-
p > 0 ? m = this.getAttentionScoresWithPast(f,
|
|
106
|
-
const
|
|
107
|
-
k: this.tf.keep(
|
|
106
|
+
p > 0 ? m = this.getAttentionScoresWithPast(f, n, i, p) : m = this.getAttentionScores(f, n, i);
|
|
107
|
+
const v = this.tf.matMul(m, l), A = this.getOutputProjection(v, i), P = {
|
|
108
|
+
k: this.tf.keep(n),
|
|
108
109
|
v: this.tf.keep(l),
|
|
109
|
-
length: p +
|
|
110
|
-
cumulativeLength: e ? e.cumulativeLength +
|
|
110
|
+
length: p + h,
|
|
111
|
+
cumulativeLength: e ? e.cumulativeLength + h : h
|
|
111
112
|
};
|
|
112
113
|
return { output: A, attention: s ? m.mean(1) : void 0, presentKV: P };
|
|
113
114
|
});
|
|
114
115
|
}
|
|
115
116
|
dispose() {
|
|
116
|
-
this.cAttn.dispose(), this.cProj.dispose(), this.attnDropout.dispose(), this.residDropout.dispose(), this.bias.dispose(), this.maskInf.dispose()
|
|
117
|
+
this.cAttn.dispose(), this.cProj.dispose(), this.attnDropout.dispose(), this.residDropout.dispose(), this.bias.dispose(), this.maskInf.dispose();
|
|
117
118
|
}
|
|
118
119
|
}
|
|
119
120
|
export {
|
|
120
|
-
|
|
121
|
+
j as default
|
|
121
122
|
};
|
|
@@ -1,29 +1,7 @@
|
|
|
1
|
-
import { o as h, c as
|
|
2
|
-
import { s as
|
|
3
|
-
import {
|
|
4
|
-
|
|
5
|
-
* @license
|
|
6
|
-
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
7
|
-
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
8
|
-
* you may not use this file except in compliance with the License.
|
|
9
|
-
* You may obtain a copy of the License at
|
|
10
|
-
*
|
|
11
|
-
* http://www.apache.org/licenses/LICENSE-2.0
|
|
12
|
-
*
|
|
13
|
-
* Unless required by applicable law or agreed to in writing, software
|
|
14
|
-
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
15
|
-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
16
|
-
* See the License for the specific language governing permissions and
|
|
17
|
-
* limitations under the License.
|
|
18
|
-
* =============================================================================
|
|
19
|
-
*/
|
|
20
|
-
function fe(t, e, s = !1, n = !1) {
|
|
21
|
-
let r = u(t, "a", "matMul"), i = u(e, "b", "matMul");
|
|
22
|
-
[r, i] = B(r, i);
|
|
23
|
-
const o = { a: r, b: i }, p = { transposeA: s, transposeB: n };
|
|
24
|
-
return c.runKernel(V, o, p);
|
|
25
|
-
}
|
|
26
|
-
const m = /* @__PURE__ */ h({ matMul_: fe });
|
|
1
|
+
import { o as h, c as i, E as o, y as V, D as X, I as Y, F as Z, N as ee, H as te, J as se, K as ne, O as re, Q as ue, g as L, x as ae, T as A, m as ie, U as oe, u as le, b as q, l as C, V as P, w as U, _ as H } from "../index-DQfEAU9u.js";
|
|
2
|
+
import { s as ce, r as f } from "../sum-B-O33dgG.js";
|
|
3
|
+
import { m } from "../mat_mul-CuHB58-H.js";
|
|
4
|
+
import { c as pe } from "../complex-CeoYJn2o.js";
|
|
27
5
|
/**
|
|
28
6
|
* @license
|
|
29
7
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -40,11 +18,11 @@ const m = /* @__PURE__ */ h({ matMul_: fe });
|
|
|
40
18
|
* limitations under the License.
|
|
41
19
|
* =============================================================================
|
|
42
20
|
*/
|
|
43
|
-
function
|
|
44
|
-
const s = { x:
|
|
45
|
-
return
|
|
21
|
+
function he(t) {
|
|
22
|
+
const s = { x: i(t, "x", "sigmoid", "float32") };
|
|
23
|
+
return o.runKernel(V, s);
|
|
46
24
|
}
|
|
47
|
-
const
|
|
25
|
+
const fe = /* @__PURE__ */ h({ sigmoid_: he });
|
|
48
26
|
/**
|
|
49
27
|
* @license
|
|
50
28
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -61,11 +39,11 @@ const me = /* @__PURE__ */ h({ sigmoid_: de });
|
|
|
61
39
|
* limitations under the License.
|
|
62
40
|
* =============================================================================
|
|
63
41
|
*/
|
|
64
|
-
function
|
|
65
|
-
const s = { x:
|
|
66
|
-
return
|
|
42
|
+
function de(t) {
|
|
43
|
+
const s = { x: i(t, "x", "elu", "float32") };
|
|
44
|
+
return o.runKernel(X, s);
|
|
67
45
|
}
|
|
68
|
-
const
|
|
46
|
+
const me = /* @__PURE__ */ h({ elu_: de });
|
|
69
47
|
/**
|
|
70
48
|
* @license
|
|
71
49
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -82,11 +60,11 @@ const $e = /* @__PURE__ */ h({ elu_: ge });
|
|
|
82
60
|
* limitations under the License.
|
|
83
61
|
* =============================================================================
|
|
84
62
|
*/
|
|
85
|
-
function
|
|
86
|
-
const s = { input:
|
|
87
|
-
return
|
|
63
|
+
function ge(t) {
|
|
64
|
+
const s = { input: i(t, "input", "imag") };
|
|
65
|
+
return o.runKernel(Y, s);
|
|
88
66
|
}
|
|
89
|
-
const
|
|
67
|
+
const $e = /* @__PURE__ */ h({ imag_: ge });
|
|
90
68
|
/**
|
|
91
69
|
* @license
|
|
92
70
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -103,11 +81,11 @@ const ke = /* @__PURE__ */ h({ imag_: xe });
|
|
|
103
81
|
* limitations under the License.
|
|
104
82
|
* =============================================================================
|
|
105
83
|
*/
|
|
106
|
-
function
|
|
107
|
-
const n = { x:
|
|
108
|
-
return
|
|
84
|
+
function xe(t, e = 0.2) {
|
|
85
|
+
const n = { x: i(t, "x", "leakyRelu") }, r = { alpha: e };
|
|
86
|
+
return o.runKernel(Z, n, r);
|
|
109
87
|
}
|
|
110
|
-
const
|
|
88
|
+
const ke = /* @__PURE__ */ h({ leakyRelu_: xe });
|
|
111
89
|
/**
|
|
112
90
|
* @license
|
|
113
91
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -124,11 +102,11 @@ const be = /* @__PURE__ */ h({ leakyRelu_: De });
|
|
|
124
102
|
* limitations under the License.
|
|
125
103
|
* =============================================================================
|
|
126
104
|
*/
|
|
127
|
-
function
|
|
128
|
-
const s = { x:
|
|
129
|
-
return
|
|
105
|
+
function De(t) {
|
|
106
|
+
const s = { x: i(t, "x", "neg") };
|
|
107
|
+
return o.runKernel(ee, s);
|
|
130
108
|
}
|
|
131
|
-
const
|
|
109
|
+
const be = /* @__PURE__ */ h({ neg_: De });
|
|
132
110
|
/**
|
|
133
111
|
* @license
|
|
134
112
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -145,11 +123,11 @@ const Se = /* @__PURE__ */ h({ neg_: ye });
|
|
|
145
123
|
* limitations under the License.
|
|
146
124
|
* =============================================================================
|
|
147
125
|
*/
|
|
148
|
-
function
|
|
149
|
-
const s =
|
|
150
|
-
return
|
|
126
|
+
function ye(t, e) {
|
|
127
|
+
const s = i(t, "x", "prelu"), n = i(e, "alpha", "prelu"), r = { x: s, alpha: n };
|
|
128
|
+
return o.runKernel(te, r);
|
|
151
129
|
}
|
|
152
|
-
const
|
|
130
|
+
const Se = /* @__PURE__ */ h({ prelu_: ye });
|
|
153
131
|
/**
|
|
154
132
|
* @license
|
|
155
133
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -166,11 +144,11 @@ const Ke = /* @__PURE__ */ h({ prelu_: Me });
|
|
|
166
144
|
* limitations under the License.
|
|
167
145
|
* =============================================================================
|
|
168
146
|
*/
|
|
169
|
-
function
|
|
170
|
-
const s = { input:
|
|
171
|
-
return
|
|
147
|
+
function Ke(t) {
|
|
148
|
+
const s = { input: i(t, "input", "real") };
|
|
149
|
+
return o.runKernel(se, s);
|
|
172
150
|
}
|
|
173
|
-
const
|
|
151
|
+
const _e = /* @__PURE__ */ h({ real_: Ke });
|
|
174
152
|
/**
|
|
175
153
|
* @license
|
|
176
154
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -187,11 +165,11 @@ const we = /* @__PURE__ */ h({ real_: _e });
|
|
|
187
165
|
* limitations under the License.
|
|
188
166
|
* =============================================================================
|
|
189
167
|
*/
|
|
190
|
-
function
|
|
191
|
-
const s = { x:
|
|
192
|
-
return
|
|
168
|
+
function Me(t) {
|
|
169
|
+
const s = { x: i(t, "x", "relu") };
|
|
170
|
+
return o.runKernel(ne, s);
|
|
193
171
|
}
|
|
194
|
-
const
|
|
172
|
+
const we = /* @__PURE__ */ h({ relu_: Me });
|
|
195
173
|
/**
|
|
196
174
|
* @license
|
|
197
175
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -208,11 +186,11 @@ const ze = /* @__PURE__ */ h({ relu_: We });
|
|
|
208
186
|
* limitations under the License.
|
|
209
187
|
* =============================================================================
|
|
210
188
|
*/
|
|
211
|
-
function
|
|
212
|
-
const s = { x:
|
|
213
|
-
return
|
|
189
|
+
function We(t) {
|
|
190
|
+
const s = { x: i(t, "x", "relu6") };
|
|
191
|
+
return o.runKernel(re, s);
|
|
214
192
|
}
|
|
215
|
-
const
|
|
193
|
+
const ze = /* @__PURE__ */ h({ relu6_: We });
|
|
216
194
|
/**
|
|
217
195
|
* @license
|
|
218
196
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -229,11 +207,11 @@ const Oe = /* @__PURE__ */ h({ relu6_: Ee });
|
|
|
229
207
|
* limitations under the License.
|
|
230
208
|
* =============================================================================
|
|
231
209
|
*/
|
|
232
|
-
function
|
|
233
|
-
const n = { x:
|
|
234
|
-
return
|
|
210
|
+
function Ee(t, e = 0) {
|
|
211
|
+
const n = { x: i(t, "x", "step") }, r = { alpha: e };
|
|
212
|
+
return o.runKernel(ue, n, r);
|
|
235
213
|
}
|
|
236
|
-
const
|
|
214
|
+
const Oe = /* @__PURE__ */ h({ step_: Ee });
|
|
237
215
|
/**
|
|
238
216
|
* @license
|
|
239
217
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -250,19 +228,19 @@ const Re = /* @__PURE__ */ h({ step_: Fe });
|
|
|
250
228
|
* limitations under the License.
|
|
251
229
|
* =============================================================================
|
|
252
230
|
*/
|
|
253
|
-
function
|
|
254
|
-
const n =
|
|
255
|
-
if (e == null && (e = n.shape.map((
|
|
256
|
-
L(
|
|
231
|
+
function Fe(t, e, s) {
|
|
232
|
+
const n = i(t, "x", "transpose");
|
|
233
|
+
if (e == null && (e = n.shape.map((l, p) => p).reverse()), L(n.rank === e.length, () => `Error in transpose: rank of input ${n.rank} must match length of perm ${e}.`), e.forEach((l) => {
|
|
234
|
+
L(l >= 0 && l < n.rank, () => `All entries in 'perm' must be between 0 and ${n.rank - 1} but got ${e}`);
|
|
257
235
|
}), n.rank <= 1)
|
|
258
236
|
return n.clone();
|
|
259
|
-
const r = { x: n },
|
|
260
|
-
return n.dtype === "complex64" ?
|
|
261
|
-
let
|
|
262
|
-
return
|
|
263
|
-
}) :
|
|
237
|
+
const r = { x: n }, c = { perm: e };
|
|
238
|
+
return n.dtype === "complex64" ? ae(() => {
|
|
239
|
+
let l = _e(n), p = $e(n);
|
|
240
|
+
return l = o.runKernel(A, { x: l }, c), p = o.runKernel(A, { x: p }, c), s && (p = be(p)), pe(l, p);
|
|
241
|
+
}) : o.runKernel(A, r, c);
|
|
264
242
|
}
|
|
265
|
-
const
|
|
243
|
+
const Re = /* @__PURE__ */ h({ transpose_: Fe });
|
|
266
244
|
/**
|
|
267
245
|
* @license
|
|
268
246
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
|
@@ -279,36 +257,36 @@ const Be = /* @__PURE__ */ h({ transpose_: Ae });
|
|
|
279
257
|
* limitations under the License.
|
|
280
258
|
* =============================================================================
|
|
281
259
|
*/
|
|
282
|
-
function
|
|
260
|
+
function Ae(t, e, s) {
|
|
283
261
|
if (s == null || s === "linear")
|
|
284
262
|
return t;
|
|
285
263
|
if (s === "relu")
|
|
286
|
-
return
|
|
264
|
+
return ie(t, Oe(e));
|
|
287
265
|
throw new Error(`Cannot compute gradient for fused activation ${s}.`);
|
|
288
266
|
}
|
|
289
|
-
function
|
|
267
|
+
function Le(t, e) {
|
|
290
268
|
let s = e;
|
|
291
|
-
const n =
|
|
292
|
-
return n.length > 0 && (s =
|
|
269
|
+
const n = oe(t.shape, e.shape);
|
|
270
|
+
return n.length > 0 && (s = ce(s, n)), f(s, t.shape);
|
|
293
271
|
}
|
|
294
|
-
function
|
|
272
|
+
function Te(t, e, s, n) {
|
|
295
273
|
if (e === "linear")
|
|
296
274
|
return t;
|
|
297
275
|
if (e === "relu")
|
|
298
|
-
return
|
|
276
|
+
return we(t);
|
|
299
277
|
if (e === "elu")
|
|
300
|
-
return
|
|
278
|
+
return me(t);
|
|
301
279
|
if (e === "relu6")
|
|
302
|
-
return
|
|
280
|
+
return ze(t);
|
|
303
281
|
if (e === "prelu")
|
|
304
|
-
return
|
|
282
|
+
return Se(t, s);
|
|
305
283
|
if (e === "leakyrelu")
|
|
306
|
-
return
|
|
284
|
+
return ke(t, n);
|
|
307
285
|
if (e === "sigmoid")
|
|
308
|
-
return
|
|
286
|
+
return fe(t);
|
|
309
287
|
throw new Error(`Unknown fused activation ${e}.`);
|
|
310
288
|
}
|
|
311
|
-
const
|
|
289
|
+
const Be = (t, e) => !(t > 0) || e === "linear";
|
|
312
290
|
/**
|
|
313
291
|
* @license
|
|
314
292
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
|
@@ -325,49 +303,49 @@ const ve = (t, e) => !(t > 0) || e === "linear";
|
|
|
325
303
|
* limitations under the License.
|
|
326
304
|
* =============================================================================
|
|
327
305
|
*/
|
|
328
|
-
function
|
|
329
|
-
if (
|
|
306
|
+
function Ne({ a: t, b: e, transposeA: s = !1, transposeB: n = !1, bias: r, activation: c = "linear", preluActivationWeights: l, leakyreluAlpha: p = 0.2 }) {
|
|
307
|
+
if (Be(o.state.gradientDepth, c) === !1) {
|
|
330
308
|
let x = m(t, e, s, n);
|
|
331
|
-
return r != null && (x =
|
|
309
|
+
return r != null && (x = le(x, r)), Te(x, c, l, p);
|
|
332
310
|
}
|
|
333
|
-
let
|
|
334
|
-
[
|
|
335
|
-
const D = s ?
|
|
336
|
-
L(D === b, () => `Error in fused matMul: inner shapes (${D}) and (${b}) of Tensors with shapes ${
|
|
337
|
-
const O = P(
|
|
311
|
+
let u = i(t, "a", "fused matMul"), a = i(e, "b", "fused matMul");
|
|
312
|
+
[u, a] = q(u, a);
|
|
313
|
+
const D = s ? u.shape[u.rank - 2] : u.shape[u.rank - 1], b = n ? a.shape[a.rank - 1] : a.shape[a.rank - 2], w = s ? u.shape[u.rank - 1] : u.shape[u.rank - 2], W = n ? a.shape[a.rank - 2] : a.shape[a.rank - 1], T = u.shape.slice(0, -2), y = a.shape.slice(0, -2), B = C(T), N = C(y);
|
|
314
|
+
L(D === b, () => `Error in fused matMul: inner shapes (${D}) and (${b}) of Tensors with shapes ${u.shape} and ${a.shape} and transposeA=${s} and transposeB=${n} must match.`);
|
|
315
|
+
const O = P(u.shape.slice(0, -2), a.shape.slice(0, -2)).concat([w, W]), F = s ? f(u, [B, D, w]) : f(u, [B, w, D]), R = n ? f(a, [N, W, b]) : f(a, [N, b, W]);
|
|
338
316
|
let S;
|
|
339
|
-
r != null && (S =
|
|
317
|
+
r != null && (S = i(r, "bias", "fused matMul"), [S] = q(S, u), P(O, S.shape));
|
|
340
318
|
let G;
|
|
341
|
-
|
|
342
|
-
const I = (x,
|
|
343
|
-
const [g, $, k, z] =
|
|
344
|
-
let
|
|
345
|
-
if (!s && !n ? (
|
|
346
|
-
const Q =
|
|
347
|
-
return [
|
|
319
|
+
l != null && (G = i(l, "prelu weights", "fused matMul"));
|
|
320
|
+
const I = (x, M) => {
|
|
321
|
+
const [g, $, k, z] = M, d = Ae(f(x, k.shape), k, c);
|
|
322
|
+
let K, _;
|
|
323
|
+
if (!s && !n ? (K = m(d, $, !1, !0), _ = m(g, d, !0, !1)) : !s && n ? (K = m(d, $, !1, !1), _ = m(d, g, !0, !1)) : s && !n ? (K = m($, d, !1, !0), _ = m(g, d, !1, !1)) : (K = m($, d, !0, !0), _ = m(d, g, !0, !0)), r != null) {
|
|
324
|
+
const Q = Le(z, d);
|
|
325
|
+
return [K, _, Q];
|
|
348
326
|
} else
|
|
349
|
-
return [
|
|
350
|
-
},
|
|
327
|
+
return [K, _];
|
|
328
|
+
}, v = {
|
|
351
329
|
a: F,
|
|
352
330
|
b: R,
|
|
353
331
|
bias: S,
|
|
354
332
|
preluActivationWeights: G
|
|
355
|
-
},
|
|
356
|
-
return r == null ? U((
|
|
333
|
+
}, j = { transposeA: s, transposeB: n, activation: c, leakyreluAlpha: p };
|
|
334
|
+
return r == null ? U((M, g, $) => {
|
|
357
335
|
const k = (
|
|
358
336
|
// tslint:disable-next-line: no-unnecessary-type-assertion
|
|
359
|
-
|
|
337
|
+
o.runKernel(H, v, j)
|
|
360
338
|
);
|
|
361
|
-
return $([
|
|
362
|
-
})(F, R) : U((
|
|
339
|
+
return $([M, g, k]), { value: f(k, O), gradFunc: I };
|
|
340
|
+
})(F, R) : U((M, g, $, k) => {
|
|
363
341
|
const z = (
|
|
364
342
|
// tslint:disable-next-line: no-unnecessary-type-assertion
|
|
365
|
-
|
|
343
|
+
o.runKernel(H, v, j)
|
|
366
344
|
);
|
|
367
|
-
return k([
|
|
345
|
+
return k([M, g, z, $]), { value: f(z, O), gradFunc: I };
|
|
368
346
|
})(F, R, S);
|
|
369
347
|
}
|
|
370
|
-
const J = /* @__PURE__ */ h({ fusedMatMul_:
|
|
348
|
+
const J = /* @__PURE__ */ h({ fusedMatMul_: Ne });
|
|
371
349
|
/**
|
|
372
350
|
* @license
|
|
373
351
|
* Copyright 2018 Google LLC
|
|
@@ -391,12 +369,12 @@ class E extends Error {
|
|
|
391
369
|
* https://opensource.org/licenses/MIT.
|
|
392
370
|
* =============================================================================
|
|
393
371
|
*/
|
|
394
|
-
function
|
|
372
|
+
function Ge(t, e, s, n) {
|
|
395
373
|
if (t.rank < 2 || e.rank < 2)
|
|
396
374
|
throw new E(`dot requires both inputs to be rank >= 2 but got x shape = ${t.shape} and y shape = ${e.shape}`);
|
|
397
375
|
if (e.rank >= 3) {
|
|
398
|
-
const r = t.shape.slice(-1)[0],
|
|
399
|
-
if (r !==
|
|
376
|
+
const r = t.shape.slice(-1)[0], c = e.shape.slice(-2)[0];
|
|
377
|
+
if (r !== c)
|
|
400
378
|
throw new E(`If rank y >= 3, then the second last dim of y must equal the last dim of x but got x shape = ${t.shape} and y shape = ${e.shape}`);
|
|
401
379
|
}
|
|
402
380
|
if (t.rank === 2 && e.rank === 2)
|
|
@@ -409,11 +387,11 @@ function Ie(t, e, s, n) {
|
|
|
409
387
|
activation: s
|
|
410
388
|
});
|
|
411
389
|
{
|
|
412
|
-
const r = t.shape.slice(),
|
|
413
|
-
t = f(t, [-1,
|
|
414
|
-
const
|
|
415
|
-
e = f(
|
|
416
|
-
const b = [...r, ...
|
|
390
|
+
const r = t.shape.slice(), c = r.pop();
|
|
391
|
+
t = f(t, [-1, c]);
|
|
392
|
+
const l = e.shape.slice(), p = l.pop(), u = l.pop(), a = [...l, p], D = Array.from({ length: e.rank }, (T, y) => y === 0 ? e.rank - 2 : y <= e.rank - 2 ? y - 1 : y);
|
|
393
|
+
e = f(Re(e, D), [u, -1]);
|
|
394
|
+
const b = [...r, ...a];
|
|
417
395
|
return f(J({
|
|
418
396
|
a: t,
|
|
419
397
|
b: e,
|
|
@@ -424,7 +402,7 @@ function Ie(t, e, s, n) {
|
|
|
424
402
|
}), b);
|
|
425
403
|
}
|
|
426
404
|
}
|
|
427
|
-
class
|
|
405
|
+
class Pe {
|
|
428
406
|
vocabSize;
|
|
429
407
|
embedDim;
|
|
430
408
|
tf;
|
|
@@ -447,7 +425,7 @@ class Ue {
|
|
|
447
425
|
return this.tf.gather(this.tiedWeights, e, 0);
|
|
448
426
|
}
|
|
449
427
|
project(e) {
|
|
450
|
-
return
|
|
428
|
+
return Ge(e, this.tiedWeights.transpose());
|
|
451
429
|
}
|
|
452
430
|
getWeights() {
|
|
453
431
|
return [this.tiedWeights];
|
|
@@ -466,5 +444,5 @@ class Ue {
|
|
|
466
444
|
}
|
|
467
445
|
}
|
|
468
446
|
export {
|
|
469
|
-
|
|
447
|
+
Pe as default
|
|
470
448
|
};
|
package/dist/main.js
CHANGED
|
@@ -1,20 +1,21 @@
|
|
|
1
|
-
import { default as
|
|
2
|
-
import { default as
|
|
3
|
-
import { default as
|
|
1
|
+
import { default as m } from "./NanoGPTModel.js";
|
|
2
|
+
import { default as i } from "./TeachableLLM.js";
|
|
3
|
+
import { default as l } from "./tokeniser/CharTokeniser.js";
|
|
4
4
|
import { default as d } from "./utilities/waitForModel.js";
|
|
5
|
-
import { default as
|
|
6
|
-
import { estimateMemoryUsage as
|
|
5
|
+
import { default as x } from "./data/textLoader.js";
|
|
6
|
+
import { estimateMemoryUsage as T, estimateParameterCount as g, estimateResources as M, estimateTrainingMemoryUsage as C, validateConfig as c } from "./utilities/parameters.js";
|
|
7
7
|
import "./ops/scatterSub.js";
|
|
8
8
|
import "./ops/gatherSub.js";
|
|
9
|
+
import "./ops/attentionMask.js";
|
|
9
10
|
export {
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
11
|
+
l as CharTokeniser,
|
|
12
|
+
m as NanoGPT,
|
|
13
|
+
i as TeachableLLM,
|
|
14
|
+
T as estimateMemoryUsage,
|
|
15
|
+
g as estimateParameterCount,
|
|
16
|
+
M as estimateResources,
|
|
17
|
+
C as estimateTrainingMemoryUsage,
|
|
18
|
+
x as loadTextData,
|
|
19
|
+
c as validateConfig,
|
|
19
20
|
d as waitForModel
|
|
20
21
|
};
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
import { o as c, c as s, b as m, E as M, B as p } from "./index-DQfEAU9u.js";
|
|
2
|
+
/**
|
|
3
|
+
* @license
|
|
4
|
+
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
5
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
* you may not use this file except in compliance with the License.
|
|
7
|
+
* You may obtain a copy of the License at
|
|
8
|
+
*
|
|
9
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
*
|
|
11
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
* See the License for the specific language governing permissions and
|
|
15
|
+
* limitations under the License.
|
|
16
|
+
* =============================================================================
|
|
17
|
+
*/
|
|
18
|
+
function b(e, o, n = !1, l = !1) {
|
|
19
|
+
let a = s(e, "a", "matMul"), t = s(o, "b", "matMul");
|
|
20
|
+
[a, t] = m(a, t);
|
|
21
|
+
const r = { a, b: t }, u = { transposeA: n, transposeB: l };
|
|
22
|
+
return M.runKernel(p, r, u);
|
|
23
|
+
}
|
|
24
|
+
const i = /* @__PURE__ */ c({ matMul_: b });
|
|
25
|
+
export {
|
|
26
|
+
i as m
|
|
27
|
+
};
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
import { engine as d } from "@tensorflow/tfjs";
|
|
2
|
+
import { r as k, s as u } from "../index-DQfEAU9u.js";
|
|
3
|
+
import { m as l } from "../mat_mul-CuHB58-H.js";
|
|
4
|
+
class p {
|
|
5
|
+
variableNames = ["q", "k", "mask"];
|
|
6
|
+
outputShape;
|
|
7
|
+
userCode;
|
|
8
|
+
// enableShapeUniforms = true;
|
|
9
|
+
customUniforms = [{ name: "divisor", type: "float" }];
|
|
10
|
+
constructor(t, e, n, a) {
|
|
11
|
+
this.outputShape = [t, e, n, n], this.userCode = `
|
|
12
|
+
void main() {
|
|
13
|
+
ivec4 coords = getOutputCoords(); // [batch, nh, t1, t2]
|
|
14
|
+
int b = coords.x;
|
|
15
|
+
int h = coords.y;
|
|
16
|
+
int t1 = coords.z;
|
|
17
|
+
int t2 = coords.w;
|
|
18
|
+
|
|
19
|
+
float sum = 0.0;
|
|
20
|
+
for (int i = 0; i < ${a}; ++i) {
|
|
21
|
+
float qv = getQ(b, h, t1, i);
|
|
22
|
+
float kv = getK(b, h, t2, i); // k is transposed on last two dims
|
|
23
|
+
sum += qv * kv;
|
|
24
|
+
}
|
|
25
|
+
|
|
26
|
+
// Scale by divisor
|
|
27
|
+
float scaled = sum * divisor;
|
|
28
|
+
|
|
29
|
+
// Add mask
|
|
30
|
+
float maskVal = getMask(t1, t2); // mask is [T,T]
|
|
31
|
+
|
|
32
|
+
setOutput(scaled + maskVal);
|
|
33
|
+
}
|
|
34
|
+
`;
|
|
35
|
+
}
|
|
36
|
+
}
|
|
37
|
+
function f(s) {
|
|
38
|
+
const { q: t, k: e, mask: n } = s.inputs, { divisor: a } = s.attrs, o = s.backend, c = t.shape[0], i = t.shape[2], r = t.shape[1], m = new p(c, r, i, t.shape[3]);
|
|
39
|
+
return o.runWebGLProgram(m, [t, e, n], "float32", [[a]]);
|
|
40
|
+
}
|
|
41
|
+
const h = {
|
|
42
|
+
kernelName: "AttentionMask",
|
|
43
|
+
backendName: "webgl",
|
|
44
|
+
kernelFunc: f
|
|
45
|
+
};
|
|
46
|
+
k(h);
|
|
47
|
+
function b(s) {
|
|
48
|
+
const { q: t, k: e, mask: n } = s.inputs, { divisor: a } = s.attrs, o = t.shape[2], i = l(t, e, !1, !0).mul(u(a)), r = n.slice([0, 0], [o, o]).expandDims(0).expandDims(0);
|
|
49
|
+
return i.add(r);
|
|
50
|
+
}
|
|
51
|
+
const v = {
|
|
52
|
+
kernelName: "AttentionMask",
|
|
53
|
+
backendName: "cpu",
|
|
54
|
+
kernelFunc: b
|
|
55
|
+
};
|
|
56
|
+
k(v);
|
|
57
|
+
function C(s, t, e, n) {
|
|
58
|
+
return d().runKernel("AttentionMask", { q: s, k: t, mask: e }, { divisor: n });
|
|
59
|
+
}
|
|
60
|
+
export {
|
|
61
|
+
C as attentionMask
|
|
62
|
+
};
|
package/dist/ops/gatherSub.js
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import { engine as l } from "@tensorflow/tfjs";
|
|
2
|
-
import { o as g, c as i, E as b, G as d, r as c, a as h } from "../index-
|
|
3
|
-
import { r as p, s as f } from "../stack-
|
|
2
|
+
import { o as g, c as i, E as b, G as d, r as c, a as h } from "../index-DQfEAU9u.js";
|
|
3
|
+
import { r as p, s as f } from "../stack-C9cTkqpq.js";
|
|
4
4
|
/**
|
|
5
5
|
* @license
|
|
6
6
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
package/dist/ops/scatterSub.js
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import { engine as $ } from "@tensorflow/tfjs";
|
|
2
|
-
import {
|
|
3
|
-
import { c as m } from "../complex-
|
|
4
|
-
import { r as v, s as T } from "../stack-
|
|
2
|
+
import { j as u, k as S, l as p, E as f, n as E, o as N, c as l, p as y, r as h, a as D, m as x } from "../index-DQfEAU9u.js";
|
|
3
|
+
import { c as m } from "../complex-CeoYJn2o.js";
|
|
4
|
+
import { r as v, s as T } from "../stack-C9cTkqpq.js";
|
|
5
5
|
/**
|
|
6
6
|
* @license
|
|
7
7
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -23,7 +23,7 @@ function i(e, t = "float32") {
|
|
|
23
23
|
const a = i(e, "float32"), o = i(e, "float32");
|
|
24
24
|
return m(a, o);
|
|
25
25
|
}
|
|
26
|
-
const r = S(
|
|
26
|
+
const r = S(p(e), t);
|
|
27
27
|
return f.makeTensor(r, e, t);
|
|
28
28
|
}
|
|
29
29
|
/**
|
|
@@ -47,7 +47,7 @@ function d(e, t = "float32") {
|
|
|
47
47
|
const a = d(e, "float32"), o = i(e, "float32");
|
|
48
48
|
return m(a, o);
|
|
49
49
|
}
|
|
50
|
-
const r = E(
|
|
50
|
+
const r = E(p(e), t);
|
|
51
51
|
return f.makeTensor(r, e, t);
|
|
52
52
|
}
|
|
53
53
|
function C(e, t, r) {
|
|
@@ -131,7 +131,7 @@ const K = {
|
|
|
131
131
|
backendName: "webgl",
|
|
132
132
|
kernelFunc: P
|
|
133
133
|
};
|
|
134
|
-
|
|
134
|
+
h(K);
|
|
135
135
|
function A(e) {
|
|
136
136
|
const { logits: t, labels: r, dy: a } = e.inputs, o = r.shape[0], s = t.shape[1], n = v(0, o, 1, "int32"), c = T([n, r], 1), b = d([o]), g = I(c, b, [o, s]), k = D(t, g), w = a.reshape([o, 1]);
|
|
137
137
|
return x(k, w);
|
|
@@ -141,7 +141,7 @@ const F = {
|
|
|
141
141
|
backendName: "cpu",
|
|
142
142
|
kernelFunc: A
|
|
143
143
|
};
|
|
144
|
-
|
|
144
|
+
h(F);
|
|
145
145
|
function M(e, t, r) {
|
|
146
146
|
return $().runKernel("EfficientScatterSub", { logits: e, labels: t, dy: r }, {});
|
|
147
147
|
}
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { E as e, R as c, o as f,
|
|
1
|
+
import { E as e, R as c, o as f, f as u, g as a, P as i } from "./index-DQfEAU9u.js";
|
|
2
2
|
/**
|
|
3
3
|
* @license
|
|
4
4
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -15,7 +15,7 @@ import { E as e, R as c, o as f, d as u, f as a, P as i } from "./index-D1SlunD-
|
|
|
15
15
|
* limitations under the License.
|
|
16
16
|
* =============================================================================
|
|
17
17
|
*/
|
|
18
|
-
function
|
|
18
|
+
function l(n, s, t = 1, r = "float32") {
|
|
19
19
|
if (t === 0)
|
|
20
20
|
throw new Error("Cannot have a step of zero");
|
|
21
21
|
const o = { start: n, stop: s, step: t, dtype: r };
|
|
@@ -45,6 +45,6 @@ function k(n, s = 0) {
|
|
|
45
45
|
}
|
|
46
46
|
const h = /* @__PURE__ */ f({ stack_: k });
|
|
47
47
|
export {
|
|
48
|
-
|
|
48
|
+
l as r,
|
|
49
49
|
h as s
|
|
50
50
|
};
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { o, c as a, E as u,
|
|
1
|
+
import { o, c as a, E as u, h as i, i as p, S as x } from "./index-DQfEAU9u.js";
|
|
2
2
|
/**
|
|
3
3
|
* @license
|
|
4
4
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
@@ -17,7 +17,7 @@ import { o, c as a, E as u, g as p, h as i, S as x } from "./index-D1SlunD-.js";
|
|
|
17
17
|
*/
|
|
18
18
|
function l(n, t) {
|
|
19
19
|
const s = { x: a(n, "x", "reshape", "string_or_numeric") }, r = { shape: t };
|
|
20
|
-
return u.runKernel(
|
|
20
|
+
return u.runKernel(i, s, r);
|
|
21
21
|
}
|
|
22
22
|
const h = /* @__PURE__ */ o({ reshape_: l });
|
|
23
23
|
/**
|
|
@@ -38,7 +38,7 @@ const h = /* @__PURE__ */ o({ reshape_: l });
|
|
|
38
38
|
*/
|
|
39
39
|
function m(n, t = null, e = !1) {
|
|
40
40
|
let s = a(n, "x", "sum");
|
|
41
|
-
s.dtype === "bool" && (s =
|
|
41
|
+
s.dtype === "bool" && (s = p(s, "int32"));
|
|
42
42
|
const r = { x: s }, c = { axis: t, keepDims: e };
|
|
43
43
|
return u.runKernel(x, r, c);
|
|
44
44
|
}
|
package/dist/training/AdamExt.js
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { A as r, m as c, s as h, a as g, e as o } from "../index-
|
|
1
|
+
import { A as r, m as c, s as h, a as g, e as o } from "../index-DQfEAU9u.js";
|
|
2
2
|
class u extends r {
|
|
3
3
|
constructor(t, e, s, a, i) {
|
|
4
4
|
super(t, e, s, a), this.config = i, this.startLearningRate = t;
|
|
@@ -1,7 +1,7 @@
|
|
|
1
|
-
import { gatherSub as
|
|
2
|
-
import { scatterSub as
|
|
3
|
-
import { o as l, c as d, E as f, M as
|
|
4
|
-
import { s as F, r as b } from "../sum-
|
|
1
|
+
import { gatherSub as w } from "../ops/gatherSub.js";
|
|
2
|
+
import { scatterSub as K } from "../ops/scatterSub.js";
|
|
3
|
+
import { o as l, c as d, E as f, M as _, q as z, L as I, t as N, a as E, u as M, v as T, e as m, w as g, x as $, z as S } from "../index-DQfEAU9u.js";
|
|
4
|
+
import { s as F, r as b } from "../sum-B-O33dgG.js";
|
|
5
5
|
/**
|
|
6
6
|
* @license
|
|
7
7
|
* Copyright 2017 Google LLC. All Rights Reserved.
|
|
@@ -47,9 +47,9 @@ function q(n, s) {
|
|
|
47
47
|
*/
|
|
48
48
|
function A(n, s = null, t = !1) {
|
|
49
49
|
const e = { x: d(n, "x", "max") }, r = { reductionIndices: s, keepDims: t };
|
|
50
|
-
return f.runKernel(
|
|
50
|
+
return f.runKernel(_, e, r);
|
|
51
51
|
}
|
|
52
|
-
const
|
|
52
|
+
const L = /* @__PURE__ */ l({ max_: A });
|
|
53
53
|
/**
|
|
54
54
|
* @license
|
|
55
55
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
@@ -109,7 +109,7 @@ const j = /* @__PURE__ */ l({ log_: W });
|
|
|
109
109
|
* =============================================================================
|
|
110
110
|
*/
|
|
111
111
|
function B(n, s = null, t = !1) {
|
|
112
|
-
const a = d(n, "x", "logSumExp"), e = N(s, a.shape), r =
|
|
112
|
+
const a = d(n, "x", "logSumExp"), e = N(s, a.shape), r = L(
|
|
113
113
|
a,
|
|
114
114
|
e,
|
|
115
115
|
!0
|
|
@@ -148,30 +148,30 @@ function J(n, s = -1) {
|
|
|
148
148
|
const Q = /* @__PURE__ */ l({ softmax_: J });
|
|
149
149
|
function R(n, s) {
|
|
150
150
|
return $(() => {
|
|
151
|
-
const t = n.shape[n.shape.length - 1], e = n.shape.slice(0, -1).reduce((h, x) => h * x, 1), r = n.shape.length > 2 ? n.reshape([e, t]) : n, c = s.shape.length > 1 ? s.reshape([e]).cast("int32") : s.cast("int32"), o =
|
|
152
|
-
return
|
|
151
|
+
const t = n.shape[n.shape.length - 1], e = n.shape.slice(0, -1).reduce((h, x) => h * x, 1), r = n.shape.length > 2 ? n.reshape([e, t]) : n, c = s.shape.length > 1 ? s.reshape([e]).cast("int32") : s.cast("int32"), o = L(r, -1, !0), p = E(r, o), u = H(p, -1);
|
|
152
|
+
return w(u, c, p);
|
|
153
153
|
});
|
|
154
154
|
}
|
|
155
|
-
function
|
|
156
|
-
return m().backendName === "tensorflow" ?
|
|
155
|
+
function ss() {
|
|
156
|
+
return m().backendName === "tensorflow" ? g((s, t, a) => {
|
|
157
157
|
const e = s.shape.length > 2 ? s.reshape([-1, s.shape[s.shape.length - 1]]) : s, r = t.shape.length > 1 ? t.reshape([-1]).cast("int32") : t.cast("int32"), [c, o] = m().runKernel(
|
|
158
158
|
"NativeSparseSoftmaxCrossEntropy",
|
|
159
159
|
{ logits: e, labels: r },
|
|
160
160
|
{}
|
|
161
161
|
);
|
|
162
|
-
return a([o.reshape(s.shape)]), { value: c, gradFunc: (p, u) => [u[0],
|
|
163
|
-
}) :
|
|
162
|
+
return a([o.reshape(s.shape)]), { value: c, gradFunc: (p, u) => [u[0], S(t)] };
|
|
163
|
+
}) : g(
|
|
164
164
|
// @ts-expect-error Invalid params
|
|
165
165
|
(s, t, a) => {
|
|
166
166
|
const e = s.shape[s.shape.length - 1], c = s.shape.slice(0, -1).reduce((h, x) => h * x, 1), o = s.reshape([c, e]), p = t.reshape([c]).cast("int32"), u = R(o, p);
|
|
167
167
|
return a([o, p]), o.dispose(), p.dispose(), { value: u, gradFunc: (h, x) => $(() => {
|
|
168
|
-
const
|
|
169
|
-
return [
|
|
168
|
+
const k = x[0], y = x[1], C = Q(k), G = K(C, y, h), v = S(t);
|
|
169
|
+
return [G.reshape(s.shape), v];
|
|
170
170
|
}) };
|
|
171
171
|
}
|
|
172
172
|
);
|
|
173
173
|
}
|
|
174
174
|
export {
|
|
175
|
-
|
|
175
|
+
ss as createSoftmaxCrossEntropyWithGrad,
|
|
176
176
|
R as sparseSoftmaxCrossEntropy
|
|
177
177
|
};
|