@stellarapp/tfjs-stellar 1.0.3 → 1.0.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (134)
  1. package/README.md +17 -0
  2. package/dist/index.d.ts +3 -1
  3. package/dist/index.d.ts.map +1 -1
  4. package/dist/index.js +3 -1
  5. package/dist/index.js.map +1 -1
  6. package/dist/kv_cache.d.ts +2 -0
  7. package/dist/kv_cache.d.ts.map +1 -1
  8. package/dist/kv_cache.js +6 -0
  9. package/dist/kv_cache.js.map +1 -1
  10. package/dist/models/index.d.ts +2 -1
  11. package/dist/models/index.d.ts.map +1 -1
  12. package/dist/models/index.js +2 -1
  13. package/dist/models/index.js.map +1 -1
  14. package/package.json +1 -1
  15. package/dist/jest.config.d.ts +0 -8
  16. package/dist/jest.config.d.ts.map +0 -1
  17. package/dist/jest.config.js +0 -147
  18. package/dist/jest.config.js.map +0 -1
  19. package/dist/src/index.d.ts +0 -6
  20. package/dist/src/index.d.ts.map +0 -1
  21. package/dist/src/index.js +0 -6
  22. package/dist/src/index.js.map +0 -1
  23. package/dist/src/kv_cache.d.ts +0 -53
  24. package/dist/src/kv_cache.d.ts.map +0 -1
  25. package/dist/src/kv_cache.js +0 -135
  26. package/dist/src/kv_cache.js.map +0 -1
  27. package/dist/src/layers/cached_rope_multihead_attention.d.ts +0 -31
  28. package/dist/src/layers/cached_rope_multihead_attention.d.ts.map +0 -1
  29. package/dist/src/layers/cached_rope_multihead_attention.js +0 -76
  30. package/dist/src/layers/cached_rope_multihead_attention.js.map +0 -1
  31. package/dist/src/layers/cached_rope_multihead_attention.test.d.ts +0 -2
  32. package/dist/src/layers/cached_rope_multihead_attention.test.d.ts.map +0 -1
  33. package/dist/src/layers/cached_rope_multihead_attention.test.js +0 -43
  34. package/dist/src/layers/cached_rope_multihead_attention.test.js.map +0 -1
  35. package/dist/src/layers/gpt_decoder_block.d.ts +0 -34
  36. package/dist/src/layers/gpt_decoder_block.d.ts.map +0 -1
  37. package/dist/src/layers/gpt_decoder_block.js +0 -51
  38. package/dist/src/layers/gpt_decoder_block.js.map +0 -1
  39. package/dist/src/layers/index.d.ts +0 -17
  40. package/dist/src/layers/index.d.ts.map +0 -1
  41. package/dist/src/layers/index.js +0 -33
  42. package/dist/src/layers/index.js.map +0 -1
  43. package/dist/src/layers/multihead_attention.d.ts +0 -106
  44. package/dist/src/layers/multihead_attention.d.ts.map +0 -1
  45. package/dist/src/layers/multihead_attention.js +0 -269
  46. package/dist/src/layers/multihead_attention.js.map +0 -1
  47. package/dist/src/layers/multihead_attention.test.d.ts +0 -2
  48. package/dist/src/layers/multihead_attention.test.d.ts.map +0 -1
  49. package/dist/src/layers/multihead_attention.test.js +0 -160
  50. package/dist/src/layers/multihead_attention.test.js.map +0 -1
  51. package/dist/src/layers/positional_encoding.d.ts +0 -37
  52. package/dist/src/layers/positional_encoding.d.ts.map +0 -1
  53. package/dist/src/layers/positional_encoding.js +0 -115
  54. package/dist/src/layers/positional_encoding.js.map +0 -1
  55. package/dist/src/layers/positional_encoding.test.d.ts +0 -2
  56. package/dist/src/layers/positional_encoding.test.d.ts.map +0 -1
  57. package/dist/src/layers/positional_encoding.test.js +0 -95
  58. package/dist/src/layers/positional_encoding.test.js.map +0 -1
  59. package/dist/src/layers/rotary_position_embedding.d.ts +0 -39
  60. package/dist/src/layers/rotary_position_embedding.d.ts.map +0 -1
  61. package/dist/src/layers/rotary_position_embedding.js +0 -99
  62. package/dist/src/layers/rotary_position_embedding.js.map +0 -1
  63. package/dist/src/layers/rotary_position_embedding.test.d.ts +0 -2
  64. package/dist/src/layers/rotary_position_embedding.test.d.ts.map +0 -1
  65. package/dist/src/layers/rotary_position_embedding.test.js +0 -88
  66. package/dist/src/layers/rotary_position_embedding.test.js.map +0 -1
  67. package/dist/src/layers/token_and_positional_embedding.d.ts +0 -47
  68. package/dist/src/layers/token_and_positional_embedding.d.ts.map +0 -1
  69. package/dist/src/layers/token_and_positional_embedding.js +0 -109
  70. package/dist/src/layers/token_and_positional_embedding.js.map +0 -1
  71. package/dist/src/layers/token_and_positional_embedding.test.d.ts +0 -2
  72. package/dist/src/layers/token_and_positional_embedding.test.d.ts.map +0 -1
  73. package/dist/src/layers/token_and_positional_embedding.test.js +0 -58
  74. package/dist/src/layers/token_and_positional_embedding.test.js.map +0 -1
  75. package/dist/src/layers/transformer_decoder.d.ts +0 -69
  76. package/dist/src/layers/transformer_decoder.d.ts.map +0 -1
  77. package/dist/src/layers/transformer_decoder.js +0 -182
  78. package/dist/src/layers/transformer_decoder.js.map +0 -1
  79. package/dist/src/layers/transformer_decoder.test.d.ts +0 -2
  80. package/dist/src/layers/transformer_decoder.test.d.ts.map +0 -1
  81. package/dist/src/layers/transformer_decoder.test.js +0 -72
  82. package/dist/src/layers/transformer_decoder.test.js.map +0 -1
  83. package/dist/src/layers/transformer_encoder.d.ts +0 -55
  84. package/dist/src/layers/transformer_encoder.d.ts.map +0 -1
  85. package/dist/src/layers/transformer_encoder.js +0 -175
  86. package/dist/src/layers/transformer_encoder.js.map +0 -1
  87. package/dist/src/layers/transformer_encoder.test.d.ts +0 -2
  88. package/dist/src/layers/transformer_encoder.test.d.ts.map +0 -1
  89. package/dist/src/layers/transformer_encoder.test.js +0 -58
  90. package/dist/src/layers/transformer_encoder.test.js.map +0 -1
  91. package/dist/src/losses/dice.d.ts +0 -30
  92. package/dist/src/losses/dice.d.ts.map +0 -1
  93. package/dist/src/losses/dice.js +0 -93
  94. package/dist/src/losses/dice.js.map +0 -1
  95. package/dist/src/losses/index.d.ts +0 -2
  96. package/dist/src/losses/index.d.ts.map +0 -1
  97. package/dist/src/losses/index.js +0 -2
  98. package/dist/src/losses/index.js.map +0 -1
  99. package/dist/src/masks.d.ts +0 -20
  100. package/dist/src/masks.d.ts.map +0 -1
  101. package/dist/src/masks.js +0 -37
  102. package/dist/src/masks.js.map +0 -1
  103. package/dist/src/metrics.d.ts +0 -20
  104. package/dist/src/metrics.d.ts.map +0 -1
  105. package/dist/src/metrics.js +0 -28
  106. package/dist/src/metrics.js.map +0 -1
  107. package/dist/src/models/gpt_model.d.ts +0 -94
  108. package/dist/src/models/gpt_model.d.ts.map +0 -1
  109. package/dist/src/models/gpt_model.js +0 -154
  110. package/dist/src/models/gpt_model.js.map +0 -1
  111. package/dist/src/models/index.d.ts +0 -3
  112. package/dist/src/models/index.d.ts.map +0 -1
  113. package/dist/src/models/index.js +0 -3
  114. package/dist/src/models/index.js.map +0 -1
  115. package/dist/src/models/llm_model.d.ts +0 -87
  116. package/dist/src/models/llm_model.d.ts.map +0 -1
  117. package/dist/src/models/llm_model.js +0 -245
  118. package/dist/src/models/llm_model.js.map +0 -1
  119. package/dist/src/models/u_net.d.ts +0 -40
  120. package/dist/src/models/u_net.d.ts.map +0 -1
  121. package/dist/src/models/u_net.js +0 -151
  122. package/dist/src/models/u_net.js.map +0 -1
  123. package/dist/src/tfjs_types.d.ts +0 -10
  124. package/dist/src/tfjs_types.d.ts.map +0 -1
  125. package/dist/src/tfjs_types.js +0 -2
  126. package/dist/src/tfjs_types.js.map +0 -1
  127. package/dist/src/utils.d.ts +0 -28
  128. package/dist/src/utils.d.ts.map +0 -1
  129. package/dist/src/utils.js +0 -63
  130. package/dist/src/utils.js.map +0 -1
  131. package/dist/src/utils.test.d.ts +0 -2
  132. package/dist/src/utils.test.d.ts.map +0 -1
  133. package/dist/src/utils.test.js +0 -73
  134. package/dist/src/utils.test.js.map +0 -1
@@ -1 +0,0 @@
1
- {"version":3,"file":"positional_encoding.js","sourceRoot":"","sources":["../../../src/layers/positional_encoding.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAavC;;;;;;;;;GASG;AACH,MAAM,OAAO,kBAAmB,SAAQ,EAAE,CAAC,MAAM,CAAC,KAAK;IACnD,MAAM,CAAC,SAAS,GAAG,oBAAoB,CAAC;IACvB,iBAAiB,CAAS;IAC1B,QAAQ,CAAS;IAC1B,mBAAmB,CAAmB;IAG9C,YAAY,IAA4B;QACpC,KAAK,CAAC,IAAI,CAAC,CAAC;QAEZ,IAAI,CAAC,iBAAiB,GAAG,IAAI,CAAC,iBAAiB,IAAI,IAAI,CAAC;QACxD,IAAI,CAAC,QAAQ,GAAG,IAAI,CAAC,QAAQ,CAAC;QAE9B,IAAI,IAAI,CAAC,iBAAiB,GAAG,CAAC,EAAE,CAAC;YAC7B,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,iBAAiB,IAAI,CAAC,IAAI,oBAAoB;gBAC5E,KAAK,IAAI,CAAC,iBAAiB,0BAA0B,CAAC,CAAC;QAC/D,CAAC;QAED,IAAI,IAAI,CAAC,QAAQ,GAAG,CAAC,EAAE,CAAC;YACpB,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,iBAAiB,IAAI,CAAC,IAAI,WAAW;gBACnE,KAAK,IAAI,CAAC,QAAQ,0BAA0B,CAAC,CAAC;QACtD,CAAC;QAED,yCAAyC;QACzC,IAAI,CAAC,mBAAmB,GAAG,IAAI,CAAC,SAAS,CAAC,sBAAsB,EAC5D,CAAC,IAAI,CAAC,iBAAiB,EAAE,IAAI,CAAC,QAAQ,CAAC,EAAE,SAAS,EAClD,EAAE,CAAC,YAAY,CAAC,KAAK,EAAE,EAAE,SAAS,EAAE,KAAK,CAAC,CAAC;IACnD,CAAC;IAGD;;OAEG;IACM,IAAI,CAAC,MAA+B,EAAE,MAAc;QACzD,6BAA6B;QAC7B,MAAM,KAAK,GAAG,KAAK,CAAC,OAAO,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC;QACzD,MAAM,SAAS,GAAG,KAAK,CAAC,KAAK,CAAC,CAAC,CAAE,CAAC;QAElC,IAAI,KAAK,CAAC,KAAK,CAAC,MAAM,IAAI,CAAC,IAAI,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,IAAI,IAAI,CAAC,QAAQ,EAAE,CAAC;YAC7D,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,UAAU,IAAI,CAAC,IAAI,6BAA6B;gBAC9E,mBAAmB,IAAI,CAAC,iBAAiB,MAAM,IAAI,CAAC,QAAQ,kBAAkB,KAAK,CAAC,KAAK,EAAE,CAAC,CAAC;QACrG,CAAC;QAED,IAAI,SAAS,GAAG,IAAI,CAAC,iBAAiB,EAAE,CAAC;YACrC,6BAA6B;YAC7B,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,UAAU,IAAI,CAAC,IAAI,yBAAyB;gBAC1E,qBAAqB,SAAS,iDAAiD;gBAC/E,IAAI,IAAI,CAAC,iBAAiB,EAAE,CAAC,CAAC;QACtC,CAAC;QAED,8BAA8B;QAC9B,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,OAAO,KAAK,CAAC,GAAG,CAAC,IAAI,CAAC,mBAAmB,CAAC,IAAI,EAAE;iBAC3C,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,SAAS,EAAE,IAAI,CAAC,QAAQ,CAAC,CAAC,CAAC,kCAAkC;iBAC5E,UAAU
,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,2DAA2D;QACpF,CAAC,CAAC,CAAA;IACN,CAAC;IAED;;;;OAIG;IACM,KAAK,CAAC,UAAiC;QAC5C,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YACT,MAAM,cAAc,GAAG,IAAI,CAAC,IAAI,CAAC,IAAI,CAAC,QAAQ,GAAG,CAAC,CAAC,CAAC;YAEpD,qDAAqD;YACrD,mEAAmE;YACnE,MAAM,SAAS,GAAG,EAAE,CAAC,KAAK,CAAC,CAAC,EAAE,IAAI,CAAC,iBAAiB,EAAE,CAAC,CAAC;iBACnD,OAAO,CAAC,CAAC,IAAI,CAAC,iBAAiB,EAAE,CAAC,CAAC,CAAC;gBACrC,4FAA4F;iBAC3F,WAAW,CAAC,CAAC,IAAI,CAAC,iBAAiB,EAAE,cAAc,CAAC,CAAC,CAAC;YAE3D,oEAAoE;YACpE,iFAAiF;YACjF,mFAAmF;YACnF,yBAAyB;YACzB,sFAAsF;YACtF,MAAM,WAAW,GAAG,EAAE,CAAC,GAAG,CAAC,MAAM,EAAE,EAAE,CAAC,KAAK,CAAC,CAAC,EAAE,IAAI,CAAC,QAAQ,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,QAAQ,CAAC,CAAC,CAAC;YAErF,MAAM,UAAU,GAAG,SAAS,CAAC,GAAG,CAAC,WAAW,CAAC,CAAC;YAE9C,MAAM,IAAI,GAAG,EAAE,CAAC,GAAG,CAAC,UAAU,CAAC,CAAC;YAChC,MAAM,MAAM,GAAG,EAAE,CAAC,GAAG,CAAC,UAAU,CAAC,CAAC;YAElC,uEAAuE;YACvE,4BAA4B;YAC5B,4BAA4B;YAC5B,MAAM;YACN,MAAM,WAAW,GAAG,EAAE,CAAC;YACvB,MAAM,QAAQ,GAAG,CAAC,CAAC,CAAC;YACpB,MAAM,OAAO,GAAG,CAAC,CAAC;YAClB,MAAM,SAAS,GAAG,CAAC,CAAC;YAEpB,KAAK,IAAI,SAAS,GAAG,CAAC,EAAE,SAAS,GAAG,IAAI,CAAC,QAAQ,GAAG,CAAC,EAAE,SAAS,EAAE,EAAE,CAAC;gBACjE,WAAW,CAAC,IAAI,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC,SAAS,EAAE,SAAS,CAAC,EAAE,CAAC,QAAQ,EAAE,OAAO,CAAC,CAAC,CAAC,CAAA;gBAEzE,IAAI,SAAS,IAAI,IAAI,CAAC,KAAK,CAAC,IAAI,CAAC,QAAQ,GAAG,CAAC,CAAC,EAAE,CAAC;oBAC7C,8DAA8D;oBAC9D,yFAAyF;oBACzF,qCAAqC;oBACrC,WAAW,CAAC,IAAI,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC,SAAS,EAAE,SAAS,CAAC,EAAE,CAAC,QAAQ,EAAE,OAAO,CAAC,CAAC,CAAC,CAAA;gBAC/E,CAAC;YACL,CAAC;YAED,8BAA8B;YAC9B,IAAI,CAAC,UAAU,CAAC,CAAC,EAAE,CAAC,MAAM,CAAC,WAAW,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC;QACjD,CAAC,CAAC,CAAC;QAEH,KAAK,CAAC,KAAK,CAAC,UAAU,CAAC,CAAC;IAC5B,CAAC;IAGQ,kBAAkB,CAAC,UAAiC;QACzD,OAAO,UAAU,CAAC;IACtB,CAAC;IAGQ,SAAS;QACd,MAAM,WAAW,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QAEtC,MAAM,MAAM,GAAG;YACX,iBAAiB,EAAE,IAAI,CAAC,iBAAiB;YACzC,QAAQ,EAAE,IAAI,CAAC,QAAQ;SAC1B,CAAA;QAED,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,WAAW,CAAC,CAAC;QAEnC,OAAO,MAAM,CAAC;IAClB,CAAC;;AAIL,EAAE,CAAC,aAAa,CAAC,aAAa,CAAC,kB
AAkB,CAAC,CAAC"}
@@ -1,2 +0,0 @@
1
- export {};
2
- //# sourceMappingURL=positional_encoding.test.d.ts.map
@@ -1 +0,0 @@
1
- {"version":3,"file":"positional_encoding.test.d.ts","sourceRoot":"","sources":["../../../src/layers/positional_encoding.test.ts"],"names":[],"mappings":""}
@@ -1,95 +0,0 @@
1
- import * as tf from '@tensorflow/tfjs';
2
- import { PositionalEncoding } from '@/layers/positional_encoding';
3
// Treat the environment as non-Node so tfjs skips its warning about using the
// faster node backend,
// https://github.com/tensorflow/tfjs/issues/5349#issuecomment-885170504
tf.env().set('IS_NODE', false);
/**
 * Unit tests for the PositionalEncoding layer: constructor validation, forward
 * calls, input validation, config serialization and numerical agreement with
 * the PyTorch reference implementation.
 */
describe("PositionalEncoding tests", () => {
    it("should fail to instantiate a layer", () => {
        expect(() => new PositionalEncoding({ maxSequenceLength: 3, embedDim: 0 })).toThrow();
        expect(() => new PositionalEncoding({ maxSequenceLength: 3, embedDim: -1 })).toThrow();
        expect(() => new PositionalEncoding({ maxSequenceLength: 0, embedDim: 32 })).toThrow();
        expect(() => new PositionalEncoding({ maxSequenceLength: -1, embedDim: 32 })).toThrow();
    });
    test("successfull forward calls", () => {
        const embed_dims = 32;
        const sequences = 4;
        const input = tf.randomUniform([2, sequences, embed_dims]);
        const positional = new PositionalEncoding({ embedDim: embed_dims });
        expect(() => positional.apply(input)).not.toThrow();
        expect(() => positional.apply([input])).not.toThrow();
        expect(positional.computeOutputShape(input.shape)).toEqual(input.shape);
    });
    it("should throw when input sequences are too large, embedding dims don't match, input aren't rank 3", () => {
        const positional = new PositionalEncoding({ maxSequenceLength: 10, embedDim: 32 });
        // Each input breaks exactly one requirement, so every expectation
        // exercises the check named in the test title rather than another one.
        // rank 3, matching embedding dim, but 100 > maxSequenceLength positions
        const sequences_too_long = tf.randomUniform([1, 100, 32]);
        // rank 3, 5 <= maxSequenceLength positions, but embedding dim 100 != 32
        const embeddings_too_large = tf.randomUniform([1, 5, 100]);
        // valid sequence length (axis 1) and embedding dim (axis 2), but rank 4
        const wrong_rank = tf.randomUniform([1, 5, 32, 32]);
        expect(() => positional.apply(sequences_too_long)).toThrow();
        expect(() => positional.apply(embeddings_too_large)).toThrow();
        expect(() => positional.apply(wrong_rank)).toThrow();
    });
    it("should return a non-empty config dict", () => {
        const positional = new PositionalEncoding({ embedDim: 32 });
        // Compare the number of keys; comparing the key array itself to 0 can never fail.
        expect(Object.keys(positional.getConfig()).length).toBeGreaterThan(0);
    });
    // PyTorch implementation found at
    // https://pytorch-tutorials-preview.netlify.app/beginner/transformer_tutorial.html
    it("should be within 1e-6 of PyTorch's implementation", () => {
        const pytorch_embed4 = tf.tensor([
            [[0.0000000, 1.0000000, 0.0000000, 1.0000000],
                [0.8414710, 0.5403023, 0.0099998, 0.9999500],
                [0.9092974, -0.4161468, 0.0199987, 0.9998000],
                [0.1411200, -0.9899925, 0.0299955, 0.9995500],
                [-0.7568025, -0.6536436, 0.0399893, 0.9992001],
                [-0.9589243, 0.2836622, 0.0499792, 0.9987503],
                [-0.2794155, 0.9601703, 0.0599640, 0.9982005],
                [0.6569866, 0.7539023, 0.0699428, 0.9975510],
                [0.9893582, -0.1455000, 0.0799147, 0.9968017],
                [0.4121185, -0.9111302, 0.0898785, 0.9959527]]
        ]);
        const pytorch_embed8 = tf.tensor([
            [[0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 1.0000000e+00,
                    0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 1.0000000e+00],
                [8.4147096e-01, 5.4030234e-01, 9.9833414e-02, 9.9500418e-01,
                    9.9998331e-03, 9.9994999e-01, 9.9999981e-04, 9.9999952e-01],
                [9.0929741e-01, -4.1614684e-01, 1.9866931e-01, 9.8006660e-01,
                    1.9998666e-02, 9.9980003e-01, 1.9999985e-03, 9.9999803e-01],
                [1.4112000e-01, -9.8999250e-01, 2.9552019e-01, 9.5533651e-01,
                    2.9995499e-02, 9.9955004e-01, 2.9999954e-03, 9.9999553e-01],
                [-7.5680250e-01, -6.5364361e-01, 3.8941833e-01, 9.2106098e-01,
                    3.9989334e-02, 9.9920011e-01, 3.9999890e-03, 9.9999201e-01],
                [-9.5892429e-01, 2.8366220e-01, 4.7942552e-01, 8.7758255e-01,
                    4.9979165e-02, 9.9875027e-01, 4.9999789e-03, 9.9998754e-01],
                [-2.7941549e-01, 9.6017027e-01, 5.6464243e-01, 8.2533562e-01,
                    5.9964005e-02, 9.9820054e-01, 5.9999637e-03, 9.9998200e-01],
                [6.5698659e-01, 7.5390226e-01, 6.4421761e-01, 7.6484221e-01,
                    6.9942847e-02, 9.9755102e-01, 6.9999420e-03, 9.9997550e-01],
                [9.8935825e-01, -1.4550003e-01, 7.1735609e-01, 6.9670677e-01,
                    7.9914689e-02, 9.9680167e-01, 7.9999138e-03, 9.9996799e-01],
                [4.1211849e-01, -9.1113025e-01, 7.8332686e-01, 6.2160999e-01,
                    8.9878544e-02, 9.9595273e-01, 8.9998785e-03, 9.9995953e-01]]
        ]);
        const positional4 = new PositionalEncoding({ embedDim: 4, maxSequenceLength: 10 });
        positional4.build([]);
        const positional8 = new PositionalEncoding({ embedDim: 8, maxSequenceLength: 10 });
        positional8.build([]);
        const margin_of_error = 1e-6;
        // the difference between this and PyTorch's implementation
        // should be insignificantly small
        expect(positional4.getWeights()[0]
            .sub(pytorch_embed4)
            .abs()
            .arraySync()
            .flat(2)
            .filter(i => i > margin_of_error)
            .length).toBe(0);
        expect(positional8.getWeights()[0]
            .sub(pytorch_embed8)
            .abs()
            .arraySync()
            .flat(2)
            .filter(i => i > margin_of_error)
            .length).toBe(0);
    });
});
95
- //# sourceMappingURL=positional_encoding.test.js.map
@@ -1 +0,0 @@
1
- {"version":3,"file":"positional_encoding.test.js","sourceRoot":"","sources":["../../../src/layers/positional_encoding.test.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAEvC,OAAO,EAAE,kBAAkB,EAAE,MAAM,8BAA8B,CAAC;AAElE,sDAAsD;AACtD,wEAAwE;AACxE,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC,SAAS,EAAE,KAAK,CAAC,CAAC;AAG/B,QAAQ,CAAC,0BAA0B,EAAE,GAAG,EAAE;IACtC,EAAE,CAAC,oCAAoC,EAAE,GAAG,EAAE;QAC1C,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,kBAAkB,CAAC,EAAE,iBAAiB,EAAE,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACtF,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,kBAAkB,CAAC,EAAE,iBAAiB,EAAE,CAAC,EAAE,QAAQ,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACvF,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,kBAAkB,CAAC,EAAE,iBAAiB,EAAE,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACvF,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,kBAAkB,CAAC,EAAE,iBAAiB,EAAE,CAAC,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;IAC5F,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,2BAA2B,EAAE,GAAG,EAAE;QACnC,MAAM,UAAU,GAAG,EAAE,CAAC;QACtB,MAAM,SAAS,GAAG,CAAC,CAAC;QACpB,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,SAAS,EAAE,UAAU,CAAC,CAAC,CAAC;QAE3D,MAAM,UAAU,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,UAAU,EAAE,CAAC,CAAC;QACpE,MAAM,CAAC,GAAG,EAAE,CAAC,UAAU,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACpD,MAAM,CAAC,GAAG,EAAE,CAAC,UAAU,CAAC,KAAK,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACtD,MAAM,CAAC,UAAU,CAAC,kBAAkB,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,OAAO,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC;IAC5E,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,kGAAkG,EAAE,GAAG,EAAE;QACxG,MAAM,kBAAkB,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,GAAG,EAAE,EAAE,CAAC,CAAC,CAAC;QACvD,MAAM,oBAAoB,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,EAAE,EAAE,GAAG,CAAC,CAAC,CAAC;QACzD,MAAM,UAAU,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,EAAE,EAAE,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC;QAElD,MAAM,UAAU,GAAG,IAAI,kBAAkB,CAAC,EAAE,iBAAiB,EAAE,EAAE,EAAE,QAAQ,EAAE,EAAE,EAAE,CAAC,CAAC;QAEnF,MAAM,CAAC,GAAG,EAAE,CAAC,UAAU,CAAC,KAAK,CAAC,kBAAkB,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAC7D,MAAM,CAAC,GAAG
,EAAE,CAAC,UAAU,CAAC,KAAK,CAAC,oBAAoB,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAC/D,MAAM,CAAC,GAAG,EAAE,CAAC,UAAU,CAAC,KAAK,CAAC,UAAU,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;IACzD,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,uCAAuC,EAAE,GAAG,EAAE;QAC7C,MAAM,SAAS,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,CAAC,CAAC;QAC3D,MAAM,CAAC,MAAM,CAAC,IAAI,CAAC,SAAS,CAAC,SAAS,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;IAC3D,CAAC,CAAC,CAAA;IAGF,qCAAqC;IACrC,mFAAmF;IACnF,EAAE,CAAC,mDAAmD,EAAE,GAAG,EAAE;QACzD,MAAM,cAAc,GAAG,EAAE,CAAC,MAAM,CAAC;YAC7B,CAAC,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC7C,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC5C,CAAC,SAAS,EAAE,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC7C,CAAC,SAAS,EAAE,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC7C,CAAC,CAAC,SAAS,EAAE,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC9C,CAAC,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC7C,CAAC,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC7C,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC5C,CAAC,SAAS,EAAE,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC7C,CAAC,SAAS,EAAE,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC,CAAC;SAAC,CAAC,CAAC;QAErD,MAAM,cAAc,GAAG,EAAE,CAAC,MAAM,CAAC;YAC7B,CAAC,CAAC,aAAa,EAAE,aAAa,EAAE,aAAa,EAAE,aAAa;oBACxD,aAAa,EAAE,aAAa,EAAE,aAAa,EAAE,aAAa,CAAC;gBAC/D,CAAC,aAAa,EAAE,aAAa,EAAE,aAAa,EAAE,aAAa;oBACvD,aAAa,EAAE,aAAa,EAAE,aAAa,EAAE,aAAa,CAAC;gBAC/D,CAAC,aAAa,EAAE,CAAC,aAAa,EAAE,aAAa,EAAE,aAAa;oBACxD,aAAa,EAAE,aAAa,EAAE,aAAa,EAAE,aAAa,CAAC;gBAC/D,CAAC,aAAa,EAAE,CAAC,aAAa,EAAE,aAAa,EAAE,aAAa;oBACxD,aAAa,EAAE,aAAa,EAAE,aAAa,EAAE,aAAa,CAAC;gBAC/D,CAAC,CAAC,aAAa,EAAE,CAAC,aAAa,EAAE,aAAa,EAAE,aAAa;oBACzD,aAAa,EAAE,aAAa,EAAE,aAAa,EAAE,aAAa,CAAC;gBAC/D,CAAC,CAAC,aAAa,EAAE,aAAa,EAAE,aAAa,EAAE,aAAa;oBACxD,aAAa,EAAE,aAAa,EAAE,aAAa,EAAE,aAAa,CAAC;gBAC/D,CAAC,CAAC,aAAa,EAAE,aAAa,EAAE,aAAa,EAAE,aAAa;oBACxD,aAAa,EAAE,aAAa,EAAE,aAAa,EAAE,aAAa,CAAC;gBAC/D,CAAC,aAAa,EAAE,aAAa,EAAE,aAAa,EAAE,aAAa;oBACvD,aAAa,EAAE,aAAa,EAAE,aAAa,EAAE,aAAa,CAAC;gBAC/D,CAAC,aAAa,EAAE,CAAC,aAAa,EAAE,aAAa,EAAE,aAAa;oBACxD,aAAa,EAAE,aAAa
,EAAE,aAAa,EAAE,aAAa,CAAC;gBAC/D,CAAC,aAAa,EAAE,CAAC,aAAa,EAAE,aAAa,EAAE,aAAa;oBACxD,aAAa,EAAE,aAAa,EAAE,aAAa,EAAE,aAAa,CAAC,CAAC;SAAC,CAAC,CAAC;QAEvE,MAAM,WAAW,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,iBAAiB,EAAE,EAAE,EAAE,CAAC,CAAC;QACnF,WAAW,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC;QAEtB,MAAM,WAAW,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,iBAAiB,EAAE,EAAE,EAAE,CAAC,CAAC;QACnF,WAAW,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC;QAEtB,MAAM,eAAe,GAAG,IAAI,CAAC;QAE7B,2DAA2D;QAC3D,iCAAiC;QACjC,MAAM,CAAE,WAAW,CAAC,UAAU,EAAE,CAAC,CAAC,CAAC;aAC9B,GAAG,CAAC,cAAc,CAAC;aACnB,GAAG,EAAE;aACL,SAAS,EAAS;aAClB,IAAI,CAAC,CAAC,CAAC;aACP,MAAM,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,GAAG,eAAe,CAAC;aAChC,MAAM,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;QAErB,MAAM,CAAE,WAAW,CAAC,UAAU,EAAE,CAAC,CAAC,CAAC;aAC9B,GAAG,CAAC,cAAc,CAAC;aACnB,GAAG,EAAE;aACL,SAAS,EAAS;aAClB,IAAI,CAAC,CAAC,CAAC;aACP,MAAM,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,GAAG,eAAe,CAAC;aAChC,MAAM,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;IACzB,CAAC,CAAC,CAAC;AACP,CAAC,CAAC,CAAC"}
@@ -1,39 +0,0 @@
1
- import * as tf from "@tensorflow/tfjs";
2
- import { type LayerArgs } from "@tensorflow/tfjs-layers/dist/engine/topology";
3
- export declare function applyRope(x: tf.Tensor, dim: number, cosine_cache: tf.Tensor, sine_cache: tf.Tensor): tf.Tensor<tf.Rank>; // returns x * cos + rotateHalf(x) * sin, using the first x.shape[2] rows of each [1, 1, maxSequenceLength, dim] cache
4
- export declare function rotateHalf(x: tf.Tensor, dim: number): tf.Tensor; // maps each adjacent pair [a, b] on the last axis to [-b, a]; dim must be even; output has x's shape
5
- export declare function createRoPECache(dim: number, max_sequence_length: number, theta?: number): tf.Tensor<tf.Rank>[]; // returns [cosine, sine], each of shape [1, 1, max_sequence_length, dim]; theta defaults to 10_000
6
- export interface RotaryPositionEmbeddingArgs extends LayerArgs {
7
- /**
8
- * The dimension of each head (rounded down), e.g. `Math.floor(embedDim / numHeads)`. Must be even; the layer constructor throws otherwise.
9
- */
10
- dim: number;
11
- /**
12
- * Initial length of the pre-calculated RoPE cache. If an input's sequence is longer, the cache is re-calculated at the next power of two. Defaults to `4096`.
13
- */
14
- maxSequenceLength?: number;
15
- /**
16
- * The base for the geometric progression used to compute the rotation angles. Defaults to `10_000`.
17
- */
18
- theta?: number;
19
- }
20
- /**
21
- * Implements RoPE from the RoFormer: Enhanced Transformer with Rotary Position Embedding paper
22
- * Inspired by: https://meta-pytorch.org/torchtune/stable/_modules/torchtune/modules/position_embeddings.html#RotaryPositionalEmbeddings
23
- */
24
- export declare class RotaryPositionEmbedding extends tf.layers.Layer {
25
- static className: string; // "RotaryPositionEmbedding"; the name registered with tf.serialization.registerClass
26
- protected dim: number; // per-head dimension; the constructor rejects odd values
27
- protected max_sequence_length: number; // current cache length; call() raises it to the next power of two for longer inputs
28
- protected theta: number; // base of the frequency progression (defaults to 10_000)
29
- protected cosine_cache: tf.LayerVariable; // non-trainable cosine table, shape [1, 1, max_sequence_length, dim]
30
- protected sine_cache: tf.LayerVariable; // non-trainable sine table, shape [1, 1, max_sequence_length, dim]
31
- constructor({ dim, maxSequenceLength, theta, ...args }: RotaryPositionEmbeddingArgs); // throws if dim is odd
32
- call(inputs: tf.Tensor | tf.Tensor[], kwargs: any): tf.Tensor | tf.Tensor[]; // applies RoPE to inputs (or inputs[0]); sequence length is read from axis 2
33
- build(input_shape: tf.Shape | tf.Shape[]): void; // (re)computes both tables for the current max_sequence_length; input_shape is not used
34
- /**
35
- * Output shape: [batch, head, sequence, head_dim]
36
- */
37
- computeOutputShape(input_shape: tf.Shape | tf.Shape[]): tf.Shape; // returns the (first) input shape unchanged
38
- }
39
- //# sourceMappingURL=rotary_position_embedding.d.ts.map
@@ -1 +0,0 @@
1
- {"version":3,"file":"rotary_position_embedding.d.ts","sourceRoot":"","sources":["../../../src/layers/rotary_position_embedding.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AACvC,OAAO,EAAE,KAAK,SAAS,EAAE,MAAM,8CAA8C,CAAC;AAG9E,wBAAgB,SAAS,CAAC,CAAC,EAAE,EAAE,CAAC,MAAM,EAAE,GAAG,EAAE,MAAM,EAAE,YAAY,EAAE,EAAE,CAAC,MAAM,EAAE,UAAU,EAAE,EAAE,CAAC,MAAM,sBAalG;AAGD,wBAAgB,UAAU,CAAC,CAAC,EAAE,EAAE,CAAC,MAAM,EAAE,GAAG,EAAE,MAAM,GAAG,EAAE,CAAC,MAAM,CAgB/D;AAGD,wBAAgB,eAAe,CAAC,GAAG,EAAE,MAAM,EAAE,mBAAmB,EAAE,MAAM,EAAE,KAAK,GAAE,MAAe,wBAqB/F;AAGD,MAAM,WAAW,2BAA4B,SAAQ,SAAS;IAC1D;;OAEG;IACH,GAAG,EAAE,MAAM,CAAC;IACZ;;OAEG;IACH,iBAAiB,CAAC,EAAE,MAAM,CAAC;IAC3B;;OAEG;IACH,KAAK,CAAC,EAAE,MAAM,CAAC;CAClB;AAGD;;;GAGG;AACH,qBAAa,uBAAwB,SAAQ,EAAE,CAAC,MAAM,CAAC,KAAK;IACxD,MAAM,CAAC,SAAS,SAA6B;IAE7C,SAAS,CAAC,GAAG,EAAE,MAAM,CAAC;IACtB,SAAS,CAAC,mBAAmB,EAAE,MAAM,CAAC;IACtC,SAAS,CAAC,KAAK,EAAE,MAAM,CAAC;IAGxB,SAAS,CAAC,YAAY,EAAE,EAAE,CAAC,aAAa,CAAC;IACzC,SAAS,CAAC,UAAU,EAAE,EAAE,CAAC,aAAa,CAAC;gBAE3B,EAAE,GAAG,EAAE,iBAAwB,EAAE,KAAc,EAAE,GAAG,IAAI,EAAE,EAAE,2BAA2B;IAqB1F,IAAI,CAAC,MAAM,EAAE,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,EAAE,EAAE,MAAM,EAAE,GAAG,GAAG,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,EAAE;IAkB3E,KAAK,CAAC,WAAW,EAAE,EAAE,CAAC,KAAK,GAAG,EAAE,CAAC,KAAK,EAAE;IAmBjD;;OAEG;IACI,kBAAkB,CAAC,WAAW,EAAE,EAAE,CAAC,KAAK,GAAG,EAAE,CAAC,KAAK,EAAE;CAK/D"}
@@ -1,99 +0,0 @@
1
- import * as tf from "@tensorflow/tfjs";
2
/**
 * Applies rotary position embedding to `x`: `x * cos + rotateHalf(x) * sin`.
 *
 * Only the first `x.shape[2]` rows of each cache are used, so `x` is expected
 * to be laid out as [batch, head, sequence, dim] and the caches as
 * [1, 1, maxSequenceLength, dim] (the layout createRoPECache produces).
 *
 * @param {tf.Tensor} x input tensor; its sequence length is read from axis 2.
 * @param {number} dim size of the last axis of `x`.
 * @param {tf.Tensor} cosine_cache pre-computed cosine table.
 * @param {tf.Tensor} sine_cache pre-computed sine table.
 * @returns {tf.Tensor} the rotated tensor, same shape as `x`.
 */
export function applyRope(x, dim, cosine_cache, sine_cache) {
    return tf.tidy(() => {
        const seqLength = x.shape[2];
        // Window of a cache covering exactly the positions present in x.
        const window = (cache) => tf.slice(cache, [0, 0, 0, 0], [1, 1, seqLength, dim]);

        const rotated = rotateHalf(x, dim);
        const straightTerm = tf.mul(x, window(cosine_cache));
        const rotatedTerm = tf.mul(rotated, window(sine_cache));
        return tf.add(straightTerm, rotatedTerm);
    });
}
13
/**
 * Rotates every adjacent coordinate pair on the last axis: [a, b] -> [-b, a].
 *
 * The leading axes are flattened before slicing because TFJS has issues during
 * backpropagation with 5D slicing.
 *
 * @param {tf.Tensor} x tensor whose last axis has size `dim`.
 * @param {number} dim size of the last axis; must be even so it splits into pairs.
 * @returns {tf.Tensor} tensor with the same shape as `x`.
 */
export function rotateHalf(x, dim) {
    return tf.tidy(() => {
        // [..., dim] -> [N, dim / 2, 2]: one row per (a, b) pair
        const pairs = tf.reshape(x, [-1, dim / 2, 2]);
        const first = tf.slice(pairs, [0, 0, 0], [-1, -1, 1]);
        const second = tf.slice(pairs, [0, 0, 1], [-1, -1, 1]);
        const swapped = tf.concat([tf.neg(second), first], -1);
        return tf.reshape(swapped, x.shape);
    });
}
27
/**
 * Pre-computes the RoPE cosine and sine tables.
 *
 * Pair `i` uses the inverse frequency `theta^(-2i / dim)`; each angle is repeated
 * for both coordinates of its pair so the tables line up with rotateHalf.
 *
 * @param {number} dim per-head dimension; must be even.
 * @param {number} max_sequence_length number of positions to pre-compute.
 * @param {number} [theta=10_000] base of the geometric progression of frequencies.
 * @returns {tf.Tensor[]} `[cosine, sine]`, each of shape [1, 1, max_sequence_length, dim].
 * @throws {Error} if `dim` is not even.
 */
export function createRoPECache(dim, max_sequence_length, theta = 10_000) {
    if (dim % 2 !== 0) {
        // An odd dim cannot be split into coordinate pairs; fail with a clear
        // message instead of an opaque reshape error further down.
        throw new Error(`createRoPECache expected dim to be even, got ${dim}`);
    }
    return tf.tidy(() => {
        // [dim / 2]: theta^(-2i / dim) for i = 0 .. dim / 2 - 1
        const inv_frequencies = tf.div(1, tf.pow(theta, tf.range(0, dim, 2, "float32").div(dim)));
        // [max_sequence_length]
        const sequence_indices = tf.range(0, max_sequence_length);
        // [max_sequence_length, dim / 2]: rotation angle for every (position, pair)
        const freq = tf.outerProduct(sequence_indices, inv_frequencies);
        // [max_sequence_length, dim]: repeat each angle for both coordinates of its pair
        const freq_pairs = tf.stack([freq, freq], -1)
            .reshape([max_sequence_length, dim]);
        return [
            tf.keep(tf.cos(freq_pairs).expandDims(0).expandDims(0)),
            tf.keep(tf.sin(freq_pairs).expandDims(0).expandDims(0)),
        ];
    });
}
44
/**
 * Implements RoPE from the RoFormer: Enhanced Transformer with Rotary Position Embedding paper
 * Inspired by: https://meta-pytorch.org/torchtune/stable/_modules/torchtune/modules/position_embeddings.html#RotaryPositionalEmbeddings
 *
 * The cosine/sine tables are non-trainable weights of shape
 * [1, 1, max_sequence_length, dim]. build() (re)computes them, and call() grows
 * them to the next power of two when an input's sequence axis (axis 2) is
 * longer than the current cache.
 */
export class RotaryPositionEmbedding extends tf.layers.Layer {
    static className = "RotaryPositionEmbedding";
    dim;
    max_sequence_length;
    theta;
    // cached sine and cosine frequencies, untrainable weights
    cosine_cache;
    sine_cache;
    /**
     * @param {object} args
     * @param {number} args.dim per-head dimension; must be even.
     * @param {number} [args.maxSequenceLength=4096] initial cache length.
     * @param {number} [args.theta=10_000] base of the frequency progression.
     * @throws {Error} if `dim` is odd.
     */
    constructor({ dim, maxSequenceLength = 4096, theta = 10_000, ...args }) {
        super(args);
        if (dim % 2 !== 0) {
            throw Error(`${this.getClassName()}::constructor ${this.name} expected dim to be even, got ${dim}`);
        }
        this.dim = dim;
        this.max_sequence_length = maxSequenceLength;
        this.theta = theta;
        // Zero-filled placeholders so the weights exist before build(); build()
        // replaces them with the real tables. Each weight's name matches the
        // field it is stored in.
        this.cosine_cache = this.addWeight("cosine_cache", [1, 1, maxSequenceLength, Math.floor(this.dim)], "float32", tf.initializers.zeros(), undefined, false);
        this.sine_cache = this.addWeight("sine_cache", [1, 1, maxSequenceLength, Math.floor(this.dim)], "float32", tf.initializers.zeros(), undefined, false);
    }
    /**
     * Applies RoPE to `inputs` (or `inputs[0]`), presumably laid out as
     * [batch, head, sequence, head_dim]. The sequence length is read from axis 2.
     */
    call(inputs, kwargs) {
        const x = Array.isArray(inputs) ? inputs[0] : inputs;
        const seq_length = x.shape[2];
        if (seq_length > this.max_sequence_length) {
            // expand cache to the nearest power of 2
            this.max_sequence_length = Math.pow(2, Math.ceil(Math.log2(seq_length)));
            this.build([]);
        }
        return applyRope(x, this.dim, this.cosine_cache.read(), this.sine_cache.read());
    }
    /**
     * (Re)computes the cosine/sine tables for the current max_sequence_length.
     * `input_shape` is not used.
     */
    build(input_shape) {
        const [cosine, sine] = createRoPECache(this.dim, this.max_sequence_length, this.theta);
        // Release the previous tables: the placeholders from the constructor or the
        // tables from an earlier build(). They are the registered weights, so
        // nothing is left behind.
        this.cosine_cache.dispose();
        this.sine_cache.dispose();
        // One non-trainable, explicitly named variable per table. call() reads the
        // very same instances that are registered as weights, so
        // getWeights()/setWeights()/dispose() and call() stay in sync. Distinct
        // names avoid duplicate weight names when weights are saved or loaded.
        this.cosine_cache = new tf.LayerVariable(cosine, "float32", "cosine_cache", false);
        this.sine_cache = new tf.LayerVariable(sine, "float32", "sine_cache", false);
        this.nonTrainableWeights = [this.cosine_cache, this.sine_cache];
        // The variables hold their own reference to the data; release the
        // handles that createRoPECache kept alive.
        tf.dispose([cosine, sine]);
    }
    /**
     * Output shape: [batch, head, sequence, head_dim]
     */
    computeOutputShape(input_shape) {
        return Array.isArray(input_shape[0])
            ? input_shape[0]
            : input_shape;
    }
    /**
     * Serializable constructor arguments, so the layer registered below can be
     * re-created from its config.
     */
    getConfig() {
        const config = {
            dim: this.dim,
            maxSequenceLength: this.max_sequence_length,
            theta: this.theta,
        };
        Object.assign(config, super.getConfig());
        return config;
    }
}
tf.serialization.registerClass(RotaryPositionEmbedding);
99
- //# sourceMappingURL=rotary_position_embedding.js.map
@@ -1 +0,0 @@
1
- {"version":3,"file":"rotary_position_embedding.js","sourceRoot":"","sources":["../../../src/layers/rotary_position_embedding.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAIvC,MAAM,UAAU,SAAS,CAAC,CAAY,EAAE,GAAW,EAAE,YAAuB,EAAE,UAAqB;IAC/F,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;QAChB,MAAM,UAAU,GAAG,CAAC,CAAC,KAAK,CAAC,CAAC,CAAE,CAAC;QAE/B,2EAA2E;QAC3E,MAAM,MAAM,GAAG,YAAY,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,UAAU,EAAE,GAAG,CAAC,CAAC,CAAC;QACzE,MAAM,IAAI,GAAG,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,UAAU,EAAE,GAAG,CAAC,CAAC,CAAC;QAErE,0DAA0D;QAC1D,MAAM,SAAS,GAAG,UAAU,CAAC,CAAC,EAAE,GAAG,CAAC,CAAC;QAErC,OAAO,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,CAAC,EAAE,MAAM,CAAC,EAAE,EAAE,CAAC,GAAG,CAAC,SAAS,EAAE,IAAI,CAAC,CAAC,CAAC;IAC9D,CAAC,CAAC,CAAC;AACP,CAAC;AAGD,MAAM,UAAU,UAAU,CAAC,CAAY,EAAE,GAAW;IAChD,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;QAChB,gFAAgF;QAChF,2CAA2C;QAC3C,sEAAsE;QACtE,kCAAkC;QAClC,MAAM,QAAQ,GAAG,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,EAAE,GAAG,GAAG,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAE7C,MAAM,EAAE,GAAG,QAAQ,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAClD,MAAM,EAAE,GAAG,QAAQ,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAElD,wBAAwB;QACxB,MAAM,OAAO,GAAG,EAAE,CAAC,MAAM,CAAC,CAAC,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAEhD,OAAO,OAAO,CAAC,OAAO,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC;IACpC,CAAC,CAAC,CAAC;AACP,CAAC;AAGD,MAAM,UAAU,eAAe,CAAC,GAAW,EAAE,mBAA2B,EAAE,QAAgB,MAAM;IAC5F,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;QAChB,QAAQ;QACR,MAAM,eAAe,GAAG,EAAE,CAAC,GAAG,CAAc,CAAC,EAAE,EAAE,CAAC,GAAG,CACjD,KAAK,EACL,EAAE,CAAC,KAAK,CAAC,CAAC,EAAE,IAAI,CAAC,KAAK,CAAC,GAAG,GAAG,CAAC,CAAC,GAAG,CAAC,EAAE,CAAC,EAAE,SAAS,CAAC,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC;QAElE,uBAAuB;QACvB,MAAM,gBAAgB,GAAG,EAAE,CAAC,KAAK,CAAC,CAA
C,EAAE,mBAAmB,CAAC,CAAC;QAC1D,GAAG;QACH,MAAM,IAAI,GAAG,EAAE,CAAC,YAAY,CAAC,gBAAgB,EAAE,eAAe,CAAC,CAAC;QAEhE,+CAA+C;QAC/C,MAAM,UAAU,GAAG,EAAE,CAAC,KAAK,CAAC,CAAC,IAAI,EAAE,IAAI,CAAC,EAAE,CAAC,CAAC,CAAC;aACxC,OAAO,CAAC,CAAC,mBAAmB,EAAE,GAAG,CAAC,CAAC,CAAC;QAEzC,OAAO;YACH,EAAE,CAAC,IAAI,CAAC,EAAE,CAAC,GAAG,CAAC,UAAU,CAAC,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC;YACvD,EAAE,CAAC,IAAI,CAAC,EAAE,CAAC,GAAG,CAAC,UAAU,CAAC,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC;SAC1D,CAAA;IACL,CAAC,CAAC,CAAC;AACP,CAAC;AAmBD;;;GAGG;AACH,MAAM,OAAO,uBAAwB,SAAQ,EAAE,CAAC,MAAM,CAAC,KAAK;IACxD,MAAM,CAAC,SAAS,GAAG,yBAAyB,CAAC;IAEnC,GAAG,CAAS;IACZ,mBAAmB,CAAS;IAC5B,KAAK,CAAS;IAExB,0DAA0D;IAChD,YAAY,CAAmB;IAC/B,UAAU,CAAmB;IAEvC,YAAY,EAAE,GAAG,EAAE,iBAAiB,GAAG,IAAI,EAAE,KAAK,GAAG,MAAM,EAAE,GAAG,IAAI,EAA+B;QAC/F,KAAK,CAAC,IAAI,CAAC,CAAC;QAEZ,IAAI,GAAG,GAAG,CAAC,KAAK,CAAC,EAAE,CAAC;YAChB,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,iBAAiB,IAAI,CAAC,IAAI,iCAAiC,GAAG,EAAE,CAAC,CAAC;QACxG,CAAC;QAED,IAAI,CAAC,GAAG,GAAG,GAAG,CAAC;QACf,IAAI,CAAC,mBAAmB,GAAG,iBAAiB,CAAC;QAC7C,IAAI,CAAC,KAAK,GAAG,KAAK,CAAC;QAEnB,IAAI,CAAC,YAAY,GAAG,IAAI,CAAC,SAAS,CAAC,YAAY,EAC3C,CAAC,CAAC,EAAE,CAAC,EAAE,iBAAiB,EAAE,IAAI,CAAC,KAAK,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC,EAC/C,SAAS,EAAE,EAAE,CAAC,YAAY,CAAC,KAAK,EAAE,EAAE,SAAS,EAAE,KAAK,CAAC,CAAC;QAE1D,IAAI,CAAC,UAAU,GAAG,IAAI,CAAC,SAAS,CAAC,cAAc,EAC3C,CAAC,CAAC,EAAE,CAAC,EAAE,iBAAiB,EAAE,IAAI,CAAC,KAAK,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC,EAC/C,SAAS,EAAE,EAAE,CAAC,YAAY,CAAC,KAAK,EAAE,EAAE,SAAS,EAAE,KAAK,CAAC,CAAC;IAC9D,CAAC;IAGQ,IAAI,CAAC,MAA+B,EAAE,MAAW;QACtD,MAAM,KAAK,GAAG,KAAK,CAAC,OAAO,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,MAAM,CAAC,KAAK,CAAC;QACrE,MAAM,UAAU,GAAG,KAAK,CAAC,CAAC,CAAC,CAAC;QAE5B,IAAI,UAAU,GAAG,IAAI,CAAC,mBAAmB,EAAE,CAAC;YACxC,yCAAyC;YACzC,IAAI,CAAC,mBAAmB,GAAG,IAAI,CAAC,GAAG,CAAC,CAAC,EAAE,IAAI,CAAC,IAAI,CAAC,IAAI,CAAC,IAAI,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC;YACzE,IAAI,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC;QACnB,CAAC;QAED,OAAO,SAA
S,CACZ,KAAK,CAAC,OAAO,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,MAAM,EAC1C,IAAI,CAAC,GAAG,EACR,IAAI,CAAC,YAAY,CAAC,IAAI,EAAE,EACxB,IAAI,CAAC,UAAU,CAAC,IAAI,EAAE,CAAC,CAAA;IAC/B,CAAC;IAGQ,KAAK,CAAC,WAAkC;QAC7C,MAAM,CAAC,MAAM,EAAE,IAAI,CAAC,GAAG,eAAe,CAClC,IAAI,CAAC,GAAG,EAAE,IAAI,CAAC,mBAAmB,EAAE,IAAI,CAAC,KAAK,CAAC,CAAC;QAEpD,IAAI,CAAC,YAAY,CAAC,OAAO,EAAE,CAAC;QAC5B,IAAI,CAAC,UAAU,CAAC,OAAO,EAAE,CAAC;QAE1B,IAAI,CAAC,YAAY,GAAG,IAAI,EAAE,CAAC,aAAa,CAAC,MAAM,CAAC,CAAC;QACjD,IAAI,CAAC,UAAU,GAAG,IAAI,EAAE,CAAC,aAAa,CAAC,IAAI,CAAC,CAAC;QAE7C,IAAI,CAAC,mBAAmB,GAAG;YACvB,IAAI,EAAE,CAAC,aAAa,CAAC,MAAM,CAAC;YAC5B,IAAI,EAAE,CAAC,aAAa,CAAC,IAAI,CAAC;SAC7B,CAAC;QAEF,IAAI,CAAC,UAAU,CAAC,CAAC,MAAM,EAAE,IAAI,CAAC,CAAC,CAAC;IACpC,CAAC;IAGD;;OAEG;IACI,kBAAkB,CAAC,WAAkC;QACxD,OAAO,KAAK,CAAC,OAAO,CAAC,WAAW,CAAC,CAAC,CAAC,CAAC;YAChC,CAAC,CAAC,WAAW,CAAC,CAAC,CAAa;YAC5B,CAAC,CAAC,WAAuB,CAAC;IAClC,CAAC;;AAGL,EAAE,CAAC,aAAa,CAAC,aAAa,CAAC,uBAAuB,CAAC,CAAC"}
@@ -1,2 +0,0 @@
1
- export {};
2
- //# sourceMappingURL=rotary_position_embedding.test.d.ts.map
@@ -1 +0,0 @@
1
- {"version":3,"file":"rotary_position_embedding.test.d.ts","sourceRoot":"","sources":["../../../src/layers/rotary_position_embedding.test.ts"],"names":[],"mappings":""}
@@ -1,88 +0,0 @@
1
- import { RotaryPositionEmbedding } from "@/layers/rotary_position_embedding";
2
- import * as tf from "@tensorflow/tfjs";
3
- // disables warning for using the faster node backend,
4
- // https://github.com/tensorflow/tfjs/issues/5349#issuecomment-885170504
5
- tf.env().set('IS_NODE', false);
6
- describe("RotaryPositionEmbedding tests", () => {
7
- test("create cache", async () => {
8
- const rope = new RotaryPositionEmbedding({ dim: 8, maxSequenceLength: 15 });
9
- rope.build([]);
10
- const expected_cosine_cache = tf.tensor([[[
11
- [1, 1, 1, 1, 1, 1, 1, 1],
12
- [0.5403022766113281, 0.5403022766113281, 0.9950041770935059, 0.9950041770935059, 0.9999499917030334, 0.9999499917030334, 0.9999995231628418, 0.9999995231628418],
13
- [-0.416146844625473, -0.416146844625473, 0.9800665974617004, 0.9800665974617004, 0.9998000264167786, 0.9998000264167786, 0.9999979734420776, 0.9999979734420776],
14
- [-0.9899924993515015, -0.9899924993515015, 0.9553365111351013, 0.9553365111351013, 0.9995500445365906, 0.9995500445365906, 0.9999955296516418, 0.9999955296516418],
15
- [-0.6536436080932617, -0.6536436080932617, 0.9210609793663025, 0.9210609793663025, 0.9992001056671143, 0.9992001056671143, 0.9999920129776001, 0.9999920129776001],
16
- [0.28366219997406006, 0.28366219997406006, 0.8775825500488281, 0.8775825500488281, 0.9987502694129944, 0.9987502694129944, 0.9999874830245972, 0.9999874830245972],
17
- [0.9601702690124512, 0.9601702690124512, 0.8253356218338013, 0.8253356218338013, 0.998200535774231, 0.998200535774231, 0.9999819993972778, 0.9999819993972778],
18
- [0.7539022564888, 0.7539022564888, 0.7648422122001648, 0.7648422122001648, 0.9975510239601135, 0.9975510239601135, 0.9999755024909973, 0.9999755024909973],
19
- [-0.1455000340938568, -0.1455000340938568, 0.6967067122459412, 0.6967067122459412, 0.9968017339706421, 0.9968017339706421, 0.9999679923057556, 0.9999679923057556],
20
- [-0.9111302495002747, -0.9111302495002747, 0.6216099262237549, 0.6216099262237549, 0.9959527254104614, 0.9959527254104614, 0.9999595284461975, 0.9999595284461975],
21
- [-0.83907151222229, -0.83907151222229, 0.5403022766113281, 0.5403022766113281, 0.9950041770935059, 0.9950041770935059, 0.9999499917030334, 0.9999499917030334],
22
- [0.004425697959959507, 0.004425697959959507, 0.4535960853099823, 0.4535960853099823, 0.9939560890197754, 0.9939560890197754, 0.999939501285553, 0.999939501285553],
23
- [0.8438539505004883, 0.8438539505004883, 0.3623577058315277, 0.3623577058315277, 0.9928086400032043, 0.9928086400032043, 0.9999279975891113, 0.9999279975891113],
24
- [0.9074468016624451, 0.9074468016624451, 0.26749876141548157, 0.26749876141548157, 0.9915618896484375, 0.9915618896484375, 0.9999154806137085, 0.9999154806137085],
25
- [0.13673721253871918, 0.13673721253871918, 0.1699671596288681, 0.1699671596288681, 0.9902160167694092, 0.9902160167694092, 0.9999020099639893, 0.9999020099639893]
26
- ]]]);
27
- const expected_sine_cache = tf.tensor([[[
28
- [0, 0, 0, 0, 0, 0, 0, 0],
29
- [0.8414709568023682, 0.8414709568023682, 0.0998334214091301, 0.0998334214091301, 0.009999833069741726, 0.009999833069741726, 0.0009999999310821295, 0.0009999999310821295],
30
- [0.9092974066734314, 0.9092974066734314, 0.19866932928562164, 0.19866932928562164, 0.019998665899038315, 0.019998665899038315, 0.0019999986980110407, 0.0019999986980110407],
31
- [0.14112000167369843, 0.14112000167369843, 0.29552021622657776, 0.29552021622657776, 0.029995499178767204, 0.029995499178767204, 0.0029999956022948027, 0.0029999956022948027],
32
- [-0.756802499294281, -0.756802499294281, 0.3894183337688446, 0.3894183337688446, 0.03998933359980583, 0.03998933359980583, 0.003999989479780197, 0.003999989479780197],
33
- [-0.9589242935180664, -0.9589242935180664, 0.4794255495071411, 0.4794255495071411, 0.04997916519641876, 0.04997916519641876, 0.0049999793991446495, 0.0049999793991446495],
34
- [-0.279415488243103, -0.279415488243103, 0.5646424889564514, 0.5646424889564514, 0.059964004904031754, 0.059964004904031754, 0.0059999641962349415, 0.0059999641962349415],
35
- [0.6569865942001343, 0.6569865942001343, 0.6442176699638367, 0.6442176699638367, 0.06994284689426422, 0.06994284689426422, 0.0069999429397284985, 0.0069999429397284985],
36
- [0.9893582463264465, 0.9893582463264465, 0.7173560857772827, 0.7173560857772827, 0.07991468906402588, 0.07991468906402588, 0.007999914698302746, 0.007999914698302746],
37
- [0.41211849451065063, 0.41211849451065063, 0.7833269238471985, 0.7833269238471985, 0.08987854421138763, 0.08987854421138763, 0.008999879471957684, 0.008999879471957684],
38
- [-0.5440211296081543, -0.5440211296081543, 0.8414709568023682, 0.8414709568023682, 0.0998334139585495, 0.0998334139585495, 0.0099998340010643, 0.0099998340010643],
39
- [-0.9999902248382568, -0.9999902248382568, 0.8912073969841003, 0.8912073969841003, 0.10977829992771149, 0.10977829992771149, 0.010999779216945171, 0.010999779216945171],
40
- [-0.5365729331970215, -0.5365729331970215, 0.9320390820503235, 0.9320390820503235, 0.11971220374107361, 0.11971220374107361, 0.011999712325632572, 0.011999712325632572],
41
- [0.4201670289039612, 0.4201670289039612, 0.9635581970214844, 0.9635581970214844, 0.12963414192199707, 0.12963414192199707, 0.012999634258449078, 0.012999634258449078],
42
- [0.9906073808670044, 0.9906073808670044, 0.9854497313499451, 0.9854497313499451, 0.13954311609268188, 0.13954311609268188, 0.013999543152749538, 0.013999543152749538]
43
- ]]]);
44
- const [cosine_cache, sine_cache] = rope.getWeights();
45
- expect(await cosine_cache?.sub(expected_cosine_cache).sum().array()).toBeLessThanOrEqual(1e-6);
46
- expect(await sine_cache?.sub(expected_sine_cache).sum().array()).toBeLessThanOrEqual(1e-6);
47
- });
48
- test("rotate inputs", async () => {
49
- const rope = new RotaryPositionEmbedding({ dim: 8, maxSequenceLength: 15 });
50
- const x = tf.tensor([[[
51
- [0.0766048, 0.5706575, 0.6705932, 0.5273118, 0.4794086, 0.9378104, 0.9888024, 0.6926053],
52
- [0.9064133, 0.5875182, 0.1681865, 0.3833345, 0.9901192, 0.4677338, 0.3353315, 0.02699],
53
- [0.3033573, 0.4139377, 0.4062586, 0.9705839, 0.3582608, 0.328775, 0.1340587, 0.2193414],
54
- [0.5565202, 0.4334963, 0.9912352, 0.3388563, 0.7991487, 0.1911893, 0.1140554, 0.6949552]
55
- ]]
56
- ]); // batch=1, seq = 1, heads=4, embedDim=8
57
- const expected_output = tf.tensor([[[
58
- [0.07660479843616486, 0.57065749168396, 0.6705932021141052, 0.5273118019104004, 0.4794085919857025, 0.9378104209899902, 0.9888023734092712, 0.6926053166389465],
59
- [-0.004642367362976074, 1.08015775680542, 0.12907665967941284, 0.39820998907089233, 0.9853923320770264, 0.47761136293411255, 0.33530429005622864, 0.027325313538312912],
60
- [-0.5026336908340454, 0.10358311235904694, 0.20533521473407745, 1.0319478511810303, 0.3516140580177307, 0.33587393164634705, 0.1336197406053543, 0.21960905194282532],
61
- [-0.6121258735656738, -0.3506217896938324, 0.8468242287635803, 0.6166517734527588, 0.7930541634559631, 0.2150741070508957, 0.11197001487016678, 0.695294201374054]
62
- ]]]);
63
- const output = rope.apply(x);
64
- expect(await expected_output.sub(output).sum().array()).toBeLessThan(1e-6);
65
- expect(rope.computeOutputShape(x.shape)).toEqual(x.shape);
66
- expect(rope.computeOutputShape([x.shape])).toEqual(x.shape);
67
- });
68
- test("expand cache when input sequences are larger than rope's max sequence length", async () => {
69
- const dim = 8;
70
- const rope = new RotaryPositionEmbedding({ dim, maxSequenceLength: 15, theta: 1_000_000 });
71
- const larger_sequence = 20;
72
- const even_larger_sequence = 50;
73
- rope.apply(tf.randomUniform([1, 1, larger_sequence, dim]));
74
- rope.getWeights().forEach(weight => {
75
- expect(weight.shape).toEqual([1, 1, 32, dim]);
76
- });
77
- rope.apply([tf.randomUniform([1, 1, even_larger_sequence, dim])]);
78
- rope.getWeights().forEach(weight => {
79
- expect(weight.shape).toEqual([1, 1, 64, dim]);
80
- });
81
- });
82
- test("create layer", async () => {
83
- // dim must be even
84
- expect(() => new RotaryPositionEmbedding({ dim: 7, maxSequenceLength: 15 })).toThrow();
85
- expect(() => new RotaryPositionEmbedding({ dim: 8, maxSequenceLength: 25 })).not.toThrow();
86
- });
87
- });
88
- //# sourceMappingURL=rotary_position_embedding.test.js.map
@@ -1 +0,0 @@
1
- {"version":3,"file":"rotary_position_embedding.test.js","sourceRoot":"","sources":["../../../src/layers/rotary_position_embedding.test.ts"],"names":[],"mappings":"AAAA,OAAO,EAAE,uBAAuB,EAAE,MAAM,oCAAoC,CAAC;AAC7E,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAEvC,sDAAsD;AACtD,wEAAwE;AACxE,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC,SAAS,EAAE,KAAK,CAAC,CAAC;AAG/B,QAAQ,CAAC,+BAA+B,EAAE,GAAG,EAAE;IAC3C,IAAI,CAAC,cAAc,EAAE,KAAK,IAAI,EAAE;QAC5B,MAAM,IAAI,GAAG,IAAI,uBAAuB,CAAC,EAAE,GAAG,EAAE,CAAC,EAAE,iBAAiB,EAAE,EAAE,EAAE,CAAC,CAAC;QAC5E,IAAI,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC;QAEf,MAAM,qBAAqB,GAAG,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC;oBACtC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC;oBACxB,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;oBAChK,CAAC,CAAC,iBAAiB,EAAE,CAAC,iBAAiB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;oBAChK,CAAC,CAAC,kBAAkB,EAAE,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;oBAClK,CAAC,CAAC,kBAAkB,EAAE,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;oBAClK,CAAC,mBAAmB,EAAE,mBAAmB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;oBAClK,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,iBAAiB,EAAE,iBAAiB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;oBAC9J,CAAC,eAAe,EAAE,eAAe,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;oBAC1J,CAAC,CAAC,kBAAkB,EAAE,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;oBAClK,CAAC,CAAC,kBAAkB,EAAE,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;oBAClK,CAAC,CAAC,gBAAgB,EAAE,CAAC,gBAAgB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;oBAC9J,CAAC,oBAAoB,EAAE,oBAAoB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,iBAAiB,EAAE,iBAAiB,CAAC;oBAClK,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAA
E,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;oBAChK,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;oBAClK,CAAC,mBAAmB,EAAE,mBAAmB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;iBACrK,CAAC,CAAC,CAAC,CAAC;QAEL,MAAM,mBAAmB,GAAG,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC;oBACpC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC;oBACxB,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,oBAAoB,EAAE,oBAAoB,EAAE,qBAAqB,EAAE,qBAAqB,CAAC;oBAC1K,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,oBAAoB,EAAE,oBAAoB,EAAE,qBAAqB,EAAE,qBAAqB,CAAC;oBAC5K,CAAC,mBAAmB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,oBAAoB,EAAE,oBAAoB,EAAE,qBAAqB,EAAE,qBAAqB,CAAC;oBAC9K,CAAC,CAAC,iBAAiB,EAAE,CAAC,iBAAiB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,oBAAoB,EAAE,oBAAoB,CAAC;oBACtK,CAAC,CAAC,kBAAkB,EAAE,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,qBAAqB,EAAE,qBAAqB,CAAC;oBAC1K,CAAC,CAAC,iBAAiB,EAAE,CAAC,iBAAiB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,oBAAoB,EAAE,oBAAoB,EAAE,qBAAqB,EAAE,qBAAqB,CAAC;oBAC1K,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,qBAAqB,EAAE,qBAAqB,CAAC;oBACxK,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,oBAAoB,EAAE,oBAAoB,CAAC;oBACtK,CAAC,mBAAmB,EAAE,mBAAmB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,oBAAoB,EAAE,oBAAoB,CAAC;oBACxK,CAAC,CAAC,kBAAkB,EAAE,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;oBAClK,CAAC,CAAC,kBAAkB,EAAE,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,oBAAoB,EAAE,oBAAoB,CAAC;oBACxK,CAAC,CAAC,kBAAkB,EAAE,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,oBAAoB,EAAE,oBAAoB,CAAC;oBACxK,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,oBAAoB,EAAE,oBAAoB,CAAC;oBACtK,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,oBAAoB,EAAE,oBA
AoB,CAAC;iBACzK,CAAC,CAAC,CAAC,CAAC;QAEL,MAAM,CAAC,YAAY,EAAE,UAAU,CAAC,GAAG,IAAI,CAAC,UAAU,EAAE,CAAC;QAErD,MAAM,CAAC,MAAM,YAAY,EAAE,GAAG,CAAC,qBAAqB,CAAC,CAAC,GAAG,EAAE,CAAC,KAAK,EAAY,CAAC,CAAC,mBAAmB,CAAC,IAAI,CAAC,CAAC;QACzG,MAAM,CAAC,MAAM,UAAU,EAAE,GAAG,CAAC,mBAAmB,CAAC,CAAC,GAAG,EAAE,CAAC,KAAK,EAAY,CAAC,CAAC,mBAAmB,CAAC,IAAI,CAAC,CAAC;IACzG,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,eAAe,EAAE,KAAK,IAAI,EAAE;QAC7B,MAAM,IAAI,GAAG,IAAI,uBAAuB,CAAC,EAAE,GAAG,EAAE,CAAC,EAAE,iBAAiB,EAAE,EAAE,EAAE,CAAC,CAAC;QAE5E,MAAM,CAAC,GAAG,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC;oBAClB,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;oBACxF,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,OAAO,CAAC;oBACtF,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,QAAQ,EAAE,SAAS,EAAE,SAAS,CAAC;oBACvF,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;iBAAC,CAAC;SAC7F,CAAC,CAAC,CAAC,wCAAwC;QAE5C,MAAM,eAAe,GAAG,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC;oBAChC,CAAC,mBAAmB,EAAE,gBAAgB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;oBAC/J,CAAC,CAAC,oBAAoB,EAAE,gBAAgB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,kBAAkB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,oBAAoB,CAAC;oBACvK,CAAC,CAAC,kBAAkB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,mBAAmB,EAAE,kBAAkB,EAAE,mBAAmB,CAAC;oBACrK,CAAC,CAAC,kBAAkB,EAAE,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,mBAAmB,EAAE,iBAAiB,CAAC;iBACrK,CAAC,CAAC,CAAC,CAAC;QAEL,MAAM,MAAM,GAAG,IAAI,CAAC,KAAK,CAAC,CAAC,CAAc,CAAC;QAE1C,MAAM,CAAC,MAAM,eAAe,CAAC,GAAG,CAAC,MAAM,CAAC,CAAC,GAAG,EAAE,CAAC,KAAK,EAAY,CAAC,CAAC,YAAY,CAAC,IAAI,CAAC,CAAC;QACrF,MAAM,CAAC,IAAI,CAAC,kBAAkB,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC;QAC1D,MAAM,CAAC,IAAI,CAAC,kBAAkB,CAAC,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC;IAChE,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,8EAA8E,EAAE,KAAK,IAAI,EAAE;QAC5F,MAAM,GAAG,GAAG,CAAC,CAAC;QACd,MAAM,IAAI,GAAG,IAAI,uBAAuB,CA
AC,EAAE,GAAG,EAAE,iBAAiB,EAAE,EAAE,EAAE,KAAK,EAAE,SAAS,EAAE,CAAC,CAAC;QAC3F,MAAM,eAAe,GAAG,EAAE,CAAC;QAC3B,MAAM,oBAAoB,GAAG,EAAE,CAAC;QAEhC,IAAI,CAAC,KAAK,CAAC,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,eAAe,EAAE,GAAG,CAAC,CAAC,CAAC,CAAC;QAE3D,IAAI,CAAC,UAAU,EAAE,CAAC,OAAO,CAAC,MAAM,CAAC,EAAE;YAC/B,MAAM,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,EAAE,GAAG,CAAC,CAAC,CAAC;QAClD,CAAC,CAAC,CAAC;QAEH,IAAI,CAAC,KAAK,CAAC,CAAC,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,oBAAoB,EAAE,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC;QAElE,IAAI,CAAC,UAAU,EAAE,CAAC,OAAO,CAAC,MAAM,CAAC,EAAE;YAC/B,MAAM,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,EAAE,GAAG,CAAC,CAAC,CAAC;QAClD,CAAC,CAAC,CAAC;IACP,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,cAAc,EAAE,KAAK,IAAI,EAAE;QAC5B,mBAAmB;QACnB,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,uBAAuB,CAAC,EAAE,GAAG,EAAE,CAAC,EAAE,iBAAiB,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACvF,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,uBAAuB,CAAC,EAAE,GAAG,EAAE,CAAC,EAAE,iBAAiB,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;IAC/F,CAAC,CAAC,CAAA;AACN,CAAC,CAAC,CAAC"}
@@ -1,47 +0,0 @@
1
- import * as tf from '@tensorflow/tfjs';
2
- import { type LayerArgs } from '@tensorflow/tfjs-layers/dist/engine/topology';
3
- import { type Kwargs } from '@tensorflow/tfjs-layers/dist/types';
4
- import { type PositionalEncodingArgs } from '@/layers/positional_encoding';
5
- export interface TokenAndPositionalEmbeddingArgs extends LayerArgs, PositionalEncodingArgs {
6
- vocabularySize: number;
7
- dropout?: number;
8
- }
9
- /**
10
- * This class implements combines sinusoidal positional encoding from the
11
- * 2017 paper "Attention Is All You Need" with a normal embedding layer to
12
- * form a simplified single embedding layer.
13
- *
14
- * This layer accepts tokenized inputs of the shape `[ batch, tokens ]` and runs
15
- * it through an embedding layer before adding sinusoidal positional encoding.
16
- *
17
- * @param embedDim the size of each token/word's embedding
18
- * @param vocabularySize the number of tokens to embed
19
- * @param maxSequenceLength the max number of tokens (words) per input (sentence), default `5120`
20
- * @param dropout applies dropout to the positionally encoded embeddings, default `0.1`
21
- */
22
- export declare class TokenAndPositionalEmbedding extends tf.layers.Layer {
23
- static className: string;
24
- private readonly embedDim;
25
- private readonly vocabularySize;
26
- private embedding;
27
- private positional;
28
- private readonly maxSequenceLength;
29
- private readonly dropout;
30
- private dropoutLayer;
31
- constructor({ embedDim, vocabularySize, maxSequenceLength, dropout, ...args }: TokenAndPositionalEmbeddingArgs);
32
- /**
33
- * Forward propagation.
34
- */
35
- call(inputs: tf.Tensor | tf.Tensor[], kwargs: Kwargs): tf.Tensor<tf.Rank>;
36
- /**
37
- * Build the sublayers and enable serialization
38
- */
39
- build(inputShape: tf.Shape | tf.Shape[]): void;
40
- /**
41
- * The output shape, for an input shape of [batch, sequences], is
42
- * [batch, sequences, embedDim]
43
- */
44
- computeOutputShape(inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[];
45
- getConfig(): tf.serialization.ConfigDict;
46
- }
47
- //# sourceMappingURL=token_and_positional_embedding.d.ts.map
@@ -1 +0,0 @@
1
- {"version":3,"file":"token_and_positional_embedding.d.ts","sourceRoot":"","sources":["../../../src/layers/token_and_positional_embedding.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AACvC,OAAO,EAAE,KAAK,SAAS,EAAE,MAAM,8CAA8C,CAAC;AAC9E,OAAO,EAAE,KAAK,MAAM,EAAE,MAAM,oCAAoC,CAAC;AAEjE,OAAO,EAAsB,KAAK,sBAAsB,EAAE,MAAM,8BAA8B,CAAC;AAG/F,MAAM,WAAW,+BAAgC,SAAQ,SAAS,EAAE,sBAAsB;IACtF,cAAc,EAAE,MAAM,CAAC;IACvB,OAAO,CAAC,EAAE,MAAM,CAAA;CACnB;AAGD;;;;;;;;;;;;GAYG;AACH,qBAAa,2BAA4B,SAAQ,EAAE,CAAC,MAAM,CAAC,KAAK;IAC5D,MAAM,CAAC,SAAS,SAAiC;IAEjD,OAAO,CAAC,QAAQ,CAAC,QAAQ,CAAS;IAClC,OAAO,CAAC,QAAQ,CAAC,cAAc,CAAS;IACxC,OAAO,CAAC,SAAS,CAAkB;IAEnC,OAAO,CAAC,UAAU,CAAiB;IACnC,OAAO,CAAC,QAAQ,CAAC,iBAAiB,CAAS;IAC3C,OAAO,CAAC,QAAQ,CAAC,OAAO,CAAS;IAEjC,OAAO,CAAC,YAAY,CAAkB;gBAG1B,EAAE,QAAQ,EAAE,cAAc,EAAE,iBAAiB,EAAE,OAAO,EAAE,GAAG,IAAI,EAAE,EAAE,+BAA+B;IA0B9G;;OAEG;IACM,IAAI,CAAC,MAAM,EAAE,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,EAAE,EAAE,MAAM,EAAE,MAAM;IAe7D;;OAEG;IACM,KAAK,CAAC,UAAU,EAAE,EAAE,CAAC,KAAK,GAAG,EAAE,CAAC,KAAK,EAAE,GAAG,IAAI;IAgCvD;;;OAGG;IACM,kBAAkB,CAAC,UAAU,EAAE,EAAE,CAAC,KAAK,GAAG,EAAE,CAAC,KAAK,EAAE,GAAG,EAAE,CAAC,KAAK,GAAG,EAAE,CAAC,KAAK,EAAE;IAQ5E,SAAS,IAAI,EAAE,CAAC,aAAa,CAAC,UAAU;CAcpD"}
@@ -1,109 +0,0 @@
1
- import * as tf from '@tensorflow/tfjs';
2
- import { PositionalEncoding } from '@/layers/positional_encoding';
3
- /**
4
- * This class implements combines sinusoidal positional encoding from the
5
- * 2017 paper "Attention Is All You Need" with a normal embedding layer to
6
- * form a simplified single embedding layer.
7
- *
8
- * This layer accepts tokenized inputs of the shape `[ batch, tokens ]` and runs
9
- * it through an embedding layer before adding sinusoidal positional encoding.
10
- *
11
- * @param embedDim the size of each token/word's embedding
12
- * @param vocabularySize the number of tokens to embed
13
- * @param maxSequenceLength the max number of tokens (words) per input (sentence), default `5120`
14
- * @param dropout applies dropout to the positionally encoded embeddings, default `0.1`
15
- */
16
- export class TokenAndPositionalEmbedding extends tf.layers.Layer {
17
- static className = "TokenAndPositionalEmbedding";
18
- embedDim;
19
- vocabularySize;
20
- embedding;
21
- positional;
22
- maxSequenceLength;
23
- dropout;
24
- dropoutLayer;
25
- constructor({ embedDim, vocabularySize, maxSequenceLength, dropout, ...args }) {
26
- super(args);
27
- this.embedDim = embedDim;
28
- this.vocabularySize = vocabularySize;
29
- this.maxSequenceLength = maxSequenceLength ?? 5120;
30
- this.dropout = dropout ?? 0.1;
31
- if (this.dropout >= 1) {
32
- throw Error(`${this.getClassName()}::constructor dropout must be within [0, 1)`);
33
- }
34
- this.embedding = tf.layers.embedding({
35
- inputDim: this.vocabularySize,
36
- outputDim: this.embedDim,
37
- });
38
- this.positional = new PositionalEncoding({
39
- maxSequenceLength: this.maxSequenceLength,
40
- embedDim: this.embedDim,
41
- });
42
- this.dropoutLayer = tf.layers.dropout({ rate: this.dropout });
43
- }
44
- /**
45
- * Forward propagation.
46
- */
47
- call(inputs, kwargs) {
48
- if (Array.isArray(inputs) && inputs.length != 1) {
49
- throw Error(`${this.getClassName()}::call ${this.name} expects exactly` +
50
- ` 1 tensor input, received ${inputs.length}`);
51
- }
52
- return tf.tidy(() => {
53
- let output = this.positional.apply(this.embedding.apply(inputs));
54
- output = this.dropoutLayer.apply(output);
55
- return output;
56
- });
57
- }
58
- /**
59
- * Build the sublayers and enable serialization
60
- */
61
- build(inputShape) {
62
- let input_shapes = [];
63
- // only consider the first shape if multiple provided
64
- if (Array.isArray(inputShape) && Array.isArray(inputShape[0])) {
65
- // input is an array of shapes
66
- input_shapes = inputShape;
67
- }
68
- else if (inputShape.length != 0) {
69
- // input is a single shape
70
- input_shapes = [inputShape];
71
- }
72
- if (input_shapes[0].length != 2 || input_shapes[0][1] > this.maxSequenceLength) {
73
- throw Error(`${this.getClassName()}::build ${this.name} expected an input of` +
74
- ` shape [batch, tokens] where tokens < ${this.maxSequenceLength},` +
75
- ` received ${JSON.stringify(input_shapes[0])}`);
76
- }
77
- // initialize the sublayers' weights
78
- this.embedding.build(input_shapes[0]);
79
- this.positional.build(this.embedding.computeOutputShape(input_shapes[0]));
80
- // no need to rename weights, haven't found a case where their names collide
81
- this.trainableWeights = [
82
- ...this.embedding.trainableWeights,
83
- ...this.positional.trainableWeights
84
- ];
85
- super.build(input_shapes[0]);
86
- }
87
- /**
88
- * The output shape, for an input shape of [batch, sequences], is
89
- * [batch, sequences, embedDim]
90
- */
91
- computeOutputShape(inputShape) {
92
- const embedding_shape = this.embedding.computeOutputShape(inputShape);
93
- const positional_shape = this.positional.computeOutputShape(embedding_shape);
94
- return positional_shape;
95
- }
96
- getConfig() {
97
- const base_config = super.getConfig();
98
- const config = {
99
- embedDim: this.embedDim,
100
- vocabularySize: this.vocabularySize,
101
- maxSequenceLength: this.maxSequenceLength,
102
- dropout: this.dropout,
103
- };
104
- Object.assign(config, base_config);
105
- return config;
106
- }
107
- }
108
- tf.serialization.registerClass(TokenAndPositionalEmbedding);
109
- //# sourceMappingURL=token_and_positional_embedding.js.map
@@ -1 +0,0 @@
1
- {"version":3,"file":"token_and_positional_embedding.js","sourceRoot":"","sources":["../../../src/layers/token_and_positional_embedding.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAIvC,OAAO,EAAE,kBAAkB,EAA+B,MAAM,8BAA8B,CAAC;AAS/F;;;;;;;;;;;;GAYG;AACH,MAAM,OAAO,2BAA4B,SAAQ,EAAE,CAAC,MAAM,CAAC,KAAK;IAC5D,MAAM,CAAC,SAAS,GAAG,6BAA6B,CAAC;IAEhC,QAAQ,CAAS;IACjB,cAAc,CAAS;IAChC,SAAS,CAAkB;IAE3B,UAAU,CAAiB;IAClB,iBAAiB,CAAS;IAC1B,OAAO,CAAS;IAEzB,YAAY,CAAkB;IAGtC,YAAY,EAAE,QAAQ,EAAE,cAAc,EAAE,iBAAiB,EAAE,OAAO,EAAE,GAAG,IAAI,EAAmC;QAC1G,KAAK,CAAC,IAAI,CAAC,CAAC;QAEZ,IAAI,CAAC,QAAQ,GAAG,QAAQ,CAAC;QACzB,IAAI,CAAC,cAAc,GAAG,cAAc,CAAC;QACrC,IAAI,CAAC,iBAAiB,GAAG,iBAAiB,IAAI,IAAI,CAAC;QACnD,IAAI,CAAC,OAAO,GAAG,OAAO,IAAI,GAAG,CAAC;QAE9B,IAAI,IAAI,CAAC,OAAO,IAAI,CAAC,EAAE,CAAC;YACpB,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,6CAA6C,CAAC,CAAC;QACrF,CAAC;QAED,IAAI,CAAC,SAAS,GAAG,EAAE,CAAC,MAAM,CAAC,SAAS,CAAC;YACjC,QAAQ,EAAE,IAAI,CAAC,cAAc;YAC7B,SAAS,EAAE,IAAI,CAAC,QAAQ;SAC3B,CAAC,CAAC;QAEH,IAAI,CAAC,UAAU,GAAG,IAAI,kBAAkB,CAAC;YACrC,iBAAiB,EAAE,IAAI,CAAC,iBAAiB;YACzC,QAAQ,EAAE,IAAI,CAAC,QAAQ;SAC1B,CAAC,CAAC;QAEH,IAAI,CAAC,YAAY,GAAG,EAAE,CAAC,MAAM,CAAC,OAAO,CAAC,EAAE,IAAI,EAAE,IAAI,CAAC,OAAO,EAAE,CAAC,CAAC;IAClE,CAAC;IAGD;;OAEG;IACM,IAAI,CAAC,MAA+B,EAAE,MAAc;QACzD,IAAI,KAAK,CAAC,OAAO,CAAC,MAAM,CAAC,IAAI,MAAM,CAAC,MAAM,IAAI,CAAC,EAAE,CAAC;YAC9C,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,UAAU,IAAI,CAAC,IAAI,kBAAkB;gBACnE,6BAA6B,MAAM,CAAC,MAAM,EAAE,CAAC,CAAC;QACtD,CAAC;QAED,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,IAAI,MAAM,GAAG,IAAI,CAAC,UAAU,CAAC,KAAK,CAAC,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,MAAM,CAAC,CAAc,CAAC;YAC9E,MAAM,GAAG,IAAI,CAAC,YAAY,CAAC,KAAK,CAAC,MAAM,CAAc,CAAC;YAEtD,OAAO,MAAM,CAAC;QAClB,CAAC,CAAC,CAAA;IACN,CAAC;IAGD;;OAEG;IACM,KAAK,CAAC,UAAiC;QAC5C,IAAI,YAAY,GAAe,EAAE,CAAC;QAElC,qDAAqD;QACrD,IAAI,KAAK,CAAC,OAAO,CAAC,UAAU,CAAC,IAAI,KAAK,CAAC,OAAO,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC;YAC5D,8BAA8B;YAC9B,YAAY,GAAG,UAAwB,CAAC;QAC5C,CAAC;aAAM,IAAI,UAAU,CAAC,MAAM,IAAI,
CAAC,EAAE,CAAC;YAChC,0BAA0B;YAC1B,YAAY,GAAG,CAAC,UAAsB,CAAC,CAAC;QAC5C,CAAC;QAED,IAAI,YAAY,CAAC,CAAC,CAAC,CAAC,MAAM,IAAI,CAAC,IAAI,YAAY,CAAC,CAAC,CAAC,CAAC,CAAC,CAAE,GAAG,IAAI,CAAC,iBAAiB,EAAE,CAAC;YAC9E,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,WAAW,IAAI,CAAC,IAAI,uBAAuB;gBACzE,yCAAyC,IAAI,CAAC,iBAAiB,GAAG;gBAClE,aAAa,IAAI,CAAC,SAAS,CAAC,YAAY,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC;QACxD,CAAC;QAED,oCAAoC;QACpC,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,YAAY,CAAC,CAAC,CAAC,CAAC,CAAC;QACtC,IAAI,CAAC,UAAU,CAAC,KAAK,CAAC,IAAI,CAAC,SAAS,CAAC,kBAAkB,CAAC,YAAY,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;QAE1E,4EAA4E;QAC5E,IAAI,CAAC,gBAAgB,GAAG;YACpB,GAAG,IAAI,CAAC,SAAS,CAAC,gBAAgB;YAClC,GAAG,IAAI,CAAC,UAAU,CAAC,gBAAgB;SACtC,CAAC;QAEF,KAAK,CAAC,KAAK,CAAC,YAAY,CAAC,CAAC,CAAC,CAAC,CAAC;IACjC,CAAC;IAGD;;;OAGG;IACM,kBAAkB,CAAC,UAAiC;QACzD,MAAM,eAAe,GAAG,IAAI,CAAC,SAAS,CAAC,kBAAkB,CAAC,UAAU,CAAC,CAAC;QACtE,MAAM,gBAAgB,GAAG,IAAI,CAAC,UAAU,CAAC,kBAAkB,CAAC,eAAe,CAAC,CAAC;QAE7E,OAAO,gBAAgB,CAAC;IAC5B,CAAC;IAGQ,SAAS;QACd,MAAM,WAAW,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QAEtC,MAAM,MAAM,GAAG;YACX,QAAQ,EAAE,IAAI,CAAC,QAAQ;YACvB,cAAc,EAAE,IAAI,CAAC,cAAc;YACnC,iBAAiB,EAAE,IAAI,CAAC,iBAAiB;YACzC,OAAO,EAAE,IAAI,CAAC,OAAO;SACxB,CAAA;QAED,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,WAAW,CAAC,CAAC;QAEnC,OAAO,MAAM,CAAC;IAClB,CAAC;;AAIL,EAAE,CAAC,aAAa,CAAC,aAAa,CAAC,2BAA2B,CAAC,CAAC"}
@@ -1,2 +0,0 @@
1
- export {};
2
- //# sourceMappingURL=token_and_positional_embedding.test.d.ts.map
@@ -1 +0,0 @@
1
- {"version":3,"file":"token_and_positional_embedding.test.d.ts","sourceRoot":"","sources":["../../../src/layers/token_and_positional_embedding.test.ts"],"names":[],"mappings":""}