@stellarapp/tfjs-stellar 1.0.0 → 1.0.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +21 -0
- package/README.md +47 -0
- package/dist/index.d.ts +7 -0
- package/dist/index.d.ts.map +1 -0
- package/dist/index.js +7 -0
- package/dist/index.js.map +1 -0
- package/dist/jest.config.d.ts +8 -0
- package/dist/jest.config.d.ts.map +1 -0
- package/{jest.config.ts → dist/jest.config.js} +8 -64
- package/dist/jest.config.js.map +1 -0
- package/dist/kv_cache.d.ts +53 -0
- package/dist/kv_cache.d.ts.map +1 -0
- package/{src/kv_cache.ts → dist/kv_cache.js} +35 -105
- package/dist/kv_cache.js.map +1 -0
- package/dist/layers/cached_rope_multihead_attention.d.ts +31 -0
- package/dist/layers/cached_rope_multihead_attention.d.ts.map +1 -0
- package/dist/layers/cached_rope_multihead_attention.js +76 -0
- package/dist/layers/cached_rope_multihead_attention.js.map +1 -0
- package/dist/layers/cached_rope_multihead_attention.test.d.ts +2 -0
- package/dist/layers/cached_rope_multihead_attention.test.d.ts.map +1 -0
- package/{src/layers/cached_rope_multihead_attention.test.ts → dist/layers/cached_rope_multihead_attention.test.js} +14 -30
- package/dist/layers/cached_rope_multihead_attention.test.js.map +1 -0
- package/dist/layers/gpt_decoder_block.d.ts +34 -0
- package/dist/layers/gpt_decoder_block.d.ts.map +1 -0
- package/{src/layers/gpt_decoder_block.ts → dist/layers/gpt_decoder_block.js} +10 -36
- package/dist/layers/gpt_decoder_block.js.map +1 -0
- package/dist/layers/index.d.ts +17 -0
- package/dist/layers/index.d.ts.map +1 -0
- package/dist/layers/index.js +33 -0
- package/dist/layers/index.js.map +1 -0
- package/dist/layers/multihead_attention.d.ts +106 -0
- package/dist/layers/multihead_attention.d.ts.map +1 -0
- package/{src/layers/multihead_attention.ts → dist/layers/multihead_attention.js} +60 -162
- package/dist/layers/multihead_attention.js.map +1 -0
- package/dist/layers/multihead_attention.test.d.ts +2 -0
- package/dist/layers/multihead_attention.test.d.ts.map +1 -0
- package/{src/layers/multihead_attention.test.ts → dist/layers/multihead_attention.test.js} +48 -100
- package/dist/layers/multihead_attention.test.js.map +1 -0
- package/dist/layers/positional_encoding.d.ts +37 -0
- package/dist/layers/positional_encoding.d.ts.map +1 -0
- package/{src/layers/positional_encoding.ts → dist/layers/positional_encoding.js} +17 -60
- package/dist/layers/positional_encoding.js.map +1 -0
- package/dist/layers/positional_encoding.test.d.ts +2 -0
- package/dist/layers/positional_encoding.test.d.ts.map +1 -0
- package/{src/layers/positional_encoding.test.ts → dist/layers/positional_encoding.test.js} +39 -57
- package/dist/layers/positional_encoding.test.js.map +1 -0
- package/dist/layers/rotary_position_embedding.d.ts +39 -0
- package/dist/layers/rotary_position_embedding.d.ts.map +1 -0
- package/{src/layers/rotary_position_embedding.ts → dist/layers/rotary_position_embedding.js} +22 -86
- package/dist/layers/rotary_position_embedding.js.map +1 -0
- package/dist/layers/rotary_position_embedding.test.d.ts +2 -0
- package/dist/layers/rotary_position_embedding.test.d.ts.map +1 -0
- package/dist/layers/rotary_position_embedding.test.js +88 -0
- package/dist/layers/rotary_position_embedding.test.js.map +1 -0
- package/dist/layers/token_and_positional_embedding.d.ts +47 -0
- package/dist/layers/token_and_positional_embedding.d.ts.map +1 -0
- package/{src/layers/token_and_positional_embedding.ts → dist/layers/token_and_positional_embedding.js} +27 -67
- package/dist/layers/token_and_positional_embedding.js.map +1 -0
- package/dist/layers/token_and_positional_embedding.test.d.ts +2 -0
- package/dist/layers/token_and_positional_embedding.test.d.ts.map +1 -0
- package/{src/layers/token_and_positional_embedding.test.ts → dist/layers/token_and_positional_embedding.test.js} +7 -30
- package/dist/layers/token_and_positional_embedding.test.js.map +1 -0
- package/dist/layers/transformer_decoder.d.ts +69 -0
- package/dist/layers/transformer_decoder.d.ts.map +1 -0
- package/dist/layers/transformer_decoder.js +182 -0
- package/dist/layers/transformer_decoder.js.map +1 -0
- package/dist/layers/transformer_decoder.test.d.ts +2 -0
- package/dist/layers/transformer_decoder.test.d.ts.map +1 -0
- package/{src/layers/transformer_decoder.test.ts → dist/layers/transformer_decoder.test.js} +20 -48
- package/dist/layers/transformer_decoder.test.js.map +1 -0
- package/dist/layers/transformer_encoder.d.ts +55 -0
- package/dist/layers/transformer_encoder.d.ts.map +1 -0
- package/{src/layers/transformer_encoder.ts → dist/layers/transformer_encoder.js} +41 -90
- package/dist/layers/transformer_encoder.js.map +1 -0
- package/dist/layers/transformer_encoder.test.d.ts +2 -0
- package/dist/layers/transformer_encoder.test.d.ts.map +1 -0
- package/{src/layers/transformer_encoder.test.ts → dist/layers/transformer_encoder.test.js} +18 -45
- package/dist/layers/transformer_encoder.test.js.map +1 -0
- package/dist/losses/dice.d.ts +30 -0
- package/dist/losses/dice.d.ts.map +1 -0
- package/{src/losses/dice.ts → dist/losses/dice.js} +17 -80
- package/dist/losses/dice.js.map +1 -0
- package/dist/losses/index.d.ts +2 -0
- package/dist/losses/index.d.ts.map +1 -0
- package/dist/losses/index.js +2 -0
- package/dist/losses/index.js.map +1 -0
- package/dist/masks.d.ts +20 -0
- package/dist/masks.d.ts.map +1 -0
- package/{src/packing_mask.ts → dist/masks.js} +16 -7
- package/dist/masks.js.map +1 -0
- package/dist/metrics.d.ts +20 -0
- package/dist/metrics.d.ts.map +1 -0
- package/{src/metrics.ts → dist/metrics.js} +8 -12
- package/dist/metrics.js.map +1 -0
- package/dist/models/gpt_model.d.ts +94 -0
- package/dist/models/gpt_model.d.ts.map +1 -0
- package/{src/models/gpt_model.ts → dist/models/gpt_model.js} +41 -119
- package/dist/models/gpt_model.js.map +1 -0
- package/dist/models/index.d.ts +7 -0
- package/dist/models/index.d.ts.map +1 -0
- package/dist/models/index.js +13 -0
- package/dist/models/index.js.map +1 -0
- package/dist/models/llm_model.d.ts +87 -0
- package/dist/models/llm_model.d.ts.map +1 -0
- package/{src/models/llm_model.ts → dist/models/llm_model.js} +51 -161
- package/dist/models/llm_model.js.map +1 -0
- package/dist/models/u_net.d.ts +40 -0
- package/dist/models/u_net.d.ts.map +1 -0
- package/{src/models/u_net.ts → dist/models/u_net.js} +27 -116
- package/dist/models/u_net.js.map +1 -0
- package/dist/src/index.d.ts +6 -0
- package/dist/src/index.d.ts.map +1 -0
- package/dist/src/index.js +6 -0
- package/dist/src/index.js.map +1 -0
- package/dist/src/kv_cache.d.ts +53 -0
- package/dist/src/kv_cache.d.ts.map +1 -0
- package/dist/src/kv_cache.js +135 -0
- package/dist/src/kv_cache.js.map +1 -0
- package/dist/src/layers/cached_rope_multihead_attention.d.ts +31 -0
- package/dist/src/layers/cached_rope_multihead_attention.d.ts.map +1 -0
- package/{src/layers/cached_rope_multihead_attention.ts → dist/src/layers/cached_rope_multihead_attention.js} +25 -62
- package/dist/src/layers/cached_rope_multihead_attention.js.map +1 -0
- package/dist/src/layers/cached_rope_multihead_attention.test.d.ts +2 -0
- package/dist/src/layers/cached_rope_multihead_attention.test.d.ts.map +1 -0
- package/dist/src/layers/cached_rope_multihead_attention.test.js +43 -0
- package/dist/src/layers/cached_rope_multihead_attention.test.js.map +1 -0
- package/dist/src/layers/gpt_decoder_block.d.ts +34 -0
- package/dist/src/layers/gpt_decoder_block.d.ts.map +1 -0
- package/dist/src/layers/gpt_decoder_block.js +51 -0
- package/dist/src/layers/gpt_decoder_block.js.map +1 -0
- package/dist/src/layers/index.d.ts +17 -0
- package/dist/src/layers/index.d.ts.map +1 -0
- package/dist/src/layers/index.js +33 -0
- package/dist/src/layers/index.js.map +1 -0
- package/dist/src/layers/multihead_attention.d.ts +106 -0
- package/dist/src/layers/multihead_attention.d.ts.map +1 -0
- package/dist/src/layers/multihead_attention.js +269 -0
- package/dist/src/layers/multihead_attention.js.map +1 -0
- package/dist/src/layers/multihead_attention.test.d.ts +2 -0
- package/dist/src/layers/multihead_attention.test.d.ts.map +1 -0
- package/dist/src/layers/multihead_attention.test.js +160 -0
- package/dist/src/layers/multihead_attention.test.js.map +1 -0
- package/dist/src/layers/positional_encoding.d.ts +37 -0
- package/dist/src/layers/positional_encoding.d.ts.map +1 -0
- package/dist/src/layers/positional_encoding.js +115 -0
- package/dist/src/layers/positional_encoding.js.map +1 -0
- package/dist/src/layers/positional_encoding.test.d.ts +2 -0
- package/dist/src/layers/positional_encoding.test.d.ts.map +1 -0
- package/dist/src/layers/positional_encoding.test.js +95 -0
- package/dist/src/layers/positional_encoding.test.js.map +1 -0
- package/dist/src/layers/rotary_position_embedding.d.ts +39 -0
- package/dist/src/layers/rotary_position_embedding.d.ts.map +1 -0
- package/dist/src/layers/rotary_position_embedding.js +99 -0
- package/dist/src/layers/rotary_position_embedding.js.map +1 -0
- package/dist/src/layers/rotary_position_embedding.test.d.ts +2 -0
- package/dist/src/layers/rotary_position_embedding.test.d.ts.map +1 -0
- package/dist/src/layers/rotary_position_embedding.test.js +88 -0
- package/dist/src/layers/rotary_position_embedding.test.js.map +1 -0
- package/dist/src/layers/token_and_positional_embedding.d.ts +47 -0
- package/dist/src/layers/token_and_positional_embedding.d.ts.map +1 -0
- package/dist/src/layers/token_and_positional_embedding.js +109 -0
- package/dist/src/layers/token_and_positional_embedding.js.map +1 -0
- package/dist/src/layers/token_and_positional_embedding.test.d.ts +2 -0
- package/dist/src/layers/token_and_positional_embedding.test.d.ts.map +1 -0
- package/dist/src/layers/token_and_positional_embedding.test.js +58 -0
- package/dist/src/layers/token_and_positional_embedding.test.js.map +1 -0
- package/dist/src/layers/transformer_decoder.d.ts +69 -0
- package/dist/src/layers/transformer_decoder.d.ts.map +1 -0
- package/{src/layers/transformer_decoder.ts → dist/src/layers/transformer_decoder.js} +41 -95
- package/dist/src/layers/transformer_decoder.js.map +1 -0
- package/dist/src/layers/transformer_decoder.test.d.ts +2 -0
- package/dist/src/layers/transformer_decoder.test.d.ts.map +1 -0
- package/dist/src/layers/transformer_decoder.test.js +72 -0
- package/dist/src/layers/transformer_decoder.test.js.map +1 -0
- package/dist/src/layers/transformer_encoder.d.ts +55 -0
- package/dist/src/layers/transformer_encoder.d.ts.map +1 -0
- package/dist/src/layers/transformer_encoder.js +175 -0
- package/dist/src/layers/transformer_encoder.js.map +1 -0
- package/dist/src/layers/transformer_encoder.test.d.ts +2 -0
- package/dist/src/layers/transformer_encoder.test.d.ts.map +1 -0
- package/dist/src/layers/transformer_encoder.test.js +58 -0
- package/dist/src/layers/transformer_encoder.test.js.map +1 -0
- package/dist/src/losses/dice.d.ts +30 -0
- package/dist/src/losses/dice.d.ts.map +1 -0
- package/dist/src/losses/dice.js +93 -0
- package/dist/src/losses/dice.js.map +1 -0
- package/dist/src/losses/index.d.ts +2 -0
- package/dist/src/losses/index.d.ts.map +1 -0
- package/dist/src/losses/index.js +2 -0
- package/dist/src/losses/index.js.map +1 -0
- package/dist/src/masks.d.ts +20 -0
- package/dist/src/masks.d.ts.map +1 -0
- package/dist/src/masks.js +37 -0
- package/dist/src/masks.js.map +1 -0
- package/dist/src/metrics.d.ts +20 -0
- package/dist/src/metrics.d.ts.map +1 -0
- package/dist/src/metrics.js +28 -0
- package/dist/src/metrics.js.map +1 -0
- package/dist/src/models/gpt_model.d.ts +94 -0
- package/dist/src/models/gpt_model.d.ts.map +1 -0
- package/dist/src/models/gpt_model.js +154 -0
- package/dist/src/models/gpt_model.js.map +1 -0
- package/dist/src/models/index.d.ts +3 -0
- package/dist/src/models/index.d.ts.map +1 -0
- package/{src/models/index.ts → dist/src/models/index.js} +1 -0
- package/dist/src/models/index.js.map +1 -0
- package/dist/src/models/llm_model.d.ts +87 -0
- package/dist/src/models/llm_model.d.ts.map +1 -0
- package/dist/src/models/llm_model.js +245 -0
- package/dist/src/models/llm_model.js.map +1 -0
- package/dist/src/models/u_net.d.ts +40 -0
- package/dist/src/models/u_net.d.ts.map +1 -0
- package/dist/src/models/u_net.js +151 -0
- package/dist/src/models/u_net.js.map +1 -0
- package/{src/tfjs_types.ts → dist/src/tfjs_types.d.ts} +1 -6
- package/dist/src/tfjs_types.d.ts.map +1 -0
- package/dist/src/tfjs_types.js +2 -0
- package/dist/src/tfjs_types.js.map +1 -0
- package/dist/src/utils.d.ts +28 -0
- package/dist/src/utils.d.ts.map +1 -0
- package/{src/utils.ts → dist/src/utils.js} +10 -33
- package/dist/src/utils.js.map +1 -0
- package/dist/src/utils.test.d.ts +2 -0
- package/dist/src/utils.test.d.ts.map +1 -0
- package/{src/utils.test.ts → dist/src/utils.test.js} +22 -50
- package/dist/src/utils.test.js.map +1 -0
- package/dist/tfjs_types.d.ts +10 -0
- package/dist/tfjs_types.d.ts.map +1 -0
- package/dist/tfjs_types.js +2 -0
- package/dist/tfjs_types.js.map +1 -0
- package/dist/utils.d.ts +28 -0
- package/dist/utils.d.ts.map +1 -0
- package/dist/utils.js +63 -0
- package/dist/utils.js.map +1 -0
- package/dist/utils.test.d.ts +2 -0
- package/dist/utils.test.d.ts.map +1 -0
- package/dist/utils.test.js +73 -0
- package/dist/utils.test.js.map +1 -0
- package/package.json +14 -4
- package/src/index.ts +0 -93
- package/src/layers/rotary_position_embedding.test.ts +0 -107
- package/src/losses/index.ts +0 -1
- package/src/testing.ts +0 -1
- package/tsconfig.json +0 -49
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
import * as tf from "@tensorflow/tfjs";
|
|
2
|
+
import { getScaleShape, getRandomCropStart } from "@/utils";
|
|
3
|
+
import { causal } from "@/masks";
|
|
4
|
+
// avoid TFJS node message during Jest testing
|
|
5
|
+
tf.env().set('IS_NODE', false);
|
|
6
|
+
describe("test custom TFJS utility functions", () => {
|
|
7
|
+
test("crop an image using the same shape, results in same shape", async () => {
|
|
8
|
+
// cropping an image of the same shape
|
|
9
|
+
const img_size = [133, 84];
|
|
10
|
+
const target_size = [133, 84];
|
|
11
|
+
expect(getRandomCropStart(img_size, target_size)).toEqual([0, 0, 0]);
|
|
12
|
+
});
|
|
13
|
+
it("should throw when crop is larger than image", async () => {
|
|
14
|
+
expect(() => getRandomCropStart([128, 128], [1000, 2000])).toThrow();
|
|
15
|
+
});
|
|
16
|
+
test("cropped image shape", async () => {
|
|
17
|
+
// cropping from wide to tall image
|
|
18
|
+
for (let i = 0; i < 100; i++) {
|
|
19
|
+
const img_size = [4923, 832];
|
|
20
|
+
const target_size = [333, 739];
|
|
21
|
+
const [crop_start_h, crop_start_w, channels] = getRandomCropStart(img_size, target_size);
|
|
22
|
+
expect(crop_start_h).toBeLessThanOrEqual(img_size[0] - target_size[0]);
|
|
23
|
+
expect(crop_start_w).toBeLessThanOrEqual(img_size[1] - target_size[1]);
|
|
24
|
+
}
|
|
25
|
+
// cropping from tall to wide image
|
|
26
|
+
for (let i = 0; i < 100; i++) {
|
|
27
|
+
const img_size = [381, 999];
|
|
28
|
+
const target_size = [300, 157];
|
|
29
|
+
const [crop_start_h, crop_start_w, channels] = getRandomCropStart(img_size, target_size);
|
|
30
|
+
expect(crop_start_h).toBeLessThanOrEqual(img_size[0] - target_size[0]);
|
|
31
|
+
expect(crop_start_w).toBeLessThanOrEqual(img_size[1] - target_size[1]);
|
|
32
|
+
}
|
|
33
|
+
});
|
|
34
|
+
test("scale 1:1, results in the same shape", async () => {
|
|
35
|
+
const scale = getScaleShape([256, 256], [256, 256]);
|
|
36
|
+
expect(scale).toEqual([256, 256]);
|
|
37
|
+
});
|
|
38
|
+
test("scaled image shape", async () => {
|
|
39
|
+
// scaling squares result in squares
|
|
40
|
+
const scale1 = getScaleShape([256, 256], [128, 128]);
|
|
41
|
+
expect(scale1).toEqual([128, 128]);
|
|
42
|
+
const scale2 = getScaleShape([128, 128], [256, 256]);
|
|
43
|
+
expect(scale2).toEqual([256, 256]);
|
|
44
|
+
const scale3 = getScaleShape([123, 123], [321, 321]);
|
|
45
|
+
expect(scale3).toEqual([321, 321]);
|
|
46
|
+
const scale4 = getScaleShape([321, 321], [123, 123]);
|
|
47
|
+
expect(scale4).toEqual([123, 123]);
|
|
48
|
+
// scaling rectangles result in rectangles
|
|
49
|
+
const scale5 = getScaleShape([640, 480], [1280, 960]);
|
|
50
|
+
expect(scale5).toEqual([1280, 960]);
|
|
51
|
+
const scale6 = getScaleShape([480, 640], [960, 1280]);
|
|
52
|
+
expect(scale6).toEqual([960, 1280]);
|
|
53
|
+
const [scale7_h, scale7_w] = getScaleShape([777, 555], [555, 333]);
|
|
54
|
+
expect(scale7_h).toBeGreaterThan(scale7_w);
|
|
55
|
+
const [scale8_h, scale8_w] = getScaleShape([555, 777], [333, 555]);
|
|
56
|
+
expect(scale8_h).toBeLessThan(scale8_w);
|
|
57
|
+
});
|
|
58
|
+
test("causal attention map", async () => {
|
|
59
|
+
const seq_len = 4;
|
|
60
|
+
const causal_mask = causal(seq_len, seq_len);
|
|
61
|
+
const _ = -1e7;
|
|
62
|
+
const expected_mask = tf.tensor([
|
|
63
|
+
[0, _, _, _],
|
|
64
|
+
[0, 0, _, _],
|
|
65
|
+
[0, 0, 0, _],
|
|
66
|
+
[0, 0, 0, 0]
|
|
67
|
+
]);
|
|
68
|
+
// this might fail due to precision issues on the masked positions,
|
|
69
|
+
// in which case use less <= to 6 or 12 (number of masked positions x2)
|
|
70
|
+
expect((await causal_mask.sub(expected_mask).sum().data())[0]).toEqual(0);
|
|
71
|
+
});
|
|
72
|
+
});
|
|
73
|
+
//# sourceMappingURL=utils.test.js.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"utils.test.js","sourceRoot":"","sources":["../src/utils.test.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AACvC,OAAO,EAAE,aAAa,EAAE,kBAAkB,EAAE,MAAM,SAAS,CAAC;AAC5D,OAAO,EAAE,MAAM,EAAE,MAAM,SAAS,CAAC;AAEjC,8CAA8C;AAC9C,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC,SAAS,EAAE,KAAK,CAAC,CAAC;AAG/B,QAAQ,CAAC,oCAAoC,EAAE,GAAG,EAAE;IAEhD,IAAI,CAAC,2DAA2D,EAAE,KAAK,IAAI,EAAE;QACzE,sCAAsC;QACtC,MAAM,QAAQ,GAAG,CAAC,GAAG,EAAE,EAAE,CAAqB,CAAC;QAC/C,MAAM,WAAW,GAAG,CAAC,GAAG,EAAE,EAAE,CAAqB,CAAC;QAElD,MAAM,CAAC,kBAAkB,CAAC,QAAQ,EAAE,WAAW,CAAC,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;IACzE,CAAC,CAAC,CAAC;IAGH,EAAE,CAAC,6CAA6C,EAAE,KAAK,IAAI,EAAE;QACzD,MAAM,CAAC,GAAG,EAAE,CAAC,kBAAkB,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,IAAI,EAAE,IAAI,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;IACzE,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,qBAAqB,EAAE,KAAK,IAAI,EAAE;QACnC,mCAAmC;QACnC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,GAAG,EAAE,CAAC,EAAE,EAAE,CAAC;YAC3B,MAAM,QAAQ,GAAG,CAAC,IAAI,EAAE,GAAG,CAAqB,CAAC;YACjD,MAAM,WAAW,GAAG,CAAC,GAAG,EAAE,GAAG,CAAqB,CAAC;YAEnD,MAAM,CAAC,YAAY,EAAE,YAAY,EAAE,QAAQ,CAAC,GAAG,kBAAkB,CAAC,QAAQ,EAAE,WAAW,CAAC,CAAA;YAExF,MAAM,CAAC,YAAY,CAAC,CAAC,mBAAmB,CAAC,QAAQ,CAAC,CAAC,CAAC,GAAG,WAAW,CAAC,CAAC,CAAC,CAAC,CAAC;YACvE,MAAM,CAAC,YAAY,CAAC,CAAC,mBAAmB,CAAC,QAAQ,CAAC,CAAC,CAAC,GAAG,WAAW,CAAC,CAAC,CAAC,CAAC,CAAC;QAC3E,CAAC;QAED,mCAAmC;QACnC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,GAAG,EAAE,CAAC,EAAE,EAAE,CAAC;YAC3B,MAAM,QAAQ,GAAG,CAAC,GAAG,EAAE,GAAG,CAAqB,CAAC;YAChD,MAAM,WAAW,GAAG,CAAC,GAAG,EAAE,GAAG,CAAqB,CAAC;YAEnD,MAAM,CAAC,YAAY,EAAE,YAAY,EAAE,QAAQ,CAAC,GAAG,kBAAkB,CAAC,QAAQ,EAAE,WAAW,CAAC,CAAA;YAExF,MAAM,CAAC,YAAY,CAAC,CAAC,mBAAmB,CAAC,QAAQ,CAAC,CAAC,CAAC,GAAG,WAAW,CAAC,CAAC,CAAC,CAAC,CAAC;YACvE,MAAM,CAAC,YAAY,CAAC,CAAC,mBAAmB,CAAC,QAAQ,CAAC,CAAC,CAAC,GAAG,WAAW,CAAC,CAAC,CAAC,CAAC,CAAC;QAC3E,CAAC;IACL,CAAC,CAAC,CAAC;IAGH,IAAI,CAAC,sCAAsC,EAAE,KAAK,IAAI,EAAE;QACpD,MAAM,KAAK,GAAG,aAAa,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAA;QACnD,MAAM,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAC;IACtC,CAAC,CAAC,CAAC;IAGH,IAAI,CAAC,oBAAoB,EAAE,KAAK,IAAI,EAAE;QAClC,oCAAoC;QACpC,MAAM,MAAM,GAAG,aAAa,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAA;QACpD,MAAM,CAAC,MAAM,CAAC,CAAC,OAAO,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAC;QAEnC,MAAM,MAAM,GAAG,aAAa,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAA;QACpD,MAAM,CAAC,MAAM,CAAC,CAAC,OAAO,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAC;QAEnC,MAAM,MAAM,GAAG,aAAa,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAA;QACpD,MAAM,CAAC,MAAM,CAAC,CAAC,OAAO,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAC;QAEnC,MAAM,MAAM,GAAG,aAAa,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAA;QACpD,MAAM,CAAC,MAAM,CAAC,CAAC,OAAO,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAC;QAEnC,0CAA0C;QAC1C,MAAM,MAAM,GAAG,aAAa,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,IAAI,EAAE,GAAG,CAAC,CAAC,CAAA;QACrD,MAAM,CAAC,MAAM,CAAC,CAAC,OAAO,CAAC,CAAC,IAAI,EAAE,GAAG,CAAC,CAAC,CAAC;QAEpC,MAAM,MAAM,GAAG,aAAa,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,IAAI,CAAC,CAAC,CAAA;QACrD,MAAM,CAAC,MAAM,CAAC,CAAC,OAAO,CAAC,CAAC,GAAG,EAAE,IAAI,CAAC,CAAC,CAAC;QAEpC,MAAM,CAAC,QAAQ,EAAE,QAAQ,CAAC,GAAG,aAAa,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAA;QAClE,MAAM,CAAC,QAAQ,CAAC,CAAC,eAAe,CAAC,QAAQ,CAAC,CAAC;QAE3C,MAAM,CAAC,QAAQ,EAAE,QAAQ,CAAC,GAAG,aAAa,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAA;QAClE,MAAM,CAAC,QAAQ,CAAC,CAAC,YAAY,CAAC,QAAQ,CAAC,CAAC;IAC5C,CAAC,CAAC,CAAC;IAGH,IAAI,CAAC,sBAAsB,EAAE,KAAK,IAAI,EAAE;QACpC,MAAM,OAAO,GAAG,CAAC,CAAC;QAClB,MAAM,WAAW,GAAG,MAAM,CAAC,OAAO,EAAE,OAAO,CAAC,CAAC;QAE7C,MAAM,CAAC,GAAG,CAAC,GAAG,CAAC;QACf,MAAM,aAAa,GAAG,EAAE,CAAC,MAAM,CAAC;YAC5B,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC;YACZ,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC;YACZ,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC;YACZ,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC;SACf,CAAC,CAAC;QAEH,mEAAmE;QACnE,uEAAuE;QACvE,MAAM,CAAC,CAAC,MAAM,WAAW,CAAC,GAAG,CAAC,aAAa,CAAC,CAAC,GAAG,EAAE,CAAC,IAAI,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC;IAC9E,CAAC,CAAC,CAAC;AAEP,CAAC,CAAC,CAAC"}
|
package/package.json
CHANGED
|
@@ -1,18 +1,24 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "@stellarapp/tfjs-stellar",
|
|
3
|
-
"version": "1.0.
|
|
3
|
+
"version": "1.0.2",
|
|
4
4
|
"description": "An extension of TensorFlow.js for implementing large language models.",
|
|
5
5
|
"license": "ISC",
|
|
6
6
|
"author": "",
|
|
7
7
|
"type": "module",
|
|
8
|
-
"main": "index.
|
|
8
|
+
"main": "dist/index.js",
|
|
9
|
+
"types": "dist/index.d.ts",
|
|
10
|
+
"files": [
|
|
11
|
+
"dist"
|
|
12
|
+
],
|
|
9
13
|
"scripts": {
|
|
10
|
-
"test": "npx jest"
|
|
14
|
+
"test": "npx jest",
|
|
15
|
+
"build": "tsc"
|
|
11
16
|
},
|
|
12
17
|
"devDependencies": {
|
|
13
18
|
"@tensorflow/tfjs": "^4.22.0",
|
|
14
19
|
"@types/jest": "^30.0.0",
|
|
15
20
|
"@types/node": "^26.0.0",
|
|
21
|
+
"globals": "^17.6.0",
|
|
16
22
|
"jest": "^30.4.2",
|
|
17
23
|
"ts-jest": "^29.4.11",
|
|
18
24
|
"tsx": "^4.22.4",
|
|
@@ -20,5 +26,9 @@
|
|
|
20
26
|
},
|
|
21
27
|
"peerDependencies": {
|
|
22
28
|
"@tensorflow/tfjs": "*"
|
|
29
|
+
},
|
|
30
|
+
"repository": {
|
|
31
|
+
"type": "git",
|
|
32
|
+
"url": "https://github.com/rkuang9/tfjs-stellar.git"
|
|
23
33
|
}
|
|
24
|
-
}
|
|
34
|
+
}
|
package/src/index.ts
DELETED
|
@@ -1,93 +0,0 @@
|
|
|
1
|
-
export * as models from "./models";
|
|
2
|
-
export * as losses from "./losses";
|
|
3
|
-
export * as metrics from "./metrics";
|
|
4
|
-
|
|
5
|
-
import { MultiHeadAttention, type MultiHeadAttentionArgs } from "@/layers/multihead_attention";
|
|
6
|
-
export { MultiHeadAttention, type MultiHeadAttentionArgs } from "@/layers/multihead_attention";
|
|
7
|
-
|
|
8
|
-
import { CachedRoPEMultiHeadAttention } from "@/layers/cached_rope_multihead_attention";
|
|
9
|
-
export { CachedRoPEMultiHeadAttention } from "@/layers/cached_rope_multihead_attention";
|
|
10
|
-
|
|
11
|
-
import { TransformerEncoder, type TransformerEncoderArgs, } from "@/layers/transformer_encoder";
|
|
12
|
-
export { TransformerEncoder, type TransformerEncoderArgs, } from "@/layers/transformer_encoder";
|
|
13
|
-
|
|
14
|
-
import { TransformerDecoder, type TransformerDecoderArgs, } from "@/layers/transformer_decoder";
|
|
15
|
-
export { TransformerDecoder, type TransformerDecoderArgs, } from "@/layers/transformer_decoder";
|
|
16
|
-
|
|
17
|
-
import { TokenAndPositionalEmbedding, type TokenAndPositionalEmbeddingArgs } from "@/layers/token_and_positional_embedding";
|
|
18
|
-
export { TokenAndPositionalEmbedding, type TokenAndPositionalEmbeddingArgs } from "@/layers/token_and_positional_embedding";
|
|
19
|
-
|
|
20
|
-
import { PositionalEncoding, type PositionalEncodingArgs } from "@/layers/positional_encoding";
|
|
21
|
-
export { PositionalEncoding, type PositionalEncodingArgs } from "@/layers/positional_encoding";
|
|
22
|
-
|
|
23
|
-
import { GPT2DecoderBlock, type GPTDecoderBlockArgs } from "@/layers/gpt_decoder_block";
|
|
24
|
-
export { GPT2DecoderBlock as GPTDecoderBlock, type GPTDecoderBlockArgs } from "@/layers/gpt_decoder_block";
|
|
25
|
-
|
|
26
|
-
import { LlmModel, type LlmModelArgs } from "@/models/llm_model";
|
|
27
|
-
export { LlmModel, type LlmModelArgs } from "@/models/llm_model";
|
|
28
|
-
|
|
29
|
-
import { UNetModel, type UNetModelArgs } from "@/models/u_net";
|
|
30
|
-
|
|
31
|
-
import { RotaryPositionEmbedding, type RotaryPositionEmbeddingArgs } from "@/layers/rotary_position_embedding";
|
|
32
|
-
export { RotaryPositionEmbedding, type RotaryPositionEmbeddingArgs } from "@/layers/rotary_position_embedding";
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
import { GptModel, type GptModelArgs } from "@/models/gpt_model";
|
|
36
|
-
export { GptModel, type GptModelArgs } from "@/models/gpt_model";
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
// The following exports give a keras-like import just like TFJS's tf.layers.<...>
|
|
40
|
-
|
|
41
|
-
export function llmModel(args: LlmModelArgs) {
|
|
42
|
-
return new LlmModel(args);
|
|
43
|
-
}
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
export function gptModel(args: GptModelArgs) {
|
|
47
|
-
return new GptModel(args);
|
|
48
|
-
}
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
export function tokenAndPositionalEmbedding(args: TokenAndPositionalEmbeddingArgs) {
|
|
52
|
-
return new TokenAndPositionalEmbedding(args);
|
|
53
|
-
}
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
export function transformerEncoder(args: TransformerEncoderArgs) {
|
|
57
|
-
return new TransformerEncoder(args);
|
|
58
|
-
}
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
export function transformerDecoder(args: TransformerDecoderArgs) {
|
|
62
|
-
return new TransformerDecoder(args);
|
|
63
|
-
}
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
export function multiheadAttention(args: MultiHeadAttentionArgs) {
|
|
67
|
-
return new MultiHeadAttention(args);
|
|
68
|
-
}
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
export function cachedRopeMultiheadAttention(args: MultiHeadAttentionArgs) {
|
|
72
|
-
return new CachedRoPEMultiHeadAttention(args);
|
|
73
|
-
}
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
export function positionalEncoding(args: PositionalEncodingArgs) {
|
|
77
|
-
return new PositionalEncoding(args);
|
|
78
|
-
}
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
export function gpt2DecoderBlock(args: GPTDecoderBlockArgs) {
|
|
82
|
-
return new GPT2DecoderBlock(args);
|
|
83
|
-
}
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
export function unetModel(args: UNetModelArgs) {
|
|
87
|
-
return new UNetModel(args);
|
|
88
|
-
}
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
export function rotaryPositionEmbedding(args: RotaryPositionEmbeddingArgs) {
|
|
92
|
-
return new RotaryPositionEmbedding(args);
|
|
93
|
-
}
|
|
@@ -1,107 +0,0 @@
|
|
|
1
|
-
import { RotaryPositionEmbedding } from "@/layers/rotary_position_embedding";
|
|
2
|
-
import * as tf from "@tensorflow/tfjs";
|
|
3
|
-
|
|
4
|
-
// disables warning for using the faster node backend,
|
|
5
|
-
// https://github.com/tensorflow/tfjs/issues/5349#issuecomment-885170504
|
|
6
|
-
tf.env().set('IS_NODE', false);
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
describe("RotaryPositionEmbedding tests", () => {
|
|
10
|
-
test("create cache", async () => {
|
|
11
|
-
const rope = new RotaryPositionEmbedding({ dim: 8, maxSequenceLength: 15 });
|
|
12
|
-
rope.build([]);
|
|
13
|
-
|
|
14
|
-
const expected_cosine_cache = tf.tensor([[[
|
|
15
|
-
[1, 1, 1, 1, 1, 1, 1, 1],
|
|
16
|
-
[0.5403022766113281, 0.5403022766113281, 0.9950041770935059, 0.9950041770935059, 0.9999499917030334, 0.9999499917030334, 0.9999995231628418, 0.9999995231628418],
|
|
17
|
-
[-0.416146844625473, -0.416146844625473, 0.9800665974617004, 0.9800665974617004, 0.9998000264167786, 0.9998000264167786, 0.9999979734420776, 0.9999979734420776],
|
|
18
|
-
[-0.9899924993515015, -0.9899924993515015, 0.9553365111351013, 0.9553365111351013, 0.9995500445365906, 0.9995500445365906, 0.9999955296516418, 0.9999955296516418],
|
|
19
|
-
[-0.6536436080932617, -0.6536436080932617, 0.9210609793663025, 0.9210609793663025, 0.9992001056671143, 0.9992001056671143, 0.9999920129776001, 0.9999920129776001],
|
|
20
|
-
[0.28366219997406006, 0.28366219997406006, 0.8775825500488281, 0.8775825500488281, 0.9987502694129944, 0.9987502694129944, 0.9999874830245972, 0.9999874830245972],
|
|
21
|
-
[0.9601702690124512, 0.9601702690124512, 0.8253356218338013, 0.8253356218338013, 0.998200535774231, 0.998200535774231, 0.9999819993972778, 0.9999819993972778],
|
|
22
|
-
[0.7539022564888, 0.7539022564888, 0.7648422122001648, 0.7648422122001648, 0.9975510239601135, 0.9975510239601135, 0.9999755024909973, 0.9999755024909973],
|
|
23
|
-
[-0.1455000340938568, -0.1455000340938568, 0.6967067122459412, 0.6967067122459412, 0.9968017339706421, 0.9968017339706421, 0.9999679923057556, 0.9999679923057556],
|
|
24
|
-
[-0.9111302495002747, -0.9111302495002747, 0.6216099262237549, 0.6216099262237549, 0.9959527254104614, 0.9959527254104614, 0.9999595284461975, 0.9999595284461975],
|
|
25
|
-
[-0.83907151222229, -0.83907151222229, 0.5403022766113281, 0.5403022766113281, 0.9950041770935059, 0.9950041770935059, 0.9999499917030334, 0.9999499917030334],
|
|
26
|
-
[0.004425697959959507, 0.004425697959959507, 0.4535960853099823, 0.4535960853099823, 0.9939560890197754, 0.9939560890197754, 0.999939501285553, 0.999939501285553],
|
|
27
|
-
[0.8438539505004883, 0.8438539505004883, 0.3623577058315277, 0.3623577058315277, 0.9928086400032043, 0.9928086400032043, 0.9999279975891113, 0.9999279975891113],
|
|
28
|
-
[0.9074468016624451, 0.9074468016624451, 0.26749876141548157, 0.26749876141548157, 0.9915618896484375, 0.9915618896484375, 0.9999154806137085, 0.9999154806137085],
|
|
29
|
-
[0.13673721253871918, 0.13673721253871918, 0.1699671596288681, 0.1699671596288681, 0.9902160167694092, 0.9902160167694092, 0.9999020099639893, 0.9999020099639893]
|
|
30
|
-
]]]);
|
|
31
|
-
|
|
32
|
-
const expected_sine_cache = tf.tensor([[[
|
|
33
|
-
[0, 0, 0, 0, 0, 0, 0, 0],
|
|
34
|
-
[0.8414709568023682, 0.8414709568023682, 0.0998334214091301, 0.0998334214091301, 0.009999833069741726, 0.009999833069741726, 0.0009999999310821295, 0.0009999999310821295],
|
|
35
|
-
[0.9092974066734314, 0.9092974066734314, 0.19866932928562164, 0.19866932928562164, 0.019998665899038315, 0.019998665899038315, 0.0019999986980110407, 0.0019999986980110407],
|
|
36
|
-
[0.14112000167369843, 0.14112000167369843, 0.29552021622657776, 0.29552021622657776, 0.029995499178767204, 0.029995499178767204, 0.0029999956022948027, 0.0029999956022948027],
|
|
37
|
-
[-0.756802499294281, -0.756802499294281, 0.3894183337688446, 0.3894183337688446, 0.03998933359980583, 0.03998933359980583, 0.003999989479780197, 0.003999989479780197],
|
|
38
|
-
[-0.9589242935180664, -0.9589242935180664, 0.4794255495071411, 0.4794255495071411, 0.04997916519641876, 0.04997916519641876, 0.0049999793991446495, 0.0049999793991446495],
|
|
39
|
-
[-0.279415488243103, -0.279415488243103, 0.5646424889564514, 0.5646424889564514, 0.059964004904031754, 0.059964004904031754, 0.0059999641962349415, 0.0059999641962349415],
|
|
40
|
-
[0.6569865942001343, 0.6569865942001343, 0.6442176699638367, 0.6442176699638367, 0.06994284689426422, 0.06994284689426422, 0.0069999429397284985, 0.0069999429397284985],
|
|
41
|
-
[0.9893582463264465, 0.9893582463264465, 0.7173560857772827, 0.7173560857772827, 0.07991468906402588, 0.07991468906402588, 0.007999914698302746, 0.007999914698302746],
|
|
42
|
-
[0.41211849451065063, 0.41211849451065063, 0.7833269238471985, 0.7833269238471985, 0.08987854421138763, 0.08987854421138763, 0.008999879471957684, 0.008999879471957684],
|
|
43
|
-
[-0.5440211296081543, -0.5440211296081543, 0.8414709568023682, 0.8414709568023682, 0.0998334139585495, 0.0998334139585495, 0.0099998340010643, 0.0099998340010643],
|
|
44
|
-
[-0.9999902248382568, -0.9999902248382568, 0.8912073969841003, 0.8912073969841003, 0.10977829992771149, 0.10977829992771149, 0.010999779216945171, 0.010999779216945171],
|
|
45
|
-
[-0.5365729331970215, -0.5365729331970215, 0.9320390820503235, 0.9320390820503235, 0.11971220374107361, 0.11971220374107361, 0.011999712325632572, 0.011999712325632572],
|
|
46
|
-
[0.4201670289039612, 0.4201670289039612, 0.9635581970214844, 0.9635581970214844, 0.12963414192199707, 0.12963414192199707, 0.012999634258449078, 0.012999634258449078],
|
|
47
|
-
[0.9906073808670044, 0.9906073808670044, 0.9854497313499451, 0.9854497313499451, 0.13954311609268188, 0.13954311609268188, 0.013999543152749538, 0.013999543152749538]
|
|
48
|
-
]]]);
|
|
49
|
-
|
|
50
|
-
const [cosine_cache, sine_cache] = rope.getWeights();
|
|
51
|
-
|
|
52
|
-
expect(await cosine_cache?.sub(expected_cosine_cache).sum().array() as number).toBeLessThanOrEqual(1e-6);
|
|
53
|
-
expect(await sine_cache?.sub(expected_sine_cache).sum().array() as number).toBeLessThanOrEqual(1e-6);
|
|
54
|
-
})
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
test("rotate inputs", async () => {
|
|
58
|
-
const rope = new RotaryPositionEmbedding({ dim: 8, maxSequenceLength: 15 });
|
|
59
|
-
|
|
60
|
-
const x = tf.tensor([[[
|
|
61
|
-
[0.0766048, 0.5706575, 0.6705932, 0.5273118, 0.4794086, 0.9378104, 0.9888024, 0.6926053],
|
|
62
|
-
[0.9064133, 0.5875182, 0.1681865, 0.3833345, 0.9901192, 0.4677338, 0.3353315, 0.02699],
|
|
63
|
-
[0.3033573, 0.4139377, 0.4062586, 0.9705839, 0.3582608, 0.328775, 0.1340587, 0.2193414],
|
|
64
|
-
[0.5565202, 0.4334963, 0.9912352, 0.3388563, 0.7991487, 0.1911893, 0.1140554, 0.6949552]]]
|
|
65
|
-
]); // batch=1, seq = 1, heads=4, embedDim=8
|
|
66
|
-
|
|
67
|
-
const expected_output = tf.tensor([[[
|
|
68
|
-
[0.07660479843616486, 0.57065749168396, 0.6705932021141052, 0.5273118019104004, 0.4794085919857025, 0.9378104209899902, 0.9888023734092712, 0.6926053166389465],
|
|
69
|
-
[-0.004642367362976074, 1.08015775680542, 0.12907665967941284, 0.39820998907089233, 0.9853923320770264, 0.47761136293411255, 0.33530429005622864, 0.027325313538312912],
|
|
70
|
-
[-0.5026336908340454, 0.10358311235904694, 0.20533521473407745, 1.0319478511810303, 0.3516140580177307, 0.33587393164634705, 0.1336197406053543, 0.21960905194282532],
|
|
71
|
-
[-0.6121258735656738, -0.3506217896938324, 0.8468242287635803, 0.6166517734527588, 0.7930541634559631, 0.2150741070508957, 0.11197001487016678, 0.695294201374054]
|
|
72
|
-
]]]);
|
|
73
|
-
|
|
74
|
-
const output = rope.apply(x) as tf.Tensor;
|
|
75
|
-
|
|
76
|
-
expect(await expected_output.sub(output).sum().array() as number).toBeLessThan(1e-6);
|
|
77
|
-
expect(rope.computeOutputShape(x.shape)).toEqual(x.shape);
|
|
78
|
-
expect(rope.computeOutputShape([x.shape])).toEqual(x.shape);
|
|
79
|
-
})
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
test("expand cache when input sequences are larger than rope's max sequence length", async () => {
|
|
83
|
-
const dim = 8;
|
|
84
|
-
const rope = new RotaryPositionEmbedding({ dim, maxSequenceLength: 15, theta: 1_000_000 });
|
|
85
|
-
const larger_sequence = 20;
|
|
86
|
-
const even_larger_sequence = 50;
|
|
87
|
-
|
|
88
|
-
rope.apply(tf.randomUniform([1, 1, larger_sequence, dim]));
|
|
89
|
-
|
|
90
|
-
rope.getWeights().forEach(weight => {
|
|
91
|
-
expect(weight.shape).toEqual([1, 1, 32, dim]);
|
|
92
|
-
});
|
|
93
|
-
|
|
94
|
-
rope.apply([tf.randomUniform([1, 1, even_larger_sequence, dim])]);
|
|
95
|
-
|
|
96
|
-
rope.getWeights().forEach(weight => {
|
|
97
|
-
expect(weight.shape).toEqual([1, 1, 64, dim]);
|
|
98
|
-
});
|
|
99
|
-
})
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
test("create layer", async () => {
|
|
103
|
-
// dim must be even
|
|
104
|
-
expect(() => new RotaryPositionEmbedding({ dim: 7, maxSequenceLength: 15 })).toThrow();
|
|
105
|
-
expect(() => new RotaryPositionEmbedding({ dim: 8, maxSequenceLength: 25 })).not.toThrow();
|
|
106
|
-
})
|
|
107
|
-
});
|
package/src/losses/index.ts
DELETED
|
@@ -1 +0,0 @@
|
|
|
1
|
-
export * from "./dice";
|
package/src/testing.ts
DELETED
|
@@ -1 +0,0 @@
|
|
|
1
|
-
console.log("test")
|
package/tsconfig.json
DELETED
|
@@ -1,49 +0,0 @@
|
|
|
1
|
-
{
|
|
2
|
-
// Visit https://aka.ms/tsconfig to read more about this file
|
|
3
|
-
"compilerOptions": {
|
|
4
|
-
// File Layout
|
|
5
|
-
// "rootDir": "./src",
|
|
6
|
-
// "outDir": "./dist",
|
|
7
|
-
// Environment Settings
|
|
8
|
-
// See also https://aka.ms/tsconfig/module
|
|
9
|
-
"module": "es2022",
|
|
10
|
-
"target": "esnext",
|
|
11
|
-
"types": ["jest"],
|
|
12
|
-
// For nodejs:
|
|
13
|
-
// "lib": ["esnext"],
|
|
14
|
-
// "types": ["node"],
|
|
15
|
-
// and npm install -D @types/node
|
|
16
|
-
// Other Outputs
|
|
17
|
-
"sourceMap": true,
|
|
18
|
-
"declaration": true,
|
|
19
|
-
"declarationMap": true,
|
|
20
|
-
// Stricter Typechecking Options
|
|
21
|
-
//"noUncheckedIndexedAccess": true,
|
|
22
|
-
"exactOptionalPropertyTypes": true,
|
|
23
|
-
// Style Options
|
|
24
|
-
// "noImplicitReturns": true,
|
|
25
|
-
// "noImplicitOverride": true,
|
|
26
|
-
// "noUnusedLocals": true,
|
|
27
|
-
// "noUnusedParameters": true,
|
|
28
|
-
// "noFallthroughCasesInSwitch": true,
|
|
29
|
-
// "noPropertyAccessFromIndexSignature": true,
|
|
30
|
-
// Recommended Options
|
|
31
|
-
"strict": true,
|
|
32
|
-
"jsx": "react-jsx",
|
|
33
|
-
//"verbatimModuleSyntax": true,
|
|
34
|
-
"isolatedModules": true,
|
|
35
|
-
"noUncheckedSideEffectImports": true,
|
|
36
|
-
"moduleDetection": "force",
|
|
37
|
-
"skipLibCheck": true,
|
|
38
|
-
"paths": {
|
|
39
|
-
"@/*": [
|
|
40
|
-
"./src/*"
|
|
41
|
-
],
|
|
42
|
-
"e2e/*": [
|
|
43
|
-
"./e2e/*"
|
|
44
|
-
]
|
|
45
|
-
},
|
|
46
|
-
"moduleResolution": "bundler",
|
|
47
|
-
"esModuleInterop": true
|
|
48
|
-
}
|
|
49
|
-
}
|