@stellarapp/tfjs-stellar 1.0.0 → 1.0.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (244) hide show
  1. package/LICENSE +21 -0
  2. package/README.md +47 -0
  3. package/dist/index.d.ts +7 -0
  4. package/dist/index.d.ts.map +1 -0
  5. package/dist/index.js +7 -0
  6. package/dist/index.js.map +1 -0
  7. package/dist/jest.config.d.ts +8 -0
  8. package/dist/jest.config.d.ts.map +1 -0
  9. package/{jest.config.ts → dist/jest.config.js} +8 -64
  10. package/dist/jest.config.js.map +1 -0
  11. package/dist/kv_cache.d.ts +53 -0
  12. package/dist/kv_cache.d.ts.map +1 -0
  13. package/{src/kv_cache.ts → dist/kv_cache.js} +35 -105
  14. package/dist/kv_cache.js.map +1 -0
  15. package/dist/layers/cached_rope_multihead_attention.d.ts +31 -0
  16. package/dist/layers/cached_rope_multihead_attention.d.ts.map +1 -0
  17. package/dist/layers/cached_rope_multihead_attention.js +76 -0
  18. package/dist/layers/cached_rope_multihead_attention.js.map +1 -0
  19. package/dist/layers/cached_rope_multihead_attention.test.d.ts +2 -0
  20. package/dist/layers/cached_rope_multihead_attention.test.d.ts.map +1 -0
  21. package/{src/layers/cached_rope_multihead_attention.test.ts → dist/layers/cached_rope_multihead_attention.test.js} +14 -30
  22. package/dist/layers/cached_rope_multihead_attention.test.js.map +1 -0
  23. package/dist/layers/gpt_decoder_block.d.ts +34 -0
  24. package/dist/layers/gpt_decoder_block.d.ts.map +1 -0
  25. package/{src/layers/gpt_decoder_block.ts → dist/layers/gpt_decoder_block.js} +10 -36
  26. package/dist/layers/gpt_decoder_block.js.map +1 -0
  27. package/dist/layers/index.d.ts +17 -0
  28. package/dist/layers/index.d.ts.map +1 -0
  29. package/dist/layers/index.js +33 -0
  30. package/dist/layers/index.js.map +1 -0
  31. package/dist/layers/multihead_attention.d.ts +106 -0
  32. package/dist/layers/multihead_attention.d.ts.map +1 -0
  33. package/{src/layers/multihead_attention.ts → dist/layers/multihead_attention.js} +60 -162
  34. package/dist/layers/multihead_attention.js.map +1 -0
  35. package/dist/layers/multihead_attention.test.d.ts +2 -0
  36. package/dist/layers/multihead_attention.test.d.ts.map +1 -0
  37. package/{src/layers/multihead_attention.test.ts → dist/layers/multihead_attention.test.js} +48 -100
  38. package/dist/layers/multihead_attention.test.js.map +1 -0
  39. package/dist/layers/positional_encoding.d.ts +37 -0
  40. package/dist/layers/positional_encoding.d.ts.map +1 -0
  41. package/{src/layers/positional_encoding.ts → dist/layers/positional_encoding.js} +17 -60
  42. package/dist/layers/positional_encoding.js.map +1 -0
  43. package/dist/layers/positional_encoding.test.d.ts +2 -0
  44. package/dist/layers/positional_encoding.test.d.ts.map +1 -0
  45. package/{src/layers/positional_encoding.test.ts → dist/layers/positional_encoding.test.js} +39 -57
  46. package/dist/layers/positional_encoding.test.js.map +1 -0
  47. package/dist/layers/rotary_position_embedding.d.ts +39 -0
  48. package/dist/layers/rotary_position_embedding.d.ts.map +1 -0
  49. package/{src/layers/rotary_position_embedding.ts → dist/layers/rotary_position_embedding.js} +22 -86
  50. package/dist/layers/rotary_position_embedding.js.map +1 -0
  51. package/dist/layers/rotary_position_embedding.test.d.ts +2 -0
  52. package/dist/layers/rotary_position_embedding.test.d.ts.map +1 -0
  53. package/dist/layers/rotary_position_embedding.test.js +88 -0
  54. package/dist/layers/rotary_position_embedding.test.js.map +1 -0
  55. package/dist/layers/token_and_positional_embedding.d.ts +47 -0
  56. package/dist/layers/token_and_positional_embedding.d.ts.map +1 -0
  57. package/{src/layers/token_and_positional_embedding.ts → dist/layers/token_and_positional_embedding.js} +27 -67
  58. package/dist/layers/token_and_positional_embedding.js.map +1 -0
  59. package/dist/layers/token_and_positional_embedding.test.d.ts +2 -0
  60. package/dist/layers/token_and_positional_embedding.test.d.ts.map +1 -0
  61. package/{src/layers/token_and_positional_embedding.test.ts → dist/layers/token_and_positional_embedding.test.js} +7 -30
  62. package/dist/layers/token_and_positional_embedding.test.js.map +1 -0
  63. package/dist/layers/transformer_decoder.d.ts +69 -0
  64. package/dist/layers/transformer_decoder.d.ts.map +1 -0
  65. package/dist/layers/transformer_decoder.js +182 -0
  66. package/dist/layers/transformer_decoder.js.map +1 -0
  67. package/dist/layers/transformer_decoder.test.d.ts +2 -0
  68. package/dist/layers/transformer_decoder.test.d.ts.map +1 -0
  69. package/{src/layers/transformer_decoder.test.ts → dist/layers/transformer_decoder.test.js} +20 -48
  70. package/dist/layers/transformer_decoder.test.js.map +1 -0
  71. package/dist/layers/transformer_encoder.d.ts +55 -0
  72. package/dist/layers/transformer_encoder.d.ts.map +1 -0
  73. package/{src/layers/transformer_encoder.ts → dist/layers/transformer_encoder.js} +41 -90
  74. package/dist/layers/transformer_encoder.js.map +1 -0
  75. package/dist/layers/transformer_encoder.test.d.ts +2 -0
  76. package/dist/layers/transformer_encoder.test.d.ts.map +1 -0
  77. package/{src/layers/transformer_encoder.test.ts → dist/layers/transformer_encoder.test.js} +18 -45
  78. package/dist/layers/transformer_encoder.test.js.map +1 -0
  79. package/dist/losses/dice.d.ts +30 -0
  80. package/dist/losses/dice.d.ts.map +1 -0
  81. package/{src/losses/dice.ts → dist/losses/dice.js} +17 -80
  82. package/dist/losses/dice.js.map +1 -0
  83. package/dist/losses/index.d.ts +2 -0
  84. package/dist/losses/index.d.ts.map +1 -0
  85. package/dist/losses/index.js +2 -0
  86. package/dist/losses/index.js.map +1 -0
  87. package/dist/masks.d.ts +20 -0
  88. package/dist/masks.d.ts.map +1 -0
  89. package/{src/packing_mask.ts → dist/masks.js} +16 -7
  90. package/dist/masks.js.map +1 -0
  91. package/dist/metrics.d.ts +20 -0
  92. package/dist/metrics.d.ts.map +1 -0
  93. package/{src/metrics.ts → dist/metrics.js} +8 -12
  94. package/dist/metrics.js.map +1 -0
  95. package/dist/models/gpt_model.d.ts +94 -0
  96. package/dist/models/gpt_model.d.ts.map +1 -0
  97. package/{src/models/gpt_model.ts → dist/models/gpt_model.js} +41 -119
  98. package/dist/models/gpt_model.js.map +1 -0
  99. package/dist/models/index.d.ts +7 -0
  100. package/dist/models/index.d.ts.map +1 -0
  101. package/dist/models/index.js +13 -0
  102. package/dist/models/index.js.map +1 -0
  103. package/dist/models/llm_model.d.ts +87 -0
  104. package/dist/models/llm_model.d.ts.map +1 -0
  105. package/{src/models/llm_model.ts → dist/models/llm_model.js} +51 -161
  106. package/dist/models/llm_model.js.map +1 -0
  107. package/dist/models/u_net.d.ts +40 -0
  108. package/dist/models/u_net.d.ts.map +1 -0
  109. package/{src/models/u_net.ts → dist/models/u_net.js} +27 -116
  110. package/dist/models/u_net.js.map +1 -0
  111. package/dist/src/index.d.ts +6 -0
  112. package/dist/src/index.d.ts.map +1 -0
  113. package/dist/src/index.js +6 -0
  114. package/dist/src/index.js.map +1 -0
  115. package/dist/src/kv_cache.d.ts +53 -0
  116. package/dist/src/kv_cache.d.ts.map +1 -0
  117. package/dist/src/kv_cache.js +135 -0
  118. package/dist/src/kv_cache.js.map +1 -0
  119. package/dist/src/layers/cached_rope_multihead_attention.d.ts +31 -0
  120. package/dist/src/layers/cached_rope_multihead_attention.d.ts.map +1 -0
  121. package/{src/layers/cached_rope_multihead_attention.ts → dist/src/layers/cached_rope_multihead_attention.js} +25 -62
  122. package/dist/src/layers/cached_rope_multihead_attention.js.map +1 -0
  123. package/dist/src/layers/cached_rope_multihead_attention.test.d.ts +2 -0
  124. package/dist/src/layers/cached_rope_multihead_attention.test.d.ts.map +1 -0
  125. package/dist/src/layers/cached_rope_multihead_attention.test.js +43 -0
  126. package/dist/src/layers/cached_rope_multihead_attention.test.js.map +1 -0
  127. package/dist/src/layers/gpt_decoder_block.d.ts +34 -0
  128. package/dist/src/layers/gpt_decoder_block.d.ts.map +1 -0
  129. package/dist/src/layers/gpt_decoder_block.js +51 -0
  130. package/dist/src/layers/gpt_decoder_block.js.map +1 -0
  131. package/dist/src/layers/index.d.ts +17 -0
  132. package/dist/src/layers/index.d.ts.map +1 -0
  133. package/dist/src/layers/index.js +33 -0
  134. package/dist/src/layers/index.js.map +1 -0
  135. package/dist/src/layers/multihead_attention.d.ts +106 -0
  136. package/dist/src/layers/multihead_attention.d.ts.map +1 -0
  137. package/dist/src/layers/multihead_attention.js +269 -0
  138. package/dist/src/layers/multihead_attention.js.map +1 -0
  139. package/dist/src/layers/multihead_attention.test.d.ts +2 -0
  140. package/dist/src/layers/multihead_attention.test.d.ts.map +1 -0
  141. package/dist/src/layers/multihead_attention.test.js +160 -0
  142. package/dist/src/layers/multihead_attention.test.js.map +1 -0
  143. package/dist/src/layers/positional_encoding.d.ts +37 -0
  144. package/dist/src/layers/positional_encoding.d.ts.map +1 -0
  145. package/dist/src/layers/positional_encoding.js +115 -0
  146. package/dist/src/layers/positional_encoding.js.map +1 -0
  147. package/dist/src/layers/positional_encoding.test.d.ts +2 -0
  148. package/dist/src/layers/positional_encoding.test.d.ts.map +1 -0
  149. package/dist/src/layers/positional_encoding.test.js +95 -0
  150. package/dist/src/layers/positional_encoding.test.js.map +1 -0
  151. package/dist/src/layers/rotary_position_embedding.d.ts +39 -0
  152. package/dist/src/layers/rotary_position_embedding.d.ts.map +1 -0
  153. package/dist/src/layers/rotary_position_embedding.js +99 -0
  154. package/dist/src/layers/rotary_position_embedding.js.map +1 -0
  155. package/dist/src/layers/rotary_position_embedding.test.d.ts +2 -0
  156. package/dist/src/layers/rotary_position_embedding.test.d.ts.map +1 -0
  157. package/dist/src/layers/rotary_position_embedding.test.js +88 -0
  158. package/dist/src/layers/rotary_position_embedding.test.js.map +1 -0
  159. package/dist/src/layers/token_and_positional_embedding.d.ts +47 -0
  160. package/dist/src/layers/token_and_positional_embedding.d.ts.map +1 -0
  161. package/dist/src/layers/token_and_positional_embedding.js +109 -0
  162. package/dist/src/layers/token_and_positional_embedding.js.map +1 -0
  163. package/dist/src/layers/token_and_positional_embedding.test.d.ts +2 -0
  164. package/dist/src/layers/token_and_positional_embedding.test.d.ts.map +1 -0
  165. package/dist/src/layers/token_and_positional_embedding.test.js +58 -0
  166. package/dist/src/layers/token_and_positional_embedding.test.js.map +1 -0
  167. package/dist/src/layers/transformer_decoder.d.ts +69 -0
  168. package/dist/src/layers/transformer_decoder.d.ts.map +1 -0
  169. package/{src/layers/transformer_decoder.ts → dist/src/layers/transformer_decoder.js} +41 -95
  170. package/dist/src/layers/transformer_decoder.js.map +1 -0
  171. package/dist/src/layers/transformer_decoder.test.d.ts +2 -0
  172. package/dist/src/layers/transformer_decoder.test.d.ts.map +1 -0
  173. package/dist/src/layers/transformer_decoder.test.js +72 -0
  174. package/dist/src/layers/transformer_decoder.test.js.map +1 -0
  175. package/dist/src/layers/transformer_encoder.d.ts +55 -0
  176. package/dist/src/layers/transformer_encoder.d.ts.map +1 -0
  177. package/dist/src/layers/transformer_encoder.js +175 -0
  178. package/dist/src/layers/transformer_encoder.js.map +1 -0
  179. package/dist/src/layers/transformer_encoder.test.d.ts +2 -0
  180. package/dist/src/layers/transformer_encoder.test.d.ts.map +1 -0
  181. package/dist/src/layers/transformer_encoder.test.js +58 -0
  182. package/dist/src/layers/transformer_encoder.test.js.map +1 -0
  183. package/dist/src/losses/dice.d.ts +30 -0
  184. package/dist/src/losses/dice.d.ts.map +1 -0
  185. package/dist/src/losses/dice.js +93 -0
  186. package/dist/src/losses/dice.js.map +1 -0
  187. package/dist/src/losses/index.d.ts +2 -0
  188. package/dist/src/losses/index.d.ts.map +1 -0
  189. package/dist/src/losses/index.js +2 -0
  190. package/dist/src/losses/index.js.map +1 -0
  191. package/dist/src/masks.d.ts +20 -0
  192. package/dist/src/masks.d.ts.map +1 -0
  193. package/dist/src/masks.js +37 -0
  194. package/dist/src/masks.js.map +1 -0
  195. package/dist/src/metrics.d.ts +20 -0
  196. package/dist/src/metrics.d.ts.map +1 -0
  197. package/dist/src/metrics.js +28 -0
  198. package/dist/src/metrics.js.map +1 -0
  199. package/dist/src/models/gpt_model.d.ts +94 -0
  200. package/dist/src/models/gpt_model.d.ts.map +1 -0
  201. package/dist/src/models/gpt_model.js +154 -0
  202. package/dist/src/models/gpt_model.js.map +1 -0
  203. package/dist/src/models/index.d.ts +3 -0
  204. package/dist/src/models/index.d.ts.map +1 -0
  205. package/{src/models/index.ts → dist/src/models/index.js} +1 -0
  206. package/dist/src/models/index.js.map +1 -0
  207. package/dist/src/models/llm_model.d.ts +87 -0
  208. package/dist/src/models/llm_model.d.ts.map +1 -0
  209. package/dist/src/models/llm_model.js +245 -0
  210. package/dist/src/models/llm_model.js.map +1 -0
  211. package/dist/src/models/u_net.d.ts +40 -0
  212. package/dist/src/models/u_net.d.ts.map +1 -0
  213. package/dist/src/models/u_net.js +151 -0
  214. package/dist/src/models/u_net.js.map +1 -0
  215. package/{src/tfjs_types.ts → dist/src/tfjs_types.d.ts} +1 -6
  216. package/dist/src/tfjs_types.d.ts.map +1 -0
  217. package/dist/src/tfjs_types.js +2 -0
  218. package/dist/src/tfjs_types.js.map +1 -0
  219. package/dist/src/utils.d.ts +28 -0
  220. package/dist/src/utils.d.ts.map +1 -0
  221. package/{src/utils.ts → dist/src/utils.js} +10 -33
  222. package/dist/src/utils.js.map +1 -0
  223. package/dist/src/utils.test.d.ts +2 -0
  224. package/dist/src/utils.test.d.ts.map +1 -0
  225. package/{src/utils.test.ts → dist/src/utils.test.js} +22 -50
  226. package/dist/src/utils.test.js.map +1 -0
  227. package/dist/tfjs_types.d.ts +10 -0
  228. package/dist/tfjs_types.d.ts.map +1 -0
  229. package/dist/tfjs_types.js +2 -0
  230. package/dist/tfjs_types.js.map +1 -0
  231. package/dist/utils.d.ts +28 -0
  232. package/dist/utils.d.ts.map +1 -0
  233. package/dist/utils.js +63 -0
  234. package/dist/utils.js.map +1 -0
  235. package/dist/utils.test.d.ts +2 -0
  236. package/dist/utils.test.d.ts.map +1 -0
  237. package/dist/utils.test.js +73 -0
  238. package/dist/utils.test.js.map +1 -0
  239. package/package.json +14 -4
  240. package/src/index.ts +0 -93
  241. package/src/layers/rotary_position_embedding.test.ts +0 -107
  242. package/src/losses/index.ts +0 -1
  243. package/src/testing.ts +0 -1
  244. package/tsconfig.json +0 -49
@@ -0,0 +1,73 @@
1
+ import * as tf from "@tensorflow/tfjs";
2
+ import { getScaleShape, getRandomCropStart } from "@/utils";
3
+ import { causal } from "@/masks";
4
+ // avoid TFJS node message during Jest testing
5
+ tf.env().set('IS_NODE', false);
6
+ describe("test custom TFJS utility functions", () => {
7
+ test("crop an image using the same shape, results in same shape", async () => {
8
+ // cropping an image of the same shape
9
+ const img_size = [133, 84];
10
+ const target_size = [133, 84];
11
+ expect(getRandomCropStart(img_size, target_size)).toEqual([0, 0, 0]);
12
+ });
13
+ it("should throw when crop is larger than image", async () => {
14
+ expect(() => getRandomCropStart([128, 128], [1000, 2000])).toThrow();
15
+ });
16
+ test("cropped image shape", async () => {
17
+ // cropping from wide to tall image
18
+ for (let i = 0; i < 100; i++) {
19
+ const img_size = [4923, 832];
20
+ const target_size = [333, 739];
21
+ const [crop_start_h, crop_start_w, channels] = getRandomCropStart(img_size, target_size);
22
+ expect(crop_start_h).toBeLessThanOrEqual(img_size[0] - target_size[0]);
23
+ expect(crop_start_w).toBeLessThanOrEqual(img_size[1] - target_size[1]);
24
+ }
25
+ // cropping from tall to wide image
26
+ for (let i = 0; i < 100; i++) {
27
+ const img_size = [381, 999];
28
+ const target_size = [300, 157];
29
+ const [crop_start_h, crop_start_w, channels] = getRandomCropStart(img_size, target_size);
30
+ expect(crop_start_h).toBeLessThanOrEqual(img_size[0] - target_size[0]);
31
+ expect(crop_start_w).toBeLessThanOrEqual(img_size[1] - target_size[1]);
32
+ }
33
+ });
34
+ test("scale 1:1, results in the same shape", async () => {
35
+ const scale = getScaleShape([256, 256], [256, 256]);
36
+ expect(scale).toEqual([256, 256]);
37
+ });
38
+ test("scaled image shape", async () => {
39
+ // scaling squares result in squares
40
+ const scale1 = getScaleShape([256, 256], [128, 128]);
41
+ expect(scale1).toEqual([128, 128]);
42
+ const scale2 = getScaleShape([128, 128], [256, 256]);
43
+ expect(scale2).toEqual([256, 256]);
44
+ const scale3 = getScaleShape([123, 123], [321, 321]);
45
+ expect(scale3).toEqual([321, 321]);
46
+ const scale4 = getScaleShape([321, 321], [123, 123]);
47
+ expect(scale4).toEqual([123, 123]);
48
+ // scaling rectangles result in rectangles
49
+ const scale5 = getScaleShape([640, 480], [1280, 960]);
50
+ expect(scale5).toEqual([1280, 960]);
51
+ const scale6 = getScaleShape([480, 640], [960, 1280]);
52
+ expect(scale6).toEqual([960, 1280]);
53
+ const [scale7_h, scale7_w] = getScaleShape([777, 555], [555, 333]);
54
+ expect(scale7_h).toBeGreaterThan(scale7_w);
55
+ const [scale8_h, scale8_w] = getScaleShape([555, 777], [333, 555]);
56
+ expect(scale8_h).toBeLessThan(scale8_w);
57
+ });
58
+ test("causal attention map", async () => {
59
+ const seq_len = 4;
60
+ const causal_mask = causal(seq_len, seq_len);
61
+ const _ = -1e7;
62
+ const expected_mask = tf.tensor([
63
+ [0, _, _, _],
64
+ [0, 0, _, _],
65
+ [0, 0, 0, _],
66
+ [0, 0, 0, 0]
67
+ ]);
68
+ // this might fail due to precision issues on the masked positions,
69
+ // in which case use less <= to 6 or 12 (number of masked positions x2)
70
+ expect((await causal_mask.sub(expected_mask).sum().data())[0]).toEqual(0);
71
+ });
72
+ });
73
+ //# sourceMappingURL=utils.test.js.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"utils.test.js","sourceRoot":"","sources":["../src/utils.test.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AACvC,OAAO,EAAE,aAAa,EAAE,kBAAkB,EAAE,MAAM,SAAS,CAAC;AAC5D,OAAO,EAAE,MAAM,EAAE,MAAM,SAAS,CAAC;AAEjC,8CAA8C;AAC9C,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC,SAAS,EAAE,KAAK,CAAC,CAAC;AAG/B,QAAQ,CAAC,oCAAoC,EAAE,GAAG,EAAE;IAEhD,IAAI,CAAC,2DAA2D,EAAE,KAAK,IAAI,EAAE;QACzE,sCAAsC;QACtC,MAAM,QAAQ,GAAG,CAAC,GAAG,EAAE,EAAE,CAAqB,CAAC;QAC/C,MAAM,WAAW,GAAG,CAAC,GAAG,EAAE,EAAE,CAAqB,CAAC;QAElD,MAAM,CAAC,kBAAkB,CAAC,QAAQ,EAAE,WAAW,CAAC,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;IACzE,CAAC,CAAC,CAAC;IAGH,EAAE,CAAC,6CAA6C,EAAE,KAAK,IAAI,EAAE;QACzD,MAAM,CAAC,GAAG,EAAE,CAAC,kBAAkB,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,IAAI,EAAE,IAAI,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;IACzE,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,qBAAqB,EAAE,KAAK,IAAI,EAAE;QACnC,mCAAmC;QACnC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,GAAG,EAAE,CAAC,EAAE,EAAE,CAAC;YAC3B,MAAM,QAAQ,GAAG,CAAC,IAAI,EAAE,GAAG,CAAqB,CAAC;YACjD,MAAM,WAAW,GAAG,CAAC,GAAG,EAAE,GAAG,CAAqB,CAAC;YAEnD,MAAM,CAAC,YAAY,EAAE,YAAY,EAAE,QAAQ,CAAC,GAAG,kBAAkB,CAAC,QAAQ,EAAE,WAAW,CAAC,CAAA;YAExF,MAAM,CAAC,YAAY,CAAC,CAAC,mBAAmB,CAAC,QAAQ,CAAC,CAAC,CAAC,GAAG,WAAW,CAAC,CAAC,CAAC,CAAC,CAAC;YACvE,MAAM,CAAC,YAAY,CAAC,CAAC,mBAAmB,CAAC,QAAQ,CAAC,CAAC,CAAC,GAAG,WAAW,CAAC,CAAC,CAAC,CAAC,CAAC;QAC3E,CAAC;QAED,mCAAmC;QACnC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,GAAG,EAAE,CAAC,EAAE,EAAE,CAAC;YAC3B,MAAM,QAAQ,GAAG,CAAC,GAAG,EAAE,GAAG,CAAqB,CAAC;YAChD,MAAM,WAAW,GAAG,CAAC,GAAG,EAAE,GAAG,CAAqB,CAAC;YAEnD,MAAM,CAAC,YAAY,EAAE,YAAY,EAAE,QAAQ,CAAC,GAAG,kBAAkB,CAAC,QAAQ,EAAE,WAAW,CAAC,CAAA;YAExF,MAAM,CAAC,YAAY,CAAC,CAAC,mBAAmB,CAAC,QAAQ,CAAC,CAAC,CAAC,GAAG,WAAW,CAAC,CAAC,CAAC,CAAC,CAAC;YACvE,MAAM,CAAC,YAAY,CAAC,CAAC,mBAAmB,CAAC,QAAQ,CAAC,CAAC,CAAC,GAAG,WAAW,CAAC,CAAC,CAAC,CAAC,CAAC;QAC3E,CAAC;IACL,CAAC,CAAC,CAAC;IAGH,IAAI,CAAC,sCAAsC,EAAE,KAAK,IAAI,EAAE;QACpD,MAAM,KAAK,GAAG,aAAa,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAA;QACnD,MAAM,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAC;IACtC,CAAC,CAAC,CAAC;IAGH,IAAI,CAAC,oBAAoB,EAAE,KAAK,IAAI,EAAE;QAClC,oCAAoC;QACpC,MAAM,MAAM,GAAG,aAAa,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAA;QACpD,MAAM,CAAC,MAAM,CAAC,CAAC,OAAO,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAC;QAEnC,MAAM,MAAM,GAAG,aAAa,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAA;QACpD,MAAM,CAAC,MAAM,CAAC,CAAC,OAAO,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAC;QAEnC,MAAM,MAAM,GAAG,aAAa,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAA;QACpD,MAAM,CAAC,MAAM,CAAC,CAAC,OAAO,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAC;QAEnC,MAAM,MAAM,GAAG,aAAa,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAA;QACpD,MAAM,CAAC,MAAM,CAAC,CAAC,OAAO,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAC;QAEnC,0CAA0C;QAC1C,MAAM,MAAM,GAAG,aAAa,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,IAAI,EAAE,GAAG,CAAC,CAAC,CAAA;QACrD,MAAM,CAAC,MAAM,CAAC,CAAC,OAAO,CAAC,CAAC,IAAI,EAAE,GAAG,CAAC,CAAC,CAAC;QAEpC,MAAM,MAAM,GAAG,aAAa,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,IAAI,CAAC,CAAC,CAAA;QACrD,MAAM,CAAC,MAAM,CAAC,CAAC,OAAO,CAAC,CAAC,GAAG,EAAE,IAAI,CAAC,CAAC,CAAC;QAEpC,MAAM,CAAC,QAAQ,EAAE,QAAQ,CAAC,GAAG,aAAa,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAA;QAClE,MAAM,CAAC,QAAQ,CAAC,CAAC,eAAe,CAAC,QAAQ,CAAC,CAAC;QAE3C,MAAM,CAAC,QAAQ,EAAE,QAAQ,CAAC,GAAG,aAAa,CAAC,CAAC,GAAG,EAAE,GAAG,CAAC,EAAE,CAAC,GAAG,EAAE,GAAG,CAAC,CAAC,CAAA;QAClE,MAAM,CAAC,QAAQ,CAAC,CAAC,YAAY,CAAC,QAAQ,CAAC,CAAC;IAC5C,CAAC,CAAC,CAAC;IAGH,IAAI,CAAC,sBAAsB,EAAE,KAAK,IAAI,EAAE;QACpC,MAAM,OAAO,GAAG,CAAC,CAAC;QAClB,MAAM,WAAW,GAAG,MAAM,CAAC,OAAO,EAAE,OAAO,CAAC,CAAC;QAE7C,MAAM,CAAC,GAAG,CAAC,GAAG,CAAC;QACf,MAAM,aAAa,GAAG,EAAE,CAAC,MAAM,CAAC;YAC5B,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC;YACZ,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC;YACZ,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC;YACZ,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC;SACf,CAAC,CAAC;QAEH,mEAAmE;QACnE,uEAAuE;QACvE,MAAM,CAAC,CAAC,MAAM,WAAW,CAAC,GAAG,CAAC,aAAa,CAAC,CAAC,GAAG,EAAE,CAAC,IAAI,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC;IAC9E,CAAC,CAAC,CAAC;AAEP,CAAC,CAAC,CAAC"}
package/package.json CHANGED
@@ -1,18 +1,24 @@
1
1
  {
2
2
  "name": "@stellarapp/tfjs-stellar",
3
- "version": "1.0.0",
3
+ "version": "1.0.2",
4
4
  "description": "An extension of TensorFlow.js for implementing large language models.",
5
5
  "license": "ISC",
6
6
  "author": "",
7
7
  "type": "module",
8
- "main": "index.ts",
8
+ "main": "dist/index.js",
9
+ "types": "dist/index.d.ts",
10
+ "files": [
11
+ "dist"
12
+ ],
9
13
  "scripts": {
10
- "test": "npx jest"
14
+ "test": "npx jest",
15
+ "build": "tsc"
11
16
  },
12
17
  "devDependencies": {
13
18
  "@tensorflow/tfjs": "^4.22.0",
14
19
  "@types/jest": "^30.0.0",
15
20
  "@types/node": "^26.0.0",
21
+ "globals": "^17.6.0",
16
22
  "jest": "^30.4.2",
17
23
  "ts-jest": "^29.4.11",
18
24
  "tsx": "^4.22.4",
@@ -20,5 +26,9 @@
20
26
  },
21
27
  "peerDependencies": {
22
28
  "@tensorflow/tfjs": "*"
29
+ },
30
+ "repository": {
31
+ "type": "git",
32
+ "url": "https://github.com/rkuang9/tfjs-stellar.git"
23
33
  }
24
- }
34
+ }
package/src/index.ts DELETED
@@ -1,93 +0,0 @@
1
- export * as models from "./models";
2
- export * as losses from "./losses";
3
- export * as metrics from "./metrics";
4
-
5
- import { MultiHeadAttention, type MultiHeadAttentionArgs } from "@/layers/multihead_attention";
6
- export { MultiHeadAttention, type MultiHeadAttentionArgs } from "@/layers/multihead_attention";
7
-
8
- import { CachedRoPEMultiHeadAttention } from "@/layers/cached_rope_multihead_attention";
9
- export { CachedRoPEMultiHeadAttention } from "@/layers/cached_rope_multihead_attention";
10
-
11
- import { TransformerEncoder, type TransformerEncoderArgs, } from "@/layers/transformer_encoder";
12
- export { TransformerEncoder, type TransformerEncoderArgs, } from "@/layers/transformer_encoder";
13
-
14
- import { TransformerDecoder, type TransformerDecoderArgs, } from "@/layers/transformer_decoder";
15
- export { TransformerDecoder, type TransformerDecoderArgs, } from "@/layers/transformer_decoder";
16
-
17
- import { TokenAndPositionalEmbedding, type TokenAndPositionalEmbeddingArgs } from "@/layers/token_and_positional_embedding";
18
- export { TokenAndPositionalEmbedding, type TokenAndPositionalEmbeddingArgs } from "@/layers/token_and_positional_embedding";
19
-
20
- import { PositionalEncoding, type PositionalEncodingArgs } from "@/layers/positional_encoding";
21
- export { PositionalEncoding, type PositionalEncodingArgs } from "@/layers/positional_encoding";
22
-
23
- import { GPT2DecoderBlock, type GPTDecoderBlockArgs } from "@/layers/gpt_decoder_block";
24
- export { GPT2DecoderBlock as GPTDecoderBlock, type GPTDecoderBlockArgs } from "@/layers/gpt_decoder_block";
25
-
26
- import { LlmModel, type LlmModelArgs } from "@/models/llm_model";
27
- export { LlmModel, type LlmModelArgs } from "@/models/llm_model";
28
-
29
- import { UNetModel, type UNetModelArgs } from "@/models/u_net";
30
-
31
- import { RotaryPositionEmbedding, type RotaryPositionEmbeddingArgs } from "@/layers/rotary_position_embedding";
32
- export { RotaryPositionEmbedding, type RotaryPositionEmbeddingArgs } from "@/layers/rotary_position_embedding";
33
-
34
-
35
- import { GptModel, type GptModelArgs } from "@/models/gpt_model";
36
- export { GptModel, type GptModelArgs } from "@/models/gpt_model";
37
-
38
-
39
- // The following exports give a keras-like import just like TFJS's tf.layers.<...>
40
-
41
- export function llmModel(args: LlmModelArgs) {
42
- return new LlmModel(args);
43
- }
44
-
45
-
46
- export function gptModel(args: GptModelArgs) {
47
- return new GptModel(args);
48
- }
49
-
50
-
51
- export function tokenAndPositionalEmbedding(args: TokenAndPositionalEmbeddingArgs) {
52
- return new TokenAndPositionalEmbedding(args);
53
- }
54
-
55
-
56
- export function transformerEncoder(args: TransformerEncoderArgs) {
57
- return new TransformerEncoder(args);
58
- }
59
-
60
-
61
- export function transformerDecoder(args: TransformerDecoderArgs) {
62
- return new TransformerDecoder(args);
63
- }
64
-
65
-
66
- export function multiheadAttention(args: MultiHeadAttentionArgs) {
67
- return new MultiHeadAttention(args);
68
- }
69
-
70
-
71
- export function cachedRopeMultiheadAttention(args: MultiHeadAttentionArgs) {
72
- return new CachedRoPEMultiHeadAttention(args);
73
- }
74
-
75
-
76
- export function positionalEncoding(args: PositionalEncodingArgs) {
77
- return new PositionalEncoding(args);
78
- }
79
-
80
-
81
- export function gpt2DecoderBlock(args: GPTDecoderBlockArgs) {
82
- return new GPT2DecoderBlock(args);
83
- }
84
-
85
-
86
- export function unetModel(args: UNetModelArgs) {
87
- return new UNetModel(args);
88
- }
89
-
90
-
91
- export function rotaryPositionEmbedding(args: RotaryPositionEmbeddingArgs) {
92
- return new RotaryPositionEmbedding(args);
93
- }
@@ -1,107 +0,0 @@
1
- import { RotaryPositionEmbedding } from "@/layers/rotary_position_embedding";
2
- import * as tf from "@tensorflow/tfjs";
3
-
4
- // disables warning for using the faster node backend,
5
- // https://github.com/tensorflow/tfjs/issues/5349#issuecomment-885170504
6
- tf.env().set('IS_NODE', false);
7
-
8
-
9
- describe("RotaryPositionEmbedding tests", () => {
10
- test("create cache", async () => {
11
- const rope = new RotaryPositionEmbedding({ dim: 8, maxSequenceLength: 15 });
12
- rope.build([]);
13
-
14
- const expected_cosine_cache = tf.tensor([[[
15
- [1, 1, 1, 1, 1, 1, 1, 1],
16
- [0.5403022766113281, 0.5403022766113281, 0.9950041770935059, 0.9950041770935059, 0.9999499917030334, 0.9999499917030334, 0.9999995231628418, 0.9999995231628418],
17
- [-0.416146844625473, -0.416146844625473, 0.9800665974617004, 0.9800665974617004, 0.9998000264167786, 0.9998000264167786, 0.9999979734420776, 0.9999979734420776],
18
- [-0.9899924993515015, -0.9899924993515015, 0.9553365111351013, 0.9553365111351013, 0.9995500445365906, 0.9995500445365906, 0.9999955296516418, 0.9999955296516418],
19
- [-0.6536436080932617, -0.6536436080932617, 0.9210609793663025, 0.9210609793663025, 0.9992001056671143, 0.9992001056671143, 0.9999920129776001, 0.9999920129776001],
20
- [0.28366219997406006, 0.28366219997406006, 0.8775825500488281, 0.8775825500488281, 0.9987502694129944, 0.9987502694129944, 0.9999874830245972, 0.9999874830245972],
21
- [0.9601702690124512, 0.9601702690124512, 0.8253356218338013, 0.8253356218338013, 0.998200535774231, 0.998200535774231, 0.9999819993972778, 0.9999819993972778],
22
- [0.7539022564888, 0.7539022564888, 0.7648422122001648, 0.7648422122001648, 0.9975510239601135, 0.9975510239601135, 0.9999755024909973, 0.9999755024909973],
23
- [-0.1455000340938568, -0.1455000340938568, 0.6967067122459412, 0.6967067122459412, 0.9968017339706421, 0.9968017339706421, 0.9999679923057556, 0.9999679923057556],
24
- [-0.9111302495002747, -0.9111302495002747, 0.6216099262237549, 0.6216099262237549, 0.9959527254104614, 0.9959527254104614, 0.9999595284461975, 0.9999595284461975],
25
- [-0.83907151222229, -0.83907151222229, 0.5403022766113281, 0.5403022766113281, 0.9950041770935059, 0.9950041770935059, 0.9999499917030334, 0.9999499917030334],
26
- [0.004425697959959507, 0.004425697959959507, 0.4535960853099823, 0.4535960853099823, 0.9939560890197754, 0.9939560890197754, 0.999939501285553, 0.999939501285553],
27
- [0.8438539505004883, 0.8438539505004883, 0.3623577058315277, 0.3623577058315277, 0.9928086400032043, 0.9928086400032043, 0.9999279975891113, 0.9999279975891113],
28
- [0.9074468016624451, 0.9074468016624451, 0.26749876141548157, 0.26749876141548157, 0.9915618896484375, 0.9915618896484375, 0.9999154806137085, 0.9999154806137085],
29
- [0.13673721253871918, 0.13673721253871918, 0.1699671596288681, 0.1699671596288681, 0.9902160167694092, 0.9902160167694092, 0.9999020099639893, 0.9999020099639893]
30
- ]]]);
31
-
32
- const expected_sine_cache = tf.tensor([[[
33
- [0, 0, 0, 0, 0, 0, 0, 0],
34
- [0.8414709568023682, 0.8414709568023682, 0.0998334214091301, 0.0998334214091301, 0.009999833069741726, 0.009999833069741726, 0.0009999999310821295, 0.0009999999310821295],
35
- [0.9092974066734314, 0.9092974066734314, 0.19866932928562164, 0.19866932928562164, 0.019998665899038315, 0.019998665899038315, 0.0019999986980110407, 0.0019999986980110407],
36
- [0.14112000167369843, 0.14112000167369843, 0.29552021622657776, 0.29552021622657776, 0.029995499178767204, 0.029995499178767204, 0.0029999956022948027, 0.0029999956022948027],
37
- [-0.756802499294281, -0.756802499294281, 0.3894183337688446, 0.3894183337688446, 0.03998933359980583, 0.03998933359980583, 0.003999989479780197, 0.003999989479780197],
38
- [-0.9589242935180664, -0.9589242935180664, 0.4794255495071411, 0.4794255495071411, 0.04997916519641876, 0.04997916519641876, 0.0049999793991446495, 0.0049999793991446495],
39
- [-0.279415488243103, -0.279415488243103, 0.5646424889564514, 0.5646424889564514, 0.059964004904031754, 0.059964004904031754, 0.0059999641962349415, 0.0059999641962349415],
40
- [0.6569865942001343, 0.6569865942001343, 0.6442176699638367, 0.6442176699638367, 0.06994284689426422, 0.06994284689426422, 0.0069999429397284985, 0.0069999429397284985],
41
- [0.9893582463264465, 0.9893582463264465, 0.7173560857772827, 0.7173560857772827, 0.07991468906402588, 0.07991468906402588, 0.007999914698302746, 0.007999914698302746],
42
- [0.41211849451065063, 0.41211849451065063, 0.7833269238471985, 0.7833269238471985, 0.08987854421138763, 0.08987854421138763, 0.008999879471957684, 0.008999879471957684],
43
- [-0.5440211296081543, -0.5440211296081543, 0.8414709568023682, 0.8414709568023682, 0.0998334139585495, 0.0998334139585495, 0.0099998340010643, 0.0099998340010643],
44
- [-0.9999902248382568, -0.9999902248382568, 0.8912073969841003, 0.8912073969841003, 0.10977829992771149, 0.10977829992771149, 0.010999779216945171, 0.010999779216945171],
45
- [-0.5365729331970215, -0.5365729331970215, 0.9320390820503235, 0.9320390820503235, 0.11971220374107361, 0.11971220374107361, 0.011999712325632572, 0.011999712325632572],
46
- [0.4201670289039612, 0.4201670289039612, 0.9635581970214844, 0.9635581970214844, 0.12963414192199707, 0.12963414192199707, 0.012999634258449078, 0.012999634258449078],
47
- [0.9906073808670044, 0.9906073808670044, 0.9854497313499451, 0.9854497313499451, 0.13954311609268188, 0.13954311609268188, 0.013999543152749538, 0.013999543152749538]
48
- ]]]);
49
-
50
- const [cosine_cache, sine_cache] = rope.getWeights();
51
-
52
- expect(await cosine_cache?.sub(expected_cosine_cache).sum().array() as number).toBeLessThanOrEqual(1e-6);
53
- expect(await sine_cache?.sub(expected_sine_cache).sum().array() as number).toBeLessThanOrEqual(1e-6);
54
- })
55
-
56
-
57
- test("rotate inputs", async () => {
58
- const rope = new RotaryPositionEmbedding({ dim: 8, maxSequenceLength: 15 });
59
-
60
- const x = tf.tensor([[[
61
- [0.0766048, 0.5706575, 0.6705932, 0.5273118, 0.4794086, 0.9378104, 0.9888024, 0.6926053],
62
- [0.9064133, 0.5875182, 0.1681865, 0.3833345, 0.9901192, 0.4677338, 0.3353315, 0.02699],
63
- [0.3033573, 0.4139377, 0.4062586, 0.9705839, 0.3582608, 0.328775, 0.1340587, 0.2193414],
64
- [0.5565202, 0.4334963, 0.9912352, 0.3388563, 0.7991487, 0.1911893, 0.1140554, 0.6949552]]]
65
- ]); // batch=1, seq = 1, heads=4, embedDim=8
66
-
67
- const expected_output = tf.tensor([[[
68
- [0.07660479843616486, 0.57065749168396, 0.6705932021141052, 0.5273118019104004, 0.4794085919857025, 0.9378104209899902, 0.9888023734092712, 0.6926053166389465],
69
- [-0.004642367362976074, 1.08015775680542, 0.12907665967941284, 0.39820998907089233, 0.9853923320770264, 0.47761136293411255, 0.33530429005622864, 0.027325313538312912],
70
- [-0.5026336908340454, 0.10358311235904694, 0.20533521473407745, 1.0319478511810303, 0.3516140580177307, 0.33587393164634705, 0.1336197406053543, 0.21960905194282532],
71
- [-0.6121258735656738, -0.3506217896938324, 0.8468242287635803, 0.6166517734527588, 0.7930541634559631, 0.2150741070508957, 0.11197001487016678, 0.695294201374054]
72
- ]]]);
73
-
74
- const output = rope.apply(x) as tf.Tensor;
75
-
76
- expect(await expected_output.sub(output).sum().array() as number).toBeLessThan(1e-6);
77
- expect(rope.computeOutputShape(x.shape)).toEqual(x.shape);
78
- expect(rope.computeOutputShape([x.shape])).toEqual(x.shape);
79
- })
80
-
81
-
82
- test("expand cache when input sequences are larger than rope's max sequence length", async () => {
83
- const dim = 8;
84
- const rope = new RotaryPositionEmbedding({ dim, maxSequenceLength: 15, theta: 1_000_000 });
85
- const larger_sequence = 20;
86
- const even_larger_sequence = 50;
87
-
88
- rope.apply(tf.randomUniform([1, 1, larger_sequence, dim]));
89
-
90
- rope.getWeights().forEach(weight => {
91
- expect(weight.shape).toEqual([1, 1, 32, dim]);
92
- });
93
-
94
- rope.apply([tf.randomUniform([1, 1, even_larger_sequence, dim])]);
95
-
96
- rope.getWeights().forEach(weight => {
97
- expect(weight.shape).toEqual([1, 1, 64, dim]);
98
- });
99
- })
100
-
101
-
102
- test("create layer", async () => {
103
- // dim must be even
104
- expect(() => new RotaryPositionEmbedding({ dim: 7, maxSequenceLength: 15 })).toThrow();
105
- expect(() => new RotaryPositionEmbedding({ dim: 8, maxSequenceLength: 25 })).not.toThrow();
106
- })
107
- });
@@ -1 +0,0 @@
1
- export * from "./dice";
package/src/testing.ts DELETED
@@ -1 +0,0 @@
1
- console.log("test")
package/tsconfig.json DELETED
@@ -1,49 +0,0 @@
1
- {
2
- // Visit https://aka.ms/tsconfig to read more about this file
3
- "compilerOptions": {
4
- // File Layout
5
- // "rootDir": "./src",
6
- // "outDir": "./dist",
7
- // Environment Settings
8
- // See also https://aka.ms/tsconfig/module
9
- "module": "es2022",
10
- "target": "esnext",
11
- "types": ["jest"],
12
- // For nodejs:
13
- // "lib": ["esnext"],
14
- // "types": ["node"],
15
- // and npm install -D @types/node
16
- // Other Outputs
17
- "sourceMap": true,
18
- "declaration": true,
19
- "declarationMap": true,
20
- // Stricter Typechecking Options
21
- //"noUncheckedIndexedAccess": true,
22
- "exactOptionalPropertyTypes": true,
23
- // Style Options
24
- // "noImplicitReturns": true,
25
- // "noImplicitOverride": true,
26
- // "noUnusedLocals": true,
27
- // "noUnusedParameters": true,
28
- // "noFallthroughCasesInSwitch": true,
29
- // "noPropertyAccessFromIndexSignature": true,
30
- // Recommended Options
31
- "strict": true,
32
- "jsx": "react-jsx",
33
- //"verbatimModuleSyntax": true,
34
- "isolatedModules": true,
35
- "noUncheckedSideEffectImports": true,
36
- "moduleDetection": "force",
37
- "skipLibCheck": true,
38
- "paths": {
39
- "@/*": [
40
- "./src/*"
41
- ],
42
- "e2e/*": [
43
- "./e2e/*"
44
- ]
45
- },
46
- "moduleResolution": "bundler",
47
- "esModuleInterop": true
48
- }
49
- }