@stellarapp/tfjs-stellar 1.0.5 → 1.0.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/masks.test.d.ts +2 -0
- package/dist/masks.test.d.ts.map +1 -0
- package/dist/masks.test.js +55 -0
- package/dist/masks.test.js.map +1 -0
- package/dist/utils.test.js +0 -15
- package/dist/utils.test.js.map +1 -1
- package/package.json +1 -1
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"masks.test.d.ts","sourceRoot":"","sources":["../src/masks.test.ts"],"names":[],"mappings":""}
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
import * as tf from "@tensorflow/tfjs";
|
|
2
|
+
import * as masks from "./masks";
|
|
3
|
+
tf.env().set('IS_NODE', false);
|
|
4
|
+
describe(" mask tests", () => {
|
|
5
|
+
test("packing mask for self-attention", () => {
|
|
6
|
+
const boundaries = new Int32Array([1, 0, 0, 1, 0, 0, 1, 0, 0]);
|
|
7
|
+
const packing_mask = masks.packing(boundaries);
|
|
8
|
+
const expected_mask = tf.tensor([
|
|
9
|
+
[0, 0, 0, -10000000, -10000000, -10000001, -10000000, -10000000, -10000002],
|
|
10
|
+
[0, 0, 0, -10000001, -10000001, -10000002, -10000000, -10000000, -10000001],
|
|
11
|
+
[0, 0, 0, -10000003, -10000000, -10000000, -10000002, -10000002, -10000000],
|
|
12
|
+
[-10000000, -10000002, -10000000, 0, 0, 0, -10000001, -10000001, -10000000],
|
|
13
|
+
[-10000000, -10000000, -10000001, 0, 0, 0, -10000003, -10000000, -10000001],
|
|
14
|
+
[-10000000, -10000002, -10000000, 0, 0, 0, -10000000, -10000001, -10000000],
|
|
15
|
+
[-10000000, -10000000, -10000000, -10000000, -10000000, -10000000, 0, 0, 0],
|
|
16
|
+
[-10000000, -10000002, -10000000, -10000002, -10000000, -10000000, 0, 0, 0],
|
|
17
|
+
[-10000000, -10000001, -10000000, -10000000, -10000002, -10000003, 0, 0, 0]
|
|
18
|
+
]);
|
|
19
|
+
// The mask uses -1e7 on masked positions which introduces extra integers on
|
|
20
|
+
// some values in the float32 tensor. Ideally it should check that the sum is equal to 0,
|
|
21
|
+
// but since there are 54 masked positions, we'll just check that it's less than 108
|
|
22
|
+
expect(packing_mask.sub(expected_mask).sum().arraySync()).toBeLessThan(108);
|
|
23
|
+
});
|
|
24
|
+
test("packing mask for non-packed sequence", () => {
|
|
25
|
+
const boundaries = new Int32Array([1, 0, 0, 0, 0]);
|
|
26
|
+
const packing_mask = masks.packing(boundaries);
|
|
27
|
+
expect(packing_mask.sum().arraySync()).toEqual(0);
|
|
28
|
+
});
|
|
29
|
+
test("causal mask size 4", async () => {
|
|
30
|
+
const seq_len = 4;
|
|
31
|
+
const causal_mask = masks.causal(seq_len, seq_len);
|
|
32
|
+
const _ = -1e7;
|
|
33
|
+
const expected_mask = tf.tensor([
|
|
34
|
+
[0, _, _, _],
|
|
35
|
+
[0, 0, _, _],
|
|
36
|
+
[0, 0, 0, _],
|
|
37
|
+
[0, 0, 0, 0]
|
|
38
|
+
]);
|
|
39
|
+
// this might fail due to precision issues on the masked positions,
|
|
40
|
+
// in which case use less <= to 6 or 12 (number of masked positions x2)
|
|
41
|
+
expect((await causal_mask.sub(expected_mask).sum().data())[0]).toEqual(0);
|
|
42
|
+
});
|
|
43
|
+
test("causal mask size 5", () => {
|
|
44
|
+
const expected_mask = tf.tensor([
|
|
45
|
+
[0, -10000000, -10000000, -10000000, -10000000],
|
|
46
|
+
[0, 0, -10000000, -10000000, -10000000],
|
|
47
|
+
[0, 0, 0, -10000000, -10000000],
|
|
48
|
+
[0, 0, 0, 0, -10000000],
|
|
49
|
+
[0, 0, 0, 0, 0]
|
|
50
|
+
]);
|
|
51
|
+
const causal_mask = masks.causal(5, 5);
|
|
52
|
+
expect(causal_mask.equal(expected_mask).sum().arraySync()).toBe(25);
|
|
53
|
+
});
|
|
54
|
+
});
|
|
55
|
+
//# sourceMappingURL=masks.test.js.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"masks.test.js","sourceRoot":"","sources":["../src/masks.test.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AACvC,OAAO,KAAK,KAAK,MAAM,SAAS,CAAC;AAEjC,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC,SAAS,EAAE,KAAK,CAAC,CAAC;AAG/B,QAAQ,CAAC,aAAa,EAAE,GAAG,EAAE;IACzB,IAAI,CAAC,iCAAiC,EAAE,GAAG,EAAE;QACzC,MAAM,UAAU,GAAG,IAAI,UAAU,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAC/D,MAAM,YAAY,GAAG,KAAK,CAAC,OAAO,CAAC,UAAU,CAAC,CAAC;QAE/C,MAAM,aAAa,GAAG,EAAE,CAAC,MAAM,CAAC;YAC5B,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,CAAC;YAC3E,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,CAAC;YAC3E,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,CAAC;YAC3E,CAAC,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,CAAC;YAC3E,CAAC,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,CAAC;YAC3E,CAAC,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,CAAC;YAC3E,CAAC,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC;YAC3E,CAAC,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC;YAC3E,CAAC,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC;SAC9E,CAAC,CAAC;QAEH,4EAA4E;QAC5E,yFAAyF;QACzF,oFAAoF;QACpF,MAAM,CAAE,YAAY,CAAC,GAAG,CAAC,aAAa,CAAC,CAAC,GAAG,EAAE,CAAC,SAAS,EAAa,CAAC,CAAC,YAAY,CAAC,GAAG,CAAC,CAAC;IAC5F,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,sCAAsC,EAAE,GAAG,EAAE;QAC9C,MAAM,UAAU,GAAG,IAAI,UAAU,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QACnD,MAAM,YAAY,GAAG,KAAK,CAAC,OAAO,CAAC,UAAU,CAAC,CAAC;QAE/C,MAAM,CAAE,YAAY,CAAC,GAAG,EAAE,CAAC,SAAS,EAAa,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC;IAClE,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,oBAAoB,EAAE,KAAK,IAAI,EAAE;QAClC,MAAM,OAAO,GAAG,CAAC,CAAC;QAClB,MAAM,WAAW,GAAG,KAAK,CAAC,MAAM,CAAC,OAAO,EAAE,OAAO,CAAC,CAAC;QAEnD,MAAM,CAAC,GAAG,CAAC,GAAG,CAAC;QACf,MAAM,aAAa,GAAG,EAAE,CAAC,MAAM,CAAC;YAC5B,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC;YACZ,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC;YACZ,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC;YACZ,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC;SACf,CAAC,CAAC;QAEH,mEAAmE;QACnE,uEAAuE;QACvE,MAAM,CAAC,CAAC,MAAM,WAAW,CAAC,GAAG,CAAC,aAAa,CAAC,CAAC,GAAG,EAAE,CAAC,IAAI,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC;IAC9E,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,oBAAoB,EAAE,GAAG,EAAE;QAC5B,MAAM,aAAa,GAAG,EAAE,CAAC,MAAM,CAAC;YAC5B,CAAC,CAAC,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,CAAC;YAC/C,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,CAAC;YACvC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,QAAQ,EAAE,CAAC,QAAQ,CAAC;YAC/B,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,QAAQ,CAAC;YACvB,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC;SAAC,CAAC,CAAC;QAEtB,MAAM,WAAW,GAAG,KAAK,CAAC,MAAM,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC;QAEvC,MAAM,CAAC,WAAW,CAAC,KAAK,CAAC,aAAa,CAAC,CAAC,GAAG,EAAE,CAAC,SAAS,EAAE,CAAC,CAAC,IAAI,CAAC,EAAE,CAAC,CAAC;IACxE,CAAC,CAAC,CAAA;AACN,CAAC,CAAC,CAAC"}
|
package/dist/utils.test.js
CHANGED
|
@@ -1,6 +1,5 @@
|
|
|
1
1
|
import * as tf from "@tensorflow/tfjs";
|
|
2
2
|
import { getScaleShape, getRandomCropStart } from "./utils";
|
|
3
|
-
import { causal } from "./masks";
|
|
4
3
|
// avoid TFJS node message during Jest testing
|
|
5
4
|
tf.env().set('IS_NODE', false);
|
|
6
5
|
describe("test custom TFJS utility functions", () => {
|
|
@@ -55,19 +54,5 @@ describe("test custom TFJS utility functions", () => {
|
|
|
55
54
|
const [scale8_h, scale8_w] = getScaleShape([555, 777], [333, 555]);
|
|
56
55
|
expect(scale8_h).toBeLessThan(scale8_w);
|
|
57
56
|
});
|
|
58
|
-
test("causal attention map", async () => {
|
|
59
|
-
const seq_len = 4;
|
|
60
|
-
const causal_mask = causal(seq_len, seq_len);
|
|
61
|
-
const _ = -1e7;
|
|
62
|
-
const expected_mask = tf.tensor([
|
|
63
|
-
[0, _, _, _],
|
|
64
|
-
[0, 0, _, _],
|
|
65
|
-
[0, 0, 0, _],
|
|
66
|
-
[0, 0, 0, 0]
|
|
67
|
-
]);
|
|
68
|
-
// this might fail due to precision issues on the masked positions,
|
|
69
|
-
// in which case use less <= to 6 or 12 (number of masked positions x2)
|
|
70
|
-
expect((await causal_mask.sub(expected_mask).sum().data())[0]).toEqual(0);
|
|
71
|
-
});
|
|
72
57
|
});
|
|
73
58
|
//# sourceMappingURL=utils.test.js.map
|
package/dist/utils.test.js.map
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
{"version":3,"file":"utils.test.js","sourceRoot":"","sources":["../src/utils.test.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AACvC,OAAO,EAAE,aAAa,EAAE,kBAAkB,EAAE,MAAM,SAAS,CAAC;
|
|
1
|
+
{"version":3,"file":"utils.test.js","sourceRoot":"","sources":["../src/utils.test.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AACvC,OAAO,EAAE,aAAa,EAAE,kBAAkB,EAAE,MAAM,SAAS,CAAC;AAG5D,8CAA8C;AAC9C,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC,SAAS,EAAE,KAAK,CAAC,CAAC;AAG/B,QAAQ,CAAC,oCAAoC,EAAE,GAAG,EAAE;IAEhD,IAAI,CAAC,2DAA2D,EAAE,KAAK,IAAI,EAAE;QACzE,sCAAsC;QACtC,MAAM,QAAQ,GAAG,CAAC,GAAG,EAAE,EAAE,CAAqB,CAAC;QAC/C,MAAM,WAAW,GAAG,CAAC,GAAG,EAAE,EAAE,CAAqB,CAAC;QAElD,MAAM,CAAC,kBAAkB,CAAC,QAAQ,EAAE,WAAW,CAAC,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;IACzE,CAAC,CAAC,CAAC;IAGH,EAAE,CAAC,6CAA6C,EAAE,KAAK,IAAI,EAAE;QACzD,MAAM,CAAC,GAAG,EAAE,CAAC,kBAAkB,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,IAAI,EAAE,IAAI,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;IACzE,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,qBAAqB,EAAE,KAAK,IAAI,EAAE;QACnC,mCAAmC;QACnC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,GAAG,EAAE,CAAC,EAAE,EAAE,CAAC;YAC3B,MAAM,QAAQ,GAAG,CAAC,IAAI,EAAE,GAAG,CAAqB,CAAC;YACjD,MAAM,WAAW,GAAG,CAAC,GAAG,EAAE,GAAG,CAAqB,CAAC;YAEnD,MAAM,CAAC,YAAY,EAAE,YAAY,EAAE,QAAQ,CAAC,GAAG,kBAAkB,CAAC,QAAQ,EAAE,WAAW,CAAC,CAAA;YAExF,MAAM,CAAC,YAAY,CAAC,CAAC,mBAAmB,CAAC,QAAQ,CAAC,CAAC,CAAC,GAAG,WAAW,CAAC,CAAC,CAAC,CAAC,CAAC;YACvE,MAAM,CAAC,YAAY,CAAC,CAAC,mBAAmB,CAAC,QAAQ,CAAC,CAAC,CAAC,GAAG,WAAW,CAAC,CAAC,CAAC,CAAC,CAAC;QAC3E,CAAC;QAED,mCAAmC;QACnC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,GAAG,EAAE,CAAC,EAAE,EAAE,CAAC;YAC3B,MAAM,QAAQ,GAAG,CAAC,GAAG,EAAE,GAAG,CAAqB,CAAC;YAChD,MAAM,WAAW,GAAG,CAAC,GAAG,EAAE,GAAG,CAAqB,CAAC;YAEnD,MAAM,CAAC,YAAY,EAAE,YAAY,EAAE,QAAQ,CAAC,GAAG,kBAAkB,CAAC,QAAQ,EAAE,WAAW,CAAC,CAAA;YAExF,MAAM,CAAC,YAAY,CAAC,CAAC,mBAAmB,CAAC,QAAQ,CAAC,CAAC,CAAC,GAAG,WAAW,CAAC,CAAC,CAAC,CAAC,CAAC;YACvE,MAAM,CAAC,YAAY,CAAC,CAAC,mBAAmB,CAAC,QAAQ,CAAC,CAAC,CAAC,GAAG,WAAW,CAAC,CAAC,CAAC,CAAC,CAAC;QAC3E,CAAC;IACL,CAAC,CAAC,CAAC;IAGH,IAAI,CAAC,sCAAsC,EAAE,KAAK,IAAI,EAAE;QACpD,MAAM,KAAK,GAAG,aAAa,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAA;QACnD,MAAM,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAC;IACtC,CAAC,CAAC,CAAC;IAGH,IAAI,CAAC,oBAAoB,EAAE,KAAK,IAAI,EAAE;QAClC,oCAAoC;QACpC,MAAM,MAAM,GAAG,aAAa,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAA;QACpD,MAAM,CAAC,MAAM,CAAC,CAAC,OAAO,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAC;QAEnC,MAAM,MAAM,GAAG,aAAa,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAA;QACpD,MAAM,CAAC,MAAM,CAAC,CAAC,OAAO,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAC;QAEnC,MAAM,MAAM,GAAG,aAAa,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAA;QACpD,MAAM,CAAC,MAAM,CAAC,CAAC,OAAO,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAC;QAEnC,MAAM,MAAM,GAAG,aAAa,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAA;QACpD,MAAM,CAAC,MAAM,CAAC,CAAC,OAAO,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAC;QAEnC,0CAA0C;QAC1C,MAAM,MAAM,GAAG,aAAa,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,IAAI,EAAE,GAAG,CAAC,CAAC,CAAA;QACrD,MAAM,CAAC,MAAM,CAAC,CAAC,OAAO,CAAC,CAAC,IAAI,EAAE,GAAG,CAAC,CAAC,CAAC;QAEpC,MAAM,MAAM,GAAG,aAAa,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,IAAI,CAAC,CAAC,CAAA;QACrD,MAAM,CAAC,MAAM,CAAC,CAAC,OAAO,CAAC,CAAC,GAAG,EAAE,IAAI,CAAC,CAAC,CAAC;QAEpC,MAAM,CAAC,QAAQ,EAAE,QAAQ,CAAC,GAAG,aAAa,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAA;QAClE,MAAM,CAAC,QAAQ,CAAC,CAAC,eAAe,CAAC,QAAQ,CAAC,CAAC;QAE3C,MAAM,CAAC,QAAQ,EAAE,QAAQ,CAAC,GAAG,aAAa,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAA;QAClE,MAAM,CAAC,QAAQ,CAAC,CAAC,YAAY,CAAC,QAAQ,CAAC,CAAC;IAC5C,CAAC,CAAC,CAAC;AAEP,CAAC,CAAC,CAAC"}
|