@thi.ng/tensors 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CHANGELOG.md +46 -0
- package/LICENSE +201 -0
- package/README.md +307 -0
- package/abs.d.ts +37 -0
- package/abs.js +14 -0
- package/add.d.ts +42 -0
- package/add.js +15 -0
- package/addn.d.ts +42 -0
- package/addn.js +15 -0
- package/api.d.ts +137 -0
- package/api.js +0 -0
- package/clamp.d.ts +47 -0
- package/clamp.js +15 -0
- package/clampn.d.ts +47 -0
- package/clampn.js +15 -0
- package/cos.d.ts +37 -0
- package/cos.js +14 -0
- package/defopn.d.ts +13 -0
- package/defopn.js +65 -0
- package/defoprt.d.ts +14 -0
- package/defoprt.js +77 -0
- package/defoprtt.d.ts +14 -0
- package/defoprtt.js +113 -0
- package/defopt.d.ts +13 -0
- package/defopt.js +109 -0
- package/defoptn.d.ts +13 -0
- package/defoptn.js +109 -0
- package/defoptnn.d.ts +13 -0
- package/defoptnn.js +109 -0
- package/defoptt.d.ts +13 -0
- package/defoptt.js +144 -0
- package/defopttt.d.ts +13 -0
- package/defopttt.js +177 -0
- package/div.d.ts +42 -0
- package/div.js +15 -0
- package/divn.d.ts +42 -0
- package/divn.js +15 -0
- package/dot.d.ts +31 -0
- package/dot.js +17 -0
- package/errors.d.ts +15 -0
- package/errors.js +9 -0
- package/exp.d.ts +37 -0
- package/exp.js +14 -0
- package/exp2.d.ts +37 -0
- package/exp2.js +14 -0
- package/format.d.ts +14 -0
- package/format.js +37 -0
- package/identity.d.ts +4 -0
- package/identity.js +11 -0
- package/index.d.ts +60 -0
- package/index.js +59 -0
- package/log.d.ts +37 -0
- package/log.js +14 -0
- package/log2.d.ts +37 -0
- package/log2.js +14 -0
- package/mag.d.ts +3 -0
- package/mag.js +5 -0
- package/magsq.d.ts +31 -0
- package/magsq.js +17 -0
- package/max.d.ts +37 -0
- package/max.js +14 -0
- package/maxn.d.ts +42 -0
- package/maxn.js +14 -0
- package/min.d.ts +37 -0
- package/min.js +14 -0
- package/minn.d.ts +42 -0
- package/minn.js +14 -0
- package/mul.d.ts +42 -0
- package/mul.js +15 -0
- package/mulm.d.ts +12 -0
- package/mulm.js +49 -0
- package/muln.d.ts +42 -0
- package/muln.js +15 -0
- package/mulv.d.ts +11 -0
- package/mulv.js +39 -0
- package/normalize.d.ts +3 -0
- package/normalize.js +11 -0
- package/package.json +261 -0
- package/pow.d.ts +37 -0
- package/pow.js +14 -0
- package/pown.d.ts +42 -0
- package/pown.js +14 -0
- package/product.d.ts +31 -0
- package/product.js +15 -0
- package/rand-distrib.d.ts +49 -0
- package/rand-distrib.js +52 -0
- package/relu.d.ts +37 -0
- package/relu.js +14 -0
- package/relun.d.ts +43 -0
- package/relun.js +14 -0
- package/select.d.ts +91 -0
- package/select.js +111 -0
- package/set.d.ts +8 -0
- package/set.js +14 -0
- package/setn.d.ts +8 -0
- package/setn.js +14 -0
- package/sigmoid.d.ts +39 -0
- package/sigmoid.js +15 -0
- package/sin.d.ts +37 -0
- package/sin.js +14 -0
- package/softmax.d.ts +27 -0
- package/softmax.js +10 -0
- package/softplus.d.ts +48 -0
- package/softplus.js +15 -0
- package/sqrt.d.ts +37 -0
- package/sqrt.js +14 -0
- package/step.d.ts +48 -0
- package/step.js +14 -0
- package/storage.d.ts +3 -0
- package/storage.js +32 -0
- package/sub.d.ts +42 -0
- package/sub.js +15 -0
- package/subn.d.ts +42 -0
- package/subn.js +15 -0
- package/sum.d.ts +31 -0
- package/sum.js +15 -0
- package/tan.d.ts +37 -0
- package/tan.js +14 -0
- package/tanh.d.ts +37 -0
- package/tanh.js +14 -0
- package/tensor.d.ts +127 -0
- package/tensor.js +517 -0
- package/top.d.ts +16 -0
- package/top.js +15 -0
package/tensor.js
ADDED
|
@@ -0,0 +1,517 @@
|
|
|
1
|
+
import { swizzle } from "@thi.ng/arrays/swizzle";
|
|
2
|
+
import { isNumber } from "@thi.ng/checks/is-number";
|
|
3
|
+
import { equiv, equivArrayLike } from "@thi.ng/equiv";
|
|
4
|
+
import { illegalArgs } from "@thi.ng/errors/illegal-arguments";
|
|
5
|
+
import { outOfBounds } from "@thi.ng/errors/out-of-bounds";
|
|
6
|
+
import { unsupported } from "@thi.ng/errors/unsupported";
|
|
7
|
+
import { dot2, dot3, dot4 } from "@thi.ng/vectors/dot";
|
|
8
|
+
import { eqDeltaS as _eqDelta } from "@thi.ng/vectors/eqdelta";
|
|
9
|
+
import { product, product2, product3, product4 } from "@thi.ng/vectors/product";
|
|
10
|
+
import { illegalShape } from "./errors.js";
|
|
11
|
+
import { format } from "./format.js";
|
|
12
|
+
import { STORAGE } from "./storage.js";
|
|
13
|
+
// Module-local aliases for the Math helpers used by tensor internals.
const abs = Math.abs;
const ceil = Math.ceil;
const min = Math.min;
|
|
14
|
+
/**
 * Abstract base class for strided tensor views. Concrete subclasses
 * (`Tensor1`..`Tensor4`) address elements as `data[offset + dot(pos, stride)]`,
 * which allows several tensors to share one `data` array using different
 * shapes, strides and offsets.
 */
class ATensor {
	/**
	 * @param type - element type ID
	 * @param storage - storage provider used for allocating/copying data
	 * @param data - backing data array (stored as is, not copied)
	 * @param shape - number of elements per axis
	 * @param stride - index step per axis
	 * @param offset - start index in `data` (default: 0)
	 */
	constructor(type, storage, data, shape, stride, offset = 0) {
		this.type = type;
		this.storage = storage;
		this.data = data;
		this.shape = shape;
		this.stride = stride;
		this.offset = offset;
	}

	/** Axis indices sorted by descending absolute stride (see `strideOrder`). */
	get order() {
		return strideOrder(this.stride);
	}

	/** The shape with axes reordered according to {@link ATensor.order}. */
	get orderedShape() {
		const reorder = swizzle(this.order);
		return reorder(this.shape);
	}

	/** The stride with axes reordered according to {@link ATensor.order}. */
	get orderedStride() {
		const reorder = swizzle(this.order);
		return reorder(this.stride);
	}

	/**
	 * Returns a new view of the same class sharing this tensor's `data`, but
	 * with its own copies of the shape & stride arrays.
	 */
	copy() {
		const { type, storage, data, shape, stride, offset } = this;
		return new this.constructor(
			type,
			storage,
			data,
			shape.slice(),
			stride.slice(),
			offset
		);
	}

	/**
	 * Returns a new tensor of the same class, type and shape, backed by freshly
	 * allocated data (via `storage.alloc()`) and using default (packed) strides.
	 *
	 * @param storage - storage provider (default: this tensor's storage)
	 */
	empty(storage = this.storage) {
		const buf = storage.alloc(this.length);
		return new this.constructor(
			this.type,
			storage,
			buf,
			this.shape.slice(),
			shapeToStride(this.shape)
		);
	}

	/**
	 * True if `o` is the same instance, or an `ATensor` with equal shape and
	 * equal elements (in iteration order).
	 */
	equiv(o) {
		if (this === o) return true;
		if (!(o instanceof ATensor) || !equiv(this.shape, o.shape)) return false;
		return equivArrayLike([...this], [...o]);
	}

	/**
	 * True if `o` is the same instance, or has an equal shape and all elements
	 * (in iteration order) are within `eps` of each other.
	 */
	eqDelta(o, eps = 1e-6) {
		if (this === o) return true;
		return (
			equiv(this.shape, o.shape) &&
			_eqDelta([...this], [...o], this.length, eps)
		);
	}

	/**
	 * Returns a view whose size along axis `i` is limited to `pos[i]` (for all
	 * axes where `pos[i] > 0`). Shares `data`, `stride` and `offset`.
	 */
	hi(pos) {
		const shape = __hi(pos, this);
		return new this.constructor(
			this.type,
			this.storage,
			this.data,
			shape,
			this.stride,
			this.offset
		);
	}

	/**
	 * Returns a view starting at `pos[i]` along each axis where `pos[i] >= 0`
	 * (with the shape reduced accordingly). Shares `data` and `stride`.
	 */
	lo(pos) {
		const view = __lo(pos, this);
		return new this.constructor(
			this.type,
			this.storage,
			this.data,
			view.shape,
			this.stride,
			view.offset
		);
	}

	/**
	 * Returns a view selecting every `select[i]`-th element along each axis.
	 * Negative steps traverse an axis backwards from its last element; zero or
	 * missing entries keep the axis unchanged.
	 */
	step(select) {
		const view = __step(select, this);
		return new this.constructor(
			this.type,
			this.storage,
			this.data,
			view.shape,
			view.stride,
			view.offset
		);
	}

	/**
	 * Returns a lower-dimensional view: each axis with `select[i] >= 0` is fixed
	 * at that index and removed, all other axes are kept.
	 */
	pick(select) {
		const view = __pick(select, this);
		return tensor(this.type, view.shape, {
			data: this.data,
			storage: this.storage,
			copy: false,
			stride: view.stride,
			offset: view.offset,
		});
	}

	/**
	 * Returns a new tensor whose data is a packed copy of this tensor's elements
	 * (via `storage.from(this)`), using default strides.
	 *
	 * @param storage - storage provider (default: this tensor's storage)
	 */
	pack(storage = this.storage) {
		const packed = storage.from(this);
		return new this.constructor(
			this.type,
			storage,
			packed,
			this.shape.slice(),
			shapeToStride(this.shape)
		);
	}

	/**
	 * Reinterprets the shared `data` using `newShape` (and optional `newStride`,
	 * otherwise default strides for `newShape`). The element count must stay
	 * the same, otherwise `illegalShape()` is called.
	 *
	 * NOTE(review): the current `offset` is reused with the new strides, which
	 * presumably is only meaningful for packed tensors — confirm with callers.
	 */
	reshape(newShape, newStride) {
		if (product(newShape) !== this.length) illegalShape(newShape);
		return tensor(this.type, newShape, {
			storage: this.storage,
			data: this.data,
			copy: false,
			stride: newStride ?? shapeToStride(newShape),
			offset: this.offset,
		});
	}

	/**
	 * Returns a view with shape & stride reordered by the given axis `order`.
	 * Shares `data` and `offset`.
	 */
	transpose(order) {
		const reorder = swizzle(order);
		const shape = reorder(this.shape);
		const stride = reorder(this.stride);
		return new this.constructor(
			this.type,
			this.storage,
			this.data,
			shape,
			stride,
			this.offset
		);
	}

	/**
	 * JSON representation: elements in iteration order (`buf`), the shape and
	 * the packed strides matching `buf`.
	 */
	toJSON() {
		const buf = [...this];
		const shape = this.shape;
		return { buf, shape, stride: shapeToStride(shape) };
	}
}
|
|
138
|
+
/**
 * One-dimensional strided tensor view. Elements are located at
 * `data[offset + x * stride[0]]`.
 */
class Tensor1 extends ATensor {
	*[Symbol.iterator]() {
		let {
			data,
			length,
			stride: [tx],
			offset,
		} = this;
		for (; length-- > 0; offset += tx) yield data[offset];
	}

	get dim() {
		return 1;
	}

	get order() {
		return [0];
	}

	get length() {
		return this.shape[0];
	}

	/** Computes the `data` index for position `[x]`. */
	index([x]) {
		return this.offset + x * this.stride[0];
	}

	/** Returns the element at position `[x]` (no bounds checking). */
	get([x]) {
		return this.data[this.offset + x * this.stride[0]];
	}

	/** Writes `v` at position `[x]` (no bounds checking) and returns `this`. */
	set([x], v) {
		this.data[this.offset + x * this.stride[0]] = v;
		return this;
	}

	/**
	 * Returns a single-element `Tensor1` view (shape `[1]`) of the element at
	 * index `x`, sharing this tensor's data.
	 *
	 * Calls `outOfBounds()` if `x` is outside `[0, length)`.
	 */
	pick([x]) {
		// Negated range test, so that NaN/undefined indices are rejected too.
		if (!(x >= 0 && x < this.length)) outOfBounds(x);
		return new Tensor1(
			this.type,
			this.storage,
			this.data,
			[1],
			[1],
			this.offset + x * this.stride[0]
		);
	}

	/**
	 * Returns a new packed tensor of `newShape`, prefilled with `fill` (if
	 * given) and populated with as many of this tensor's elements (in iteration
	 * order) as fit.
	 */
	resize(newShape, fill, storage = this.storage) {
		const newLength = product(newShape);
		const newData = storage.alloc(newLength);
		if (fill !== void 0) newData.fill(fill);
		const {
			data,
			shape: [sx],
			stride: [tx],
		} = this;
		const n = min(sx, newLength);
		for (let i = 0, src = this.offset; i < n; i++, src += tx) {
			newData[i] = data[src];
		}
		return tensor(this.type, newShape, {
			storage,
			data: newData,
			copy: false,
		});
	}

	/** Transposing a 1D tensor is a no-op, so this simply returns a copy. */
	transpose(_) {
		return this.copy();
	}

	toString() {
		const res = [];
		for (let x of this) res.push(format(x));
		return res.join(" ");
	}
}
|
|
206
|
+
/**
 * Two-dimensional strided tensor view. Elements are addressed as `[x, y]` and
 * iterated with `y` as the inner (fastest varying) index.
 */
class Tensor2 extends ATensor {
	/** Lazily computed element count (recomputed while falsy). */
	_n;

	*[Symbol.iterator]() {
		const { data, offset } = this;
		const [sx, sy] = this.shape;
		const [tx, ty] = this.stride;
		for (let x = 0, row = offset; x < sx; x++, row += tx) {
			for (let y = 0; y < sy; y++) yield data[row + y * ty];
		}
	}

	get length() {
		if (!this._n) this._n = product2(this.shape);
		return this._n;
	}

	get dim() {
		return 2;
	}

	/** `[1, 0]` if axis 1 has the larger absolute stride, else `[0, 1]`. */
	get order() {
		const [tx, ty] = this.stride;
		return abs(ty) > abs(tx) ? [1, 0] : [0, 1];
	}

	/** Computes the `data` index for position `pos`. */
	index(pos) {
		return this.offset + dot2(pos, this.stride);
	}

	/** Returns the element at `pos` (no bounds checking). */
	get(pos) {
		const idx = this.offset + dot2(pos, this.stride);
		return this.data[idx];
	}

	/** Writes `v` at `pos` (no bounds checking) and returns `this`. */
	set(pos, v) {
		const idx = this.offset + dot2(pos, this.stride);
		this.data[idx] = v;
		return this;
	}

	/**
	 * Returns a new packed tensor of `newShape`, prefilled with `fill` (if
	 * given) and populated with as many of this tensor's elements (in iteration
	 * order) as fit.
	 */
	resize(newShape, fill, storage = this.storage) {
		const newLength = product(newShape);
		const newData = storage.alloc(newLength);
		if (fill !== void 0) newData.fill(fill);
		const { data, offset } = this;
		const [sx, sy] = this.shape;
		const [tx, ty] = this.stride;
		const limit = min(this.length, newLength);
		let dest = 0;
		for (let x = 0, row = offset; x < sx; x++, row += tx) {
			for (let y = 0; y < sy && dest < limit; y++, dest++) {
				newData[dest] = data[row + y * ty];
			}
		}
		return tensor(this.type, newShape, {
			storage,
			data: newData,
			copy: false,
		});
	}

	/** Formats each row (via `pick`) on its own line. */
	toString() {
		const rows = [];
		for (let i = 0; i < this.shape[0]; i++) {
			rows.push(this.pick([i]).toString());
		}
		return rows.join("\n");
	}
}
|
|
268
|
+
/**
 * Three-dimensional strided tensor view. Elements are addressed as
 * `[x, y, z]` and iterated with `z` as the innermost index.
 */
class Tensor3 extends ATensor {
	/** Lazily computed element count (recomputed while falsy). */
	_n;

	*[Symbol.iterator]() {
		const { data, offset } = this;
		const [sx, sy, sz] = this.shape;
		const [tx, ty, tz] = this.stride;
		for (let x = 0, ox = offset; x < sx; x++, ox += tx) {
			for (let y = 0, oy = ox; y < sy; y++, oy += ty) {
				for (let z = 0; z < sz; z++) yield data[oy + z * tz];
			}
		}
	}

	get length() {
		if (!this._n) this._n = product3(this.shape);
		return this._n;
	}

	get dim() {
		return 3;
	}

	/** Computes the `data` index for position `pos`. */
	index(pos) {
		return this.offset + dot3(pos, this.stride);
	}

	/** Returns the element at `pos` (no bounds checking). */
	get(pos) {
		const idx = this.offset + dot3(pos, this.stride);
		return this.data[idx];
	}

	/** Writes `v` at `pos` (no bounds checking) and returns `this`. */
	set(pos, v) {
		const idx = this.offset + dot3(pos, this.stride);
		this.data[idx] = v;
		return this;
	}

	/**
	 * Returns a new packed tensor of `newShape`, prefilled with `fill` (if
	 * given) and populated with as many of this tensor's elements (in iteration
	 * order) as fit.
	 */
	resize(newShape, fill, storage = this.storage) {
		const newLength = product(newShape);
		const newData = storage.alloc(newLength);
		if (fill !== void 0) newData.fill(fill);
		const { data, offset } = this;
		const [sx, sy, sz] = this.shape;
		const [tx, ty, tz] = this.stride;
		const limit = min(this.length, newLength);
		let dest = 0;
		for (let x = 0, ox = offset; x < sx; x++, ox += tx) {
			for (let y = 0, oy = ox; y < sy; y++, oy += ty) {
				for (let z = 0; z < sz && dest < limit; z++, dest++) {
					newData[dest] = data[oy + z * tz];
				}
			}
		}
		return tensor(this.type, newShape, {
			storage,
			data: newData,
			copy: false,
		});
	}

	/** Formats each 2D slice (via `pick`) preceded by a header line. */
	toString() {
		const parts = [];
		for (let i = 0; i < this.shape[0]; i++) {
			parts.push(`--- ${i}: ---`);
			parts.push(this.pick([i]).toString());
		}
		return parts.join("\n");
	}
}
|
|
331
|
+
/**
 * Four-dimensional strided tensor view. Elements are addressed as
 * `[x, y, z, w]` and iterated with `w` as the innermost index.
 */
class Tensor4 extends ATensor {
	/** Lazily computed element count (recomputed while falsy). */
	_n;

	*[Symbol.iterator]() {
		const { data, offset } = this;
		const [sx, sy, sz, sw] = this.shape;
		const [tx, ty, tz, tw] = this.stride;
		for (let x = 0, ox = offset; x < sx; x++, ox += tx) {
			for (let y = 0, oy = ox; y < sy; y++, oy += ty) {
				for (let z = 0, oz = oy; z < sz; z++, oz += tz) {
					for (let w = 0; w < sw; w++) yield data[oz + w * tw];
				}
			}
		}
	}

	get length() {
		if (!this._n) this._n = product4(this.shape);
		return this._n;
	}

	get dim() {
		return 4;
	}

	/** Computes the `data` index for position `pos`. */
	index(pos) {
		return this.offset + dot4(pos, this.stride);
	}

	/** Returns the element at `pos` (no bounds checking). */
	get(pos) {
		const idx = this.offset + dot4(pos, this.stride);
		return this.data[idx];
	}

	/** Writes `v` at `pos` (no bounds checking) and returns `this`. */
	set(pos, v) {
		const idx = this.offset + dot4(pos, this.stride);
		this.data[idx] = v;
		return this;
	}

	/**
	 * Returns a new packed tensor of `newShape`, prefilled with `fill` (if
	 * given) and populated with as many of this tensor's elements (in iteration
	 * order) as fit.
	 */
	resize(newShape, fill, storage = this.storage) {
		const newLength = product(newShape);
		const newData = storage.alloc(newLength);
		if (fill !== void 0) newData.fill(fill);
		const { data, offset } = this;
		const [sx, sy, sz, sw] = this.shape;
		const [tx, ty, tz, tw] = this.stride;
		const limit = min(this.length, newLength);
		let dest = 0;
		for (let x = 0, ox = offset; x < sx; x++, ox += tx) {
			for (let y = 0, oy = ox; y < sy; y++, oy += ty) {
				for (let z = 0, oz = oy; z < sz; z++, oz += tz) {
					for (let w = 0; w < sw && dest < limit; w++, dest++) {
						newData[dest] = data[oz + w * tw];
					}
				}
			}
		}
		return tensor(this.type, newShape, {
			storage,
			data: newData,
			copy: false,
		});
	}

	/** Formats each 3D sub-tensor (via `pick`) preceded by a header line. */
	toString() {
		const parts = [];
		for (let i = 0; i < this.shape[0]; i++) {
			parts.push(`--- cube ${i}: ---`);
			parts.push(this.pick([i]).toString());
		}
		return parts.join("\n");
	}
}
|
|
399
|
+
/** Tensor classes indexed by dimension (index 0 is unused). */
const TENSOR_IMPLS = [undefined, Tensor1, Tensor2, Tensor3, Tensor4];
|
|
406
|
+
/**
 * Tensor factory.
 *
 * - `tensor(array, opts?)`: delegates to `tensorFromArray()`.
 * - `tensor(type, shape, opts?)`: creates a tensor of the given type & shape.
 *   Uses `opts.storage` (or the default storage for `type`), `opts.stride` (or
 *   packed default strides) and `opts.data` (copied via `storage.from()`
 *   unless `opts.copy === false`; if no data is given, new data is allocated).
 *   If `opts.offset` is undefined, the offset is chosen so that axes with
 *   negative strides start at their last element.
 *
 * Calls `unsupported()` for dimensions without an implementation.
 */
function tensor(...args) {
	if (Array.isArray(args[0])) return tensorFromArray(args[0], args[1]);
	const [type, shape, opts] = args;
	const storage = opts?.storage ?? STORAGE[type];
	const stride = opts?.stride ?? shapeToStride(shape);
	let data;
	if (!opts?.data) {
		data = storage.alloc(product(shape));
	} else {
		data = opts.copy === false ? opts.data : storage.from(opts.data);
	}
	let offset = opts?.offset;
	if (offset === undefined) {
		offset = 0;
		// axes with negative strides are walked backwards, so shift the start
		// to the last element along each such axis
		shape.forEach((size, i) => {
			if (stride[i] < 0) offset -= (size - 1) * stride[i];
		});
	}
	const impl = TENSOR_IMPLS[shape.length];
	if (!impl) return unsupported(`unsupported dimension: ${shape.length}`);
	return new impl(type, storage, data, shape, stride, offset);
}
|
|
432
|
+
/**
 * Creates a tensor from a (possibly nested) JS array. The shape is derived
 * from the length of the first element at each nesting level, and nested
 * arrays are flattened one level at a time. If `opts.type` is not given, the
 * type is `"num"` if the first flattened element is a number, else `"str"`.
 * Calls `illegalArgs()` if type `"str"` is used with numeric data.
 *
 * NOTE(review): for types `"num"` and `"str"` the (flattened) array is used
 * without copying, so a non-nested input array ends up shared with the
 * tensor — confirm this aliasing is intended.
 */
function tensorFromArray(data, opts) {
	const shape = [data.length];
	let flat = data;
	for (; Array.isArray(flat[0]); flat = flat.flat()) {
		shape.push(flat[0].length);
	}
	const first = flat[0];
	const type = opts?.type ?? (isNumber(first) ? "num" : "str");
	if (type === "str" && isNumber(first)) illegalArgs("mismatched data type");
	const copy = type !== "num" && type !== "str";
	return tensor(type, shape, {
		data: flat,
		copy,
		storage: opts?.storage,
	});
}
|
|
448
|
+
/**
 * Computes packed (row-major) strides for the given shape: the last axis has
 * stride 1 and each preceding axis' stride is the product of all later sizes.
 */
const shapeToStride = (shape) => {
	const n = shape.length;
	const stride = new Array(n);
	let acc = 1;
	for (let i = n - 1; i >= 0; i--) {
		stride[i] = acc;
		acc *= shape[i];
	}
	return stride;
};
|
|
456
|
+
/**
 * Returns axis indices sorted by descending absolute stride (stable for
 * equal strides).
 */
const strideOrder = (strides) => {
	const pairs = strides.map((s, axis) => [s, axis]);
	pairs.sort((a, b) => abs(b[0]) - abs(a[0]));
	return pairs.map(([, axis]) => axis);
};
|
|
457
|
+
/**
 * Computes shape & offset for `ATensor.lo()`: for each axis with
 * `select[i] >= 0`, the start moves by `select[i]` elements and the size
 * shrinks accordingly. Other axes are left unchanged.
 */
const __lo = (select, { shape, stride, offset }) => {
	const newShape = [];
	const n = shape.length;
	for (let i = 0; i < n; i++) {
		const start = select[i];
		if (start >= 0) {
			offset += stride[i] * start;
			newShape.push(shape[i] - start);
		} else {
			newShape.push(shape[i]);
		}
	}
	return { shape: newShape, offset };
};
|
|
467
|
+
/**
 * Computes the shape for `ATensor.hi()`: each axis with `select[i] > 0` is
 * limited to `select[i]` elements; other axes keep their size.
 */
const __hi = (select, { shape }) =>
	Array.from({ length: shape.length }, (_, i) =>
		select[i] > 0 ? select[i] : shape[i]
	);
|
|
475
|
+
/**
 * Computes shape, stride & offset for `ATensor.step()`. For each axis with a
 * non-zero `select[i]`, the stride is multiplied by that step and the size
 * becomes `ceil(size / |step|)`. Negative steps additionally move the start to
 * the last element of that axis. Zero/missing steps leave the axis unchanged.
 */
const __step = (select, { shape, stride, offset }) => {
	const newShape = shape.slice();
	const newStride = stride.slice();
	for (let i = 0; i < shape.length; i++) {
		const step = select[i];
		if (!step) continue;
		if (step < 0) offset += stride[i] * (shape[i] - 1);
		newShape[i] = ceil((step < 0 ? -shape[i] : shape[i]) / step);
		newStride[i] *= step;
	}
	return { shape: newShape, stride: newStride, offset };
};
|
|
492
|
+
/**
 * Computes shape, stride & offset for `ATensor.pick()`: each axis with
 * `select[i] >= 0` is fixed at that index (contributing to the offset) and
 * dropped; all other axes are retained with their original size & stride.
 */
const __pick = (select, { shape, stride, offset }) => {
	const newShape = [];
	const newStride = [];
	for (let i = 0, n = shape.length; i < n; i++) {
		const idx = select[i];
		if (idx >= 0) {
			offset += stride[i] * idx;
			continue;
		}
		newShape.push(shape[i]);
		newStride.push(stride[i]);
	}
	return { shape: newShape, stride: newStride, offset };
};
|
|
506
|
+
export {
|
|
507
|
+
ATensor,
|
|
508
|
+
TENSOR_IMPLS,
|
|
509
|
+
Tensor1,
|
|
510
|
+
Tensor2,
|
|
511
|
+
Tensor3,
|
|
512
|
+
Tensor4,
|
|
513
|
+
shapeToStride,
|
|
514
|
+
strideOrder,
|
|
515
|
+
tensor,
|
|
516
|
+
tensorFromArray
|
|
517
|
+
};
|
package/top.d.ts
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
import type { Maybe } from "@thi.ng/api";
|
|
2
|
+
/**
 * Specialized / optimized version of
 * [`thi.ng/defmulti`](https://thi.ng/defmulti) for tensor operations. Instead
 * of a generic dispatch function, it dispatches on the tensor dimension
 * ({@link ITensor.dim}) of the argument at index `dispatch`.
 *
 * @param dispatch - arg index (default: 1)
 * @param fallback - implementation used if no dimension-specific one is
 * registered
 * @param optimized - dimension-specific implementations, the first one being
 * used for dimension 1, the second for dimension 2 etc.
 */
export declare const top: <T extends Function>(
	dispatch?: number,
	fallback?: T,
	...optimized: Array<Maybe<T>>
) => {
	(...args: any[]): any;
	add(dim: number, fn: T): T;
	default(fn: T): T;
	impl(dim?: number): Maybe<T>;
};
|
|
16
|
+
//# sourceMappingURL=top.d.ts.map
|
package/top.js
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
import { unsupported } from "@thi.ng/errors/unsupported";
|
|
2
|
+
const top = (dispatch = 1, fallback, ...optimized) => {
|
|
3
|
+
const impls = [void 0].concat(optimized);
|
|
4
|
+
const fn = (...args) => {
|
|
5
|
+
const g = impls[args[dispatch].dim] || fallback;
|
|
6
|
+
return g ? g(...args) : unsupported(`no impl for dimension ${args[dispatch].dim}`);
|
|
7
|
+
};
|
|
8
|
+
fn.add = (dim, fn2) => impls[dim] = fn2;
|
|
9
|
+
fn.default = (fn2) => fallback = fn2;
|
|
10
|
+
fn.impl = (dim) => dim != null ? impls[dim] || fallback : fallback;
|
|
11
|
+
return fn;
|
|
12
|
+
};
|
|
13
|
+
export {
|
|
14
|
+
top
|
|
15
|
+
};
|