@stellarapp/tfjs-stellar 1.0.0 → 1.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (244) hide show
  1. package/LICENSE +21 -0
  2. package/README.md +47 -0
  3. package/dist/index.d.ts +7 -0
  4. package/dist/index.d.ts.map +1 -0
  5. package/dist/index.js +7 -0
  6. package/dist/index.js.map +1 -0
  7. package/dist/jest.config.d.ts +8 -0
  8. package/dist/jest.config.d.ts.map +1 -0
  9. package/{jest.config.ts → dist/jest.config.js} +8 -64
  10. package/dist/jest.config.js.map +1 -0
  11. package/dist/kv_cache.d.ts +53 -0
  12. package/dist/kv_cache.d.ts.map +1 -0
  13. package/{src/kv_cache.ts → dist/kv_cache.js} +35 -105
  14. package/dist/kv_cache.js.map +1 -0
  15. package/dist/layers/cached_rope_multihead_attention.d.ts +31 -0
  16. package/dist/layers/cached_rope_multihead_attention.d.ts.map +1 -0
  17. package/dist/layers/cached_rope_multihead_attention.js +76 -0
  18. package/dist/layers/cached_rope_multihead_attention.js.map +1 -0
  19. package/dist/layers/cached_rope_multihead_attention.test.d.ts +2 -0
  20. package/dist/layers/cached_rope_multihead_attention.test.d.ts.map +1 -0
  21. package/{src/layers/cached_rope_multihead_attention.test.ts → dist/layers/cached_rope_multihead_attention.test.js} +14 -30
  22. package/dist/layers/cached_rope_multihead_attention.test.js.map +1 -0
  23. package/dist/layers/gpt_decoder_block.d.ts +34 -0
  24. package/dist/layers/gpt_decoder_block.d.ts.map +1 -0
  25. package/{src/layers/gpt_decoder_block.ts → dist/layers/gpt_decoder_block.js} +10 -36
  26. package/dist/layers/gpt_decoder_block.js.map +1 -0
  27. package/dist/layers/index.d.ts +17 -0
  28. package/dist/layers/index.d.ts.map +1 -0
  29. package/dist/layers/index.js +33 -0
  30. package/dist/layers/index.js.map +1 -0
  31. package/dist/layers/multihead_attention.d.ts +106 -0
  32. package/dist/layers/multihead_attention.d.ts.map +1 -0
  33. package/{src/layers/multihead_attention.ts → dist/layers/multihead_attention.js} +60 -162
  34. package/dist/layers/multihead_attention.js.map +1 -0
  35. package/dist/layers/multihead_attention.test.d.ts +2 -0
  36. package/dist/layers/multihead_attention.test.d.ts.map +1 -0
  37. package/{src/layers/multihead_attention.test.ts → dist/layers/multihead_attention.test.js} +48 -100
  38. package/dist/layers/multihead_attention.test.js.map +1 -0
  39. package/dist/layers/positional_encoding.d.ts +37 -0
  40. package/dist/layers/positional_encoding.d.ts.map +1 -0
  41. package/{src/layers/positional_encoding.ts → dist/layers/positional_encoding.js} +17 -60
  42. package/dist/layers/positional_encoding.js.map +1 -0
  43. package/dist/layers/positional_encoding.test.d.ts +2 -0
  44. package/dist/layers/positional_encoding.test.d.ts.map +1 -0
  45. package/{src/layers/positional_encoding.test.ts → dist/layers/positional_encoding.test.js} +39 -57
  46. package/dist/layers/positional_encoding.test.js.map +1 -0
  47. package/dist/layers/rotary_position_embedding.d.ts +39 -0
  48. package/dist/layers/rotary_position_embedding.d.ts.map +1 -0
  49. package/{src/layers/rotary_position_embedding.ts → dist/layers/rotary_position_embedding.js} +22 -86
  50. package/dist/layers/rotary_position_embedding.js.map +1 -0
  51. package/dist/layers/rotary_position_embedding.test.d.ts +2 -0
  52. package/dist/layers/rotary_position_embedding.test.d.ts.map +1 -0
  53. package/dist/layers/rotary_position_embedding.test.js +88 -0
  54. package/dist/layers/rotary_position_embedding.test.js.map +1 -0
  55. package/dist/layers/token_and_positional_embedding.d.ts +47 -0
  56. package/dist/layers/token_and_positional_embedding.d.ts.map +1 -0
  57. package/{src/layers/token_and_positional_embedding.ts → dist/layers/token_and_positional_embedding.js} +27 -67
  58. package/dist/layers/token_and_positional_embedding.js.map +1 -0
  59. package/dist/layers/token_and_positional_embedding.test.d.ts +2 -0
  60. package/dist/layers/token_and_positional_embedding.test.d.ts.map +1 -0
  61. package/{src/layers/token_and_positional_embedding.test.ts → dist/layers/token_and_positional_embedding.test.js} +7 -30
  62. package/dist/layers/token_and_positional_embedding.test.js.map +1 -0
  63. package/dist/layers/transformer_decoder.d.ts +69 -0
  64. package/dist/layers/transformer_decoder.d.ts.map +1 -0
  65. package/dist/layers/transformer_decoder.js +182 -0
  66. package/dist/layers/transformer_decoder.js.map +1 -0
  67. package/dist/layers/transformer_decoder.test.d.ts +2 -0
  68. package/dist/layers/transformer_decoder.test.d.ts.map +1 -0
  69. package/{src/layers/transformer_decoder.test.ts → dist/layers/transformer_decoder.test.js} +20 -48
  70. package/dist/layers/transformer_decoder.test.js.map +1 -0
  71. package/dist/layers/transformer_encoder.d.ts +55 -0
  72. package/dist/layers/transformer_encoder.d.ts.map +1 -0
  73. package/{src/layers/transformer_encoder.ts → dist/layers/transformer_encoder.js} +41 -90
  74. package/dist/layers/transformer_encoder.js.map +1 -0
  75. package/dist/layers/transformer_encoder.test.d.ts +2 -0
  76. package/dist/layers/transformer_encoder.test.d.ts.map +1 -0
  77. package/{src/layers/transformer_encoder.test.ts → dist/layers/transformer_encoder.test.js} +18 -45
  78. package/dist/layers/transformer_encoder.test.js.map +1 -0
  79. package/dist/losses/dice.d.ts +30 -0
  80. package/dist/losses/dice.d.ts.map +1 -0
  81. package/{src/losses/dice.ts → dist/losses/dice.js} +17 -80
  82. package/dist/losses/dice.js.map +1 -0
  83. package/dist/losses/index.d.ts +2 -0
  84. package/dist/losses/index.d.ts.map +1 -0
  85. package/dist/losses/index.js +2 -0
  86. package/dist/losses/index.js.map +1 -0
  87. package/dist/masks.d.ts +20 -0
  88. package/dist/masks.d.ts.map +1 -0
  89. package/{src/packing_mask.ts → dist/masks.js} +16 -7
  90. package/dist/masks.js.map +1 -0
  91. package/dist/metrics.d.ts +20 -0
  92. package/dist/metrics.d.ts.map +1 -0
  93. package/{src/metrics.ts → dist/metrics.js} +8 -12
  94. package/dist/metrics.js.map +1 -0
  95. package/dist/models/gpt_model.d.ts +94 -0
  96. package/dist/models/gpt_model.d.ts.map +1 -0
  97. package/{src/models/gpt_model.ts → dist/models/gpt_model.js} +41 -119
  98. package/dist/models/gpt_model.js.map +1 -0
  99. package/dist/models/index.d.ts +7 -0
  100. package/dist/models/index.d.ts.map +1 -0
  101. package/dist/models/index.js +13 -0
  102. package/dist/models/index.js.map +1 -0
  103. package/dist/models/llm_model.d.ts +87 -0
  104. package/dist/models/llm_model.d.ts.map +1 -0
  105. package/{src/models/llm_model.ts → dist/models/llm_model.js} +51 -161
  106. package/dist/models/llm_model.js.map +1 -0
  107. package/dist/models/u_net.d.ts +40 -0
  108. package/dist/models/u_net.d.ts.map +1 -0
  109. package/{src/models/u_net.ts → dist/models/u_net.js} +27 -116
  110. package/dist/models/u_net.js.map +1 -0
  111. package/dist/src/index.d.ts +6 -0
  112. package/dist/src/index.d.ts.map +1 -0
  113. package/dist/src/index.js +6 -0
  114. package/dist/src/index.js.map +1 -0
  115. package/dist/src/kv_cache.d.ts +53 -0
  116. package/dist/src/kv_cache.d.ts.map +1 -0
  117. package/dist/src/kv_cache.js +135 -0
  118. package/dist/src/kv_cache.js.map +1 -0
  119. package/dist/src/layers/cached_rope_multihead_attention.d.ts +31 -0
  120. package/dist/src/layers/cached_rope_multihead_attention.d.ts.map +1 -0
  121. package/{src/layers/cached_rope_multihead_attention.ts → dist/src/layers/cached_rope_multihead_attention.js} +25 -62
  122. package/dist/src/layers/cached_rope_multihead_attention.js.map +1 -0
  123. package/dist/src/layers/cached_rope_multihead_attention.test.d.ts +2 -0
  124. package/dist/src/layers/cached_rope_multihead_attention.test.d.ts.map +1 -0
  125. package/dist/src/layers/cached_rope_multihead_attention.test.js +43 -0
  126. package/dist/src/layers/cached_rope_multihead_attention.test.js.map +1 -0
  127. package/dist/src/layers/gpt_decoder_block.d.ts +34 -0
  128. package/dist/src/layers/gpt_decoder_block.d.ts.map +1 -0
  129. package/dist/src/layers/gpt_decoder_block.js +51 -0
  130. package/dist/src/layers/gpt_decoder_block.js.map +1 -0
  131. package/dist/src/layers/index.d.ts +17 -0
  132. package/dist/src/layers/index.d.ts.map +1 -0
  133. package/dist/src/layers/index.js +33 -0
  134. package/dist/src/layers/index.js.map +1 -0
  135. package/dist/src/layers/multihead_attention.d.ts +106 -0
  136. package/dist/src/layers/multihead_attention.d.ts.map +1 -0
  137. package/dist/src/layers/multihead_attention.js +269 -0
  138. package/dist/src/layers/multihead_attention.js.map +1 -0
  139. package/dist/src/layers/multihead_attention.test.d.ts +2 -0
  140. package/dist/src/layers/multihead_attention.test.d.ts.map +1 -0
  141. package/dist/src/layers/multihead_attention.test.js +160 -0
  142. package/dist/src/layers/multihead_attention.test.js.map +1 -0
  143. package/dist/src/layers/positional_encoding.d.ts +37 -0
  144. package/dist/src/layers/positional_encoding.d.ts.map +1 -0
  145. package/dist/src/layers/positional_encoding.js +115 -0
  146. package/dist/src/layers/positional_encoding.js.map +1 -0
  147. package/dist/src/layers/positional_encoding.test.d.ts +2 -0
  148. package/dist/src/layers/positional_encoding.test.d.ts.map +1 -0
  149. package/dist/src/layers/positional_encoding.test.js +95 -0
  150. package/dist/src/layers/positional_encoding.test.js.map +1 -0
  151. package/dist/src/layers/rotary_position_embedding.d.ts +39 -0
  152. package/dist/src/layers/rotary_position_embedding.d.ts.map +1 -0
  153. package/dist/src/layers/rotary_position_embedding.js +99 -0
  154. package/dist/src/layers/rotary_position_embedding.js.map +1 -0
  155. package/dist/src/layers/rotary_position_embedding.test.d.ts +2 -0
  156. package/dist/src/layers/rotary_position_embedding.test.d.ts.map +1 -0
  157. package/dist/src/layers/rotary_position_embedding.test.js +88 -0
  158. package/dist/src/layers/rotary_position_embedding.test.js.map +1 -0
  159. package/dist/src/layers/token_and_positional_embedding.d.ts +47 -0
  160. package/dist/src/layers/token_and_positional_embedding.d.ts.map +1 -0
  161. package/dist/src/layers/token_and_positional_embedding.js +109 -0
  162. package/dist/src/layers/token_and_positional_embedding.js.map +1 -0
  163. package/dist/src/layers/token_and_positional_embedding.test.d.ts +2 -0
  164. package/dist/src/layers/token_and_positional_embedding.test.d.ts.map +1 -0
  165. package/dist/src/layers/token_and_positional_embedding.test.js +58 -0
  166. package/dist/src/layers/token_and_positional_embedding.test.js.map +1 -0
  167. package/dist/src/layers/transformer_decoder.d.ts +69 -0
  168. package/dist/src/layers/transformer_decoder.d.ts.map +1 -0
  169. package/{src/layers/transformer_decoder.ts → dist/src/layers/transformer_decoder.js} +41 -95
  170. package/dist/src/layers/transformer_decoder.js.map +1 -0
  171. package/dist/src/layers/transformer_decoder.test.d.ts +2 -0
  172. package/dist/src/layers/transformer_decoder.test.d.ts.map +1 -0
  173. package/dist/src/layers/transformer_decoder.test.js +72 -0
  174. package/dist/src/layers/transformer_decoder.test.js.map +1 -0
  175. package/dist/src/layers/transformer_encoder.d.ts +55 -0
  176. package/dist/src/layers/transformer_encoder.d.ts.map +1 -0
  177. package/dist/src/layers/transformer_encoder.js +175 -0
  178. package/dist/src/layers/transformer_encoder.js.map +1 -0
  179. package/dist/src/layers/transformer_encoder.test.d.ts +2 -0
  180. package/dist/src/layers/transformer_encoder.test.d.ts.map +1 -0
  181. package/dist/src/layers/transformer_encoder.test.js +58 -0
  182. package/dist/src/layers/transformer_encoder.test.js.map +1 -0
  183. package/dist/src/losses/dice.d.ts +30 -0
  184. package/dist/src/losses/dice.d.ts.map +1 -0
  185. package/dist/src/losses/dice.js +93 -0
  186. package/dist/src/losses/dice.js.map +1 -0
  187. package/dist/src/losses/index.d.ts +2 -0
  188. package/dist/src/losses/index.d.ts.map +1 -0
  189. package/dist/src/losses/index.js +2 -0
  190. package/dist/src/losses/index.js.map +1 -0
  191. package/dist/src/masks.d.ts +20 -0
  192. package/dist/src/masks.d.ts.map +1 -0
  193. package/dist/src/masks.js +37 -0
  194. package/dist/src/masks.js.map +1 -0
  195. package/dist/src/metrics.d.ts +20 -0
  196. package/dist/src/metrics.d.ts.map +1 -0
  197. package/dist/src/metrics.js +28 -0
  198. package/dist/src/metrics.js.map +1 -0
  199. package/dist/src/models/gpt_model.d.ts +94 -0
  200. package/dist/src/models/gpt_model.d.ts.map +1 -0
  201. package/dist/src/models/gpt_model.js +154 -0
  202. package/dist/src/models/gpt_model.js.map +1 -0
  203. package/dist/src/models/index.d.ts +3 -0
  204. package/dist/src/models/index.d.ts.map +1 -0
  205. package/{src/models/index.ts → dist/src/models/index.js} +1 -0
  206. package/dist/src/models/index.js.map +1 -0
  207. package/dist/src/models/llm_model.d.ts +87 -0
  208. package/dist/src/models/llm_model.d.ts.map +1 -0
  209. package/dist/src/models/llm_model.js +245 -0
  210. package/dist/src/models/llm_model.js.map +1 -0
  211. package/dist/src/models/u_net.d.ts +40 -0
  212. package/dist/src/models/u_net.d.ts.map +1 -0
  213. package/dist/src/models/u_net.js +151 -0
  214. package/dist/src/models/u_net.js.map +1 -0
  215. package/{src/tfjs_types.ts → dist/src/tfjs_types.d.ts} +1 -6
  216. package/dist/src/tfjs_types.d.ts.map +1 -0
  217. package/dist/src/tfjs_types.js +2 -0
  218. package/dist/src/tfjs_types.js.map +1 -0
  219. package/dist/src/utils.d.ts +28 -0
  220. package/dist/src/utils.d.ts.map +1 -0
  221. package/{src/utils.ts → dist/src/utils.js} +10 -33
  222. package/dist/src/utils.js.map +1 -0
  223. package/dist/src/utils.test.d.ts +2 -0
  224. package/dist/src/utils.test.d.ts.map +1 -0
  225. package/{src/utils.test.ts → dist/src/utils.test.js} +22 -50
  226. package/dist/src/utils.test.js.map +1 -0
  227. package/dist/tfjs_types.d.ts +10 -0
  228. package/dist/tfjs_types.d.ts.map +1 -0
  229. package/dist/tfjs_types.js +2 -0
  230. package/dist/tfjs_types.js.map +1 -0
  231. package/dist/utils.d.ts +28 -0
  232. package/dist/utils.d.ts.map +1 -0
  233. package/dist/utils.js +63 -0
  234. package/dist/utils.js.map +1 -0
  235. package/dist/utils.test.d.ts +2 -0
  236. package/dist/utils.test.d.ts.map +1 -0
  237. package/dist/utils.test.js +73 -0
  238. package/dist/utils.test.js.map +1 -0
  239. package/package.json +10 -4
  240. package/src/index.ts +0 -93
  241. package/src/layers/rotary_position_embedding.test.ts +0 -107
  242. package/src/losses/index.ts +0 -1
  243. package/src/testing.ts +0 -1
  244. package/tsconfig.json +0 -49
@@ -0,0 +1 @@
1
+ {"version":3,"file":"positional_encoding.js","sourceRoot":"","sources":["../../../src/layers/positional_encoding.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAavC;;;;;;;;;GASG;AACH,MAAM,OAAO,kBAAmB,SAAQ,EAAE,CAAC,MAAM,CAAC,KAAK;IACnD,MAAM,CAAC,SAAS,GAAG,oBAAoB,CAAC;IACvB,iBAAiB,CAAS;IAC1B,QAAQ,CAAS;IAC1B,mBAAmB,CAAmB;IAG9C,YAAY,IAA4B;QACpC,KAAK,CAAC,IAAI,CAAC,CAAC;QAEZ,IAAI,CAAC,iBAAiB,GAAG,IAAI,CAAC,iBAAiB,IAAI,IAAI,CAAC;QACxD,IAAI,CAAC,QAAQ,GAAG,IAAI,CAAC,QAAQ,CAAC;QAE9B,IAAI,IAAI,CAAC,iBAAiB,GAAG,CAAC,EAAE,CAAC;YAC7B,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,iBAAiB,IAAI,CAAC,IAAI,oBAAoB;gBAC5E,KAAK,IAAI,CAAC,iBAAiB,0BAA0B,CAAC,CAAC;QAC/D,CAAC;QAED,IAAI,IAAI,CAAC,QAAQ,GAAG,CAAC,EAAE,CAAC;YACpB,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,iBAAiB,IAAI,CAAC,IAAI,WAAW;gBACnE,KAAK,IAAI,CAAC,QAAQ,0BAA0B,CAAC,CAAC;QACtD,CAAC;QAED,yCAAyC;QACzC,IAAI,CAAC,mBAAmB,GAAG,IAAI,CAAC,SAAS,CAAC,sBAAsB,EAC5D,CAAC,IAAI,CAAC,iBAAiB,EAAE,IAAI,CAAC,QAAQ,CAAC,EAAE,SAAS,EAClD,EAAE,CAAC,YAAY,CAAC,KAAK,EAAE,EAAE,SAAS,EAAE,KAAK,CAAC,CAAC;IACnD,CAAC;IAGD;;OAEG;IACM,IAAI,CAAC,MAA+B,EAAE,MAAc;QACzD,6BAA6B;QAC7B,MAAM,KAAK,GAAG,KAAK,CAAC,OAAO,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC;QACzD,MAAM,SAAS,GAAG,KAAK,CAAC,KAAK,CAAC,CAAC,CAAE,CAAC;QAElC,IAAI,KAAK,CAAC,KAAK,CAAC,MAAM,IAAI,CAAC,IAAI,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,IAAI,IAAI,CAAC,QAAQ,EAAE,CAAC;YAC7D,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,UAAU,IAAI,CAAC,IAAI,6BAA6B;gBAC9E,mBAAmB,IAAI,CAAC,iBAAiB,MAAM,IAAI,CAAC,QAAQ,kBAAkB,KAAK,CAAC,KAAK,EAAE,CAAC,CAAC;QACrG,CAAC;QAED,IAAI,SAAS,GAAG,IAAI,CAAC,iBAAiB,EAAE,CAAC;YACrC,6BAA6B;YAC7B,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,UAAU,IAAI,CAAC,IAAI,yBAAyB;gBAC1E,qBAAqB,SAAS,iDAAiD;gBAC/E,IAAI,IAAI,CAAC,iBAAiB,EAAE,CAAC,CAAC;QACtC,CAAC;QAED,8BAA8B;QAC9B,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,OAAO,KAAK,CAAC,GAAG,CAAC,IAAI,CAAC,mBAAmB,CAAC,IAAI,EAAE;iBAC3C,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,SAAS,EAAE,IAAI,CAAC,QAAQ,CAAC,CAAC,CAAC,kCAAkC;iBAC5E,UAAU
,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,2DAA2D;QACpF,CAAC,CAAC,CAAA;IACN,CAAC;IAED;;;;OAIG;IACM,KAAK,CAAC,UAAiC;QAC5C,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YACT,MAAM,cAAc,GAAG,IAAI,CAAC,IAAI,CAAC,IAAI,CAAC,QAAQ,GAAG,CAAC,CAAC,CAAC;YAEpD,qDAAqD;YACrD,mEAAmE;YACnE,MAAM,SAAS,GAAG,EAAE,CAAC,KAAK,CAAC,CAAC,EAAE,IAAI,CAAC,iBAAiB,EAAE,CAAC,CAAC;iBACnD,OAAO,CAAC,CAAC,IAAI,CAAC,iBAAiB,EAAE,CAAC,CAAC,CAAC;gBACrC,4FAA4F;iBAC3F,WAAW,CAAC,CAAC,IAAI,CAAC,iBAAiB,EAAE,cAAc,CAAC,CAAC,CAAC;YAE3D,oEAAoE;YACpE,iFAAiF;YACjF,mFAAmF;YACnF,yBAAyB;YACzB,sFAAsF;YACtF,MAAM,WAAW,GAAG,EAAE,CAAC,GAAG,CAAC,MAAM,EAAE,EAAE,CAAC,KAAK,CAAC,CAAC,EAAE,IAAI,CAAC,QAAQ,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,QAAQ,CAAC,CAAC,CAAC;YAErF,MAAM,UAAU,GAAG,SAAS,CAAC,GAAG,CAAC,WAAW,CAAC,CAAC;YAE9C,MAAM,IAAI,GAAG,EAAE,CAAC,GAAG,CAAC,UAAU,CAAC,CAAC;YAChC,MAAM,MAAM,GAAG,EAAE,CAAC,GAAG,CAAC,UAAU,CAAC,CAAC;YAElC,uEAAuE;YACvE,4BAA4B;YAC5B,4BAA4B;YAC5B,MAAM;YACN,MAAM,WAAW,GAAG,EAAE,CAAC;YACvB,MAAM,QAAQ,GAAG,CAAC,CAAC,CAAC;YACpB,MAAM,OAAO,GAAG,CAAC,CAAC;YAClB,MAAM,SAAS,GAAG,CAAC,CAAC;YAEpB,KAAK,IAAI,SAAS,GAAG,CAAC,EAAE,SAAS,GAAG,IAAI,CAAC,QAAQ,GAAG,CAAC,EAAE,SAAS,EAAE,EAAE,CAAC;gBACjE,WAAW,CAAC,IAAI,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC,SAAS,EAAE,SAAS,CAAC,EAAE,CAAC,QAAQ,EAAE,OAAO,CAAC,CAAC,CAAC,CAAA;gBAEzE,IAAI,SAAS,IAAI,IAAI,CAAC,KAAK,CAAC,IAAI,CAAC,QAAQ,GAAG,CAAC,CAAC,EAAE,CAAC;oBAC7C,8DAA8D;oBAC9D,yFAAyF;oBACzF,qCAAqC;oBACrC,WAAW,CAAC,IAAI,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC,SAAS,EAAE,SAAS,CAAC,EAAE,CAAC,QAAQ,EAAE,OAAO,CAAC,CAAC,CAAC,CAAA;gBAC/E,CAAC;YACL,CAAC;YAED,8BAA8B;YAC9B,IAAI,CAAC,UAAU,CAAC,CAAC,EAAE,CAAC,MAAM,CAAC,WAAW,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC;QACjD,CAAC,CAAC,CAAC;QAEH,KAAK,CAAC,KAAK,CAAC,UAAU,CAAC,CAAC;IAC5B,CAAC;IAGQ,kBAAkB,CAAC,UAAiC;QACzD,OAAO,UAAU,CAAC;IACtB,CAAC;IAGQ,SAAS;QACd,MAAM,WAAW,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QAEtC,MAAM,MAAM,GAAG;YACX,iBAAiB,EAAE,IAAI,CAAC,iBAAiB;YACzC,QAAQ,EAAE,IAAI,CAAC,QAAQ;SAC1B,CAAA;QAED,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,WAAW,CAAC,CAAC;QAEnC,OAAO,MAAM,CAAC;IAClB,CAAC;;AAIL,EAAE,CAAC,aAAa,CAAC,aAAa,CAAC,kB
AAkB,CAAC,CAAC"}
@@ -0,0 +1,2 @@
1
+ export {};
2
+ //# sourceMappingURL=positional_encoding.test.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"positional_encoding.test.d.ts","sourceRoot":"","sources":["../../../src/layers/positional_encoding.test.ts"],"names":[],"mappings":""}
@@ -0,0 +1,95 @@
1
+ import * as tf from '@tensorflow/tfjs';
2
+ import { PositionalEncoding } from '@/layers/positional_encoding';
3
+ // disables warning for using the faster node backend,
4
+ // https://github.com/tensorflow/tfjs/issues/5349#issuecomment-885170504
5
+ tf.env().set('IS_NODE', false);
6
+ describe("PositionalEncoding tests", () => {
7
// Constructing the layer with a non-positive embedding size or a non-positive
// maximum sequence length is expected to throw.
it("should fail to instantiate a layer", () => {
    const invalidConfigs = [
        { maxSequenceLength: 3, embedDim: 0 },
        { maxSequenceLength: 3, embedDim: -1 },
        { maxSequenceLength: 0, embedDim: 32 },
        { maxSequenceLength: -1, embedDim: 32 },
    ];

    for (const config of invalidConfigs) {
        expect(() => new PositionalEncoding(config)).toThrow();
    }
});
13
// A rank-3 input must be accepted both as a bare tensor and wrapped in an
// array, and the reported output shape must equal the input shape.
test("successfull forward calls", () => {
    const embedDim = 32;
    const sequenceLength = 4;
    const batch = tf.randomUniform([2, sequenceLength, embedDim]);
    const layer = new PositionalEncoding({ embedDim: embedDim });

    for (const layerInput of [batch, [batch]]) {
        expect(() => layer.apply(layerInput)).not.toThrow();
    }

    expect(layer.computeOutputShape(batch.shape)).toEqual(batch.shape);
});
22
// Each fixture violates exactly one constraint so that every assertion
// exercises the failure it is named after. Inputs are laid out as
// [batch, sequence, embed], matching the successful forward-call test.
// Previously `wrong_rank` was actually rank 3 and the other two fixtures were
// rank 2, so the assertions could pass for the wrong reason.
it("should throw when input sequences are too large, embedding dims don't match, input aren't rank 3", () => {
    const maxSequenceLength = 10;
    const embedDim = 32;
    const positional = new PositionalEncoding({ maxSequenceLength, embedDim });

    // rank 3, correct embedding size, but 100 positions > maxSequenceLength
    const sequences_too_long = tf.randomUniform([1, 100, embedDim]);
    // rank 3, sequence fits, but the embedding size is 100 instead of 32
    const embeddings_too_large = tf.randomUniform([1, maxSequenceLength, 100]);
    // sequence and embedding sizes are valid, but the tensor is rank 4
    const wrong_rank = tf.randomUniform([1, maxSequenceLength, embedDim, 1]);

    expect(() => positional.apply(sequences_too_long)).toThrow();
    expect(() => positional.apply(embeddings_too_large)).toThrow();
    expect(() => positional.apply(wrong_rank)).toThrow();
});
31
// The serialised config must expose at least one key.
// The previous assertion compared the key *array* itself against 0 with
// `.not.toBe(0)`, which is always true; assert on the number of keys instead.
it("should return a non-empty config dict", () => {
    const positional = new PositionalEncoding({ embedDim: 32 });
    const configKeys = Object.keys(positional.getConfig());

    expect(configKeys.length).toBeGreaterThan(0);
});
35
+ // PyTorch reference implementation found at
36
+ // https://pytorch-tutorials-preview.netlify.app/beginner/transformer_tutorial.html
37
// Compares the layer's built weights against reference tables (one per
// embedding size, 10 positions each) and requires that no element differs by
// more than the tolerance.
it("should be within 1e-6 of PyTorch's implementation", () => {
    const margin_of_error = 1e-6;
    const maxSequenceLength = 10;

    // embedDim -> expected encoding table
    const expectedByEmbedDim = new Map([
        [4, tf.tensor([
            [[0.0000000, 1.0000000, 0.0000000, 1.0000000],
                [0.8414710, 0.5403023, 0.0099998, 0.9999500],
                [0.9092974, -0.4161468, 0.0199987, 0.9998000],
                [0.1411200, -0.9899925, 0.0299955, 0.9995500],
                [-0.7568025, -0.6536436, 0.0399893, 0.9992001],
                [-0.9589243, 0.2836622, 0.0499792, 0.9987503],
                [-0.2794155, 0.9601703, 0.0599640, 0.9982005],
                [0.6569866, 0.7539023, 0.0699428, 0.9975510],
                [0.9893582, -0.1455000, 0.0799147, 0.9968017],
                [0.4121185, -0.9111302, 0.0898785, 0.9959527]]
        ])],
        [8, tf.tensor([
            [[0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 1.0000000e+00,
                    0.0000000e+00, 1.0000000e+00, 0.0000000e+00, 1.0000000e+00],
                [8.4147096e-01, 5.4030234e-01, 9.9833414e-02, 9.9500418e-01,
                    9.9998331e-03, 9.9994999e-01, 9.9999981e-04, 9.9999952e-01],
                [9.0929741e-01, -4.1614684e-01, 1.9866931e-01, 9.8006660e-01,
                    1.9998666e-02, 9.9980003e-01, 1.9999985e-03, 9.9999803e-01],
                [1.4112000e-01, -9.8999250e-01, 2.9552019e-01, 9.5533651e-01,
                    2.9995499e-02, 9.9955004e-01, 2.9999954e-03, 9.9999553e-01],
                [-7.5680250e-01, -6.5364361e-01, 3.8941833e-01, 9.2106098e-01,
                    3.9989334e-02, 9.9920011e-01, 3.9999890e-03, 9.9999201e-01],
                [-9.5892429e-01, 2.8366220e-01, 4.7942552e-01, 8.7758255e-01,
                    4.9979165e-02, 9.9875027e-01, 4.9999789e-03, 9.9998754e-01],
                [-2.7941549e-01, 9.6017027e-01, 5.6464243e-01, 8.2533562e-01,
                    5.9964005e-02, 9.9820054e-01, 5.9999637e-03, 9.9998200e-01],
                [6.5698659e-01, 7.5390226e-01, 6.4421761e-01, 7.6484221e-01,
                    6.9942847e-02, 9.9755102e-01, 6.9999420e-03, 9.9997550e-01],
                [9.8935825e-01, -1.4550003e-01, 7.1735609e-01, 6.9670677e-01,
                    7.9914689e-02, 9.9680167e-01, 7.9999138e-03, 9.9996799e-01],
                [4.1211849e-01, -9.1113025e-01, 7.8332686e-01, 6.2160999e-01,
                    8.9878544e-02, 9.9595273e-01, 8.9998785e-03, 9.9995953e-01]]
        ])],
    ]);

    for (const [embedDim, expected] of expectedByEmbedDim) {
        const layer = new PositionalEncoding({ embedDim: embedDim, maxSequenceLength: maxSequenceLength });
        layer.build([]);

        // element-wise absolute error, flattened to a plain list of numbers
        const absoluteErrors = layer.getWeights()[0]
            .sub(expected)
            .abs()
            .arraySync()
            .flat(2);
        const outliers = absoluteErrors.filter((delta) => delta > margin_of_error);

        // the difference to the reference should be insignificantly small
        expect(outliers.length).toBe(0);
    }
});
94
+ });
95
+ //# sourceMappingURL=positional_encoding.test.js.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"positional_encoding.test.js","sourceRoot":"","sources":["../../../src/layers/positional_encoding.test.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAEvC,OAAO,EAAE,kBAAkB,EAAE,MAAM,8BAA8B,CAAC;AAElE,sDAAsD;AACtD,wEAAwE;AACxE,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC,SAAS,EAAE,KAAK,CAAC,CAAC;AAG/B,QAAQ,CAAC,0BAA0B,EAAE,GAAG,EAAE;IACtC,EAAE,CAAC,oCAAoC,EAAE,GAAG,EAAE;QAC1C,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,kBAAkB,CAAC,EAAE,iBAAiB,EAAE,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACtF,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,kBAAkB,CAAC,EAAE,iBAAiB,EAAE,CAAC,EAAE,QAAQ,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACvF,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,kBAAkB,CAAC,EAAE,iBAAiB,EAAE,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACvF,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,kBAAkB,CAAC,EAAE,iBAAiB,EAAE,CAAC,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;IAC5F,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,2BAA2B,EAAE,GAAG,EAAE;QACnC,MAAM,UAAU,GAAG,EAAE,CAAC;QACtB,MAAM,SAAS,GAAG,CAAC,CAAC;QACpB,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,SAAS,EAAE,UAAU,CAAC,CAAC,CAAC;QAE3D,MAAM,UAAU,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,UAAU,EAAE,CAAC,CAAC;QACpE,MAAM,CAAC,GAAG,EAAE,CAAC,UAAU,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACpD,MAAM,CAAC,GAAG,EAAE,CAAC,UAAU,CAAC,KAAK,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACtD,MAAM,CAAC,UAAU,CAAC,kBAAkB,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,OAAO,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC;IAC5E,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,kGAAkG,EAAE,GAAG,EAAE;QACxG,MAAM,kBAAkB,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,GAAG,EAAE,EAAE,CAAC,CAAC,CAAC;QACvD,MAAM,oBAAoB,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,EAAE,EAAE,GAAG,CAAC,CAAC,CAAC;QACzD,MAAM,UAAU,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,EAAE,EAAE,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC;QAElD,MAAM,UAAU,GAAG,IAAI,kBAAkB,CAAC,EAAE,iBAAiB,EAAE,EAAE,EAAE,QAAQ,EAAE,EAAE,EAAE,CAAC,CAAC;QAEnF,MAAM,CAAC,GAAG,EAAE,CAAC,UAAU,CAAC,KAAK,CAAC,kBAAkB,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAC7D,MAAM,CAAC,GAAG
,EAAE,CAAC,UAAU,CAAC,KAAK,CAAC,oBAAoB,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAC/D,MAAM,CAAC,GAAG,EAAE,CAAC,UAAU,CAAC,KAAK,CAAC,UAAU,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;IACzD,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,uCAAuC,EAAE,GAAG,EAAE;QAC7C,MAAM,SAAS,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,EAAE,EAAE,CAAC,CAAC;QAC3D,MAAM,CAAC,MAAM,CAAC,IAAI,CAAC,SAAS,CAAC,SAAS,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;IAC3D,CAAC,CAAC,CAAA;IAGF,qCAAqC;IACrC,mFAAmF;IACnF,EAAE,CAAC,mDAAmD,EAAE,GAAG,EAAE;QACzD,MAAM,cAAc,GAAG,EAAE,CAAC,MAAM,CAAC;YAC7B,CAAC,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC7C,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC5C,CAAC,SAAS,EAAE,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC7C,CAAC,SAAS,EAAE,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC7C,CAAC,CAAC,SAAS,EAAE,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC9C,CAAC,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC7C,CAAC,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC7C,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC5C,CAAC,SAAS,EAAE,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;gBAC7C,CAAC,SAAS,EAAE,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC,CAAC;SAAC,CAAC,CAAC;QAErD,MAAM,cAAc,GAAG,EAAE,CAAC,MAAM,CAAC;YAC7B,CAAC,CAAC,aAAa,EAAE,aAAa,EAAE,aAAa,EAAE,aAAa;oBACxD,aAAa,EAAE,aAAa,EAAE,aAAa,EAAE,aAAa,CAAC;gBAC/D,CAAC,aAAa,EAAE,aAAa,EAAE,aAAa,EAAE,aAAa;oBACvD,aAAa,EAAE,aAAa,EAAE,aAAa,EAAE,aAAa,CAAC;gBAC/D,CAAC,aAAa,EAAE,CAAC,aAAa,EAAE,aAAa,EAAE,aAAa;oBACxD,aAAa,EAAE,aAAa,EAAE,aAAa,EAAE,aAAa,CAAC;gBAC/D,CAAC,aAAa,EAAE,CAAC,aAAa,EAAE,aAAa,EAAE,aAAa;oBACxD,aAAa,EAAE,aAAa,EAAE,aAAa,EAAE,aAAa,CAAC;gBAC/D,CAAC,CAAC,aAAa,EAAE,CAAC,aAAa,EAAE,aAAa,EAAE,aAAa;oBACzD,aAAa,EAAE,aAAa,EAAE,aAAa,EAAE,aAAa,CAAC;gBAC/D,CAAC,CAAC,aAAa,EAAE,aAAa,EAAE,aAAa,EAAE,aAAa;oBACxD,aAAa,EAAE,aAAa,EAAE,aAAa,EAAE,aAAa,CAAC;gBAC/D,CAAC,CAAC,aAAa,EAAE,aAAa,EAAE,aAAa,EAAE,aAAa;oBACxD,aAAa,EAAE,aAAa,EAAE,aAAa,EAAE,aAAa,CAAC;gBAC/D,CAAC,aAAa,EAAE,aAAa,EAAE,aAAa,EAAE,aAAa;oBACvD,aAAa,EAAE,aAAa,EAAE,aAAa,EAAE,aAAa,CAAC;gBAC/D,CAAC,aAAa,EAAE,CAAC,aAAa,EAAE,aAAa,EAAE,aAAa;oBACxD,aAAa,EAAE,aAAa
,EAAE,aAAa,EAAE,aAAa,CAAC;gBAC/D,CAAC,aAAa,EAAE,CAAC,aAAa,EAAE,aAAa,EAAE,aAAa;oBACxD,aAAa,EAAE,aAAa,EAAE,aAAa,EAAE,aAAa,CAAC,CAAC;SAAC,CAAC,CAAC;QAEvE,MAAM,WAAW,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,iBAAiB,EAAE,EAAE,EAAE,CAAC,CAAC;QACnF,WAAW,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC;QAEtB,MAAM,WAAW,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,iBAAiB,EAAE,EAAE,EAAE,CAAC,CAAC;QACnF,WAAW,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC;QAEtB,MAAM,eAAe,GAAG,IAAI,CAAC;QAE7B,2DAA2D;QAC3D,iCAAiC;QACjC,MAAM,CAAE,WAAW,CAAC,UAAU,EAAE,CAAC,CAAC,CAAC;aAC9B,GAAG,CAAC,cAAc,CAAC;aACnB,GAAG,EAAE;aACL,SAAS,EAAS;aAClB,IAAI,CAAC,CAAC,CAAC;aACP,MAAM,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,GAAG,eAAe,CAAC;aAChC,MAAM,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;QAErB,MAAM,CAAE,WAAW,CAAC,UAAU,EAAE,CAAC,CAAC,CAAC;aAC9B,GAAG,CAAC,cAAc,CAAC;aACnB,GAAG,EAAE;aACL,SAAS,EAAS;aAClB,IAAI,CAAC,CAAC,CAAC;aACP,MAAM,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,GAAG,eAAe,CAAC;aAChC,MAAM,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;IACzB,CAAC,CAAC,CAAC;AACP,CAAC,CAAC,CAAC"}
@@ -0,0 +1,39 @@
1
+ import * as tf from "@tensorflow/tfjs";
2
+ import { type LayerArgs } from "@tensorflow/tfjs-layers/dist/engine/topology";
3
+ export declare function applyRope(x: tf.Tensor, dim: number, cosine_cache: tf.Tensor, sine_cache: tf.Tensor): tf.Tensor<tf.Rank>;
4
+ export declare function rotateHalf(x: tf.Tensor, dim: number): tf.Tensor;
5
+ export declare function createRoPECache(dim: number, max_sequence_length: number, theta?: number): tf.Tensor<tf.Rank>[];
6
+ export interface RotaryPositionEmbeddingArgs extends LayerArgs {
7
+ /**
8
+ * The dimension of each head (rounded down), e.g. `Math.floor(embedDim / numHeads)`
9
+ */
10
+ dim: number;
11
+ /**
12
+ * The RoPE cache will be pre-calculated up to the max sequence length, and recalculated as needed. Defaults to `4096`.
13
+ */
14
+ maxSequenceLength?: number;
15
+ /**
16
+ * The base for the geometric progression used to compute the rotation angles. Defaults to `10_000`.
17
+ */
18
+ theta?: number;
19
+ }
20
+ /**
21
+ * Implements RoPE from the RoFormer: Enhanced Transformer with Rotary Position Embedding paper
22
+ * Inspired by: https://meta-pytorch.org/torchtune/stable/_modules/torchtune/modules/position_embeddings.html#RotaryPositionalEmbeddings
23
+ */
24
+ export declare class RotaryPositionEmbedding extends tf.layers.Layer {
25
+ static className: string;
26
+ protected dim: number;
27
+ protected max_sequence_length: number;
28
+ protected theta: number;
29
+ protected cosine_cache: tf.LayerVariable;
30
+ protected sine_cache: tf.LayerVariable;
31
+ constructor({ dim, maxSequenceLength, theta, ...args }: RotaryPositionEmbeddingArgs);
32
+ call(inputs: tf.Tensor | tf.Tensor[], kwargs: any): tf.Tensor | tf.Tensor[];
33
+ build(input_shape: tf.Shape | tf.Shape[]): void;
34
+ /**
35
+ * Output shape: [batch, head, sequence, head_dim]
36
+ */
37
+ computeOutputShape(input_shape: tf.Shape | tf.Shape[]): tf.Shape;
38
+ }
39
+ //# sourceMappingURL=rotary_position_embedding.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"rotary_position_embedding.d.ts","sourceRoot":"","sources":["../../../src/layers/rotary_position_embedding.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AACvC,OAAO,EAAE,KAAK,SAAS,EAAE,MAAM,8CAA8C,CAAC;AAG9E,wBAAgB,SAAS,CAAC,CAAC,EAAE,EAAE,CAAC,MAAM,EAAE,GAAG,EAAE,MAAM,EAAE,YAAY,EAAE,EAAE,CAAC,MAAM,EAAE,UAAU,EAAE,EAAE,CAAC,MAAM,sBAalG;AAGD,wBAAgB,UAAU,CAAC,CAAC,EAAE,EAAE,CAAC,MAAM,EAAE,GAAG,EAAE,MAAM,GAAG,EAAE,CAAC,MAAM,CAgB/D;AAGD,wBAAgB,eAAe,CAAC,GAAG,EAAE,MAAM,EAAE,mBAAmB,EAAE,MAAM,EAAE,KAAK,GAAE,MAAe,wBAqB/F;AAGD,MAAM,WAAW,2BAA4B,SAAQ,SAAS;IAC1D;;OAEG;IACH,GAAG,EAAE,MAAM,CAAC;IACZ;;OAEG;IACH,iBAAiB,CAAC,EAAE,MAAM,CAAC;IAC3B;;OAEG;IACH,KAAK,CAAC,EAAE,MAAM,CAAC;CAClB;AAGD;;;GAGG;AACH,qBAAa,uBAAwB,SAAQ,EAAE,CAAC,MAAM,CAAC,KAAK;IACxD,MAAM,CAAC,SAAS,SAA6B;IAE7C,SAAS,CAAC,GAAG,EAAE,MAAM,CAAC;IACtB,SAAS,CAAC,mBAAmB,EAAE,MAAM,CAAC;IACtC,SAAS,CAAC,KAAK,EAAE,MAAM,CAAC;IAGxB,SAAS,CAAC,YAAY,EAAE,EAAE,CAAC,aAAa,CAAC;IACzC,SAAS,CAAC,UAAU,EAAE,EAAE,CAAC,aAAa,CAAC;gBAE3B,EAAE,GAAG,EAAE,iBAAwB,EAAE,KAAc,EAAE,GAAG,IAAI,EAAE,EAAE,2BAA2B;IAqB1F,IAAI,CAAC,MAAM,EAAE,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,EAAE,EAAE,MAAM,EAAE,GAAG,GAAG,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,EAAE;IAkB3E,KAAK,CAAC,WAAW,EAAE,EAAE,CAAC,KAAK,GAAG,EAAE,CAAC,KAAK,EAAE;IAmBjD;;OAEG;IACI,kBAAkB,CAAC,WAAW,EAAE,EAAE,CAAC,KAAK,GAAG,EAAE,CAAC,KAAK,EAAE;CAK/D"}
@@ -0,0 +1,99 @@
1
+ import * as tf from "@tensorflow/tfjs";
2
/**
 * Rotates `x` with pre-computed rotary position embedding (RoPE) caches.
 *
 * The result is `x * cos + rotateHalf(x) * sin`, evaluated inside `tf.tidy`
 * so that every intermediate tensor is released.
 *
 * @param x input tensor; `x.shape[2]` is read as the sequence length
 *     (presumably laid out as [batch, head, sequence, head_dim] — confirm
 *     against the caller)
 * @param dim number of trailing-axis entries taken from each cache
 * @param cosine_cache 4D cosine table, sliced to `[1, 1, sequence, dim]`
 * @param sine_cache 4D sine table, sliced to `[1, 1, sequence, dim]`
 * @returns the rotated tensor
 */
export function applyRope(x, dim, cosine_cache, sine_cache) {
    return tf.tidy(() => {
        // only the first `x.shape[2]` positions of the caches are needed
        const begin = [0, 0, 0, 0];
        const size = [1, 1, x.shape[2], dim];
        const cosine_window = tf.slice(cosine_cache, begin, size);
        const sine_window = tf.slice(sine_cache, begin, size);

        // each pair [x1, x2] becomes [-x2, x1] before it is scaled by sine
        const turned = rotateHalf(x, dim);

        const cosine_term = tf.mul(x, cosine_window);
        const sine_term = tf.mul(turned, sine_window);
        return tf.add(cosine_term, sine_term);
    });
}
13
/**
 * Swaps and negates adjacent coordinate pairs along the last axis:
 * `[x1, x2, x3, x4] -> [-x2, x1, -x4, x3]`.
 *
 * @param x tensor whose last axis holds `dim` values
 * @param dim size of the last axis; it is split into `dim / 2` pairs
 * @returns a tensor with the same shape as `x`
 */
export function rotateHalf(x, dim) {
    return tf.tidy(() => {
        // Group neighbours: [x1, x2, x3, x4] -> [[x1, x2], [x3, x4]].
        // All leading axes are collapsed into one because TFJS has issues
        // during backpropagation with 5D slicing.
        const pairs = tf.reshape(x, [-1, dim / 2, 2]);

        // pick the first and the second member of every pair
        const [first, second] = [0, 1].map(
            (offset) => tf.slice(pairs, [0, 0, offset], [-1, -1, 1]));

        // [x1, x2] -> [-x2, x1]
        const swapped = tf.concat([tf.neg(second), first], -1);

        return tf.reshape(swapped, x.shape);
    });
}
27
/**
 * Pre-computes the cosine and sine tables used by rotary position embedding.
 *
 * @param dim size of the rotated axis
 * @param max_sequence_length number of positions to pre-compute
 * @param theta base of the inverse frequencies, default `10_000`
 * @returns `[cosine, sine]`, each of shape `[1, 1, max_sequence_length, dim]`.
 *     Both tensors are marked with `tf.keep`, so the caller is responsible
 *     for disposing them.
 */
export function createRoPECache(dim, max_sequence_length, theta = 10_000) {
    return tf.tidy(() => {
        // one inverse frequency per coordinate pair: 1 / theta^(2i / dim), shape [dim / 2]
        const even_indices = tf.range(0, Math.floor(dim / 2) * 2, 2, "float32");
        const inv_frequencies = tf.div(1, tf.pow(theta, even_indices.div(dim)));

        // positions 0 .. max_sequence_length - 1, shape [max_sequence_length]
        const positions = tf.range(0, max_sequence_length);

        // angle of every (position, pair), shape [max_sequence_length, dim / 2]
        const angles = tf.outerProduct(positions, inv_frequencies);

        // both members of a pair share one angle: duplicate each column,
        // final shape [max_sequence_length, dim]
        const paired_angles = tf.stack([angles, angles], -1)
            .reshape([max_sequence_length, dim]);

        // add two leading singleton axes and protect the result from tidy scopes
        const toCache = (table) => tf.keep(table.expandDims(0).expandDims(0));

        return [
            toCache(tf.cos(paired_angles)),
            toCache(tf.sin(paired_angles))
        ];
    });
}
44
/**
 * Implements RoPE from the "RoFormer: Enhanced Transformer with Rotary Position Embedding" paper.
 * Inspired by: https://meta-pytorch.org/torchtune/stable/_modules/torchtune/modules/position_embeddings.html#RotaryPositionalEmbeddings
 *
 * The layer holds two non-trainable caches (cosine and sine) of shape
 * `[1, 1, max_sequence_length, dim]`. They are computed in `build` and are
 * re-computed with a larger size when `call` receives a longer sequence.
 *
 * @param dim size of the rotated (last) axis, must be even
 * @param maxSequenceLength number of positions pre-computed up front, default `4096`
 * @param theta base of the inverse frequencies, default `10_000`
 */
export class RotaryPositionEmbedding extends tf.layers.Layer {
    static className = "RotaryPositionEmbedding";
    dim;
    max_sequence_length;
    theta;
    // cached sine and cosine frequencies, untrainable weights
    cosine_cache;
    sine_cache;

    constructor({ dim, maxSequenceLength = 4096, theta = 10_000, ...args }) {
        super(args);
        if (dim % 2 !== 0) {
            throw new Error(`${this.getClassName()}::constructor ${this.name} expected dim to be even, got ${dim}`);
        }
        this.dim = dim;
        this.max_sequence_length = maxSequenceLength;
        this.theta = theta;

        // zero-filled placeholders; `build` swaps in the real tables.
        // Each variable is registered under the name matching its content.
        const cache_shape = [1, 1, maxSequenceLength, this.dim];
        this.cosine_cache = this.addWeight("cosine_cache", cache_shape, "float32", tf.initializers.zeros(), undefined, false);
        this.sine_cache = this.addWeight("sine_cache", cache_shape, "float32", tf.initializers.zeros(), undefined, false);
    }

    /**
     * Applies the rotation to the (first) input tensor.
     *
     * When the input's sequence length (`shape[2]`) exceeds the cached length,
     * the caches are rebuilt at the next power of two that fits the input.
     */
    call(inputs, kwargs) {
        const x = Array.isArray(inputs) ? inputs[0] : inputs;
        const seq_length = x.shape[2];
        if (seq_length > this.max_sequence_length) {
            // expand cache to the nearest power of 2
            this.max_sequence_length = Math.pow(2, Math.ceil(Math.log2(seq_length)));
            this.build([]);
        }
        return applyRope(x, this.dim, this.cosine_cache.read(), this.sine_cache.read());
    }

    /**
     * (Re)computes both caches for the current `max_sequence_length`.
     * `input_shape` is not used.
     */
    build(input_shape) {
        const [cosine, sine] = createRoPECache(this.dim, this.max_sequence_length, this.theta);

        // release the variables of the previous cache (or the constructor placeholders)
        this.cosine_cache.dispose();
        this.sine_cache.dispose();

        // one non-trainable variable per table, shared between the fields
        // read by `call` and the layer's weight list
        this.cosine_cache = new tf.LayerVariable(cosine, "float32", "cosine_cache", false);
        this.sine_cache = new tf.LayerVariable(sine, "float32", "sine_cache", false);

        // the variables own their data now; drop the kept source tensors
        cosine.dispose();
        sine.dispose();

        this.nonTrainableWeights = [this.cosine_cache, this.sine_cache];
    }

    /**
     * Output shape: [batch, head, sequence, head_dim]
     */
    computeOutputShape(input_shape) {
        return Array.isArray(input_shape[0])
            ? input_shape[0]
            : input_shape;
    }

    /**
     * Serializes the constructor arguments so the registered class can be
     * re-created when a saved model is loaded.
     */
    getConfig() {
        const config = {
            dim: this.dim,
            maxSequenceLength: this.max_sequence_length,
            theta: this.theta,
        };
        Object.assign(config, super.getConfig());
        return config;
    }
}
tf.serialization.registerClass(RotaryPositionEmbedding);
99
+ //# sourceMappingURL=rotary_position_embedding.js.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"rotary_position_embedding.js","sourceRoot":"","sources":["../../../src/layers/rotary_position_embedding.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAIvC,MAAM,UAAU,SAAS,CAAC,CAAY,EAAE,GAAW,EAAE,YAAuB,EAAE,UAAqB;IAC/F,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;QAChB,MAAM,UAAU,GAAG,CAAC,CAAC,KAAK,CAAC,CAAC,CAAE,CAAC;QAE/B,2EAA2E;QAC3E,MAAM,MAAM,GAAG,YAAY,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,UAAU,EAAE,GAAG,CAAC,CAAC,CAAC;QACzE,MAAM,IAAI,GAAG,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,UAAU,EAAE,GAAG,CAAC,CAAC,CAAC;QAErE,0DAA0D;QAC1D,MAAM,SAAS,GAAG,UAAU,CAAC,CAAC,EAAE,GAAG,CAAC,CAAC;QAErC,OAAO,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,CAAC,EAAE,MAAM,CAAC,EAAE,EAAE,CAAC,GAAG,CAAC,SAAS,EAAE,IAAI,CAAC,CAAC,CAAC;IAC9D,CAAC,CAAC,CAAC;AACP,CAAC;AAGD,MAAM,UAAU,UAAU,CAAC,CAAY,EAAE,GAAW;IAChD,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;QAChB,gFAAgF;QAChF,2CAA2C;QAC3C,sEAAsE;QACtE,kCAAkC;QAClC,MAAM,QAAQ,GAAG,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,EAAE,GAAG,GAAG,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAE7C,MAAM,EAAE,GAAG,QAAQ,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAClD,MAAM,EAAE,GAAG,QAAQ,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAElD,wBAAwB;QACxB,MAAM,OAAO,GAAG,EAAE,CAAC,MAAM,CAAC,CAAC,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAEhD,OAAO,OAAO,CAAC,OAAO,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC;IACpC,CAAC,CAAC,CAAC;AACP,CAAC;AAGD,MAAM,UAAU,eAAe,CAAC,GAAW,EAAE,mBAA2B,EAAE,QAAgB,MAAM;IAC5F,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;QAChB,QAAQ;QACR,MAAM,eAAe,GAAG,EAAE,CAAC,GAAG,CAAc,CAAC,EAAE,EAAE,CAAC,GAAG,CACjD,KAAK,EACL,EAAE,CAAC,KAAK,CAAC,CAAC,EAAE,IAAI,CAAC,KAAK,CAAC,GAAG,GAAG,CAAC,CAAC,GAAG,CAAC,EAAE,CAAC,EAAE,SAAS,CAAC,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC;QAElE,uBAAuB;QACvB,MAAM,gBAAgB,GAAG,EAAE,CAAC,KAAK,CAAC,CAA
C,EAAE,mBAAmB,CAAC,CAAC;QAC1D,GAAG;QACH,MAAM,IAAI,GAAG,EAAE,CAAC,YAAY,CAAC,gBAAgB,EAAE,eAAe,CAAC,CAAC;QAEhE,+CAA+C;QAC/C,MAAM,UAAU,GAAG,EAAE,CAAC,KAAK,CAAC,CAAC,IAAI,EAAE,IAAI,CAAC,EAAE,CAAC,CAAC,CAAC;aACxC,OAAO,CAAC,CAAC,mBAAmB,EAAE,GAAG,CAAC,CAAC,CAAC;QAEzC,OAAO;YACH,EAAE,CAAC,IAAI,CAAC,EAAE,CAAC,GAAG,CAAC,UAAU,CAAC,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC;YACvD,EAAE,CAAC,IAAI,CAAC,EAAE,CAAC,GAAG,CAAC,UAAU,CAAC,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC;SAC1D,CAAA;IACL,CAAC,CAAC,CAAC;AACP,CAAC;AAmBD;;;GAGG;AACH,MAAM,OAAO,uBAAwB,SAAQ,EAAE,CAAC,MAAM,CAAC,KAAK;IACxD,MAAM,CAAC,SAAS,GAAG,yBAAyB,CAAC;IAEnC,GAAG,CAAS;IACZ,mBAAmB,CAAS;IAC5B,KAAK,CAAS;IAExB,0DAA0D;IAChD,YAAY,CAAmB;IAC/B,UAAU,CAAmB;IAEvC,YAAY,EAAE,GAAG,EAAE,iBAAiB,GAAG,IAAI,EAAE,KAAK,GAAG,MAAM,EAAE,GAAG,IAAI,EAA+B;QAC/F,KAAK,CAAC,IAAI,CAAC,CAAC;QAEZ,IAAI,GAAG,GAAG,CAAC,KAAK,CAAC,EAAE,CAAC;YAChB,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,iBAAiB,IAAI,CAAC,IAAI,iCAAiC,GAAG,EAAE,CAAC,CAAC;QACxG,CAAC;QAED,IAAI,CAAC,GAAG,GAAG,GAAG,CAAC;QACf,IAAI,CAAC,mBAAmB,GAAG,iBAAiB,CAAC;QAC7C,IAAI,CAAC,KAAK,GAAG,KAAK,CAAC;QAEnB,IAAI,CAAC,YAAY,GAAG,IAAI,CAAC,SAAS,CAAC,YAAY,EAC3C,CAAC,CAAC,EAAE,CAAC,EAAE,iBAAiB,EAAE,IAAI,CAAC,KAAK,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC,EAC/C,SAAS,EAAE,EAAE,CAAC,YAAY,CAAC,KAAK,EAAE,EAAE,SAAS,EAAE,KAAK,CAAC,CAAC;QAE1D,IAAI,CAAC,UAAU,GAAG,IAAI,CAAC,SAAS,CAAC,cAAc,EAC3C,CAAC,CAAC,EAAE,CAAC,EAAE,iBAAiB,EAAE,IAAI,CAAC,KAAK,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC,EAC/C,SAAS,EAAE,EAAE,CAAC,YAAY,CAAC,KAAK,EAAE,EAAE,SAAS,EAAE,KAAK,CAAC,CAAC;IAC9D,CAAC;IAGQ,IAAI,CAAC,MAA+B,EAAE,MAAW;QACtD,MAAM,KAAK,GAAG,KAAK,CAAC,OAAO,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,MAAM,CAAC,KAAK,CAAC;QACrE,MAAM,UAAU,GAAG,KAAK,CAAC,CAAC,CAAC,CAAC;QAE5B,IAAI,UAAU,GAAG,IAAI,CAAC,mBAAmB,EAAE,CAAC;YACxC,yCAAyC;YACzC,IAAI,CAAC,mBAAmB,GAAG,IAAI,CAAC,GAAG,CAAC,CAAC,EAAE,IAAI,CAAC,IAAI,CAAC,IAAI,CAAC,IAAI,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC;YACzE,IAAI,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC;QACnB,CAAC;QAED,OAAO,SAA
S,CACZ,KAAK,CAAC,OAAO,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,MAAM,EAC1C,IAAI,CAAC,GAAG,EACR,IAAI,CAAC,YAAY,CAAC,IAAI,EAAE,EACxB,IAAI,CAAC,UAAU,CAAC,IAAI,EAAE,CAAC,CAAA;IAC/B,CAAC;IAGQ,KAAK,CAAC,WAAkC;QAC7C,MAAM,CAAC,MAAM,EAAE,IAAI,CAAC,GAAG,eAAe,CAClC,IAAI,CAAC,GAAG,EAAE,IAAI,CAAC,mBAAmB,EAAE,IAAI,CAAC,KAAK,CAAC,CAAC;QAEpD,IAAI,CAAC,YAAY,CAAC,OAAO,EAAE,CAAC;QAC5B,IAAI,CAAC,UAAU,CAAC,OAAO,EAAE,CAAC;QAE1B,IAAI,CAAC,YAAY,GAAG,IAAI,EAAE,CAAC,aAAa,CAAC,MAAM,CAAC,CAAC;QACjD,IAAI,CAAC,UAAU,GAAG,IAAI,EAAE,CAAC,aAAa,CAAC,IAAI,CAAC,CAAC;QAE7C,IAAI,CAAC,mBAAmB,GAAG;YACvB,IAAI,EAAE,CAAC,aAAa,CAAC,MAAM,CAAC;YAC5B,IAAI,EAAE,CAAC,aAAa,CAAC,IAAI,CAAC;SAC7B,CAAC;QAEF,IAAI,CAAC,UAAU,CAAC,CAAC,MAAM,EAAE,IAAI,CAAC,CAAC,CAAC;IACpC,CAAC;IAGD;;OAEG;IACI,kBAAkB,CAAC,WAAkC;QACxD,OAAO,KAAK,CAAC,OAAO,CAAC,WAAW,CAAC,CAAC,CAAC,CAAC;YAChC,CAAC,CAAC,WAAW,CAAC,CAAC,CAAa;YAC5B,CAAC,CAAC,WAAuB,CAAC;IAClC,CAAC;;AAGL,EAAE,CAAC,aAAa,CAAC,aAAa,CAAC,uBAAuB,CAAC,CAAC"}
@@ -0,0 +1,2 @@
1
+ export {};
2
+ //# sourceMappingURL=rotary_position_embedding.test.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"rotary_position_embedding.test.d.ts","sourceRoot":"","sources":["../../../src/layers/rotary_position_embedding.test.ts"],"names":[],"mappings":""}
@@ -0,0 +1,88 @@
1
+ import { RotaryPositionEmbedding } from "@/layers/rotary_position_embedding";
2
+ import * as tf from "@tensorflow/tfjs";
3
+ // disables warning for using the faster node backend,
4
+ // https://github.com/tensorflow/tfjs/issues/5349#issuecomment-885170504
5
+ tf.env().set('IS_NODE', false);
6
describe("RotaryPositionEmbedding tests", () => {
    /**
     * Largest element-wise absolute difference between two tensors.
     * A signed sum is not used because positive and negative errors cancel
     * out and negative errors always satisfy a "less than" check.
     */
    const maxAbsDiff = async (actual, expected) => actual.sub(expected).abs().max().array();

    test("create cache", async () => {
        const rope = new RotaryPositionEmbedding({ dim: 8, maxSequenceLength: 15 });
        rope.build([]);
        const expected_cosine_cache = tf.tensor([[[
                    [1, 1, 1, 1, 1, 1, 1, 1],
                    [0.5403022766113281, 0.5403022766113281, 0.9950041770935059, 0.9950041770935059, 0.9999499917030334, 0.9999499917030334, 0.9999995231628418, 0.9999995231628418],
                    [-0.416146844625473, -0.416146844625473, 0.9800665974617004, 0.9800665974617004, 0.9998000264167786, 0.9998000264167786, 0.9999979734420776, 0.9999979734420776],
                    [-0.9899924993515015, -0.9899924993515015, 0.9553365111351013, 0.9553365111351013, 0.9995500445365906, 0.9995500445365906, 0.9999955296516418, 0.9999955296516418],
                    [-0.6536436080932617, -0.6536436080932617, 0.9210609793663025, 0.9210609793663025, 0.9992001056671143, 0.9992001056671143, 0.9999920129776001, 0.9999920129776001],
                    [0.28366219997406006, 0.28366219997406006, 0.8775825500488281, 0.8775825500488281, 0.9987502694129944, 0.9987502694129944, 0.9999874830245972, 0.9999874830245972],
                    [0.9601702690124512, 0.9601702690124512, 0.8253356218338013, 0.8253356218338013, 0.998200535774231, 0.998200535774231, 0.9999819993972778, 0.9999819993972778],
                    [0.7539022564888, 0.7539022564888, 0.7648422122001648, 0.7648422122001648, 0.9975510239601135, 0.9975510239601135, 0.9999755024909973, 0.9999755024909973],
                    [-0.1455000340938568, -0.1455000340938568, 0.6967067122459412, 0.6967067122459412, 0.9968017339706421, 0.9968017339706421, 0.9999679923057556, 0.9999679923057556],
                    [-0.9111302495002747, -0.9111302495002747, 0.6216099262237549, 0.6216099262237549, 0.9959527254104614, 0.9959527254104614, 0.9999595284461975, 0.9999595284461975],
                    [-0.83907151222229, -0.83907151222229, 0.5403022766113281, 0.5403022766113281, 0.9950041770935059, 0.9950041770935059, 0.9999499917030334, 0.9999499917030334],
                    [0.004425697959959507, 0.004425697959959507, 0.4535960853099823, 0.4535960853099823, 0.9939560890197754, 0.9939560890197754, 0.999939501285553, 0.999939501285553],
                    [0.8438539505004883, 0.8438539505004883, 0.3623577058315277, 0.3623577058315277, 0.9928086400032043, 0.9928086400032043, 0.9999279975891113, 0.9999279975891113],
                    [0.9074468016624451, 0.9074468016624451, 0.26749876141548157, 0.26749876141548157, 0.9915618896484375, 0.9915618896484375, 0.9999154806137085, 0.9999154806137085],
                    [0.13673721253871918, 0.13673721253871918, 0.1699671596288681, 0.1699671596288681, 0.9902160167694092, 0.9902160167694092, 0.9999020099639893, 0.9999020099639893]
                ]]]);
        const expected_sine_cache = tf.tensor([[[
                    [0, 0, 0, 0, 0, 0, 0, 0],
                    [0.8414709568023682, 0.8414709568023682, 0.0998334214091301, 0.0998334214091301, 0.009999833069741726, 0.009999833069741726, 0.0009999999310821295, 0.0009999999310821295],
                    [0.9092974066734314, 0.9092974066734314, 0.19866932928562164, 0.19866932928562164, 0.019998665899038315, 0.019998665899038315, 0.0019999986980110407, 0.0019999986980110407],
                    [0.14112000167369843, 0.14112000167369843, 0.29552021622657776, 0.29552021622657776, 0.029995499178767204, 0.029995499178767204, 0.0029999956022948027, 0.0029999956022948027],
                    [-0.756802499294281, -0.756802499294281, 0.3894183337688446, 0.3894183337688446, 0.03998933359980583, 0.03998933359980583, 0.003999989479780197, 0.003999989479780197],
                    [-0.9589242935180664, -0.9589242935180664, 0.4794255495071411, 0.4794255495071411, 0.04997916519641876, 0.04997916519641876, 0.0049999793991446495, 0.0049999793991446495],
                    [-0.279415488243103, -0.279415488243103, 0.5646424889564514, 0.5646424889564514, 0.059964004904031754, 0.059964004904031754, 0.0059999641962349415, 0.0059999641962349415],
                    [0.6569865942001343, 0.6569865942001343, 0.6442176699638367, 0.6442176699638367, 0.06994284689426422, 0.06994284689426422, 0.0069999429397284985, 0.0069999429397284985],
                    [0.9893582463264465, 0.9893582463264465, 0.7173560857772827, 0.7173560857772827, 0.07991468906402588, 0.07991468906402588, 0.007999914698302746, 0.007999914698302746],
                    [0.41211849451065063, 0.41211849451065063, 0.7833269238471985, 0.7833269238471985, 0.08987854421138763, 0.08987854421138763, 0.008999879471957684, 0.008999879471957684],
                    [-0.5440211296081543, -0.5440211296081543, 0.8414709568023682, 0.8414709568023682, 0.0998334139585495, 0.0998334139585495, 0.0099998340010643, 0.0099998340010643],
                    [-0.9999902248382568, -0.9999902248382568, 0.8912073969841003, 0.8912073969841003, 0.10977829992771149, 0.10977829992771149, 0.010999779216945171, 0.010999779216945171],
                    [-0.5365729331970215, -0.5365729331970215, 0.9320390820503235, 0.9320390820503235, 0.11971220374107361, 0.11971220374107361, 0.011999712325632572, 0.011999712325632572],
                    [0.4201670289039612, 0.4201670289039612, 0.9635581970214844, 0.9635581970214844, 0.12963414192199707, 0.12963414192199707, 0.012999634258449078, 0.012999634258449078],
                    [0.9906073808670044, 0.9906073808670044, 0.9854497313499451, 0.9854497313499451, 0.13954311609268188, 0.13954311609268188, 0.013999543152749538, 0.013999543152749538]
                ]]]);
        const [cosine_cache, sine_cache] = rope.getWeights();

        // fail loudly if a weight is missing instead of comparing `undefined`
        expect(cosine_cache).toBeDefined();
        expect(sine_cache).toBeDefined();

        expect(await maxAbsDiff(cosine_cache, expected_cosine_cache)).toBeLessThanOrEqual(1e-6);
        expect(await maxAbsDiff(sine_cache, expected_sine_cache)).toBeLessThanOrEqual(1e-6);
    });

    test("rotate inputs", async () => {
        const rope = new RotaryPositionEmbedding({ dim: 8, maxSequenceLength: 15 });
        const x = tf.tensor([[[
                    [0.0766048, 0.5706575, 0.6705932, 0.5273118, 0.4794086, 0.9378104, 0.9888024, 0.6926053],
                    [0.9064133, 0.5875182, 0.1681865, 0.3833345, 0.9901192, 0.4677338, 0.3353315, 0.02699],
                    [0.3033573, 0.4139377, 0.4062586, 0.9705839, 0.3582608, 0.328775, 0.1340587, 0.2193414],
                    [0.5565202, 0.4334963, 0.9912352, 0.3388563, 0.7991487, 0.1911893, 0.1140554, 0.6949552]
                ]]
        ]); // shape [1, 1, 4, 8]: batch=1, heads=1, seq=4, dim=8
        const expected_output = tf.tensor([[[
                    [0.07660479843616486, 0.57065749168396, 0.6705932021141052, 0.5273118019104004, 0.4794085919857025, 0.9378104209899902, 0.9888023734092712, 0.6926053166389465],
                    [-0.004642367362976074, 1.08015775680542, 0.12907665967941284, 0.39820998907089233, 0.9853923320770264, 0.47761136293411255, 0.33530429005622864, 0.027325313538312912],
                    [-0.5026336908340454, 0.10358311235904694, 0.20533521473407745, 1.0319478511810303, 0.3516140580177307, 0.33587393164634705, 0.1336197406053543, 0.21960905194282532],
                    [-0.6121258735656738, -0.3506217896938324, 0.8468242287635803, 0.6166517734527588, 0.7930541634559631, 0.2150741070508957, 0.11197001487016678, 0.695294201374054]
                ]]]);
        const output = rope.apply(x);

        expect(await maxAbsDiff(output, expected_output)).toBeLessThan(1e-6);
        expect(rope.computeOutputShape(x.shape)).toEqual(x.shape);
        expect(rope.computeOutputShape([x.shape])).toEqual(x.shape);
    });

    test("expand cache when input sequences are larger than rope's max sequence length", async () => {
        const dim = 8;
        const rope = new RotaryPositionEmbedding({ dim, maxSequenceLength: 15, theta: 1_000_000 });
        const larger_sequence = 20;
        const even_larger_sequence = 50;

        // 20 positions do not fit into 15: the cache grows to 32
        rope.apply(tf.randomUniform([1, 1, larger_sequence, dim]));
        rope.getWeights().forEach(weight => {
            expect(weight.shape).toEqual([1, 1, 32, dim]);
        });

        // 50 positions do not fit into 32: the cache grows to 64
        rope.apply([tf.randomUniform([1, 1, even_larger_sequence, dim])]);
        rope.getWeights().forEach(weight => {
            expect(weight.shape).toEqual([1, 1, 64, dim]);
        });
    });

    test("create layer", async () => {
        // dim must be even
        expect(() => new RotaryPositionEmbedding({ dim: 7, maxSequenceLength: 15 })).toThrow();
        expect(() => new RotaryPositionEmbedding({ dim: 8, maxSequenceLength: 25 })).not.toThrow();
    });
});
88
+ //# sourceMappingURL=rotary_position_embedding.test.js.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"rotary_position_embedding.test.js","sourceRoot":"","sources":["../../../src/layers/rotary_position_embedding.test.ts"],"names":[],"mappings":"AAAA,OAAO,EAAE,uBAAuB,EAAE,MAAM,oCAAoC,CAAC;AAC7E,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAEvC,sDAAsD;AACtD,wEAAwE;AACxE,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC,SAAS,EAAE,KAAK,CAAC,CAAC;AAG/B,QAAQ,CAAC,+BAA+B,EAAE,GAAG,EAAE;IAC3C,IAAI,CAAC,cAAc,EAAE,KAAK,IAAI,EAAE;QAC5B,MAAM,IAAI,GAAG,IAAI,uBAAuB,CAAC,EAAE,GAAG,EAAE,CAAC,EAAE,iBAAiB,EAAE,EAAE,EAAE,CAAC,CAAC;QAC5E,IAAI,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC;QAEf,MAAM,qBAAqB,GAAG,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC;oBACtC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC;oBACxB,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;oBAChK,CAAC,CAAC,iBAAiB,EAAE,CAAC,iBAAiB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;oBAChK,CAAC,CAAC,kBAAkB,EAAE,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;oBAClK,CAAC,CAAC,kBAAkB,EAAE,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;oBAClK,CAAC,mBAAmB,EAAE,mBAAmB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;oBAClK,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,iBAAiB,EAAE,iBAAiB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;oBAC9J,CAAC,eAAe,EAAE,eAAe,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;oBAC1J,CAAC,CAAC,kBAAkB,EAAE,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;oBAClK,CAAC,CAAC,kBAAkB,EAAE,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;oBAClK,CAAC,CAAC,gBAAgB,EAAE,CAAC,gBAAgB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;oBAC9J,CAAC,oBAAoB,EAAE,oBAAoB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,iBAAiB,EAAE,iBAAiB,CAAC;oBAClK,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAA
E,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;oBAChK,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;oBAClK,CAAC,mBAAmB,EAAE,mBAAmB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;iBACrK,CAAC,CAAC,CAAC,CAAC;QAEL,MAAM,mBAAmB,GAAG,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC;oBACpC,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC;oBACxB,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,oBAAoB,EAAE,oBAAoB,EAAE,qBAAqB,EAAE,qBAAqB,CAAC;oBAC1K,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,oBAAoB,EAAE,oBAAoB,EAAE,qBAAqB,EAAE,qBAAqB,CAAC;oBAC5K,CAAC,mBAAmB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,oBAAoB,EAAE,oBAAoB,EAAE,qBAAqB,EAAE,qBAAqB,CAAC;oBAC9K,CAAC,CAAC,iBAAiB,EAAE,CAAC,iBAAiB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,oBAAoB,EAAE,oBAAoB,CAAC;oBACtK,CAAC,CAAC,kBAAkB,EAAE,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,qBAAqB,EAAE,qBAAqB,CAAC;oBAC1K,CAAC,CAAC,iBAAiB,EAAE,CAAC,iBAAiB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,oBAAoB,EAAE,oBAAoB,EAAE,qBAAqB,EAAE,qBAAqB,CAAC;oBAC1K,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,qBAAqB,EAAE,qBAAqB,CAAC;oBACxK,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,oBAAoB,EAAE,oBAAoB,CAAC;oBACtK,CAAC,mBAAmB,EAAE,mBAAmB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,oBAAoB,EAAE,oBAAoB,CAAC;oBACxK,CAAC,CAAC,kBAAkB,EAAE,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;oBAClK,CAAC,CAAC,kBAAkB,EAAE,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,oBAAoB,EAAE,oBAAoB,CAAC;oBACxK,CAAC,CAAC,kBAAkB,EAAE,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,oBAAoB,EAAE,oBAAoB,CAAC;oBACxK,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,oBAAoB,EAAE,oBAAoB,CAAC;oBACtK,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,oBAAoB,EAAE,oBA
AoB,CAAC;iBACzK,CAAC,CAAC,CAAC,CAAC;QAEL,MAAM,CAAC,YAAY,EAAE,UAAU,CAAC,GAAG,IAAI,CAAC,UAAU,EAAE,CAAC;QAErD,MAAM,CAAC,MAAM,YAAY,EAAE,GAAG,CAAC,qBAAqB,CAAC,CAAC,GAAG,EAAE,CAAC,KAAK,EAAY,CAAC,CAAC,mBAAmB,CAAC,IAAI,CAAC,CAAC;QACzG,MAAM,CAAC,MAAM,UAAU,EAAE,GAAG,CAAC,mBAAmB,CAAC,CAAC,GAAG,EAAE,CAAC,KAAK,EAAY,CAAC,CAAC,mBAAmB,CAAC,IAAI,CAAC,CAAC;IACzG,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,eAAe,EAAE,KAAK,IAAI,EAAE;QAC7B,MAAM,IAAI,GAAG,IAAI,uBAAuB,CAAC,EAAE,GAAG,EAAE,CAAC,EAAE,iBAAiB,EAAE,EAAE,EAAE,CAAC,CAAC;QAE5E,MAAM,CAAC,GAAG,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC;oBAClB,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;oBACxF,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,OAAO,CAAC;oBACtF,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,QAAQ,EAAE,SAAS,EAAE,SAAS,CAAC;oBACvF,CAAC,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,EAAE,SAAS,CAAC;iBAAC,CAAC;SAC7F,CAAC,CAAC,CAAC,wCAAwC;QAE5C,MAAM,eAAe,GAAG,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC;oBAChC,CAAC,mBAAmB,EAAE,gBAAgB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,CAAC;oBAC/J,CAAC,CAAC,oBAAoB,EAAE,gBAAgB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,kBAAkB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,oBAAoB,CAAC;oBACvK,CAAC,CAAC,kBAAkB,EAAE,mBAAmB,EAAE,mBAAmB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,mBAAmB,EAAE,kBAAkB,EAAE,mBAAmB,CAAC;oBACrK,CAAC,CAAC,kBAAkB,EAAE,CAAC,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,kBAAkB,EAAE,mBAAmB,EAAE,iBAAiB,CAAC;iBACrK,CAAC,CAAC,CAAC,CAAC;QAEL,MAAM,MAAM,GAAG,IAAI,CAAC,KAAK,CAAC,CAAC,CAAc,CAAC;QAE1C,MAAM,CAAC,MAAM,eAAe,CAAC,GAAG,CAAC,MAAM,CAAC,CAAC,GAAG,EAAE,CAAC,KAAK,EAAY,CAAC,CAAC,YAAY,CAAC,IAAI,CAAC,CAAC;QACrF,MAAM,CAAC,IAAI,CAAC,kBAAkB,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC;QAC1D,MAAM,CAAC,IAAI,CAAC,kBAAkB,CAAC,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC;IAChE,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,8EAA8E,EAAE,KAAK,IAAI,EAAE;QAC5F,MAAM,GAAG,GAAG,CAAC,CAAC;QACd,MAAM,IAAI,GAAG,IAAI,uBAAuB,CA
AC,EAAE,GAAG,EAAE,iBAAiB,EAAE,EAAE,EAAE,KAAK,EAAE,SAAS,EAAE,CAAC,CAAC;QAC3F,MAAM,eAAe,GAAG,EAAE,CAAC;QAC3B,MAAM,oBAAoB,GAAG,EAAE,CAAC;QAEhC,IAAI,CAAC,KAAK,CAAC,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,eAAe,EAAE,GAAG,CAAC,CAAC,CAAC,CAAC;QAE3D,IAAI,CAAC,UAAU,EAAE,CAAC,OAAO,CAAC,MAAM,CAAC,EAAE;YAC/B,MAAM,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,EAAE,GAAG,CAAC,CAAC,CAAC;QAClD,CAAC,CAAC,CAAC;QAEH,IAAI,CAAC,KAAK,CAAC,CAAC,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,oBAAoB,EAAE,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC;QAElE,IAAI,CAAC,UAAU,EAAE,CAAC,OAAO,CAAC,MAAM,CAAC,EAAE;YAC/B,MAAM,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC,OAAO,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,EAAE,GAAG,CAAC,CAAC,CAAC;QAClD,CAAC,CAAC,CAAC;IACP,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,cAAc,EAAE,KAAK,IAAI,EAAE;QAC5B,mBAAmB;QACnB,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,uBAAuB,CAAC,EAAE,GAAG,EAAE,CAAC,EAAE,iBAAiB,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACvF,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,uBAAuB,CAAC,EAAE,GAAG,EAAE,CAAC,EAAE,iBAAiB,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;IAC/F,CAAC,CAAC,CAAA;AACN,CAAC,CAAC,CAAC"}
@@ -0,0 +1,47 @@
1
+ import * as tf from '@tensorflow/tfjs';
2
+ import { type LayerArgs } from '@tensorflow/tfjs-layers/dist/engine/topology';
3
+ import { type Kwargs } from '@tensorflow/tfjs-layers/dist/types';
4
+ import { type PositionalEncodingArgs } from '@/layers/positional_encoding';
5
/**
 * Constructor arguments of `TokenAndPositionalEmbedding`: the generic layer
 * arguments, the positional encoding arguments, and the two fields below.
 */
export interface TokenAndPositionalEmbeddingArgs extends LayerArgs, PositionalEncodingArgs {
    /** The number of tokens to embed. */
    vocabularySize: number;
    /** Dropout rate applied to the positionally encoded embeddings, default `0.1`. */
    dropout?: number;
}
9
/**
 * Combines the sinusoidal positional encoding from the 2017 paper
 * "Attention Is All You Need" with a regular embedding layer, forming a
 * single simplified embedding layer.
 *
 * The layer takes tokenized inputs of shape `[ batch, tokens ]`, passes them
 * through an embedding layer and then adds sinusoidal positional encoding.
 *
 * @param embedDim the size of each token/word's embedding
 * @param vocabularySize the number of tokens to embed
 * @param maxSequenceLength the max number of tokens (words) per input (sentence), default `5120`
 * @param dropout applies dropout to the positionally encoded embeddings, default `0.1`
 */
export declare class TokenAndPositionalEmbedding extends tf.layers.Layer {
    /** Name under which the class is registered for serialization. */
    static className: string;
    private readonly embedDim;
    private readonly vocabularySize;
    /** Token embedding sublayer. */
    private embedding;
    /** Positional encoding sublayer. */
    private positional;
    private readonly maxSequenceLength;
    private readonly dropout;
    /** Dropout sublayer applied after the positional encoding. */
    private dropoutLayer;
    constructor({ embedDim, vocabularySize, maxSequenceLength, dropout, ...args }: TokenAndPositionalEmbeddingArgs);
    /**
     * Forward propagation: embedding, then positional encoding, then dropout.
     */
    call(inputs: tf.Tensor | tf.Tensor[], kwargs: Kwargs): tf.Tensor<tf.Rank>;
    /**
     * Builds the sublayers and enables serialization.
     */
    build(inputShape: tf.Shape | tf.Shape[]): void;
    /**
     * For an input shape of [batch, sequences], the output shape is
     * [batch, sequences, embedDim].
     */
    computeOutputShape(inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[];
    /**
     * Returns the constructor arguments merged with the base layer config.
     */
    getConfig(): tf.serialization.ConfigDict;
}
47
+ //# sourceMappingURL=token_and_positional_embedding.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"token_and_positional_embedding.d.ts","sourceRoot":"","sources":["../../../src/layers/token_and_positional_embedding.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AACvC,OAAO,EAAE,KAAK,SAAS,EAAE,MAAM,8CAA8C,CAAC;AAC9E,OAAO,EAAE,KAAK,MAAM,EAAE,MAAM,oCAAoC,CAAC;AAEjE,OAAO,EAAsB,KAAK,sBAAsB,EAAE,MAAM,8BAA8B,CAAC;AAG/F,MAAM,WAAW,+BAAgC,SAAQ,SAAS,EAAE,sBAAsB;IACtF,cAAc,EAAE,MAAM,CAAC;IACvB,OAAO,CAAC,EAAE,MAAM,CAAA;CACnB;AAGD;;;;;;;;;;;;GAYG;AACH,qBAAa,2BAA4B,SAAQ,EAAE,CAAC,MAAM,CAAC,KAAK;IAC5D,MAAM,CAAC,SAAS,SAAiC;IAEjD,OAAO,CAAC,QAAQ,CAAC,QAAQ,CAAS;IAClC,OAAO,CAAC,QAAQ,CAAC,cAAc,CAAS;IACxC,OAAO,CAAC,SAAS,CAAkB;IAEnC,OAAO,CAAC,UAAU,CAAiB;IACnC,OAAO,CAAC,QAAQ,CAAC,iBAAiB,CAAS;IAC3C,OAAO,CAAC,QAAQ,CAAC,OAAO,CAAS;IAEjC,OAAO,CAAC,YAAY,CAAkB;gBAG1B,EAAE,QAAQ,EAAE,cAAc,EAAE,iBAAiB,EAAE,OAAO,EAAE,GAAG,IAAI,EAAE,EAAE,+BAA+B;IA0B9G;;OAEG;IACM,IAAI,CAAC,MAAM,EAAE,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,EAAE,EAAE,MAAM,EAAE,MAAM;IAe7D;;OAEG;IACM,KAAK,CAAC,UAAU,EAAE,EAAE,CAAC,KAAK,GAAG,EAAE,CAAC,KAAK,EAAE,GAAG,IAAI;IAgCvD;;;OAGG;IACM,kBAAkB,CAAC,UAAU,EAAE,EAAE,CAAC,KAAK,GAAG,EAAE,CAAC,KAAK,EAAE,GAAG,EAAE,CAAC,KAAK,GAAG,EAAE,CAAC,KAAK,EAAE;IAQ5E,SAAS,IAAI,EAAE,CAAC,aAAa,CAAC,UAAU;CAcpD"}
@@ -0,0 +1,109 @@
1
+ import * as tf from '@tensorflow/tfjs';
2
+ import { PositionalEncoding } from '@/layers/positional_encoding';
3
/**
 * Combines the sinusoidal positional encoding from the 2017 paper
 * "Attention Is All You Need" with a regular embedding layer, forming a
 * single simplified embedding layer.
 *
 * The layer takes tokenized inputs of shape `[ batch, tokens ]`, passes them
 * through an embedding layer and then adds sinusoidal positional encoding.
 *
 * @param embedDim the size of each token/word's embedding
 * @param vocabularySize the number of tokens to embed
 * @param maxSequenceLength the max number of tokens (words) per input (sentence), default `5120`
 * @param dropout applies dropout to the positionally encoded embeddings, default `0.1`
 */
export class TokenAndPositionalEmbedding extends tf.layers.Layer {
    static className = "TokenAndPositionalEmbedding";
    embedDim;
    vocabularySize;
    embedding;
    positional;
    maxSequenceLength;
    dropout;
    dropoutLayer;

    /**
     * Stores the configuration and creates the three sublayers.
     *
     * @throws Error when `dropout` is not within [0, 1)
     */
    constructor({ embedDim, vocabularySize, maxSequenceLength, dropout, ...args }) {
        super(args);
        this.embedDim = embedDim;
        this.vocabularySize = vocabularySize;
        this.maxSequenceLength = maxSequenceLength ?? 5120;
        this.dropout = dropout ?? 0.1;

        // written as a negated range check so negative rates and NaN are rejected too
        if (!(this.dropout >= 0 && this.dropout < 1)) {
            throw new Error(`${this.getClassName()}::constructor dropout must be within [0, 1)`);
        }

        this.embedding = tf.layers.embedding({
            inputDim: this.vocabularySize,
            outputDim: this.embedDim,
        });
        this.positional = new PositionalEncoding({
            maxSequenceLength: this.maxSequenceLength,
            embedDim: this.embedDim,
        });
        this.dropoutLayer = tf.layers.dropout({ rate: this.dropout });
    }

    /**
     * Forward propagation: embedding, then positional encoding, then dropout.
     *
     * @param inputs a single tensor, or an array holding exactly one tensor
     * @param kwargs forwarded to the dropout sublayer so that it sees the
     *     `training` flag of the current invocation
     * @throws Error when an array with more or fewer than one tensor is given
     */
    call(inputs, kwargs) {
        if (Array.isArray(inputs) && inputs.length !== 1) {
            throw new Error(`${this.getClassName()}::call ${this.name} expects exactly` +
                ` 1 tensor input, received ${inputs.length}`);
        }
        return tf.tidy(() => {
            const embedded = this.embedding.apply(inputs);
            const encoded = this.positional.apply(embedded);
            // without kwargs the dropout layer never sees `training: true`
            // and therefore never drops anything
            return this.dropoutLayer.apply(encoded, kwargs);
        });
    }

    /**
     * Build the sublayers and enable serialization
     *
     * @param inputShape a single shape or an array of shapes; only the first
     *     shape is considered
     * @throws Error when no shape is given, the shape is not 2D, or its token
     *     count exceeds `maxSequenceLength`
     */
    build(inputShape) {
        let input_shapes = [];
        // only consider the first shape if multiple provided
        if (Array.isArray(inputShape) && Array.isArray(inputShape[0])) {
            // input is an array of shapes
            input_shapes = inputShape;
        }
        else if (inputShape != null && inputShape.length !== 0) {
            // input is a single shape
            input_shapes = [inputShape];
        }

        const first_shape = input_shapes[0];
        if (first_shape === undefined
            || first_shape.length !== 2
            || first_shape[1] > this.maxSequenceLength) {
            throw new Error(`${this.getClassName()}::build ${this.name} expected an input of` +
                ` shape [batch, tokens] where tokens <= ${this.maxSequenceLength},` +
                ` received ${JSON.stringify(first_shape ?? inputShape)}`);
        }

        // initialize the sublayers' weights
        this.embedding.build(first_shape);
        this.positional.build(this.embedding.computeOutputShape(first_shape));

        // no need to rename weights, haven't found a case where their names collide
        this.trainableWeights = [
            ...this.embedding.trainableWeights,
            ...this.positional.trainableWeights
        ];
        super.build(first_shape);
    }

    /**
     * The output shape, for an input shape of [batch, sequences], is
     * [batch, sequences, embedDim]
     */
    computeOutputShape(inputShape) {
        const embedding_shape = this.embedding.computeOutputShape(inputShape);
        return this.positional.computeOutputShape(embedding_shape);
    }

    /**
     * Returns the constructor arguments merged with the base layer config
     * (base config entries win on a key collision).
     */
    getConfig() {
        const base_config = super.getConfig();
        const config = {
            embedDim: this.embedDim,
            vocabularySize: this.vocabularySize,
            maxSequenceLength: this.maxSequenceLength,
            dropout: this.dropout,
        };
        Object.assign(config, base_config);
        return config;
    }
}
tf.serialization.registerClass(TokenAndPositionalEmbedding);
109
+ //# sourceMappingURL=token_and_positional_embedding.js.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"token_and_positional_embedding.js","sourceRoot":"","sources":["../../../src/layers/token_and_positional_embedding.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAIvC,OAAO,EAAE,kBAAkB,EAA+B,MAAM,8BAA8B,CAAC;AAS/F;;;;;;;;;;;;GAYG;AACH,MAAM,OAAO,2BAA4B,SAAQ,EAAE,CAAC,MAAM,CAAC,KAAK;IAC5D,MAAM,CAAC,SAAS,GAAG,6BAA6B,CAAC;IAEhC,QAAQ,CAAS;IACjB,cAAc,CAAS;IAChC,SAAS,CAAkB;IAE3B,UAAU,CAAiB;IAClB,iBAAiB,CAAS;IAC1B,OAAO,CAAS;IAEzB,YAAY,CAAkB;IAGtC,YAAY,EAAE,QAAQ,EAAE,cAAc,EAAE,iBAAiB,EAAE,OAAO,EAAE,GAAG,IAAI,EAAmC;QAC1G,KAAK,CAAC,IAAI,CAAC,CAAC;QAEZ,IAAI,CAAC,QAAQ,GAAG,QAAQ,CAAC;QACzB,IAAI,CAAC,cAAc,GAAG,cAAc,CAAC;QACrC,IAAI,CAAC,iBAAiB,GAAG,iBAAiB,IAAI,IAAI,CAAC;QACnD,IAAI,CAAC,OAAO,GAAG,OAAO,IAAI,GAAG,CAAC;QAE9B,IAAI,IAAI,CAAC,OAAO,IAAI,CAAC,EAAE,CAAC;YACpB,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,6CAA6C,CAAC,CAAC;QACrF,CAAC;QAED,IAAI,CAAC,SAAS,GAAG,EAAE,CAAC,MAAM,CAAC,SAAS,CAAC;YACjC,QAAQ,EAAE,IAAI,CAAC,cAAc;YAC7B,SAAS,EAAE,IAAI,CAAC,QAAQ;SAC3B,CAAC,CAAC;QAEH,IAAI,CAAC,UAAU,GAAG,IAAI,kBAAkB,CAAC;YACrC,iBAAiB,EAAE,IAAI,CAAC,iBAAiB;YACzC,QAAQ,EAAE,IAAI,CAAC,QAAQ;SAC1B,CAAC,CAAC;QAEH,IAAI,CAAC,YAAY,GAAG,EAAE,CAAC,MAAM,CAAC,OAAO,CAAC,EAAE,IAAI,EAAE,IAAI,CAAC,OAAO,EAAE,CAAC,CAAC;IAClE,CAAC;IAGD;;OAEG;IACM,IAAI,CAAC,MAA+B,EAAE,MAAc;QACzD,IAAI,KAAK,CAAC,OAAO,CAAC,MAAM,CAAC,IAAI,MAAM,CAAC,MAAM,IAAI,CAAC,EAAE,CAAC;YAC9C,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,UAAU,IAAI,CAAC,IAAI,kBAAkB;gBACnE,6BAA6B,MAAM,CAAC,MAAM,EAAE,CAAC,CAAC;QACtD,CAAC;QAED,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,IAAI,MAAM,GAAG,IAAI,CAAC,UAAU,CAAC,KAAK,CAAC,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,MAAM,CAAC,CAAc,CAAC;YAC9E,MAAM,GAAG,IAAI,CAAC,YAAY,CAAC,KAAK,CAAC,MAAM,CAAc,CAAC;YAEtD,OAAO,MAAM,CAAC;QAClB,CAAC,CAAC,CAAA;IACN,CAAC;IAGD;;OAEG;IACM,KAAK,CAAC,UAAiC;QAC5C,IAAI,YAAY,GAAe,EAAE,CAAC;QAElC,qDAAqD;QACrD,IAAI,KAAK,CAAC,OAAO,CAAC,UAAU,CAAC,IAAI,KAAK,CAAC,OAAO,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC;YAC5D,8BAA8B;YAC9B,YAAY,GAAG,UAAwB,CAAC;QAC5C,CAAC;aAAM,IAAI,UAAU,CAAC,MAAM,IAAI,
CAAC,EAAE,CAAC;YAChC,0BAA0B;YAC1B,YAAY,GAAG,CAAC,UAAsB,CAAC,CAAC;QAC5C,CAAC;QAED,IAAI,YAAY,CAAC,CAAC,CAAC,CAAC,MAAM,IAAI,CAAC,IAAI,YAAY,CAAC,CAAC,CAAC,CAAC,CAAC,CAAE,GAAG,IAAI,CAAC,iBAAiB,EAAE,CAAC;YAC9E,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,WAAW,IAAI,CAAC,IAAI,uBAAuB;gBACzE,yCAAyC,IAAI,CAAC,iBAAiB,GAAG;gBAClE,aAAa,IAAI,CAAC,SAAS,CAAC,YAAY,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC;QACxD,CAAC;QAED,oCAAoC;QACpC,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,YAAY,CAAC,CAAC,CAAC,CAAC,CAAC;QACtC,IAAI,CAAC,UAAU,CAAC,KAAK,CAAC,IAAI,CAAC,SAAS,CAAC,kBAAkB,CAAC,YAAY,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;QAE1E,4EAA4E;QAC5E,IAAI,CAAC,gBAAgB,GAAG;YACpB,GAAG,IAAI,CAAC,SAAS,CAAC,gBAAgB;YAClC,GAAG,IAAI,CAAC,UAAU,CAAC,gBAAgB;SACtC,CAAC;QAEF,KAAK,CAAC,KAAK,CAAC,YAAY,CAAC,CAAC,CAAC,CAAC,CAAC;IACjC,CAAC;IAGD;;;OAGG;IACM,kBAAkB,CAAC,UAAiC;QACzD,MAAM,eAAe,GAAG,IAAI,CAAC,SAAS,CAAC,kBAAkB,CAAC,UAAU,CAAC,CAAC;QACtE,MAAM,gBAAgB,GAAG,IAAI,CAAC,UAAU,CAAC,kBAAkB,CAAC,eAAe,CAAC,CAAC;QAE7E,OAAO,gBAAgB,CAAC;IAC5B,CAAC;IAGQ,SAAS;QACd,MAAM,WAAW,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QAEtC,MAAM,MAAM,GAAG;YACX,QAAQ,EAAE,IAAI,CAAC,QAAQ;YACvB,cAAc,EAAE,IAAI,CAAC,cAAc;YACnC,iBAAiB,EAAE,IAAI,CAAC,iBAAiB;YACzC,OAAO,EAAE,IAAI,CAAC,OAAO;SACxB,CAAA;QAED,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,WAAW,CAAC,CAAC;QAEnC,OAAO,MAAM,CAAC;IAClB,CAAC;;AAIL,EAAE,CAAC,aAAa,CAAC,aAAa,CAAC,2BAA2B,CAAC,CAAC"}
@@ -0,0 +1,2 @@
1
+ export {};
2
+ //# sourceMappingURL=token_and_positional_embedding.test.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"token_and_positional_embedding.test.d.ts","sourceRoot":"","sources":["../../../src/layers/token_and_positional_embedding.test.ts"],"names":[],"mappings":""}