@genai-fi/nanogpt 0.4.5 → 0.5.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/BaseLayer-BhrMN8JO.js +135 -0
- package/dist/Generator.js +52 -49
- package/dist/NanoGPTModel.d.ts +13 -17
- package/dist/NanoGPTModel.js +128 -136
- package/dist/{Reshape-CiAY8ltP.js → Reshape-BE5rA4rT.js} +8 -8
- package/dist/TeachableLLM.js +1 -1
- package/dist/{TiedEmbedding-DznFwzcB.js → TiedEmbedding-DsDRvLB0.js} +751 -768
- package/dist/{axis_util-QP0LdI1v.js → axis_util-97KkkyRQ.js} +1 -1
- package/dist/broadcast_to-CMlkG8NS.js +44 -0
- package/dist/{concat-DvWM7HGZ.js → concat-Cxbo2sOz.js} +3 -3
- package/dist/{dropout-DFEXTPV0.js → dropout-kbDY39Ci.js} +1 -1
- package/dist/{gather-C5D8PxwA.js → gather-Bxe1Qip8.js} +4 -4
- package/dist/{gpgpu_math-CUzjlO9A.js → gpgpu_math-C0zyxKFi.js} +1 -1
- package/dist/{index--6vO-cOz.js → index-iNhkcAEQ.js} +82 -82
- package/dist/{kernel_funcs_utils-C6YBCuOt.js → kernel_funcs_utils-C4eIk4fE.js} +20 -20
- package/dist/layers/BaseLayer.d.ts +28 -4
- package/dist/layers/BaseLayer.js +3 -16
- package/dist/layers/CausalSelfAttention.d.ts +21 -24
- package/dist/layers/CausalSelfAttention.js +73 -128
- package/dist/layers/MLP.d.ts +8 -15
- package/dist/layers/MLP.js +43 -81
- package/dist/layers/RMSNorm.d.ts +5 -10
- package/dist/layers/RMSNorm.js +13 -29
- package/dist/layers/RoPECache.js +14 -12
- package/dist/layers/TiedEmbedding.d.ts +6 -16
- package/dist/layers/TiedEmbedding.js +5 -5
- package/dist/layers/TransformerBlock.d.ts +12 -16
- package/dist/layers/TransformerBlock.js +20 -41
- package/dist/{log_sum_exp-CiEy1aUe.js → log_sum_exp-CkumwesB.js} +11 -11
- package/dist/main.js +1 -1
- package/dist/{mat_mul-BEHRPMh0.js → mat_mul-D0SifYfJ.js} +3 -3
- package/dist/{max-BUShNgfh.js → max-CYaAjEEp.js} +3 -3
- package/dist/{moments-DYOHXoRV.js → moments-B06NlR_V.js} +6 -6
- package/dist/{norm-DSva3hI3.js → norm-D3676xIo.js} +7 -7
- package/dist/{ones-D6kB8bdY.js → ones-BIeFnPHR.js} +2 -2
- package/dist/ops/appendCache.js +4 -4
- package/dist/ops/attentionMask.d.ts +1 -1
- package/dist/ops/attentionMask.js +4 -4
- package/dist/ops/cpu/appendCache.js +2 -2
- package/dist/ops/cpu/attentionMask.js +14 -15
- package/dist/ops/cpu/fusedSoftmax.js +2 -2
- package/dist/ops/cpu/gatherSub.js +5 -5
- package/dist/ops/cpu/gelu.js +1 -1
- package/dist/ops/cpu/matMulGelu.js +1 -1
- package/dist/ops/cpu/matMulMul.d.ts +1 -0
- package/dist/ops/cpu/matMulMul.js +17 -0
- package/dist/ops/cpu/mulDropout.js +1 -1
- package/dist/ops/cpu/normRMS.js +1 -1
- package/dist/ops/cpu/qkv.js +3 -3
- package/dist/ops/cpu/rope.js +5 -5
- package/dist/ops/cpu/scatterSub.js +8 -8
- package/dist/ops/fusedSoftmax.js +1 -1
- package/dist/ops/gatherSub.js +1 -1
- package/dist/ops/gelu.js +1 -1
- package/dist/ops/grads/attentionMask.js +13 -9
- package/dist/ops/grads/fusedSoftmax.js +12 -9
- package/dist/ops/grads/gelu.js +1 -1
- package/dist/ops/grads/matMulGelu.js +1 -1
- package/dist/ops/grads/normRMS.js +1 -1
- package/dist/ops/grads/qkv.js +19 -9
- package/dist/ops/grads/rope.js +1 -1
- package/dist/ops/matMulGelu.js +1 -1
- package/dist/ops/matMulMul.d.ts +2 -0
- package/dist/ops/matMulMul.js +9 -0
- package/dist/ops/mulDrop.js +1 -1
- package/dist/ops/node/sparseCrossEntropy.js +1 -1
- package/dist/ops/normRMS.js +1 -1
- package/dist/ops/qkv.js +1 -1
- package/dist/ops/scatterSub.js +1 -1
- package/dist/ops/webgl/appendCache.js +1 -1
- package/dist/ops/webgl/attentionMask.js +13 -12
- package/dist/ops/webgl/fusedSoftmax.js +43 -40
- package/dist/ops/webgl/gatherSub.js +1 -1
- package/dist/ops/webgl/gelu.js +2 -2
- package/dist/ops/webgl/matMulGelu.js +17 -17
- package/dist/ops/webgl/matMulMul.d.ts +14 -0
- package/dist/ops/webgl/matMulMul.js +28 -0
- package/dist/ops/webgl/mulDropout.js +1 -1
- package/dist/ops/webgl/normRMS.js +29 -21
- package/dist/ops/webgl/qkv.js +1 -1
- package/dist/ops/webgl/rope.js +1 -1
- package/dist/ops/webgl/scatterSub.js +1 -1
- package/dist/ops-ObfXLHYQ.js +1269 -0
- package/dist/{range-C_vpUjBu.js → range-BsFU-SNG.js} +1 -1
- package/dist/{reshape-z51Eu-re.js → reshape-DxTPgnwL.js} +3 -3
- package/dist/{sin-H567uayl.js → sin-BOX-JVAj.js} +5 -5
- package/dist/slice_util-D-kaD4ZV.js +49 -0
- package/dist/{softmax-Dsxflvdl.js → softmax-BjsptB07.js} +2 -2
- package/dist/{split-B_k_jwud.js → split-BCbrzthj.js} +4 -4
- package/dist/{stack-CmqSdsfs.js → stack--cqr9Dgc.js} +2 -2
- package/dist/{sum-DdkDf2MG.js → sum-B_92TaHD.js} +5 -5
- package/dist/{tensor-BGYi41cj.js → tensor-CfiPXsW4.js} +1 -1
- package/dist/{tensor2d-DUr_htjt.js → tensor2d-tSxWdFMH.js} +1 -1
- package/dist/tfjs_backend-NucKez4s.js +1010 -0
- package/dist/training/AdamExt.js +1 -1
- package/dist/training/DatasetBuilder.js +44 -44
- package/dist/training/Evaluator.js +6 -6
- package/dist/training/FullTrainer.js +1 -1
- package/dist/training/Trainer.js +7 -7
- package/dist/training/sparseCrossEntropy.js +4 -4
- package/dist/utilities/dummy.js +10 -10
- package/dist/utilities/generate.js +3 -3
- package/dist/utilities/load.js +1 -1
- package/dist/utilities/profile.js +1 -1
- package/dist/utilities/save.js +13 -11
- package/dist/utilities/weights.js +2 -2
- package/dist/{zeros-8xl-W2DC.js → zeros-NMYTayy7.js} +3 -3
- package/package.json +1 -1
- package/dist/slice_util-BdhYwFY_.js +0 -90
- package/dist/tfjs_backend-DuKis_xG.js +0 -2271
- package/dist/variable-BJTZ3jOy.js +0 -23
|
@@ -0,0 +1,1010 @@
|
|
|
1
|
+
import { o as h, i as f, E as $, ap as Te, k as _, g as Ee, aq as xe, ar as Ie, as as Le, at as Ne, au as be, av as Ce, aw as Pe, b as Q, ax as Fe, a9 as U, q as ae, p as ie, N as le, c as fe, ay as he, am as pe, az as je, t as S, y as $e, ai as Me, a4 as Be } from "./index-iNhkcAEQ.js";
|
|
2
|
+
import { s as C, t as Ke, a as Ue, b as ve } from "./ops-ObfXLHYQ.js";
|
|
3
|
+
import { r as Re, d as Ve } from "./dropout-kbDY39Ci.js";
|
|
4
|
+
import { r as u } from "./reshape-DxTPgnwL.js";
|
|
5
|
+
import { g as qe } from "./gather-Bxe1Qip8.js";
|
|
6
|
+
import { s as Ge } from "./sum-B_92TaHD.js";
|
|
7
|
+
import { m as A } from "./mat_mul-D0SifYfJ.js";
|
|
8
|
+
import { c as M } from "./concat-Cxbo2sOz.js";
|
|
9
|
+
/**
|
|
10
|
+
* @license
|
|
11
|
+
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
12
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
13
|
+
* you may not use this file except in compliance with the License.
|
|
14
|
+
* You may obtain a copy of the License at
|
|
15
|
+
*
|
|
16
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
17
|
+
*
|
|
18
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
19
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
20
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
21
|
+
* See the License for the specific language governing permissions and
|
|
22
|
+
* limitations under the License.
|
|
23
|
+
* =============================================================================
|
|
24
|
+
*/
|
|
25
|
+
/**
 * Element-wise sigmoid.
 *
 * The input is converted with `f` (presumably tfjs-core's convertToTensor,
 * here forcing a float32 dtype). The kernel imported as `Te` is then
 * dispatched through the engine `$`.
 */
function Je(x) {
  return $.runKernel(Te, { x: f(x, "x", "sigmoid", "float32") });
}
const Ze = /* @__PURE__ */ h({ sigmoid_: Je });
|
|
30
|
+
/**
|
|
31
|
+
* @license
|
|
32
|
+
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
33
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
34
|
+
* you may not use this file except in compliance with the License.
|
|
35
|
+
* You may obtain a copy of the License at
|
|
36
|
+
*
|
|
37
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
38
|
+
*
|
|
39
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
40
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
41
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
42
|
+
* See the License for the specific language governing permissions and
|
|
43
|
+
* limitations under the License.
|
|
44
|
+
* =============================================================================
|
|
45
|
+
*/
|
|
46
|
+
/**
 * Clip every element of `x` into the closed range [min, max].
 *
 * @param x tensor-like input.
 * @param min lower bound; it must be <= `max` (checked with `_`).
 * @param max upper bound.
 * @returns a tensor filled with `min` when `min === max`; otherwise the
 *   result of the kernel imported as `xe`.
 */
function We(x, min, max) {
  const input = f(x, "x", "clipByValue");
  _(min <= max, () => `Error in clip: min (${min}) must be less than or equal to max (${max}).`);
  // Degenerate range: every output element equals the single allowed value.
  if (min === max) {
    return Ee(input.shape, min, input.dtype);
  }
  const attrs = { clipValueMin: min, clipValueMax: max };
  return $.runKernel(xe, { x: input }, attrs);
}
const Ye = /* @__PURE__ */ h({ clipByValue_: We });
|
|
54
|
+
/** Concatenate a list of 1-D tensors along axis 0 by delegating to concat (`M`). */
function He(tensors) {
  const axis = 0;
  return M(tensors, axis);
}
const Qe = /* @__PURE__ */ h({ concat1d_: He });
|
|
62
|
+
/** Concatenate 2-D tensors along `axis` by delegating to concat (`M`). */
function Xe(tensors, axis) {
  const result = M(tensors, axis);
  return result;
}
const ze = /* @__PURE__ */ h({ concat2d_: Xe });
|
|
66
|
+
/** Concatenate 3-D tensors along `axis` by delegating to concat (`M`). */
function en(tensors, axis) {
  const result = M(tensors, axis);
  return result;
}
const nn = /* @__PURE__ */ h({ concat3d_: en });
|
|
70
|
+
/** Concatenate 4-D tensors along `axis` by delegating to concat (`M`). */
function tn(tensors, axis) {
  const result = M(tensors, axis);
  return result;
}
const rn = /* @__PURE__ */ h({ concat4d_: tn });
|
|
74
|
+
/**
|
|
75
|
+
* @license
|
|
76
|
+
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
77
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
78
|
+
* you may not use this file except in compliance with the License.
|
|
79
|
+
* You may obtain a copy of the License at
|
|
80
|
+
*
|
|
81
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
82
|
+
*
|
|
83
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
84
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
85
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
86
|
+
* See the License for the specific language governing permissions and
|
|
87
|
+
* limitations under the License.
|
|
88
|
+
* =============================================================================
|
|
89
|
+
*/
|
|
90
|
+
/**
 * Element-wise ELU.
 *
 * The input is converted to float32 with `f`, and the kernel imported as
 * `Ie` is dispatched through the engine `$`.
 */
function sn(x) {
  return $.runKernel(Ie, { x: f(x, "x", "elu", "float32") });
}
const ke = /* @__PURE__ */ h({ elu_: sn });
|
|
95
|
+
/**
|
|
96
|
+
* @license
|
|
97
|
+
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
98
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
99
|
+
* you may not use this file except in compliance with the License.
|
|
100
|
+
* You may obtain a copy of the License at
|
|
101
|
+
*
|
|
102
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
103
|
+
*
|
|
104
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
105
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
106
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
107
|
+
* See the License for the specific language governing permissions and
|
|
108
|
+
* limitations under the License.
|
|
109
|
+
* =============================================================================
|
|
110
|
+
*/
|
|
111
|
+
/**
 * Element-wise leaky ReLU.
 *
 * @param x tensor-like input.
 * @param alpha slope passed to the kernel as the `alpha` attribute (default 0.2).
 */
function on(x, alpha = 0.2) {
  return $.runKernel(Le, { x: f(x, "x", "leakyRelu") }, { alpha });
}
const un = /* @__PURE__ */ h({ leakyRelu_: on });
|
|
116
|
+
/**
|
|
117
|
+
* @license
|
|
118
|
+
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
119
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
120
|
+
* you may not use this file except in compliance with the License.
|
|
121
|
+
* You may obtain a copy of the License at
|
|
122
|
+
*
|
|
123
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
124
|
+
*
|
|
125
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
126
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
127
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
128
|
+
* See the License for the specific language governing permissions and
|
|
129
|
+
* limitations under the License.
|
|
130
|
+
* =============================================================================
|
|
131
|
+
*/
|
|
132
|
+
/**
 * Parametric ReLU.
 *
 * Both `x` and `alpha` are converted with `f` (x first, then alpha). They
 * are passed as inputs to the kernel imported as `Ne`.
 */
function cn(x, alpha) {
  return $.runKernel(Ne, {
    x: f(x, "x", "prelu"),
    alpha: f(alpha, "alpha", "prelu"),
  });
}
const an = /* @__PURE__ */ h({ prelu_: cn });
|
|
137
|
+
/**
|
|
138
|
+
* @license
|
|
139
|
+
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
140
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
141
|
+
* you may not use this file except in compliance with the License.
|
|
142
|
+
* You may obtain a copy of the License at
|
|
143
|
+
*
|
|
144
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
145
|
+
*
|
|
146
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
147
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
148
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
149
|
+
* See the License for the specific language governing permissions and
|
|
150
|
+
* limitations under the License.
|
|
151
|
+
* =============================================================================
|
|
152
|
+
*/
|
|
153
|
+
/** Element-wise ReLU: converts `x` with `f`, then runs the kernel imported as `be`. */
function ln(x) {
  return $.runKernel(be, { x: f(x, "x", "relu") });
}
const fn = /* @__PURE__ */ h({ relu_: ln });
|
|
158
|
+
/**
|
|
159
|
+
* @license
|
|
160
|
+
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
161
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
162
|
+
* you may not use this file except in compliance with the License.
|
|
163
|
+
* You may obtain a copy of the License at
|
|
164
|
+
*
|
|
165
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
166
|
+
*
|
|
167
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
168
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
169
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
170
|
+
* See the License for the specific language governing permissions and
|
|
171
|
+
* limitations under the License.
|
|
172
|
+
* =============================================================================
|
|
173
|
+
*/
|
|
174
|
+
/** Element-wise ReLU6: converts `x` with `f`, then runs the kernel imported as `Ce`. */
function hn(x) {
  return $.runKernel(Ce, { x: f(x, "x", "relu6") });
}
const pn = /* @__PURE__ */ h({ relu6_: hn });
|
|
179
|
+
/**
|
|
180
|
+
* @license
|
|
181
|
+
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
182
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
183
|
+
* you may not use this file except in compliance with the License.
|
|
184
|
+
* You may obtain a copy of the License at
|
|
185
|
+
*
|
|
186
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
187
|
+
*
|
|
188
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
189
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
190
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
191
|
+
* See the License for the specific language governing permissions and
|
|
192
|
+
* limitations under the License.
|
|
193
|
+
* =============================================================================
|
|
194
|
+
*/
|
|
195
|
+
/**
 * Slice a rank-1 tensor.
 *
 * @param x tensor-like input; it must have rank 1.
 * @param begin scalar start index (wrapped into a one-element array for `C`).
 * @param size scalar slice length (wrapped likewise).
 */
function dn(x, begin, size) {
  const tensor = f(x, "x", "slice1d");
  _(tensor.rank === 1, () => `slice1d expects a rank-1 tensor, but got a rank-${tensor.rank} tensor`);
  return C(tensor, [begin], [size]);
}
const X = /* @__PURE__ */ h({ slice1d_: dn });
|
|
200
|
+
/**
|
|
201
|
+
* @license
|
|
202
|
+
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
203
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
204
|
+
* you may not use this file except in compliance with the License.
|
|
205
|
+
* You may obtain a copy of the License at
|
|
206
|
+
*
|
|
207
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
208
|
+
*
|
|
209
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
210
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
211
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
212
|
+
* See the License for the specific language governing permissions and
|
|
213
|
+
* limitations under the License.
|
|
214
|
+
* =============================================================================
|
|
215
|
+
*/
|
|
216
|
+
/** Slice a rank-2 tensor; `begin` and `size` are forwarded unchanged to `C` (slice). */
function mn(x, begin, size) {
  const tensor = f(x, "x", "slice2d");
  _(tensor.rank === 2, () => `slice2d expects a rank-2 tensor, but got a rank-${tensor.rank} tensor`);
  return C(tensor, begin, size);
}
const we = /* @__PURE__ */ h({ slice2d_: mn });
|
|
221
|
+
/**
|
|
222
|
+
* @license
|
|
223
|
+
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
224
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
225
|
+
* you may not use this file except in compliance with the License.
|
|
226
|
+
* You may obtain a copy of the License at
|
|
227
|
+
*
|
|
228
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
229
|
+
*
|
|
230
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
231
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
232
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
233
|
+
* See the License for the specific language governing permissions and
|
|
234
|
+
* limitations under the License.
|
|
235
|
+
* =============================================================================
|
|
236
|
+
*/
|
|
237
|
+
/** Slice a rank-3 tensor; `begin` and `size` are forwarded unchanged to `C` (slice). */
function gn(x, begin, size) {
  const tensor = f(x, "x", "slice3d");
  _(tensor.rank === 3, () => `slice3d expects a rank-3 tensor, but got a rank-${tensor.rank} tensor`);
  return C(tensor, begin, size);
}
const z = /* @__PURE__ */ h({ slice3d_: gn });
|
|
242
|
+
/**
|
|
243
|
+
* @license
|
|
244
|
+
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
245
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
246
|
+
* you may not use this file except in compliance with the License.
|
|
247
|
+
* You may obtain a copy of the License at
|
|
248
|
+
*
|
|
249
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
250
|
+
*
|
|
251
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
252
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
253
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
254
|
+
* See the License for the specific language governing permissions and
|
|
255
|
+
* limitations under the License.
|
|
256
|
+
* =============================================================================
|
|
257
|
+
*/
|
|
258
|
+
/** Slice a rank-4 tensor; `begin` and `size` are forwarded unchanged to `C` (slice). */
function $n(x, begin, size) {
  const tensor = f(x, "x", "slice4d");
  _(tensor.rank === 4, () => `slice4d expects a rank-4 tensor, but got a rank-${tensor.rank} tensor`);
  return C(tensor, begin, size);
}
const K = /* @__PURE__ */ h({ slice4d_: $n });
|
|
263
|
+
/**
|
|
264
|
+
* @license
|
|
265
|
+
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
266
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
267
|
+
* you may not use this file except in compliance with the License.
|
|
268
|
+
* You may obtain a copy of the License at
|
|
269
|
+
*
|
|
270
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
271
|
+
*
|
|
272
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
273
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
274
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
275
|
+
* See the License for the specific language governing permissions and
|
|
276
|
+
* limitations under the License.
|
|
277
|
+
* =============================================================================
|
|
278
|
+
*/
|
|
279
|
+
/**
 * Element-wise step function.
 *
 * @param x tensor-like input.
 * @param alpha forwarded as the kernel's `alpha` attribute (default 0).
 */
function kn(x, alpha = 0) {
  return $.runKernel(Pe, { x: f(x, "x", "step") }, { alpha });
}
const wn = /* @__PURE__ */ h({ step_: kn });
|
|
284
|
+
/**
|
|
285
|
+
* @license
|
|
286
|
+
* Copyright 2019 Google LLC. All Rights Reserved.
|
|
287
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
288
|
+
* you may not use this file except in compliance with the License.
|
|
289
|
+
* You may obtain a copy of the License at
|
|
290
|
+
*
|
|
291
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
292
|
+
*
|
|
293
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
294
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
295
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
296
|
+
* See the License for the specific language governing permissions and
|
|
297
|
+
* limitations under the License.
|
|
298
|
+
* =============================================================================
|
|
299
|
+
*/
|
|
300
|
+
/**
 * Gradient of a fused activation with respect to its pre-activation input.
 *
 * @param dy upstream gradient.
 * @param y value used for the relu mask (`wn` = step applied to it).
 * @param activation `null`/`undefined`/"linear" pass `dy` through; "relu"
 *   multiplies `dy` by step(y).
 * @throws Error for any other activation.
 */
function An(dy, y, activation) {
  const isIdentity = activation == null || activation === "linear";
  if (isIdentity) {
    return dy;
  }
  if (activation !== "relu") {
    throw new Error(`Cannot compute gradient for fused activation ${activation}.`);
  }
  return Q(dy, wn(y));
}
|
|
307
|
+
/**
 * Reduce an output-shaped gradient back to the bias's shape.
 *
 * The gradient is summed over the axes reported by `Fe`, presumably the
 * broadcast axes between the bias and the gradient shapes. The result is
 * then reshaped to `bias.shape`.
 */
function Sn(bias, dyActivation) {
  const axes = Fe(bias.shape, dyActivation.shape);
  const reduced = axes.length > 0 ? Ge(dyActivation, axes) : dyActivation;
  return u(reduced, bias.shape);
}
|
|
312
|
+
function yn(e, n, t, r) {
|
|
313
|
+
if (n === "linear")
|
|
314
|
+
return e;
|
|
315
|
+
if (n === "relu")
|
|
316
|
+
return fn(e);
|
|
317
|
+
if (n === "elu")
|
|
318
|
+
return ke(e);
|
|
319
|
+
if (n === "relu6")
|
|
320
|
+
return pn(e);
|
|
321
|
+
if (n === "prelu")
|
|
322
|
+
return an(e, t);
|
|
323
|
+
if (n === "leakyrelu")
|
|
324
|
+
return un(e, r);
|
|
325
|
+
if (n === "sigmoid")
|
|
326
|
+
return Ze(e);
|
|
327
|
+
throw new Error(`Unknown fused activation ${n}.`);
|
|
328
|
+
}
|
|
329
|
+
// Fusing is allowed for a linear activation, or whenever no gradient is being recorded (gradientDepth not > 0).
const On = (gradientDepth, activation) => activation === "linear" || !(gradientDepth > 0);
|
|
330
|
+
/**
|
|
331
|
+
* @license
|
|
332
|
+
* Copyright 2019 Google LLC. All Rights Reserved.
|
|
333
|
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
334
|
+
* you may not use this file except in compliance with the License.
|
|
335
|
+
* You may obtain a copy of the License at
|
|
336
|
+
*
|
|
337
|
+
* http://www.apache.org/licenses/LICENSE-2.0
|
|
338
|
+
*
|
|
339
|
+
* Unless required by applicable law or agreed to in writing, software
|
|
340
|
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
341
|
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
342
|
+
* See the License for the specific language governing permissions and
|
|
343
|
+
* limitations under the License.
|
|
344
|
+
* =============================================================================
|
|
345
|
+
*/
|
|
346
|
+
/**
 * Matrix multiplication with an optional bias add and activation.
 *
 * When `On` (shouldFuse) returns false, the result is composed from
 * separate ops: matMul (`A`), add (`U`) and `yn`. Otherwise both operands
 * are reshaped to 3-D and the kernel imported as `he` runs inside a custom
 * gradient (`fe`). The gradient function computes the a/b derivatives for
 * each transpose combination, plus a bias gradient when a bias is given.
 *
 * @throws (via `_`) when the inner dimensions of `a` and `b` differ.
 */
function _n({
  a,
  b,
  transposeA = false,
  transposeB = false,
  bias,
  activation = "linear",
  preluActivationWeights,
  leakyreluAlpha = 0.2,
}) {
  if (!On($.state.gradientDepth, activation)) {
    let out = A(a, b, transposeA, transposeB);
    if (bias != null) {
      out = U(out, bias);
    }
    return yn(out, activation, preluActivationWeights, leakyreluAlpha);
  }

  let aT = f(a, "a", "fused matMul");
  let bT = f(b, "b", "fused matMul");
  [aT, bT] = ae(aT, bT);

  const aInner = transposeA ? aT.shape[aT.rank - 2] : aT.shape[aT.rank - 1];
  const bInner = transposeB ? bT.shape[bT.rank - 1] : bT.shape[bT.rank - 2];
  const outRows = transposeA ? aT.shape[aT.rank - 1] : aT.shape[aT.rank - 2];
  const outCols = transposeB ? bT.shape[bT.rank - 2] : bT.shape[bT.rank - 1];
  const aBatch = ie(aT.shape.slice(0, -2));
  const bBatch = ie(bT.shape.slice(0, -2));
  _(aInner === bInner, () => `Error in fused matMul: inner shapes (${aInner}) and (${bInner}) of Tensors with shapes ${aT.shape} and ${bT.shape} and transposeA=${transposeA} and transposeB=${transposeB} must match.`);

  // Broadcast the leading (batch) dimensions, then append the matrix dims.
  const outShape = le(aT.shape.slice(0, -2), bT.shape.slice(0, -2)).concat([outRows, outCols]);
  const a3D = transposeA ? u(aT, [aBatch, aInner, outRows]) : u(aT, [aBatch, outRows, aInner]);
  const b3D = transposeB ? u(bT, [bBatch, outCols, bInner]) : u(bT, [bBatch, bInner, outCols]);

  let biasT;
  if (bias != null) {
    biasT = f(bias, "bias", "fused matMul");
    [biasT] = ae(biasT, aT);
    // Called only for its check that the bias broadcasts against the output.
    le(outShape, biasT.shape);
  }
  let preluT;
  if (preluActivationWeights != null) {
    preluT = f(preluActivationWeights, "prelu weights", "fused matMul");
  }

  const gradFunc = (dy, saved) => {
    const [a3DSaved, b3DSaved, y, savedBias] = saved;
    const dyActivation = An(u(dy, y.shape), y, activation);
    let aDer;
    let bDer;
    if (!transposeA && !transposeB) {
      aDer = A(dyActivation, b3DSaved, false, true);
      bDer = A(a3DSaved, dyActivation, true, false);
    } else if (!transposeA && transposeB) {
      aDer = A(dyActivation, b3DSaved, false, false);
      bDer = A(dyActivation, a3DSaved, true, false);
    } else if (transposeA && !transposeB) {
      aDer = A(b3DSaved, dyActivation, false, true);
      bDer = A(a3DSaved, dyActivation, false, false);
    } else {
      aDer = A(b3DSaved, dyActivation, true, true);
      bDer = A(dyActivation, a3DSaved, true, true);
    }
    if (bias != null) {
      return [aDer, bDer, Sn(savedBias, dyActivation)];
    }
    return [aDer, bDer];
  };

  const inputs = { a: a3D, b: b3D, bias: biasT, preluActivationWeights: preluT };
  const attrs = { transposeA, transposeB, activation, leakyreluAlpha };

  if (bias == null) {
    const withoutBias = fe((a3DIn, b3DIn, save) => {
      const res = (
        // tslint:disable-next-line: no-unnecessary-type-assertion
        $.runKernel(he, inputs, attrs)
      );
      save([a3DIn, b3DIn, res]);
      return { value: u(res, outShape), gradFunc };
    });
    return withoutBias(a3D, b3D);
  }
  const withBias = fe((a3DIn, b3DIn, biasIn, save) => {
    const res = (
      // tslint:disable-next-line: no-unnecessary-type-assertion
      $.runKernel(he, inputs, attrs)
    );
    save([a3DIn, b3DIn, res, biasIn]);
    return { value: u(res, outShape), gradFunc };
  });
  return withBias(a3D, b3D, biasT);
}
const de = /* @__PURE__ */ h({ fusedMatMul_: _n });
|
|
389
|
+
/**
|
|
390
|
+
* @license
|
|
391
|
+
* Copyright 2018 Google LLC
|
|
392
|
+
*
|
|
393
|
+
* Use of this source code is governed by an MIT-style
|
|
394
|
+
* license that can be found in the LICENSE file or at
|
|
395
|
+
* https://opensource.org/licenses/MIT.
|
|
396
|
+
* =============================================================================
|
|
397
|
+
*/
|
|
398
|
+
// Allowed string values for several layer options (judging by the values:
// data format, interpolation, padding, pooling and merge modes).
const Dn = ["channelsFirst", "channelsLast"];
const Tn = ["nearest", "bilinear"];
const En = ["valid", "same", "causal"];
const xn = ["max", "avg"];
const Gn = ["sum", "mul", "concat", "ave"];
|
|
399
|
+
/**
|
|
400
|
+
* @license
|
|
401
|
+
* Copyright 2018 Google LLC
|
|
402
|
+
*
|
|
403
|
+
* Use of this source code is governed by an MIT-style
|
|
404
|
+
* license that can be found in the LICENSE file or at
|
|
405
|
+
* https://opensource.org/licenses/MIT.
|
|
406
|
+
* =============================================================================
|
|
407
|
+
*/
|
|
408
|
+
// Error subclasses used by the helpers below. Each constructor explicitly
// re-pins the prototype so that `instanceof` keeps working. This is the usual
// TypeScript pattern for Error subclasses compiled to ES5 targets.
class Ae extends Error {
  constructor(message) {
    super(message);
    Object.setPrototypeOf(this, Ae.prototype);
  }
}
class Se extends Error {
  constructor(message) {
    super(message);
    Object.setPrototypeOf(this, Se.prototype);
  }
}
// Thrown for invalid values/configs throughout this module.
class l extends Error {
  constructor(message) {
    super(message);
    Object.setPrototypeOf(this, l.prototype);
  }
}
class j extends Error {
  constructor(message) {
    super(message);
    Object.setPrototypeOf(this, j.prototype);
  }
}
// Thrown by the `me` assertion helper.
class ee extends Error {
  constructor(message) {
    super(message);
    Object.setPrototypeOf(this, ee.prototype);
  }
}
|
|
433
|
+
/**
|
|
434
|
+
* @license
|
|
435
|
+
* Copyright 2018 Google LLC
|
|
436
|
+
*
|
|
437
|
+
* Use of this source code is governed by an MIT-style
|
|
438
|
+
* license that can be found in the LICENSE file or at
|
|
439
|
+
* https://opensource.org/licenses/MIT.
|
|
440
|
+
* =============================================================================
|
|
441
|
+
*/
|
|
442
|
+
/**
 * Python-style list repetition (`value * n`).
 *
 * @param e an array (its elements are repeated, flattened one level) or any
 *   other value (repeated as a single element).
 * @param n repetition count.
 * @returns a new array.
 */
function Jn(e, n) {
  if (!Array.isArray(e)) {
    return new Array(n).fill(e);
  }
  let repeated = [];
  let i = 0;
  while (i < n) {
    repeated = repeated.concat(e);
    i += 1;
  }
  return repeated;
}
|
|
453
|
+
/** Assert that `e` is truthy; otherwise throw `ee` carrying message `n`. */
function me(e, n) {
  if (e) {
    return;
  }
  throw new ee(n);
}
|
|
457
|
+
/** Count how many items of iterable `e` are strictly equal (`===`) to `n`. */
function Zn(e, n) {
  return Array.from(e).reduce((total, item) => (item === n ? total + 1 : total), 0);
}
|
|
463
|
+
/** Unwrap a one-element list to its sole element; return any other list unchanged. */
function Wn(e) {
  if (e.length !== 1) {
    return e;
  }
  return e[0];
}
|
|
466
|
+
/** Wrap a non-array value in a one-element array; return arrays unchanged. */
function Yn(e) {
  if (!Array.isArray(e)) {
    return [e];
  }
  return e;
}
|
|
469
|
+
/**
 * Convert a camelCase / PascalCase name to snake_case.
 * A result that starts with "_" gets a "private" prefix.
 */
function Hn(e) {
  const withWordBreaks = e.replace(/(.)([A-Z][a-z0-9]+)/g, "$1_$2");
  const snake = withWordBreaks.replace(/([a-z])([A-Z])/g, "$1_$2").toLowerCase();
  if (snake[0] === "_") {
    return "private" + snake;
  }
  return snake;
}
|
|
473
|
+
/**
 * Convert a snake_case name to camelCase.
 * Strings of length <= 1, or without any underscore, are returned as-is.
 */
function Qn(e) {
  if (e.length <= 1) {
    return e;
  }
  if (e.indexOf("_") === -1) {
    return e;
  }
  return e.replace(/[_]+(\w|$)/g, (whole, next) => next.toUpperCase());
}
|
|
476
|
+
// Module-level registry of custom objects, keyed by name. The deserializer
// `zn` looks names up here and, while it builds an object, temporarily
// merges that call's custom objects into this registry. It is declared with
// `let` because the registry is restored by reassignment.
let m = {};
|
|
477
|
+
/**
 * Serialize an object to `{className, config}` using its `getClassName()`
 * and `getConfig()` methods. Returns null for null/undefined.
 */
function Xn(e) {
  if (e == null) {
    return null;
  }
  return { className: e.getClassName(), config: e.getConfig() };
}
|
|
483
|
+
/**
 * Recursively walk a config, in place. Every property whose value is a
 * non-array object of the form `{type: "ndarray", value: <number>}` is
 * replaced by that number. Non-objects are ignored. Arrays are walked
 * element by element, but an array element itself is never replaced.
 */
function W(e) {
  if (e == null || typeof e != "object") {
    return;
  }
  if (Array.isArray(e)) {
    e.forEach((item) => W(item));
    return;
  }
  for (const key of Object.keys(e)) {
    const value = e[key];
    if (value == null || typeof value != "object") {
      continue;
    }
    const isScalarNdarray = !Array.isArray(value) && value.type === "ndarray" && typeof value.value == "number";
    if (isScalarNdarray) {
      e[key] = value.value;
    } else {
      W(value);
    }
  }
}
|
|
495
|
+
/**
 * Deserialize an object from a registered name or a `{className, config}` descriptor.
 *
 * Names are looked up in this order: `t` (per-call custom objects), then the
 * module-level registry `m`, then `n` (module objects).
 *
 * @param e identifier string, or an object with `className` and `config`.
 * @param n fallback lookup table (default `{}`).
 * @param t per-call custom objects (default `{}`). While the object is being
 *   built, they are merged into `m`. `m` is restored afterwards, even when
 *   construction throws.
 * @param r printable name of the kind of object; used only in error messages.
 * @param s forwarded as the last argument to a `fromConfig` function
 *   (presumably a fast-weight-init flag; TODO confirm).
 * @returns for a string identifier, the registered value. Otherwise, the
 *   result of `fromConfig(cls, config, t, s)` when a `fromConfig` is
 *   registered, or `new cls(config)` when it is not.
 * @throws l when the name/class is unknown, or when the descriptor lacks
 *   `className`/`config`.
 */
function zn(e, n = {}, t = {}, r = "object", s = !1) {
  if (typeof e == "string") {
    const name = e;
    let found;
    if (name in t) {
      found = t[name];
    } else if (name in m) {
      found = m[name];
    } else {
      found = n[name];
      if (found == null)
        throw new l(`Unknown ${r}: ${e}. This may be due to one of the following reasons:
1. The ${r} is defined in Python, in which case it needs to be ported to TensorFlow.js or your JavaScript code.
2. The custom ${r} is defined in JavaScript, but is not registered properly with tf.serialization.registerClass().`);
    }
    return found;
  }

  const descriptor = e;
  if (descriptor.className == null || descriptor.config == null)
    throw new l(`${r}: Improper config format: ${JSON.stringify(descriptor)}.
'className' and 'config' must set.`);
  const className = descriptor.className;
  let cls;
  let fromConfig;
  if (className in t) {
    [cls, fromConfig] = t[className];
  } else if (className in m) {
    // Index by the class name. The previous `m.className` read a literal key
    // and made every globally registered class fail with a TypeError.
    [cls, fromConfig] = m[className];
  } else if (className in n) {
    [cls, fromConfig] = n[className];
  }
  if (cls == null)
    throw new l(`Unknown ${r}: ${className}. This may be due to one of the following reasons:
1. The ${r} is defined in Python, in which case it needs to be ported to TensorFlow.js or your JavaScript code.
2. The custom ${r} is defined in JavaScript, but is not registered properly with tf.serialization.registerClass().`);

  if (fromConfig != null) {
    // Give nested deserialization access to both the global and per-call customs.
    const combined = {};
    for (const key of Object.keys(m)) combined[key] = m[key];
    for (const key of Object.keys(t)) combined[key] = t[key];
    descriptor.config.customObjects = combined;
  }

  const backup = Object.assign({}, m);
  for (const key of Object.keys(t)) m[key] = t[key];
  try {
    if (fromConfig != null) {
      W(descriptor.config);
      return fromConfig(cls, descriptor.config, t, s);
    }
    return new cls(descriptor.config);
  } finally {
    // Always undo the temporary registration, even if construction threw.
    m = Object.assign({}, backup);
  }
}
|
|
542
|
+
/** Three-way comparison: -1 if a < b, 1 if a > b, otherwise 0. */
function In(a, b) {
  if (a < b) {
    return -1;
  }
  if (a > b) {
    return 1;
  }
  return 0;
}
|
|
545
|
+
/**
 * Descending-order comparator: the arithmetic negation of `In(e, n)`.
 *
 * @param {*} e - Left operand.
 * @param {*} n - Right operand.
 * @returns {number} 1 when e < n, -1 when e > n, otherwise -0.
 */
function et(e, n) {
  const ascending = In(e, n);
  return -ascending;
}
|
|
548
|
+
/**
 * Returns the distinct elements of an iterable, preserving first-occurrence order.
 *
 * A `Set` makes de-duplication O(n) instead of an O(n^2) `indexOf` scan.
 * `Set` compares with SameValueZero, so repeated `NaN` entries collapse to a
 * single `NaN`; an `indexOf` scan can never match `NaN`.
 *
 * @param {Iterable<*>|null|undefined} e - Values to de-duplicate.
 * @returns {Array<*>|null|undefined} A new array of unique values, or `e`
 *   itself when it is null/undefined.
 */
function nt(e) {
  if (e == null) {
    return e;
  }
  return Array.from(new Set(e));
}
|
|
556
|
+
/**
 * Reports whether an object has no own enumerable string-keyed properties.
 *
 * @param {object} e - Object to inspect.
 * @returns {boolean} true when `e` has no own enumerable properties.
 * @throws {l} If `e` is null or undefined.
 */
function tt(e) {
  if (e == null) {
    throw new l(`Invalid value in obj: ${JSON.stringify(e)}`);
  }
  // Borrow hasOwnProperty from Object.prototype. Objects built with
  // Object.create(null), or objects that define their own `hasOwnProperty`
  // key, do not expose a callable `e.hasOwnProperty`.
  const hasOwn = Object.prototype.hasOwnProperty;
  for (const n in e) {
    if (hasOwn.call(e, n)) {
      return false;
    }
  }
  return true;
}
|
|
564
|
+
/**
 * Validates that `t` is one of the allowed values in `e`.
 * null and undefined are always accepted.
 *
 * @param {Array<*>} e - Allowed values.
 * @param {string} n - Name of the value kind, used in the error message.
 * @param {*} t - Value to check.
 * @throws {l} If `t` is non-null and not contained in `e`.
 */
function v(e, n, t) {
  const isMissing = t == null;
  if (isMissing || e.indexOf(t) >= 0) {
    return;
  }
  throw new l(`${t} is not a valid ${n}. Valid values are ${e} or null/undefined.`);
}
|
|
568
|
+
/**
 * Checks that `e` is an array whose length lies in [t, r] and whose elements
 * all have `typeof` equal to `n`.
 *
 * @param {*} e - Candidate array.
 * @param {string} n - Expected `typeof` of every element.
 * @param {number} [t=0] - Minimum length; asserted to be >= 0.
 * @param {number} [r=Infinity] - Maximum length; asserted to be >= t.
 * @returns {boolean} Whether all conditions hold.
 */
function rt(e, n, t = 0, r = 1 / 0) {
  me(t >= 0);
  me(r >= t);
  if (!Array.isArray(e)) {
    return false;
  }
  if (e.length < t || e.length > r) {
    return false;
  }
  return e.every((item) => typeof item === n);
}
|
|
571
|
+
/**
 * Asserts (via `_`) that `e` is a positive integer or a non-empty, possibly
 * nested, array of positive integers.
 *
 * @param {*} e - Value or array to check.
 * @param {string} n - Name used in assertion messages.
 */
function Ln(e, n) {
  if (!Array.isArray(e)) {
    _(Number.isInteger(e) && e > 0, () => `Expected ${n} to be a positive integer, but got ${ye(e)}.`);
    return;
  }
  _(e.length > 0, () => `${n} is unexpectedly an empty array.`);
  e.forEach((item, idx) => Ln(item, `element ${idx + 1} of ${n}`));
}
|
|
574
|
+
/**
 * Renders a value for error messages: strings are double-quoted, arrays are
 * rendered recursively as `[a,b,...]`, null becomes "null", and anything else
 * uses template-literal string conversion.
 *
 * @param {*} e - Value to format.
 * @returns {string} Human-friendly representation.
 */
function ye(e) {
  if (e === null) {
    return "null";
  }
  if (Array.isArray(e)) {
    const inner = e.map((item) => ye(item)).join(",");
    return `[${inner}]`;
  }
  if (typeof e == "string") {
    return `"${e}"`;
  }
  return `${e}`;
}
|
|
577
|
+
/**
 * Wraps `e` so it is re-invoked only when at least `n` clock units have passed
 * since the last invocation (or since the wrapper was created). Earlier calls
 * return the cached previous result, which is undefined until the first real
 * invocation.
 *
 * @param {Function} e - Function to wrap.
 * @param {number} n - Minimum elapsed time between invocations.
 * @param {Function} [t] - Optional clock; when null/undefined, `pe()` is used.
 * @returns {Function} The rate-limited wrapper.
 */
function st(e, n, t) {
  const readClock = () => (t != null ? t() : pe());
  let lastTime = readClock();
  let lastResult;
  return (...args) => {
    const now = readClock();
    // Written as !(elapsed < n) rather than elapsed >= n so NaN behaves identically.
    if (!(now - lastTime < n)) {
      lastTime = now;
      lastResult = e(...args);
    }
    return lastResult;
  };
}
|
|
584
|
+
/**
 * Passes through the activation names "relu", "linear" and "elu";
 * any other value maps to null.
 *
 * @param {*} e - Activation identifier.
 * @returns {string|null} The same name when supported, else null.
 */
function ot(e) {
  switch (e) {
    case "relu":
    case "linear":
    case "elu":
      return e;
    default:
      return null;
  }
}
|
|
587
|
+
/**
|
|
588
|
+
* @license
|
|
589
|
+
* Copyright 2018 Google LLC
|
|
590
|
+
*
|
|
591
|
+
* Use of this source code is governed by an MIT-style
|
|
592
|
+
* license that can be found in the LICENSE file or at
|
|
593
|
+
* https://opensource.org/licenses/MIT.
|
|
594
|
+
* =============================================================================
|
|
595
|
+
*/
|
|
596
|
+
// Registry of tensor names already handed out, mapped to the next numeric
// suffix to try for that name. Starts empty.
const b = /* @__PURE__ */ new Map([]);
|
|
597
|
+
/**
 * Throws (via `v`) unless `e` is null/undefined or one of the allowed
 * values listed in `Dn`.
 *
 * @param {*} e - Candidate data format.
 */
function Nn(e) {
  const allowed = Dn;
  v(allowed, "DataFormat", e);
}
|
|
600
|
+
/**
 * Throws (via `v`) unless `e` is null/undefined or one of the allowed
 * values listed in `Tn`.
 *
 * @param {*} e - Candidate interpolation format.
 */
function ut(e) {
  const allowed = Tn;
  v(allowed, "InterpolationFormat", e);
}
|
|
603
|
+
/**
 * Throws (via `v`) unless `e` is null/undefined or one of the allowed
 * values listed in `En`.
 *
 * @param {*} e - Candidate padding mode.
 */
function ct(e) {
  const allowed = En;
  v(allowed, "PaddingMode", e);
}
|
|
606
|
+
/**
 * Throws (via `v`) unless `e` is null/undefined or one of the allowed
 * values listed in `xn`.
 *
 * @param {*} e - Candidate pool mode.
 */
function at(e) {
  const allowed = xn;
  v(allowed, "PoolMode", e);
}
|
|
609
|
+
// Stack of currently active name scopes, innermost last.
const F = [];
// Separator placed between scope names and after the final scope.
const ge = "/";

/**
 * Runs `n` with `e` pushed onto the name-scope stack. The scope is popped
 * whether `n` returns or throws.
 *
 * @param {string} e - Scope name to push.
 * @param {Function} n - Callback to run inside the scope.
 * @returns {*} Whatever `n` returns; exceptions propagate unchanged.
 */
function it(e, n) {
  F.push(e);
  try {
    return n();
  } finally {
    F.pop();
  }
}

/**
 * Returns the active scope path with a trailing separator
 * (e.g. "outer/inner/"), or "" when no scope is active.
 *
 * @returns {string} Current name-scope prefix.
 */
function bn() {
  if (F.length === 0) {
    return "";
  }
  return `${F.join(ge)}${ge}`;
}
|
|
622
|
+
/**
 * Prefixes a tensor name with the current name-scope path from `bn()`.
 *
 * @param {string} e - Tensor name; must pass `Oe`.
 * @returns {string} `bn()` followed by `e`.
 * @throws {Error} If `e` is not a valid tensor name.
 */
function lt(e) {
  if (Oe(e)) {
    return bn() + e;
  }
  throw new Error("Not a valid tensor name: '" + e + "'");
}
|
|
627
|
+
/**
 * Returns a tensor name that has not been handed out before, recording it in `b`.
 *
 * The first request for `e` returns `e` itself. Later requests return
 * `e_<k>` with the smallest k (starting from the stored counter) whose
 * composed name is not already registered. Without that check, a name such
 * as "x_1" requested earlier would be handed out again for the second "x".
 *
 * @param {string} e - Requested base name; must pass `Oe`.
 * @returns {string} A name not previously returned.
 * @throws {Error} If `e` is not a valid tensor name.
 */
function ft(e) {
  if (!Oe(e)) {
    throw new Error("Not a valid tensor name: '" + e + "'");
  }
  if (!b.has(e)) {
    b.set(e, 0);
  }
  let n = b.get(e);
  if (n === 0) {
    b.set(e, 1);
    return e;
  }
  let t = `${e}_${n}`;
  // Skip suffixes already registered, e.g. by an explicit request for "x_1".
  while (b.has(t)) {
    n += 1;
    t = `${e}_${n}`;
  }
  b.set(e, n + 1);
  // Mark the composed name as used so it is never returned again.
  b.set(t, 1);
  return t;
}
|
|
638
|
+
// Valid tensor names start with an ASCII letter or digit, followed by any mix
// of letters, digits, '-', '.', '_' and '/'.
const Cn = /^[A-Za-z0-9][-A-Za-z0-9\._\/]*$/;

/**
 * Tests whether the whole of `e` matches the tensor-name pattern `Cn`.
 *
 * @param {string} e - Candidate name.
 * @returns {boolean} true when `e` is a valid tensor name.
 */
function Oe(e) {
  const match = e.match(Cn);
  return match !== null;
}
|
|
642
|
+
/**
|
|
643
|
+
* @license
|
|
644
|
+
* Copyright 2018 Google LLC
|
|
645
|
+
*
|
|
646
|
+
* Use of this source code is governed by an MIT-style
|
|
647
|
+
* license that can be found in the LICENSE file or at
|
|
648
|
+
* https://opensource.org/licenses/MIT.
|
|
649
|
+
* =============================================================================
|
|
650
|
+
*/
|
|
651
|
+
/**
 * Reports whether `e` is an integer-valued number.
 *
 * Uses `Number.isInteger` rather than comparing against
 * `parseInt(e.toString(), 10)`. The string round-trip wrongly rejects integers
 * of 1e21 and above (they stringify in exponent form, e.g. "1e+21" parses as 1)
 * and throws for null/undefined.
 *
 * @param {*} e - Value to test.
 * @returns {boolean} true for finite integral numbers (including -0).
 */
function ht(e) {
  return Number.isInteger(e);
}
|
|
654
|
+
/**
 * Multiplies the elements of `e` in the half-open index range [n, t).
 *
 * @param {ArrayLike<number>} e - Values to multiply.
 * @param {number} [n] - Start index; null/undefined means 0.
 * @param {number} [t] - End index (exclusive); null/undefined means e.length.
 * @returns {number} The product (1 for an empty range).
 */
function _e(e, n, t) {
  const start = n == null ? 0 : n;
  const end = t == null ? e.length : t;
  let product = 1;
  for (let i = start; i < end; i += 1) {
    product *= e[i];
  }
  return product;
}
|
|
661
|
+
/**
 * Returns the smallest element of `e`, scanning left to right; ties keep the
 * first occurrence.
 *
 * @param {ArrayLike<number>} e - Values to scan.
 * @returns {number} The minimum, or NaN when `e` is empty.
 */
function pt(e) {
  const count = e.length;
  if (count === 0) {
    return Number.NaN;
  }
  let smallest = Number.POSITIVE_INFINITY;
  for (let i = 0; i < count; i += 1) {
    const value = e[i];
    if (value < smallest) {
      smallest = value;
    }
  }
  return smallest;
}
|
|
671
|
+
/**
 * Returns the largest element of `e`, scanning left to right; ties keep the
 * first occurrence.
 *
 * @param {ArrayLike<number>} e - Values to scan.
 * @returns {number} The maximum, or NaN when `e` is empty.
 */
function dt(e) {
  const count = e.length;
  if (count === 0) {
    return Number.NaN;
  }
  let largest = Number.NEGATIVE_INFINITY;
  for (let i = 0; i < count; i += 1) {
    const value = e[i];
    if (value > largest) {
      largest = value;
    }
  }
  return largest;
}
|
|
681
|
+
/**
 * Builds the list e, e+1, ... while the value stays below n.
 *
 * @param {number} e - First value (inclusive).
 * @param {number} n - Upper bound (exclusive).
 * @returns {number[]} The generated values.
 * @throws {l} If n < e.
 */
function mt(e, n) {
  if (n < e) {
    throw new l(`end (${n}) < begin (${e}) is forbidden.`);
  }
  const values = [];
  for (let value = e; value < n; value += 1) {
    values.push(value);
  }
  return values;
}
|
|
689
|
+
/**
|
|
690
|
+
* @license
|
|
691
|
+
* Copyright 2018 Google LLC
|
|
692
|
+
*
|
|
693
|
+
* Use of this source code is governed by an MIT-style
|
|
694
|
+
* license that can be found in the LICENSE file or at
|
|
695
|
+
* https://opensource.org/licenses/MIT.
|
|
696
|
+
* =============================================================================
|
|
697
|
+
*/
|
|
698
|
+
// Cached result of je().epsilon(); filled on the first gt() call.
let G;

/**
 * Returns `je().epsilon()`, evaluating it only on the first call (or while
 * the cached value is still null/undefined).
 *
 * @returns {*} The cached epsilon value (presumably a number — confirm with je()).
 */
function gt() {
  if (G == null) {
    G = je().epsilon();
  }
  return G;
}
|
|
702
|
+
/**
 * Returns the fixed default image data format.
 *
 * @returns {string} Always "channelsLast".
 */
function Y() {
  const defaultFormat = "channelsLast";
  return defaultFormat;
}
|
|
705
|
+
/**
|
|
706
|
+
* @license
|
|
707
|
+
* Copyright 2018 Google LLC
|
|
708
|
+
*
|
|
709
|
+
* Use of this source code is governed by an MIT-style
|
|
710
|
+
* license that can be found in the LICENSE file or at
|
|
711
|
+
* https://opensource.org/licenses/MIT.
|
|
712
|
+
* =============================================================================
|
|
713
|
+
*/
|
|
714
|
+
/**
 * Forwards both arguments unchanged to `$e` and returns its result.
 * NOTE(review): `$e` is used elsewhere in this file as `$e(x, "int32")`,
 * so this is presumably a dtype cast — confirm.
 *
 * @param {*} e - Value to convert.
 * @param {*} n - Target passed through to `$e`.
 * @returns {*} The result of `$e(e, n)`.
 */
function $t(e, n) {
  const converted = $e(e, n);
  return converted;
}
|
|
717
|
+
/**
 * Inserts a size-1 dimension into the shape of `e` and reshapes via `u`.
 * A negative position counts from the end, so -1 appends the new axis.
 * The input's `shape` array is not mutated.
 *
 * @param {*} e - Tensor-like value exposing a `shape` array.
 * @param {number} [n=-1] - Position of the new axis.
 * @returns {*} The result of `u(e, newShape)`.
 */
function Pn(e, n = -1) {
  const shape = [...e.shape];
  const position = n < 0 ? shape.length + n + 1 : n;
  shape.splice(position, 0, 1);
  return u(e, shape);
}
|
|
721
|
+
/**
 * Repeats a rank-2 tensor `n` times along a new axis 1.
 * `Pn` inserts the axis and `Fn` tiles with reps [1, n, 1], inside `S`.
 *
 * @param {*} e - Tensor-like value whose `shape` has length 2.
 * @param {number} n - Number of repetitions.
 * @returns {*} The tiled result.
 * @throws {l} If `e` is not rank 2.
 */
function kt(e, n) {
  return S(() => {
    const rank = e.shape.length;
    if (rank !== 2) {
      throw new l(`repeat() expects a rank-2 tensor, but received a rank-${rank} tensor.`);
    }
    const expanded = Pn(e, 1);
    return Fn(expanded, [1, n, 1]);
  });
}
|
|
729
|
+
/**
 * Reshapes `e` into a single dimension whose size is the product of its shape.
 *
 * @param {*} e - Tensor-like value exposing `shape`.
 * @returns {*} The result of `u(e, [size])`.
 */
function wt(e) {
  const size = _e(e.shape);
  return u(e, [size]);
}
|
|
733
|
+
/**
 * Flattens every dimension except the first: shape [b, d1, d2, ...] becomes
 * [b, d1*d2*...].
 *
 * @param {*} e - Tensor-like value exposing `rank` and `shape`.
 * @returns {*} The result of `u(e, [b, rest])`.
 * @throws {l} If `e.rank` is 1 or less.
 */
function At(e) {
  const { rank } = e;
  if (rank <= 1) {
    throw new l(`batchFlatten requires a minimum rank of 2. Got rank: ${rank}.`);
  }
  const { shape } = e;
  return u(e, [shape[0], _e(shape, 1)]);
}
|
|
739
|
+
/**
 * Slices `t` entries starting at index `n` along the first axis of `e`,
 * keeping all other axes whole. Ranks 1–6 are supported; the work runs
 * inside `S`.
 *
 * @param {*} e - Tensor-like value exposing `rank` and `shape`.
 * @param {number} n - Start index on axis 0.
 * @param {number} t - Number of entries to keep on axis 0.
 * @returns {*} The sliced tensor.
 * @throws {l} For unsupported ranks.
 */
function J(e, n, t) {
  return S(() => {
    const trailing = e.shape.slice(1);
    const begin = [n, ...trailing.map(() => 0)];
    const size = [t, ...trailing];
    switch (e.rank) {
      case 1:
        return X(e, n, t);
      case 2:
        return we(e, begin, size);
      case 3:
        return z(e, begin, size);
      case 4:
        return K(e, begin, size);
      case 5:
      case 6:
        return C(e, begin, size);
      default:
        throw new l(`sliceAlongFirstAxis() received an unsupported tensor rank: ${e.rank}`);
    }
  });
}
|
|
772
|
+
/**
 * Slices `t` entries starting at index `n` along the last axis of `e`,
 * keeping all other axes whole. Ranks 1–4 are supported; the work runs
 * inside `S`.
 *
 * @param {*} e - Tensor-like value exposing `rank` and `shape`.
 * @param {number} n - Start index on the last axis.
 * @param {number} t - Number of entries to keep on the last axis.
 * @returns {*} The sliced tensor.
 * @throws {l} For unsupported ranks.
 */
function Z(e, n, t) {
  return S(() => {
    const leading = e.shape.slice(0, -1);
    const begin = [...leading.map(() => 0), n];
    const size = [...leading, t];
    switch (e.rank) {
      case 1:
        return X(e, n, t);
      case 2:
        return we(e, begin, size);
      case 3:
        return z(e, begin, size);
      case 4:
        return K(e, begin, size);
      default:
        throw new l(`sliceAlongLastAxis() received an unsupported tensor rank: ${e.rank}`);
    }
  });
}
|
|
788
|
+
/**
 * Slices `t` entries starting at index `n` along axis `r` of `e`.
 *
 * `r` is 1-based per the dispatch below: 1 selects the first axis (handled by
 * `J`), `e.rank` selects the last axis (handled by `Z`), and interior axes are
 * sliced directly. Rank-1 inputs ignore `r`. The work runs inside `S`.
 *
 * @param {*} e - Tensor-like value of rank 1–4 exposing `rank` and `shape`.
 * @param {number} n - Start index along the chosen axis.
 * @param {number} t - Number of entries to keep along that axis.
 * @param {number} r - 1-based axis selector.
 * @returns {*} The sliced tensor.
 * @throws {l} If the rank is unsupported or `r` is out of range for that rank.
 */
function St(e, n, t, r) {
  return S(() => {
    const badAxis = () => new l(`The axis is not within the rank of the tensor ${r}`);
    switch (e.rank) {
      case 1:
        return X(e, n, t);
      case 2:
        switch (r) {
          case 1:
            return J(e, n, t);
          case 2:
            return Z(e, n, t);
          default:
            throw badAxis();
        }
      case 3:
        switch (r) {
          case 1:
            return J(e, n, t);
          case 2:
            return z(e, [0, n, 0], [e.shape[0], t, e.shape[2]]);
          case 3:
            return Z(e, n, t);
          default:
            throw badAxis();
        }
      case 4:
        switch (r) {
          case 1:
            return J(e, n, t);
          case 2:
            return K(e, [0, n, 0, 0], [e.shape[0], t, e.shape[2], e.shape[3]]);
          case 3:
            return K(e, [0, 0, n, 0], [e.shape[0], e.shape[1], t, e.shape[3]]);
          case 4:
            return Z(e, n, t);
          default:
            throw badAxis();
        }
      default:
        // Name this function in the message; the previous text wrongly said
        // sliceAlongLastAxis().
        throw new l(`sliceAlongAxis() received an unsupported tensor rank: ${e.rank}`);
    }
  });
}
|
|
831
|
+
/**
 * Concatenates tensors along axis `n` by delegating to `M`.
 *
 * Negative axes are normalized as `axis % rank`, with a Python-style
 * non-negative remainder: -1 selects the last axis, -2 the one before it,
 * and so on. The previous code mapped every negative axis to the last axis,
 * so e.g. -2 on rank-3 inputs concatenated along axis 2 instead of 1.
 *
 * @param {Array<*>} e - Tensors to concatenate; `e[0].rank` is the reference rank.
 * @param {number} [n=-1] - Axis to concatenate along.
 * @returns {*} The result of `M(e, axis)`.
 */
function yt(e, n = -1) {
  const rank = e[0].rank;
  let axis = n;
  if (axis < 0) {
    axis = rank !== 0 ? ((axis % rank) + rank) % rank : 0;
  }
  // An axis equal to the rank (one past the last) is passed on as -1.
  if (axis === rank) {
    axis = -1;
  }
  return M(e, axis);
}
|
|
835
|
+
/**
 * Concatenates two tensors along their first axis, dispatching on the rank of
 * `e` to a rank-specific concat helper (ranks 1–4).
 *
 * @param {*} e - First tensor; its `rank` selects the helper.
 * @param {*} n - Second tensor.
 * @returns {*} The concatenated tensor.
 * @throws {l} For unsupported ranks.
 */
function Ot(e, n) {
  const pair = [e, n];
  switch (e.rank) {
    case 1:
      return Qe(pair);
    case 2:
      return ze(pair, 0);
    case 3:
      return nn(pair, 0);
    case 4:
      return rn(pair, 0);
    default:
      throw new l(`concatAlongFirstAxis() received an unsupported tensor rank: ${e.rank}`);
  }
}
|
|
849
|
+
/**
 * Tiles `e` by `n` repetitions per axis via `Ue`. A scalar `n` is treated
 * as a one-element list.
 *
 * @param {*} e - Tensor-like value exposing `rank`.
 * @param {number|number[]} n - Repetitions per axis.
 * @returns {*} The tiled tensor.
 * @throws {l} If the number of repetitions differs from `e.rank`.
 */
function Fn(e, n) {
  const reps = Array.isArray(n) ? n : [n];
  if (e.rank !== reps.length) {
    throw new l(`The length of input n (${reps.length}) does not match the number of dimensions in input x (${e.rank})`);
  }
  return Ue(e, reps);
}
|
|
854
|
+
/**
 * Forwards to `Re` with defaults n = 0 and t = 1.
 * NOTE(review): presumably shape, mean, stddev, dtype and seed for a
 * random-normal sampler — confirm against `Re`.
 *
 * @returns {*} The result of `Re(e, n, t, r, s)`.
 */
function _t(e, n = 0, t = 1, r, s) {
  const sample = Re(e, n, t, r, s);
  return sample;
}
|
|
857
|
+
/**
 * Generalized dot product of `e` (x) and `n` (y) via the fused matMul `de`.
 *
 * Two rank-2 inputs go straight to `de`. Otherwise:
 * - x is flattened to [-1, lastDim(x)];
 * - y's second-to-last axis is transposed to the front and y is flattened
 *   to [thatDim, -1];
 * - the product is reshaped to [...x.shape[:-1], ...y.shape[:-2], lastDim(y)].
 *
 * @param {*} e - x, rank >= 2.
 * @param {*} n - y, rank >= 2.
 * @param {*} [t] - Activation passed to `de`.
 * @param {*} [r] - Optional bias, reshaped by `H` when provided.
 * @returns {*} The product.
 * @throws {j} If either rank is below 2, or (for rank(y) >= 3) the last dim
 *   of x differs from the second-to-last dim of y.
 */
function Dt(e, n, t, r) {
  if (e.rank < 2 || n.rank < 2) {
    throw new j(`dot requires both inputs to be rank >= 2 but got x shape = ${e.shape} and y shape = ${n.shape}`);
  }
  if (n.rank >= 3) {
    const xLast = e.shape[e.shape.length - 1];
    const yContract = n.shape[n.shape.length - 2];
    if (xLast !== yContract) {
      throw new j(`If rank y >= 3, then the second last dim of y must equal the last dim of x but got x shape = ${e.shape} and y shape = ${n.shape}`);
    }
  }
  const fusedArgs = (a, b) => ({
    a,
    b,
    transposeA: false,
    transposeB: false,
    bias: r ? H(a.rank, r, Y()) : null,
    activation: t,
  });
  if (e.rank === 2 && n.rank === 2) {
    return de(fusedArgs(e, n));
  }
  const xLead = e.shape.slice();
  const xLastDim = xLead.pop();
  const xFlat = u(e, [-1, xLastDim]);
  const yRest = n.shape.slice();
  const yLastDim = yRest.pop();
  const yContractDim = yRest.pop();
  const yRank = n.rank;
  // Moves axis yRank-2 to the front and keeps the others in order.
  const perm = Array.from({ length: yRank }, (_unused, i) => {
    if (i === 0) {
      return yRank - 2;
    }
    return i <= yRank - 2 ? i - 1 : i;
  });
  const yFlat = u(ve(n, perm), [yContractDim, -1]);
  const outShape = [...xLead, ...yRest, yLastDim];
  return u(de(fusedArgs(xFlat, yFlat)), outShape);
}
|
|
890
|
+
/**
 * Gathers from `e` using indices `n` along axis `t`, inside `S`.
 * An array of indices is converted with `Ke(n, "int32")`; any other value is
 * converted with `$e(n, "int32")`.
 *
 * @param {*} e - Source tensor.
 * @param {Array<number>|*} n - Indices.
 * @param {number} [t] - Axis passed through to `qe`.
 * @returns {*} The result of `qe`.
 */
function Tt(e, n, t) {
  return S(() => {
    const indices = Array.isArray(n) ? Ke(n, "int32") : $e(n, "int32");
    return qe(e, indices, t);
  });
}
|
|
893
|
+
/**
 * Element-wise square computed as `Q(e, e)`.
 * NOTE(review): `Q` is presumably multiplication — confirm.
 *
 * @param {*} e - Input.
 * @returns {*} The result of `Q(e, e)`.
 */
function Et(e) {
  const squared = Q(e, e);
  return squared;
}
|
|
896
|
+
/**
 * Reshapes a bias so it can be added to an input of rank `e`.
 *
 * For input ranks 3–5 the bias is reshaped according to the data format
 * ("channelsFirst" / "channelsLast"). For input ranks below 3 it is returned
 * unchanged.
 *
 * @param {number} e - Rank of the input the bias will be added to.
 * @param {*} n - Bias value exposing `shape` and `rank`; rank must be 1 or `e`.
 * @param {string} t - Data format, "channelsFirst" or "channelsLast".
 * @returns {*} `n` itself, or the result of reshaping it via `u`.
 * @throws {l} If the bias rank is invalid, or the input rank / data format
 *   combination is not handled.
 */
function H(e, n, t) {
  const r = n.shape;
  if (n.rank !== 1 && n.rank !== e) {
    throw new l(`Unexpected bias dimensions: ${n.rank}; expected it to be 1 or ${e}`);
  }
  if (e === 5) {
    if (t === "channelsFirst")
      return r.length === 1 ? u(n, [1, r[0], 1, 1, 1]) : u(n, [1, r[3], r[0], r[1], r[2]]);
    if (t === "channelsLast")
      return r.length === 1 ? u(n, [1, 1, 1, 1, r[0]]) : u(n, [1].concat(r));
  } else if (e === 4) {
    if (t === "channelsFirst")
      return r.length === 1 ? u(n, [1, r[0], 1, 1]) : u(n, [1, r[2], r[0], r[1]]);
    if (t === "channelsLast")
      return r.length === 1 ? u(n, [1, 1, 1, r[0]]) : u(n, [1].concat(r));
  } else if (e === 3) {
    if (t === "channelsFirst")
      return r.length === 1 ? u(n, [1, r[0], 1]) : u(n, [1, r[1], r[0]]);
    if (t === "channelsLast")
      return r.length === 1 ? u(n, [1, 1, r[0]]) : u(n, [1].concat(r));
  } else if (e < 3) {
    return n;
  }
  // Report the input rank `e`; the previous message printed the bias rank,
  // which is not what "input rank" refers to.
  throw new l(`Unsupported input rank by biasAdd: ${e}`);
}
|
|
919
|
+
/**
 * Adds a bias to `e` inside `S`, after reshaping it with `H`.
 * The data format defaults to `Y()` when `t` is null/undefined and is
 * validated with `Nn`.
 *
 * @param {*} e - Input tensor exposing `rank`.
 * @param {*} n - Bias.
 * @param {string} [t] - Data format.
 * @returns {*} The result of `U(e, reshapedBias)`.
 */
function xt(e, n, t) {
  return S(() => {
    const format = t == null ? Y() : t;
    Nn(format);
    const bias = H(e.rank, n, format);
    return U(e, bias);
  });
}
|
|
922
|
+
/**
 * Applies `ke` to `e`. Only alpha = 1 is supported.
 *
 * @param {*} e - Input.
 * @param {number} [n=1] - Alpha; any value other than 1 is rejected.
 * @returns {*} The result of `ke(e)`.
 * @throws {j} If `n` is not 1.
 */
function It(e, n = 1) {
  if (n === 1) {
    return ke(e);
  }
  throw new j(`Support for alpha values other than 1 (${n}) is not implemented yet.`);
}
|
|
927
|
+
/**
 * Computes `Me(e, U(Be(e), 1))` inside `S`.
 * NOTE(review): presumably the softsign x / (|x| + 1), assuming Me = div,
 * U = add and Be = abs — confirm.
 *
 * @param {*} e - Input.
 * @returns {*} The result.
 */
function Lt(e) {
  return S(() => {
    const denominator = U(Be(e), 1);
    return Me(e, denominator);
  });
}
|
|
930
|
+
/**
 * Runs `Ve(e, n, t, r)` inside `S`, forwarding all four arguments unchanged.
 *
 * @returns {*} The result of `Ve`.
 */
function Nt(e, n, t, r) {
  const apply = () => Ve(e, n, t, r);
  return S(apply);
}
|
|
933
|
+
/**
 * Computes `Ye(U(0.5, Q(0.2, e)), 0, 1)` inside `S`.
 * NOTE(review): presumably the hard sigmoid clip(0.5 + 0.2 * x, 0, 1),
 * assuming U = add, Q = mul and Ye = clip — confirm.
 *
 * @param {*} e - Input.
 * @returns {*} The clipped result.
 */
function bt(e) {
  return S(() => {
    const scaled = Q(0.2, e);
    const shifted = U(0.5, scaled);
    return Ye(shifted, 0, 1);
  });
}
|
|
939
|
+
/**
 * Calls `e` when `t` is truthy, otherwise calls `n`.
 *
 * @param {Function} e - Branch used when `t` is truthy.
 * @param {Function} n - Branch used otherwise.
 * @param {boolean} [t=false] - Selector.
 * @returns {*} The chosen branch's return value.
 */
function Ct(e, n, t = false) {
  if (t) {
    return e();
  }
  return n();
}
|
|
942
|
+
export {
|
|
943
|
+
Ln as $,
|
|
944
|
+
Ae as A,
|
|
945
|
+
$t as B,
|
|
946
|
+
Qn as C,
|
|
947
|
+
nt as D,
|
|
948
|
+
et as E,
|
|
949
|
+
Jn as F,
|
|
950
|
+
tt as G,
|
|
951
|
+
J as H,
|
|
952
|
+
Pn as I,
|
|
953
|
+
Tt as J,
|
|
954
|
+
mt as K,
|
|
955
|
+
Zn as L,
|
|
956
|
+
It as M,
|
|
957
|
+
j as N,
|
|
958
|
+
bt as O,
|
|
959
|
+
Lt as P,
|
|
960
|
+
un as Q,
|
|
961
|
+
Se as R,
|
|
962
|
+
an as S,
|
|
963
|
+
ke as T,
|
|
964
|
+
dt as U,
|
|
965
|
+
l as V,
|
|
966
|
+
ht as W,
|
|
967
|
+
rt as X,
|
|
968
|
+
xt as Y,
|
|
969
|
+
St as Z,
|
|
970
|
+
ut as _,
|
|
971
|
+
yn as a,
|
|
972
|
+
ot as a0,
|
|
973
|
+
ct as a1,
|
|
974
|
+
Y as a2,
|
|
975
|
+
Fn as a3,
|
|
976
|
+
pt as a4,
|
|
977
|
+
Ot as a5,
|
|
978
|
+
Ct as a6,
|
|
979
|
+
Nt as a7,
|
|
980
|
+
yt as a8,
|
|
981
|
+
At as a9,
|
|
982
|
+
kt as aa,
|
|
983
|
+
at as ab,
|
|
984
|
+
Gn as ac,
|
|
985
|
+
Sn as b,
|
|
986
|
+
v as c,
|
|
987
|
+
Dt as d,
|
|
988
|
+
Nn as e,
|
|
989
|
+
_e as f,
|
|
990
|
+
An as g,
|
|
991
|
+
zn as h,
|
|
992
|
+
Xn as i,
|
|
993
|
+
lt as j,
|
|
994
|
+
ft as k,
|
|
995
|
+
Wn as l,
|
|
996
|
+
Yn as m,
|
|
997
|
+
it as n,
|
|
998
|
+
wn as o,
|
|
999
|
+
Ze as p,
|
|
1000
|
+
Ye as q,
|
|
1001
|
+
_t as r,
|
|
1002
|
+
On as s,
|
|
1003
|
+
Hn as t,
|
|
1004
|
+
gt as u,
|
|
1005
|
+
fn as v,
|
|
1006
|
+
st as w,
|
|
1007
|
+
wt as x,
|
|
1008
|
+
Et as y,
|
|
1009
|
+
me as z
|
|
1010
|
+
};
|