@stellarapp/tfjs-stellar 1.0.3 → 1.0.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (134)
  1. package/README.md +17 -0
  2. package/dist/index.d.ts +3 -1
  3. package/dist/index.d.ts.map +1 -1
  4. package/dist/index.js +3 -1
  5. package/dist/index.js.map +1 -1
  6. package/dist/kv_cache.d.ts +2 -0
  7. package/dist/kv_cache.d.ts.map +1 -1
  8. package/dist/kv_cache.js +6 -0
  9. package/dist/kv_cache.js.map +1 -1
  10. package/dist/models/index.d.ts +2 -1
  11. package/dist/models/index.d.ts.map +1 -1
  12. package/dist/models/index.js +2 -1
  13. package/dist/models/index.js.map +1 -1
  14. package/package.json +1 -1
  15. package/dist/jest.config.d.ts +0 -8
  16. package/dist/jest.config.d.ts.map +0 -1
  17. package/dist/jest.config.js +0 -147
  18. package/dist/jest.config.js.map +0 -1
  19. package/dist/src/index.d.ts +0 -6
  20. package/dist/src/index.d.ts.map +0 -1
  21. package/dist/src/index.js +0 -6
  22. package/dist/src/index.js.map +0 -1
  23. package/dist/src/kv_cache.d.ts +0 -53
  24. package/dist/src/kv_cache.d.ts.map +0 -1
  25. package/dist/src/kv_cache.js +0 -135
  26. package/dist/src/kv_cache.js.map +0 -1
  27. package/dist/src/layers/cached_rope_multihead_attention.d.ts +0 -31
  28. package/dist/src/layers/cached_rope_multihead_attention.d.ts.map +0 -1
  29. package/dist/src/layers/cached_rope_multihead_attention.js +0 -76
  30. package/dist/src/layers/cached_rope_multihead_attention.js.map +0 -1
  31. package/dist/src/layers/cached_rope_multihead_attention.test.d.ts +0 -2
  32. package/dist/src/layers/cached_rope_multihead_attention.test.d.ts.map +0 -1
  33. package/dist/src/layers/cached_rope_multihead_attention.test.js +0 -43
  34. package/dist/src/layers/cached_rope_multihead_attention.test.js.map +0 -1
  35. package/dist/src/layers/gpt_decoder_block.d.ts +0 -34
  36. package/dist/src/layers/gpt_decoder_block.d.ts.map +0 -1
  37. package/dist/src/layers/gpt_decoder_block.js +0 -51
  38. package/dist/src/layers/gpt_decoder_block.js.map +0 -1
  39. package/dist/src/layers/index.d.ts +0 -17
  40. package/dist/src/layers/index.d.ts.map +0 -1
  41. package/dist/src/layers/index.js +0 -33
  42. package/dist/src/layers/index.js.map +0 -1
  43. package/dist/src/layers/multihead_attention.d.ts +0 -106
  44. package/dist/src/layers/multihead_attention.d.ts.map +0 -1
  45. package/dist/src/layers/multihead_attention.js +0 -269
  46. package/dist/src/layers/multihead_attention.js.map +0 -1
  47. package/dist/src/layers/multihead_attention.test.d.ts +0 -2
  48. package/dist/src/layers/multihead_attention.test.d.ts.map +0 -1
  49. package/dist/src/layers/multihead_attention.test.js +0 -160
  50. package/dist/src/layers/multihead_attention.test.js.map +0 -1
  51. package/dist/src/layers/positional_encoding.d.ts +0 -37
  52. package/dist/src/layers/positional_encoding.d.ts.map +0 -1
  53. package/dist/src/layers/positional_encoding.js +0 -115
  54. package/dist/src/layers/positional_encoding.js.map +0 -1
  55. package/dist/src/layers/positional_encoding.test.d.ts +0 -2
  56. package/dist/src/layers/positional_encoding.test.d.ts.map +0 -1
  57. package/dist/src/layers/positional_encoding.test.js +0 -95
  58. package/dist/src/layers/positional_encoding.test.js.map +0 -1
  59. package/dist/src/layers/rotary_position_embedding.d.ts +0 -39
  60. package/dist/src/layers/rotary_position_embedding.d.ts.map +0 -1
  61. package/dist/src/layers/rotary_position_embedding.js +0 -99
  62. package/dist/src/layers/rotary_position_embedding.js.map +0 -1
  63. package/dist/src/layers/rotary_position_embedding.test.d.ts +0 -2
  64. package/dist/src/layers/rotary_position_embedding.test.d.ts.map +0 -1
  65. package/dist/src/layers/rotary_position_embedding.test.js +0 -88
  66. package/dist/src/layers/rotary_position_embedding.test.js.map +0 -1
  67. package/dist/src/layers/token_and_positional_embedding.d.ts +0 -47
  68. package/dist/src/layers/token_and_positional_embedding.d.ts.map +0 -1
  69. package/dist/src/layers/token_and_positional_embedding.js +0 -109
  70. package/dist/src/layers/token_and_positional_embedding.js.map +0 -1
  71. package/dist/src/layers/token_and_positional_embedding.test.d.ts +0 -2
  72. package/dist/src/layers/token_and_positional_embedding.test.d.ts.map +0 -1
  73. package/dist/src/layers/token_and_positional_embedding.test.js +0 -58
  74. package/dist/src/layers/token_and_positional_embedding.test.js.map +0 -1
  75. package/dist/src/layers/transformer_decoder.d.ts +0 -69
  76. package/dist/src/layers/transformer_decoder.d.ts.map +0 -1
  77. package/dist/src/layers/transformer_decoder.js +0 -182
  78. package/dist/src/layers/transformer_decoder.js.map +0 -1
  79. package/dist/src/layers/transformer_decoder.test.d.ts +0 -2
  80. package/dist/src/layers/transformer_decoder.test.d.ts.map +0 -1
  81. package/dist/src/layers/transformer_decoder.test.js +0 -72
  82. package/dist/src/layers/transformer_decoder.test.js.map +0 -1
  83. package/dist/src/layers/transformer_encoder.d.ts +0 -55
  84. package/dist/src/layers/transformer_encoder.d.ts.map +0 -1
  85. package/dist/src/layers/transformer_encoder.js +0 -175
  86. package/dist/src/layers/transformer_encoder.js.map +0 -1
  87. package/dist/src/layers/transformer_encoder.test.d.ts +0 -2
  88. package/dist/src/layers/transformer_encoder.test.d.ts.map +0 -1
  89. package/dist/src/layers/transformer_encoder.test.js +0 -58
  90. package/dist/src/layers/transformer_encoder.test.js.map +0 -1
  91. package/dist/src/losses/dice.d.ts +0 -30
  92. package/dist/src/losses/dice.d.ts.map +0 -1
  93. package/dist/src/losses/dice.js +0 -93
  94. package/dist/src/losses/dice.js.map +0 -1
  95. package/dist/src/losses/index.d.ts +0 -2
  96. package/dist/src/losses/index.d.ts.map +0 -1
  97. package/dist/src/losses/index.js +0 -2
  98. package/dist/src/losses/index.js.map +0 -1
  99. package/dist/src/masks.d.ts +0 -20
  100. package/dist/src/masks.d.ts.map +0 -1
  101. package/dist/src/masks.js +0 -37
  102. package/dist/src/masks.js.map +0 -1
  103. package/dist/src/metrics.d.ts +0 -20
  104. package/dist/src/metrics.d.ts.map +0 -1
  105. package/dist/src/metrics.js +0 -28
  106. package/dist/src/metrics.js.map +0 -1
  107. package/dist/src/models/gpt_model.d.ts +0 -94
  108. package/dist/src/models/gpt_model.d.ts.map +0 -1
  109. package/dist/src/models/gpt_model.js +0 -154
  110. package/dist/src/models/gpt_model.js.map +0 -1
  111. package/dist/src/models/index.d.ts +0 -3
  112. package/dist/src/models/index.d.ts.map +0 -1
  113. package/dist/src/models/index.js +0 -3
  114. package/dist/src/models/index.js.map +0 -1
  115. package/dist/src/models/llm_model.d.ts +0 -87
  116. package/dist/src/models/llm_model.d.ts.map +0 -1
  117. package/dist/src/models/llm_model.js +0 -245
  118. package/dist/src/models/llm_model.js.map +0 -1
  119. package/dist/src/models/u_net.d.ts +0 -40
  120. package/dist/src/models/u_net.d.ts.map +0 -1
  121. package/dist/src/models/u_net.js +0 -151
  122. package/dist/src/models/u_net.js.map +0 -1
  123. package/dist/src/tfjs_types.d.ts +0 -10
  124. package/dist/src/tfjs_types.d.ts.map +0 -1
  125. package/dist/src/tfjs_types.js +0 -2
  126. package/dist/src/tfjs_types.js.map +0 -1
  127. package/dist/src/utils.d.ts +0 -28
  128. package/dist/src/utils.d.ts.map +0 -1
  129. package/dist/src/utils.js +0 -63
  130. package/dist/src/utils.js.map +0 -1
  131. package/dist/src/utils.test.d.ts +0 -2
  132. package/dist/src/utils.test.d.ts.map +0 -1
  133. package/dist/src/utils.test.js +0 -73
  134. package/dist/src/utils.test.js.map +0 -1
@@ -1 +0,0 @@
1
- {"version":3,"file":"gpt_model.d.ts","sourceRoot":"","sources":["../../../src/models/gpt_model.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AACvC,OAAO,EAAE,KAAK,cAAc,EAAE,MAAM,cAAc,CAAC;AACnD,OAAO,EAAE,QAAQ,EAAE,KAAK,YAAY,EAAE,MAAM,oBAAoB,CAAC;AACjE,OAAO,EAAE,gBAAgB,EAAE,MAAM,YAAY,CAAC;AAC9C,OAAO,EAAE,KAAK,aAAa,EAAE,MAAM,8CAA8C,CAAC;AAIlF,MAAM,WAAW,YAAa,SAAQ,YAAY;IAC9C;;OAEG;IACH,QAAQ,EAAE,MAAM,CAAC;IACjB;;OAEG;IACH,SAAS,EAAE,MAAM,CAAC;IAClB;;OAEG;IACH,QAAQ,EAAE,MAAM,CAAC;IACjB;;;OAGG;IACH,SAAS,EAAE,MAAM,CAAC;IAClB;;;;;;OAMG;IACH,iBAAiB,CAAC,EAAE,OAAO,CAAC;CAC/B;AAGD;;;;;;;;;;;;;;;;;;;;;;GAsBG;AACH,qBAAa,QAAS,SAAQ,QAAQ;IAClC,MAAM,CAAC,SAAS,SAAc;IAE9B,SAAS,CAAC,QAAQ,CAAC,QAAQ,EAAE,MAAM,CAAC;IACpC,SAAS,CAAC,QAAQ,CAAC,SAAS,EAAE,MAAM,CAAC;IACrC,SAAS,CAAC,QAAQ,CAAC,QAAQ,EAAE,MAAM,CAAC;IACpC,SAAS,CAAC,QAAQ,CAAC,SAAS,EAAE,MAAM,CAAC;IACrC,SAAS,CAAC,QAAQ,CAAC,iBAAiB,EAAE,OAAO,CAAC;IAI9C,SAAS,CAAC,QAAQ,CAAC,eAAe,EAAE,MAAM,CAAC;IAG3C,SAAS,CAAC,kBAAkB,CAAC,EAAE,EAAE,CAAC,QAAQ,CAAC;IAG3C;;;OAGG;gBACS,IAAI,EAAE,YAAY;cAgBX,QAAQ,CACvB,EAAE,EAAE,EAAE,CAAC,MAAM,EACb,EAAE,EAAE,EAAE,CAAC,MAAM,EACb,SAAS,EAAE,EAAE,CAAC,MAAM,GAAG,SAAS,EAChC,aAAa,EAAE,cAAc,EAC7B,WAAW,CAAC,EAAE;QAAE,CAAC,GAAG,EAAE,MAAM,GAAG,EAAE,CAAC,MAAM,GAAG,SAAS,CAAA;KAAE;;;;IAsC1D;;;;;OAKG;IACM,gBAAgB,CAAC,KAAK,EAAE,EAAE,CAAC,QAAQ,EAAE,QAAQ,EAAE,gBAAgB,GAAG,EAAE,CAAC,QAAQ;IA4B7E,KAAK,CAAC,UAAU,CAAC,EAAE,EAAE,CAAC,KAAK,GAAG,EAAE,CAAC,KAAK,EAAE,GAAG,IAAI;IA+B/C,OAAO,IAAI,aAAa;IAMxB,SAAS;;;;;;;;CAiBrB"}
@@ -1,154 +0,0 @@
1
- import * as tf from "@tensorflow/tfjs";
2
- import { LlmModel } from "@/models/llm_model";
3
- import { GPT2DecoderBlock } from "@/layers/gpt_decoder_block";
/**
 * A GPT-style language model: a token embedding, `numLayers` GPT2DecoderBlock
 * layers and a final dense projection onto the vocabulary. It extends LlmModel
 * (itself a subclass of tf.Sequential). By default it pads the vocab size up to
 * a multiple of 64 for hardware efficiency. The padded vocab positions are
 * masked out of the logits so that softmax ignores them.
 *
 * Example:
 *
 * ```javascript
 * const model = new GptModel({ numLayers: 1, numHeads: 1, embedDim: 16, vocabSize: 64 });
 * model.compile({ loss: "sparseCategoricalCrossentropy", optimizer: "adam" });
 *
 * // use fitDataset() instead of fit() for masking support
 * model.fitDataset(your_batched_generator_dataset, { epochs: 1 });
 *
 * const kv_cache = new KvCacheContainer(your_preferred_max_sequence_length);
 *
 * // use generate() and predictNextToken() instead of predict() for masking and auto memory cleanup
 * model.generate(tokenized_tensor1d_input, kv_cache, onPredict_callback);
 * ```
 */
export class GptModel extends LlmModel {
    static className = "GptModel";
    numHeads;
    numLayers;
    embedDim;
    vocabSize;
    padToMultipleOf64;
    // Kept for reproducibility and model history. It is not essential because it
    // can always be recomputed from vocabSize and padToMultipleOf64.
    vocabSizePadded;
    // Additive mask over the padded vocab: 0 for real vocab entries, -1e7 for
    // padding entries. It is added to the final dense layer's output before
    // softmax, and is left undefined when no padding is applied.
    vocab_padding_mask;

    /**
     * DO NOT add layers in the constructor or it will break tf.loadLayersModel().
     * It should be done in build() instead.
     *
     * @param args numHeads, numLayers, embedDim and vocabSize configure the
     *   architecture. padToMultipleOf64 (default true) enables vocab padding.
     *   Remaining fields are forwarded to LlmModel, and the name defaults to "model".
     */
    constructor(args) {
        const { numHeads, numLayers, embedDim, vocabSize, padToMultipleOf64 = true, ...rest } = args;
        super({ name: "model", ...rest });
        this.numHeads = numHeads;
        this.numLayers = numLayers;
        this.embedDim = embedDim;
        this.vocabSize = vocabSize;
        this.padToMultipleOf64 = padToMultipleOf64;
        this.vocabSizePadded = this.padToMultipleOf64
            ? Math.ceil(this.vocabSize / 64) * 64
            : this.vocabSize;
    }

    /**
     * One forward/backward training step. Unlike LlmModel.fitBatch, this adds
     * the vocab padding mask to the logits and applies softmax before computing
     * the loss, because the final dense layer has no activation.
     *
     * @returns `y_pred`, the kept softmax output (disposed by the caller), and
     *   `loss`, the scalar mean loss
     */
    fitBatch(xs, ys, loss_mask, loss_function, other_masks) {
        let y_pred;
        // forward pass, calculate loss
        const { value: loss, grads } = tf.variableGrads(() => {
            let logits = this.apply(xs, {
                training: true,
                ...other_masks
            });
            // apply vocab pad masking so padded positions get ~0 probability
            if (this.vocab_padding_mask) {
                logits = logits.add(this.vocab_padding_mask);
            }
            y_pred = tf.softmax(logits);
            // manually dispose later instead of the built-in disposal from variableGrads
            tf.keep(y_pred);
            const token_loss = loss_function(ys, y_pred);
            const masked_loss = loss_mask ? token_loss.mul(loss_mask) : token_loss;
            return masked_loss.mean();
        });
        // backpropagation
        this.optimizer.applyGradients(grads);
        return { y_pred, loss };
    }

    /**
     * Overrides LlmModel.predictNextToken to mask the padded vocab and add
     * softmax before argMax, because the final dense layer doesn't have an
     * activation.
     *
     * TODO: implement temperature and multinomial sampling so that the model has varied outputs
     *
     * @param input token ids with a batch size of exactly 1
     * @param kv_cache KV cache forwarded to the layers as `kvCache`
     * @returns the argMax token id at the last sequence position, shape [1, 1]
     * @throws if the input's batch size is not 1
     */
    predictNextToken(input, kv_cache) {
        if (input.shape[0] != 1) {
            throw Error(`GptModel.predictNextToken: ${this.name} expects an input with a batch size of 1`);
        }
        return tf.tidy(() => {
            // comes back as [batch, sequence_length, vocab_size]
            const prediction = this.apply(input, { kvCache: kv_cache });
            const [batch_size, sequence_length, vocab_size] = prediction.shape;
            // keep only the logits of the last token
            let last_logits = prediction.slice([0, sequence_length - 1, 0], [batch_size, 1, vocab_size]);
            if (this.vocab_padding_mask != undefined) {
                last_logits = last_logits.add(this.vocab_padding_mask);
            }
            return last_logits.softmax().argMax(2);
        });
    }

    /**
     * Adds the layers on the first build (not in the constructor, so that
     * tf.loadLayersModel() keeps working) and (re)creates the vocab padding mask.
     */
    build(inputShape) {
        const actual_vocab_size = this.vocabSizePadded
            ? this.vocabSizePadded
            : this.padToMultipleOf64
                ? Math.ceil(this.vocabSize / 64) * 64
                : this.vocabSize;
        if (this.layers.length == 0) {
            [
                tf.layers.embedding({ inputDim: actual_vocab_size, outputDim: this.embedDim, batchInputShape: [null, null] }),
                ...Array(this.numLayers).fill(0).map(_ => new GPT2DecoderBlock({ numHeads: this.numHeads, embedDim: this.embedDim })),
                tf.layers.dense({ units: actual_vocab_size })
            ].forEach(layer => this.add(layer));
        }
        // Release the mask from any previous build and drop the reference, so a
        // disposed tensor can never be added to the logits afterwards.
        if (this.vocab_padding_mask) {
            this.vocab_padding_mask.dispose();
            this.vocab_padding_mask = undefined;
        }
        if (this.padToMultipleOf64 && actual_vocab_size > this.vocabSize) {
            this.vocab_padding_mask = tf.tidy(() => tf.where(
            // Create a mask of padded vocab length. Values at or after the index
            // "vocabSize" are set to -1e7 so that softmax ignores those positions.
            // This mask is added to the final dense layer's output.
            tf.range(0, actual_vocab_size).greaterEqual(this.vocabSize), -1e7, 0).toFloat());
        }
        super.build(inputShape);
    }

    /** Disposes the vocab padding mask (if any), then the model itself. */
    dispose() {
        this.vocab_padding_mask?.dispose();
        this.vocab_padding_mask = undefined;
        return super.dispose();
    }

    /**
     * Serializable config: the architecture fields plus vocabSizePadded, merged
     * with the base tf.Sequential config (base keys win on a clash).
     */
    getConfig() {
        const base_config = super.getConfig();
        const config = {
            numHeads: this.numHeads,
            numLayers: this.numLayers,
            embedDim: this.embedDim,
            vocabSize: this.vocabSize,
            vocabSizePadded: this.vocabSizePadded,
            padToMultipleOf64: this.padToMultipleOf64
        };
        Object.assign(config, base_config);
        return config;
    }
}
153
- tf.serialization.registerClass(GptModel);
154
- //# sourceMappingURL=gpt_model.js.map
@@ -1 +0,0 @@
1
- {"version":3,"file":"gpt_model.js","sourceRoot":"","sources":["../../../src/models/gpt_model.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAEvC,OAAO,EAAE,QAAQ,EAAqB,MAAM,oBAAoB,CAAC;AAGjE,OAAO,EAAE,gBAAgB,EAAE,MAAM,4BAA4B,CAAC;AAgC9D;;;;;;;;;;;;;;;;;;;;;;GAsBG;AACH,MAAM,OAAO,QAAS,SAAQ,QAAQ;IAClC,MAAM,CAAC,SAAS,GAAG,UAAU,CAAC;IAEX,QAAQ,CAAS;IACjB,SAAS,CAAS;IAClB,QAAQ,CAAS;IACjB,SAAS,CAAS;IAClB,iBAAiB,CAAU;IAE9C,gFAAgF;IAChF,sCAAsC;IACnB,eAAe,CAAS;IAE3C,0EAA0E;IAChE,kBAAkB,CAAe;IAG3C;;;OAGG;IACH,YAAY,IAAkB;QAC1B,MAAM,EAAE,QAAQ,EAAE,SAAS,EAAE,QAAQ,EAAE,SAAS,EAAE,iBAAiB,GAAG,IAAI,EAAE,GAAG,IAAI,EAAE,GAAG,IAAI,CAAC;QAE7F,KAAK,CAAC,EAAE,IAAI,EAAE,OAAO,EAAE,GAAG,IAAI,EAAE,CAAC,CAAC;QAElC,IAAI,CAAC,QAAQ,GAAG,QAAQ,CAAC;QACzB,IAAI,CAAC,SAAS,GAAG,SAAS,CAAC;QAC3B,IAAI,CAAC,QAAQ,GAAG,QAAQ,CAAC;QACzB,IAAI,CAAC,SAAS,GAAG,SAAS,CAAC;QAC3B,IAAI,CAAC,iBAAiB,GAAG,iBAAiB,CAAC;QAC3C,IAAI,CAAC,eAAe,GAAG,IAAI,CAAC,iBAAiB;YACzC,CAAC,CAAC,IAAI,CAAC,IAAI,CAAC,IAAI,CAAC,SAAS,GAAG,EAAE,CAAC,GAAG,EAAE;YACrC,CAAC,CAAC,IAAI,CAAC,SAAS,CAAC;IACzB,CAAC;IAGkB,QAAQ,CACvB,EAAa,EACb,EAAa,EACb,SAAgC,EAChC,aAA6B,EAC7B,WAAsD;QAEtD,IAAI,MAAiB,CAAC;QAEtB,+BAA+B;QAC/B,MAAM,EAAE,KAAK,EAAE,IAAI,EAAE,KAAK,EAAE,GAAG,EAAE,CAAC,aAAa,CAAC,GAAG,EAAE;YACjD,MAAM,GAAG,IAAI,CAAC,KAAK,CAAC,EAAE,EAAE;gBACpB,QAAQ,EAAE,IAAI;gBACd,GAAG,WAAW;aACjB,CAAc,CAAC;YAEhB,0BAA0B;YAC1B,IAAI,IAAI,CAAC,kBAAkB,EAAE,CAAC;gBAC1B,MAAM,GAAG,MAAM,CAAC,GAAG,CAAC,IAAI,CAAC,kBAAkB,CAAC,CAAC;YACjD,CAAC;YAED,MAAM,GAAG,EAAE,CAAC,OAAO,CAAC,MAAM,CAAC,CAAC;YAE5B,6EAA6E;YAC7E,EAAE,CAAC,IAAI,CAAC,MAAM,CAAC,CAAC;YAEhB,MAAM,IAAI,GAAG,SAAS;gBAClB,CAAC,CAAC,aAAa,CAAC,EAAE,EAAE,MAAM,CAAC,CAAC,GAAG,CAAC,SAAS,CAAC;gBAC1C,CAAC,CAAC,aAAa,CAAC,EAAE,EAAE,MAAM,CAAC,CAAC;YAEhC,OAAO,IAAI,CAAC,IAAI,EAAe,CAAC;QACpC,CAAC,CAAC,CAAC;QAEH,kBAAkB;QAClB,IAAI,CAAC,SAAS,CAAC,cAAc,CAAC,KAAK,CAAC,CAAC;QAErC,OAAO;YACH,MAAM,EAAE,MAAO;YACf,IAAI;SACP,CAAC;IACN,CAAC;IAGD;;;;;OAKG;IACM,gBAAgB,CAAC,KAAkB,EAAE,QAA0B;QACpE,IAAI,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,IAAI,CAAC,E
AAE,CAAC;YACtB,MAAM,KAAK,CAAC,8BAA8B,IAAI,CAAC,IAAI,0CAA0C,CAAC,CAAC;QACnG,CAAC;QAED,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,qDAAqD;YACrD,MAAM,UAAU,GAAG,IAAI,CAAC,KAAK,CAAC,KAAK,EAAE,EAAE,OAAO,EAAE,QAAQ,EAAE,CAAc,CAAC;YAEzE,MAAM,CAAC,UAAU,EAAE,eAAe,EAAE,UAAU,CAAC,GAAG,UAAU,CAAC,KAAK,CAAC;YAEnE,qBAAqB;YACrB,MAAM,UAAU,GAAG,IAAI,CAAC,kBAAkB,IAAI,SAAS;gBACnD,CAAC,CAAC,UAAU;qBACP,KAAK,CAAC,CAAC,CAAC,EAAE,eAAe,GAAG,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,UAAU,EAAE,CAAC,EAAE,UAAU,CAAC,CAAC;qBAC/D,GAAG,CAAC,IAAI,CAAC,kBAAkB,CAAC;qBAC5B,OAAO,EAAE;qBACT,MAAM,CAAC,CAAC,CAAC;gBACd,CAAC,CAAC,UAAU;qBACP,KAAK,CAAC,CAAC,CAAC,EAAE,eAAe,GAAG,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,UAAU,EAAE,CAAC,EAAE,UAAU,CAAC,CAAC;qBAC/D,OAAO,EAAE;qBACT,MAAM,CAAC,CAAC,CAAC,CAAC;YAEnB,OAAO,UAAyB,CAAC;QACrC,CAAC,CAAC,CAAA;IACN,CAAC;IAGQ,KAAK,CAAC,UAAkC;QAC7C,MAAM,iBAAiB,GAAG,IAAI,CAAC,eAAe;YAC1C,CAAC,CAAC,IAAI,CAAC,eAAe;YACtB,CAAC,CAAC,IAAI,CAAC,iBAAiB;gBACpB,CAAC,CAAC,IAAI,CAAC,IAAI,CAAC,IAAI,CAAC,SAAS,GAAG,EAAE,CAAC,GAAG,EAAE;gBACrC,CAAC,CAAC,IAAI,CAAC,SAAS,CAAA;QAExB,IAAI,IAAI,CAAC,MAAM,CAAC,MAAM,IAAI,CAAC,EAAE,CAAC;YAC1B;gBACI,EAAE,CAAC,MAAM,CAAC,SAAS,CAAC,EAAE,QAAQ,EAAE,iBAAiB,EAAE,SAAS,EAAE,IAAI,CAAC,QAAQ,EAAE,eAAe,EAAE,CAAC,IAAI,EAAE,IAAI,CAAC,EAAE,CAAC;gBAC7G,GAAG,KAAK,CAAC,IAAI,CAAC,SAAS,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,IAAI,gBAAgB,CAAC,EAAE,QAAQ,EAAE,IAAI,CAAC,QAAQ,EAAE,QAAQ,EAAE,IAAI,CAAC,QAAQ,EAAE,CAAC,CAAC;gBACrH,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC,EAAE,KAAK,EAAE,iBAAiB,EAAE,CAAC;aAChD,CAAC,OAAO,CAAC,KAAK,CAAC,EAAE,CAAC,IAAI,CAAC,GAAG,CAAC,KAAK,CAAC,CAAC,CAAA;QACvC,CAAC;QAED,IAAI,IAAI,CAAC,kBAAkB,EAAE,CAAC;YAC1B,IAAI,CAAC,kBAAkB,CAAC,OAAO,EAAE,CAAC;QACtC,CAAC;QAED,IAAI,IAAI,CAAC,iBAAiB,IAAI,iBAAiB,GAAG,IAAI,CAAC,SAAS,EAAE,CAAC;YAC/D,IAAI,CAAC,kBAAkB,GAAG,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE,CAAC,EAAE,CAAC,KAAK;YAC5C,2EAA2E;YAC3E,0EAA0E;YAC1E,6DAA6D;YAC7D,EAAE,CAAC,KAAK,CAAC,CAAC,EAAE,iBAAiB,CAAC,CAAC,YAAY,CAAC,IAAI,CAAC,SAAS,CAAC,EAAE,CAAC,GAAG,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC,CAAA;QACx
F,CAAC;QAED,KAAK,CAAC,KAAK,CAAC,UAAU,CAAC,CAAC;IAC5B,CAAC;IAGQ,OAAO;QACZ,IAAI,CAAC,kBAAkB,EAAE,OAAO,EAAE,CAAC;QACnC,OAAO,KAAK,CAAC,OAAO,EAAE,CAAC;IAC3B,CAAC;IAGQ,SAAS;QACd,MAAM,WAAW,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QAEtC,MAAM,MAAM,GAAG;YACX,QAAQ,EAAE,IAAI,CAAC,QAAQ;YACvB,SAAS,EAAE,IAAI,CAAC,SAAS;YACzB,QAAQ,EAAE,IAAI,CAAC,QAAQ;YACvB,SAAS,EAAE,IAAI,CAAC,SAAS;YACzB,eAAe,EAAE,IAAI,CAAC,eAAe;YACrC,iBAAiB,EAAE,IAAI,CAAC,iBAAiB;SAC5C,CAAA;QAED,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,WAAW,CAAC,CAAC;QAEnC,OAAO,MAAM,CAAC;IAClB,CAAC;;AAKL,EAAE,CAAC,aAAa,CAAC,aAAa,CAAC,QAAQ,CAAC,CAAC"}
@@ -1,3 +0,0 @@
1
- export * from "@/models/gpt_model";
2
- export * from "@/models/u_net";
3
- //# sourceMappingURL=index.d.ts.map
@@ -1 +0,0 @@
1
- {"version":3,"file":"index.d.ts","sourceRoot":"","sources":["../../../src/models/index.ts"],"names":[],"mappings":"AAAA,cAAc,oBAAoB,CAAC;AACnC,cAAc,gBAAgB,CAAC"}
@@ -1,3 +0,0 @@
1
- export * from "@/models/gpt_model";
2
- export * from "@/models/u_net";
3
- //# sourceMappingURL=index.js.map
@@ -1 +0,0 @@
1
- {"version":3,"file":"index.js","sourceRoot":"","sources":["../../../src/models/index.ts"],"names":[],"mappings":"AAAA,cAAc,oBAAoB,CAAC;AACnC,cAAc,gBAAgB,CAAC"}
@@ -1,87 +0,0 @@
1
- import * as tf from "@tensorflow/tfjs";
2
- import { Dataset, type LossOrMetricFn } from "@/tfjs_types";
3
- import { KvCacheContainer } from "@/kv_cache";
/**
 * Constructor arguments for LlmModel. Currently identical to tf.SequentialArgs;
 * no extra fields are added.
 */
export interface LlmModelArgs extends tf.SequentialArgs {
}
/**
 * Default shape of each element yielded by the dataset passed to
 * `LlmModel.fitDataset()`.
 */
interface DatasetArgs extends tf.TensorContainerObject {
    // model input; fitDataset() throws unless it is 2-D ([batch, sequence_length])
    xs: tf.Tensor;
    // training targets, passed to the loss function together with the prediction
    ys: tf.Tensor;
    // optional mask multiplied into the per-position loss (e.g. to ignore non-assistant tokens when fine-tuning)
    loss_mask?: tf.Tensor;
    // optional mask forwarded to the model's layers as `packingMask` (when sequence packing was done)
    packing_mask?: tf.Tensor;
}
/**
 * This class overrides the `fitDataset()` function of tf.Sequential to support loss
 * and packing masking. Use the `generate()` function to autoregressively predict the
 * next token; set `stopPredicting=true` to stop.
 */
export declare class LlmModel extends tf.Sequential {
    static className: string;
    // presumably the backing field for the `stopPredicting` accessors — TODO confirm
    private stopPredicting_;
    /**
     * @param args tf.Sequential arguments; the implementation defaults `name` to "model"
     */
    constructor(args: LlmModelArgs);
    /**
     * Returns the metric functions and names so that metrics can be reported
     * as they are in the base version of model.fitDataset
     *
     * e.g. "categoricalAccuracy" should be reported as "acc"
     */
    protected getMetricFunctions(): {
        metric_fn: import("@tensorflow/tfjs-layers/dist/types").LossOrMetricFn;
        metric_label: string;
    }[];
    /**
     * Get exactly one loss function from the loss function provided in `model.compile()`.
     * If a string identifier was used, convert it to the actual loss function.
     */
    protected getLossFunction(): LossOrMetricFn;
    /**
     * Train on a `tf.data.generator` dataset. See https://js.tensorflow.org/api/latest/#data.generator.
     *
     * The generator should yield `xs`, `ys`, `loss_mask` (if fine-tuning), and
     * `packing_mask` (if sequence packing was done)
     *
     * @param tfdataset an instance of a `tf.Dataset` generator
     * @param args a ModelFitDatasetArgs
     */
    fitDataset<T = DatasetArgs>(tfdataset: Dataset<T>, args: tf.ModelFitDatasetArgs<T>): Promise<any>;
    /**
     * Run the core forward and backward propagation on one training batch. This
     * should be called within a `tf.tidy()`.
     *
     * @param xs the sample/input tensor
     * @param ys the label/target tensor
     * @param loss_mask a loss mask to ignore the prediction's non-assistant tokens
     * @param loss_function the model's loss function
     * @param other_masks other masks used by the model's layers e.g. packing mask, causal mask
     * @returns the kept prediction (`y_pred`, disposed by the caller) and the scalar mean loss
     */
    protected fitBatch(xs: tf.Tensor, ys: tf.Tensor, loss_mask: tf.Tensor | undefined, loss_function: LossOrMetricFn, other_masks?: {
        [key: string]: tf.Tensor | undefined;
    }): {
        y_pred: tf.Tensor<tf.Rank>;
        loss: tf.Scalar;
    };
    /**
     * Compiles the model; the implementation throws if `args.loss` is
     * "categoricalCrossentropy".
     */
    compile(args: tf.ModelCompileArgs): void;
    /**
     * Autoregressively generate the next token until `model.stopPredicting` is set
     * to `true` or the KV cache reaches its maximum sequence length. For a single chat
     * session, the input should only be the most recent prompt(s). The KV cache stores
     * the prior chat history up until the most recent chat.
     *
     * @param input tokenized input of the newest chat
     * @param kv_cache an instance of a KV cache container
     * @param onPredict callback function to receive the most recent token predicted
     */
    generate(input: tf.Tensor1D, kv_cache: KvCacheContainer, onPredict: (token: tf.Tensor) => Promise<void>): Promise<void>;
    /**
     * Given a tokenized sentence, predict the next token (word).
     * A normal prediction is run to get an output with the shape
     * `[ batch_size, sentence_length, vocab_size ]` and the `vocab_size`
     * position with the highest scored probability in the last
     * position of `sentence_length` is returned as the next predicted
     * token.
     */
    predictNextToken(input: tf.Tensor2D, kv_cache: KvCacheContainer): tf.Tensor2D;
    // Whether generation should stop; set it to `true` to make `generate()` stop producing tokens.
    get stopPredicting(): boolean;
    set stopPredicting(stop: boolean);
}
86
- export {};
87
- //# sourceMappingURL=llm_model.d.ts.map
@@ -1 +0,0 @@
1
- {"version":3,"file":"llm_model.d.ts","sourceRoot":"","sources":["../../../src/models/llm_model.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAEvC,OAAO,EAAE,OAAO,EAAE,KAAK,cAAc,EAAE,MAAM,cAAc,CAAC;AAE5D,OAAO,EAAE,gBAAgB,EAAE,MAAM,YAAY,CAAC;AAK9C,MAAM,WAAW,YAAa,SAAQ,EAAE,CAAC,cAAc;CACtD;AAGD,UAAU,WAAY,SAAQ,EAAE,CAAC,qBAAqB;IAClD,EAAE,EAAE,EAAE,CAAC,MAAM,CAAC;IACd,EAAE,EAAE,EAAE,CAAC,MAAM,CAAC;IACd,SAAS,CAAC,EAAE,EAAE,CAAC,MAAM,CAAC;IACtB,YAAY,CAAC,EAAE,EAAE,CAAC,MAAM,CAAC;CAC5B;AAGD;;;;GAIG;AACH,qBAAa,QAAS,SAAQ,EAAE,CAAC,UAAU;IACvC,MAAM,CAAC,SAAS,SAAc;IAE9B,OAAO,CAAC,eAAe,CAAiB;gBAE5B,IAAI,EAAE,YAAY;IAM9B;;;;;OAKG;IACH,SAAS,CAAC,kBAAkB;;;;IAU5B;;;OAGG;IACH,SAAS,CAAC,eAAe,IAAI,cAAc;IAoC3C;;;;;;;;OAQG;IACY,UAAU,CAAC,CAAC,GAAG,WAAW,EAAE,SAAS,EAAE,OAAO,CAAC,CAAC,CAAC,EAAE,IAAI,EAAE,EAAE,CAAC,mBAAmB,CAAC,CAAC,CAAC,GAAG,OAAO,CAAC,GAAG,CAAC;IAqHhH;;;;;;;;;OASG;IACH,SAAS,CAAC,QAAQ,CACd,EAAE,EAAE,EAAE,CAAC,MAAM,EACb,EAAE,EAAE,EAAE,CAAC,MAAM,EACb,SAAS,EAAE,EAAE,CAAC,MAAM,GAAG,SAAS,EAChC,aAAa,EAAE,cAAc,EAC7B,WAAW,CAAC,EAAE;QAAE,CAAC,GAAG,EAAE,MAAM,GAAG,EAAE,CAAC,MAAM,GAAG,SAAS,CAAA;KAAE,GACvD;QACC,MAAM,EAAE,EAAE,CAAC,MAAM,CAAC,EAAE,CAAC,IAAI,CAAC,CAAC;QAC3B,IAAI,EAAE,EAAE,CAAC,MAAM,CAAC;KACnB;IA+BQ,OAAO,CAAC,IAAI,EAAE,EAAE,CAAC,gBAAgB,GAAG,IAAI;IASjD;;;;;;;;;OASG;IACU,QAAQ,CAAC,KAAK,EAAE,EAAE,CAAC,QAAQ,EAAE,QAAQ,EAAE,gBAAgB,EAAE,SAAS,EAAE,CAAC,KAAK,EAAE,EAAE,CAAC,MAAM,KAAK,OAAO,CAAC,IAAI,CAAC;IA2BpH;;;;;;;OAOG;IACI,gBAAgB,CAAC,KAAK,EAAE,EAAE,CAAC,QAAQ,EAAE,QAAQ,EAAE,gBAAgB;IAmBtE,IAAI,cAAc,IAKO,OAAO,CAH/B;IAGD,IAAI,cAAc,CAAC,IAAI,EAAE,OAAO,EAE/B;CAEJ"}
@@ -1,245 +0,0 @@
1
- import * as tf from "@tensorflow/tfjs";
2
- import { sparseCategoricalCrossentropy } from "@tensorflow/tfjs-layers/dist/losses";
3
- import { causal as generateCausalMask } from "@/masks";
4
- import * as losses from "@/losses";
5
- ;
6
- /**
7
- * This class overrides the `fitDataset()` function of tf.Sequential to support loss
8
- * and packing masking. Use the `generate()` function to autoregressively predict the
9
- * next token; set `stopPredicting=true` to stop.
10
- */
11
- export class LlmModel extends tf.Sequential {
12
- static className = "LlmModel";
13
- stopPredicting_ = true;
14
- constructor(args) {
15
- args.name = args.name ?? "model";
16
- super(args);
17
- }
18
- /**
19
- * Returns the metric functions and names so that metrics can be reported
20
- * as they are in the base version of model.fitDataset
21
- *
22
- * e.g. "categoricalAccuracy" should be reported as "acc"
23
- */
24
- getMetricFunctions() {
25
- const [loss, ...metric_fn_names] = this.metricsNames;
26
- return this.metricsTensors.map((metric_tensor, index) => ({
27
- metric_fn: metric_tensor[0],
28
- metric_label: metric_fn_names[index]
29
- }));
30
- }
31
- /**
32
- * Get exactly one loss function from the loss function provided in `model.compile()`.
33
- * If a string identifier was used, convert it to the actual loss function.
34
- */
35
- getLossFunction() {
36
- let loss = this.loss;
37
- if (Array.isArray(loss)) {
38
- loss = loss[0];
39
- }
40
- if (typeof loss == "string") {
41
- if (loss == "sparseCategoricalCrossentropy") {
42
- return sparseCategoricalCrossentropy;
43
- /* throw Error("LlmModel.getLossFunction: TFJS's sparseCategoricalCrossentropy" +
44
- " is not truly sparse, it simply converts it to onehot." +
45
- " Use categoricalCrossentropy instead. See" +
46
- " https://github.com/tensorflow/tfjs/blob/0fc04d958ea592f3b8db79a8b3b497b5c8904097/tfjs-layers/src/losses.ts#L143-L146"); */
47
- }
48
- const loss_id = loss;
49
- const loss_fn = (losses[loss_id] ??
50
- tf.losses[loss_id] ??
51
- tf.metrics[loss_id]);
52
- if (loss_fn) {
53
- return loss_fn;
54
- }
55
- else {
56
- throw Error(`LlmModel.getLossFunction: ${loss_id} is not a valid loss function`);
57
- }
58
- }
59
- else if (typeof loss == "function") {
60
- return loss;
61
- }
62
- throw Error("LlmModel.getLossFunction: the loss function's type should be string or function");
63
- }
64
- /**
65
- * Train on a `tf.data.generator` dataset. See https://js.tensorflow.org/api/latest/#data.generator.
66
- *
67
- * The generator should yield `xs`, `ys`, `loss_mask` (if fine-tuning), and
68
- * `packing_mask` (if sequence packing was done)
69
- *
70
- * @param tfdataset an instance of a `tf.Dataset` generator
71
- * @param args a ModelFitDatasetArgs
72
- */
73
- async fitDataset(tfdataset, args) {
74
- this.stopTraining = false;
75
- const dataset = tfdataset;
76
- const { epochs, callbacks } = args;
77
- const metric_functions = this.getMetricFunctions();
78
- const loss_function = this.getLossFunction();
79
- this.lossFunctions = [loss_function];
80
- const { onBatchBegin, onBatchEnd, onEpochBegin, onEpochEnd, onTrainBegin, onTrainEnd, } = callbacks ?? {};
81
- await onTrainBegin?.();
82
- let cached_causal_mask = undefined;
83
- for (let epoch = 0; epoch < epochs; epoch++) {
84
- await onEpochBegin?.(epoch);
85
- let batch = 0;
86
- let total_samples = 0;
87
- const accumulated_epoch_metrics = {};
88
- // loop through dataset using its iterator
89
- const iterator = await dataset.iterator();
90
- let sample = await iterator.next();
91
- while (!sample.done) {
92
- const batch_metrics = { batch };
93
- const { xs, ys, loss_mask, packing_mask } = sample.value;
94
- const batch_size = xs.shape[0];
95
- total_samples += batch_size; // for epoch metrics averaging
96
- if (xs.shape.length != 2) {
97
- throw Error(`LlmModel.fitDataset: ${this.name} the generator dataset should be batched, run: dataset.batch(batch_size)`);
98
- }
99
- // pre-calculate the causal attention mask and reuse it for all attention layers,
100
- const seq_length = xs.shape[xs.shape.length - 1];
101
- if (!cached_causal_mask || cached_causal_mask.shape[0] != seq_length) {
102
- cached_causal_mask = generateCausalMask(seq_length, seq_length);
103
- }
104
- await onBatchBegin?.(batch);
105
- tf.tidy(() => {
106
- const { y_pred, loss } = this.fitBatch(xs, ys, loss_mask, loss_function, {
107
- packingMask: packing_mask,
108
- causalMask: cached_causal_mask
109
- });
110
- const loss_value = (loss.dataSync())[0];
111
- batch_metrics.loss = loss_value;
112
- accumulated_epoch_metrics.loss = (accumulated_epoch_metrics.loss || 0) + loss_value * batch_size;
113
- // calculate and store metrics
114
- for (const { metric_fn, metric_label } of metric_functions) {
115
- const metric_sum = metric_fn(ys, y_pred).mean();
116
- const metric_value = (metric_sum.dataSync())[0];
117
- batch_metrics[metric_label] = metric_value; // / batch_size;
118
- accumulated_epoch_metrics[metric_label] = (accumulated_epoch_metrics[metric_label] || 0) + metric_value * batch_size;
119
- }
120
- tf.dispose(y_pred);
121
- });
122
- tf.dispose(xs);
123
- tf.dispose(ys);
124
- tf.dispose(loss_mask);
125
- if (packing_mask) {
126
- tf.dispose(packing_mask);
127
- }
128
- await onBatchEnd?.(batch, batch_metrics);
129
- // so that stop training works
130
- await tf.nextFrame();
131
- if (this.stopTraining) {
132
- break;
133
- }
134
- sample = await iterator.next();
135
- batch++;
136
- }
137
- for (const metric in accumulated_epoch_metrics) {
138
- accumulated_epoch_metrics[metric] = accumulated_epoch_metrics[metric] / total_samples;
139
- }
140
- await onEpochEnd?.(epoch, accumulated_epoch_metrics);
141
- if (this.stopTraining) {
142
- break;
143
- }
144
- }
145
- tf.dispose(cached_causal_mask);
146
- await onTrainEnd?.();
147
- return {};
148
- }
149
- /**
150
- * Run the core forward and backward propagation on one training batch. This
151
- * should be called within a `tf.tidy()`.
152
- *
153
- * @param xs the sample/input tensor
154
- * @param ys the label/target tensor
155
- * @param loss_mask a loss mask to ignore the prediction's non-assistant tokens
156
- * @param loss_function the model's loss function
157
- * @param other_masks other masks used by the model's layers e.g. packing mask, causal mask
158
- */
159
- fitBatch(xs, ys, loss_mask, loss_function, other_masks) {
160
- let y_pred;
161
- // forward pass, calculate loss
162
- const { value: loss, grads } = tf.variableGrads(() => {
163
- // prediction has shape [batch, sequence_length, vocab_size]
164
- y_pred = this.apply(xs, {
165
- training: true,
166
- ...other_masks
167
- });
168
- // manually dispose later instead of the built-in disposal from variableGrads
169
- tf.keep(y_pred);
170
- const loss = loss_mask
171
- ? loss_function(ys, y_pred).mul(loss_mask)
172
- : loss_function(ys, y_pred);
173
- return loss.mean();
174
- });
175
- // backpropagation
176
- this.optimizer.applyGradients(grads);
177
- return {
178
- y_pred: y_pred,
179
- loss
180
- };
181
- }
182
- compile(args) {
183
- if (args.loss == "categoricalCrossentropy") {
184
- throw Error(`LlmModel.compile: use sparseCategoricalCrossentropy loss (along with onehot encoded labels) instead of categoricalCrossEntropy`);
185
- }
186
- super.compile(args);
187
- }
188
- /**
189
- * Autoregressively generate the next token until `model.stopPredicting` is set
190
- * to `true` or the KV cache reaches its maximum sequence length. For a single chat
191
- * session, the input should only be the most recent prompt(s). The KV cache stores
192
- * the prior chat history up until the most recent chat.
193
- *
194
- * @param input tokenized input of the newest chat
195
- * @param kv_cache an instance of a KV cache container
196
- * @param onPredict callback function to receive the most recent token predicted
197
- */
198
- async generate(input, kv_cache, onPredict) {
199
- if (kv_cache.size >= kv_cache.maxSequenceLength) {
200
- throw Error(`LlmModel.generate: ${this.name} KV cache's size reached the maxSequenceLength (${kv_cache.maxSequenceLength})`);
201
- }
202
- this.stopPredicting = false;
203
- let current_token = tf.tidy(() => input.expandDims(0)); // it's 2D because of the required batch dimension
204
- while (!this.stopPredicting && kv_cache.size < kv_cache.maxSequenceLength) {
205
- // add a batch dimension because forward pass requires inputs batched
206
- const next_token = tf.tidy(() => this.predictNextToken(current_token, kv_cache));
207
- // pass back the predicted token, without the batch dim,
208
- const unbatched_next_token = tf.tidy(() => next_token.squeeze([0]));
209
- await onPredict(unbatched_next_token);
210
- unbatched_next_token.dispose();
211
- current_token.dispose();
212
- current_token = next_token;
213
- }
214
- tf.dispose(current_token);
215
- }
216
- /**
217
- * Given a tokenized sentence, predict the next token (word).
218
- * A normal prediction is ran to get an output with the shape
219
- * `[ batch_size, sentence_length, vocab_size ]` and the `vocab_size`
220
- * position with the highest scored probability in the last
221
- * position of `sentence_length` is returned as the next predicted
222
- * token.
223
- */
224
- predictNextToken(input, kv_cache) {
225
- if (input.shape[0] != 1) {
226
- throw Error(`LlmModel.predictNextToken: ${this.name} expects an input with a batch size of 1`);
227
- }
228
- return tf.tidy(() => {
229
- // comes back as [batch, sequence_length, vocab_size]
230
- const prediction = this.apply(input, { kvCache: kv_cache });
231
- const [batch_size, sequence_length, vocab_size] = prediction.shape;
232
- // get the last token
233
- const next_token = prediction.slice([0, sequence_length - 1, 0], [batch_size, 1, vocab_size]).argMax(2);
234
- return next_token;
235
- });
236
- }
237
- get stopPredicting() {
238
- return this.stopPredicting_;
239
- }
240
- set stopPredicting(stop) {
241
- this.stopPredicting_ = stop;
242
- }
243
- }
244
- tf.serialization.registerClass(LlmModel);
245
- //# sourceMappingURL=llm_model.js.map
@@ -1 +0,0 @@
1
- {"version":3,"file":"llm_model.js","sourceRoot":"","sources":["../../../src/models/llm_model.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AACvC,OAAO,EAAE,6BAA6B,EAAE,MAAM,qCAAqC,CAAC;AAEpF,OAAO,EAAE,MAAM,IAAI,kBAAkB,EAAE,MAAM,SAAS,CAAC;AAEvD,OAAO,KAAK,MAAM,MAAM,UAAU,CAAC;AAKlC,CAAC;AAWF;;;;GAIG;AACH,MAAM,OAAO,QAAS,SAAQ,EAAE,CAAC,UAAU;IACvC,MAAM,CAAC,SAAS,GAAG,UAAU,CAAC;IAEtB,eAAe,GAAY,IAAI,CAAC;IAExC,YAAY,IAAkB;QAC1B,IAAI,CAAC,IAAI,GAAG,IAAI,CAAC,IAAI,IAAI,OAAO,CAAC;QACjC,KAAK,CAAC,IAAI,CAAC,CAAC;IAChB,CAAC;IAGD;;;;;OAKG;IACO,kBAAkB;QACxB,MAAM,CAAC,IAAI,EAAE,GAAG,eAAe,CAAC,GAAG,IAAI,CAAC,YAAY,CAAC;QAErD,OAAO,IAAI,CAAC,cAAc,CAAC,GAAG,CAAC,CAAC,aAAa,EAAE,KAAK,EAAE,EAAE,CAAC,CAAC;YACtD,SAAS,EAAE,aAAa,CAAC,CAAC,CAAC;YAC3B,YAAY,EAAE,eAAe,CAAC,KAAK,CAAC;SACvC,CAAC,CAAC,CAAA;IACP,CAAC;IAGD;;;OAGG;IACO,eAAe;QACrB,IAAI,IAAI,GAAG,IAAI,CAAC,IAAI,CAAC;QAErB,IAAI,KAAK,CAAC,OAAO,CAAC,IAAI,CAAC,EAAE,CAAC;YACtB,IAAI,GAAG,IAAI,CAAC,CAAC,CAAC,CAAC;QACnB,CAAC;QAED,IAAI,OAAO,IAAI,IAAI,QAAQ,EAAE,CAAC;YAC1B,IAAI,IAAI,IAAI,+BAA+B,EAAE,CAAC;gBAC1C,OAAO,6BAA6B,CAAC;gBACrC;;;gJAGgI;YACpI,CAAC;YAED,MAAM,OAAO,GAAG,IAAc,CAAC;YAE/B,MAAM,OAAO,GACT,CAAE,MAA8B,CAAC,OAAO,CAAC;gBACpC,EAAE,CAAC,MAA8B,CAAC,OAAO,CAAC;gBAC1C,EAAE,CAAC,OAA+B,CAAC,OAAO,CAAC,CAAmB,CAAA;YAEvE,IAAI,OAAO,EAAE,CAAC;gBACV,OAAO,OAAO,CAAA;YAClB,CAAC;iBAAM,CAAC;gBACJ,MAAM,KAAK,CAAC,6BAA6B,OAAO,+BAA+B,CAAC,CAAC;YACrF,CAAC;QACL,CAAC;aAAM,IAAI,OAAO,IAAI,IAAI,UAAU,EAAE,CAAC;YACnC,OAAO,IAAI,CAAC;QAChB,CAAC;QAED,MAAM,KAAK,CAAC,iFAAiF,CAAC,CAAC;IACnG,CAAC;IAGD;;;;;;;;OAQG;IACM,KAAK,CAAC,UAAU,CAAkB,SAAqB,EAAE,IAA+B;QAC7F,IAAI,CAAC,YAAY,GAAG,KAAK,CAAC;QAE1B,MAAM,OAAO,GAAG,SAAyC,CAAC;QAC1D,MAAM,EAAE,MAAM,EAAE,SAAS,EAAE,GAAG,IAAI,CAAC;QAEnC,MAAM,gBAAgB,GAAG,IAAI,CAAC,kBAAkB,EAAE,CAAC;QACnD,MAAM,aAAa,GAAG,IAAI,CAAC,eAAe,EAAE,CAAC;QAC7C,IAAI,CAAC,aAAa,GAAG,CAAC,aAAa,CAAC,CAAC;QAErC,MAAM,EACF,YAAY,EACZ,UAAU,EACV,YAAY,EACZ,UAAU,EACV,YAAY,EACZ,UAAU,GACb,GAAG,SAAkC,IAAI,EAAE,CAAC;QAE7C,MAAM,YAAY,EAAE,EAAE,CAAC;QAEvB,IAAI,kBAAk
B,GAA0B,SAAS,CAAC;QAE1D,KAAK,IAAI,KAAK,GAAG,CAAC,EAAE,KAAK,GAAG,MAAM,EAAE,KAAK,EAAE,EAAE,CAAC;YAC1C,MAAM,YAAY,EAAE,CAAC,KAAK,CAAC,CAAC;YAE5B,IAAI,KAAK,GAAG,CAAC,CAAC;YACd,IAAI,aAAa,GAAG,CAAC,CAAC;YACtB,MAAM,yBAAyB,GAAiC,EAAE,CAAC;YAEnE,0CAA0C;YAC1C,MAAM,QAAQ,GAAG,MAAM,OAAO,CAAC,QAAQ,EAAE,CAAC;YAC1C,IAAI,MAAM,GAAG,MAAM,QAAQ,CAAC,IAAI,EAAE,CAAC;YAEnC,OAAO,CAAC,MAAM,CAAC,IAAI,EAAE,CAAC;gBAClB,MAAM,aAAa,GAAiC,EAAE,KAAK,EAAE,CAAC;gBAE9D,MAAM,EAAE,EAAE,EAAE,EAAE,EAAE,SAAS,EAAE,YAAY,EAAE,GAAG,MAAM,CAAC,KAAK,CAAC;gBACzD,MAAM,UAAU,GAAG,EAAE,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;gBAC/B,aAAa,IAAI,UAAU,CAAC,CAAC,8BAA8B;gBAE3D,IAAI,EAAE,CAAC,KAAK,CAAC,MAAM,IAAI,CAAC,EAAE,CAAC;oBACvB,MAAM,KAAK,CAAC,wBAAwB,IAAI,CAAC,IAAI,0EAA0E,CAAC,CAAC;gBAC7H,CAAC;gBAED,iFAAiF;gBACjF,MAAM,UAAU,GAAG,EAAE,CAAC,KAAK,CAAC,EAAE,CAAC,KAAK,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC;gBAEjD,IAAI,CAAC,kBAAkB,IAAI,kBAAkB,CAAC,KAAK,CAAC,CAAC,CAAC,IAAI,UAAU,EAAE,CAAC;oBACnE,kBAAkB,GAAG,kBAAkB,CAAC,UAAU,EAAE,UAAU,CAAC,CAAC;gBACpE,CAAC;gBAED,MAAM,YAAY,EAAE,CAAC,KAAK,CAAC,CAAC;gBAE5B,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;oBACT,MAAM,EAAE,MAAM,EAAE,IAAI,EAAE,GAAG,IAAI,CAAC,QAAQ,CAAC,EAAE,EAAE,EAAE,EAAE,SAAS,EAAE,aAAa,EAAE;wBACrE,WAAW,EAAE,YAAY;wBACzB,UAAU,EAAE,kBAAkB;qBACjC,CAAC,CAAA;oBAEF,MAAM,UAAU,GAAG,CAAC,IAAI,CAAC,QAAQ,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC;oBAExC,aAAa,CAAC,IAAI,GAAG,UAAU,CAAC;oBAChC,yBAAyB,CAAC,IAAI,GAAG,CAAC,yBAAyB,CAAC,IAAI,IAAI,CAAC,CAAC,GAAG,UAAU,GAAG,UAAU,CAAC;oBAEjG,8BAA8B;oBAC9B,KAAK,MAAM,EAAE,SAAS,EAAE,YAAY,EAAE,IAAI,gBAAgB,EAAE,CAAC;wBACzD,MAAM,UAAU,GAAG,SAAS,CAAC,EAAE,EAAE,MAAO,CAAC,CAAC,IAAI,EAAE,CAAC;wBAEjD,MAAM,YAAY,GAAG,CAAC,UAAU,CAAC,QAAQ,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC;wBAEhD,aAAa,CAAC,YAAY,CAAC,GAAG,YAAY,CAAA,CAAA,gBAAgB;wBAC1D,yBAAyB,CAAC,YAAY,CAAC,GAAG,CAAC,yBAAyB,CAAC,YAAY,CAAC,IAAI,CAAC,CAAC,GAAG,YAAY,GAAG,UAAU,CAAC;oBACzH,CAAC;oBAED,EAAE,CAAC,OAAO,CAAC,MAAO,CAAC,CAAC;gBACxB,CAAC,CAAC,CAAA;gBAEF,EAAE,CAAC,OAAO,CAAC,EAAE,CAAC,CAAC;gBACf,EAAE,CAAC,OAAO,CAAC,EAAE,CAAC,CAAC;gBACf,EAAE,CAAC,OAAO,CAAC,SAAS,CAAC,CAA
C;gBAEtB,IAAI,YAAY,EAAE,CAAC;oBACf,EAAE,CAAC,OAAO,CAAC,YAAY,CAAC,CAAC;gBAC7B,CAAC;gBAED,MAAM,UAAU,EAAE,CAAC,KAAK,EAAE,aAAa,CAAC,CAAC;gBAEzC,8BAA8B;gBAC9B,MAAM,EAAE,CAAC,SAAS,EAAE,CAAC;gBAErB,IAAI,IAAI,CAAC,YAAY,EAAE,CAAC;oBACpB,MAAM;gBACV,CAAC;gBAED,MAAM,GAAG,MAAM,QAAQ,CAAC,IAAI,EAAE,CAAC;gBAC/B,KAAK,EAAE,CAAC;YACZ,CAAC;YAED,KAAK,MAAM,MAAM,IAAI,yBAAyB,EAAE,CAAC;gBAC7C,yBAAyB,CAAC,MAAM,CAAC,GAAG,yBAAyB,CAAC,MAAM,CAAC,GAAG,aAAa,CAAC;YAC1F,CAAC;YAED,MAAM,UAAU,EAAE,CAAC,KAAK,EAAE,yBAAyB,CAAC,CAAC;YAErD,IAAI,IAAI,CAAC,YAAY,EAAE,CAAC;gBACpB,MAAM;YACV,CAAC;QACL,CAAC;QAED,EAAE,CAAC,OAAO,CAAC,kBAAkB,CAAC,CAAC;QAC/B,MAAM,UAAU,EAAE,EAAE,CAAA;QAEpB,OAAO,EAAE,CAAC;IACd,CAAC;IAGD;;;;;;;;;OASG;IACO,QAAQ,CACd,EAAa,EACb,EAAa,EACb,SAAgC,EAChC,aAA6B,EAC7B,WAAsD;QAKtD,IAAI,MAAiB,CAAC;QAEtB,+BAA+B;QAC/B,MAAM,EAAE,KAAK,EAAE,IAAI,EAAE,KAAK,EAAE,GAAG,EAAE,CAAC,aAAa,CAAC,GAAG,EAAE;YACjD,4DAA4D;YAC5D,MAAM,GAAG,IAAI,CAAC,KAAK,CAAC,EAAE,EAAE;gBACpB,QAAQ,EAAE,IAAI;gBACd,GAAG,WAAW;aACjB,CAAc,CAAC;YAEhB,6EAA6E;YAC7E,EAAE,CAAC,IAAI,CAAC,MAAM,CAAC,CAAC;YAEhB,MAAM,IAAI,GAAG,SAAS;gBAClB,CAAC,CAAC,aAAa,CAAC,EAAE,EAAE,MAAM,CAAC,CAAC,GAAG,CAAC,SAAS,CAAC;gBAC1C,CAAC,CAAC,aAAa,CAAC,EAAE,EAAE,MAAM,CAAC,CAAC;YAEhC,OAAO,IAAI,CAAC,IAAI,EAAe,CAAC;QACpC,CAAC,CAAC,CAAC;QAEH,kBAAkB;QAClB,IAAI,CAAC,SAAS,CAAC,cAAc,CAAC,KAAK,CAAC,CAAC;QAErC,OAAO;YACH,MAAM,EAAE,MAAO;YACf,IAAI;SACP,CAAC;IACN,CAAC;IAGQ,OAAO,CAAC,IAAyB;QACtC,IAAI,IAAI,CAAC,IAAI,IAAI,yBAAyB,EAAE,CAAC;YACzC,MAAM,KAAK,CAAC,gIAAgI,CAAC,CAAA;QACjJ,CAAC;QAED,KAAK,CAAC,OAAO,CAAC,IAAI,CAAC,CAAC;IACxB,CAAC;IAGD;;;;;;;;;OASG;IACI,KAAK,CAAC,QAAQ,CAAC,KAAkB,EAAE,QAA0B,EAAE,SAA8C;QAChH,IAAI,QAAQ,CAAC,IAAI,IAAI,QAAQ,CAAC,iBAAiB,EAAE,CAAC;YAC9C,MAAM,KAAK,CAAC,sBAAsB,IAAI,CAAC,IAAI,mDAAmD,QAAQ,CAAC,iBAAiB,GAAG,CAAC,CAAC;QACjI,CAAC;QAED,IAAI,CAAC,cAAc,GAAG,KAAK,CAAC;QAE5B,IAAI,aAAa,GAAgB,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE,CAAC,KAAK,CAAC,UAAU,CAAC,CAAC,CAAC,CAAgB,CAAC,CAAC,kDAAkD;QAEtI,OAAO,CAAC,IAAI,CAAC,cAAc,IAAI,QAAQ,CAAC,IAAI,GAAG,QAAQ,CAAC,iBAAiB,EAAE,CAAC;YACxE,qEAAqE;Y
ACrE,MAAM,UAAU,GAAG,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE,CAAC,IAAI,CAAC,gBAAgB,CAAC,aAAa,EAAE,QAAQ,CAAC,CAAC,CAAC;YAEjF,wDAAwD;YACxD,MAAM,oBAAoB,GAAG,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE,CAAC,UAAU,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;YACpE,MAAM,SAAS,CAAC,oBAAoB,CAAC,CAAC;YAEtC,oBAAoB,CAAC,OAAO,EAAE,CAAC;YAE/B,aAAa,CAAC,OAAO,EAAE,CAAC;YACxB,aAAa,GAAG,UAAU,CAAC;QAC/B,CAAC;QAED,EAAE,CAAC,OAAO,CAAC,aAAa,CAAC,CAAC;IAC9B,CAAC;IAGD;;;;;;;OAOG;IACI,gBAAgB,CAAC,KAAkB,EAAE,QAA0B;QAClE,IAAI,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,IAAI,CAAC,EAAE,CAAC;YACtB,MAAM,KAAK,CAAC,8BAA8B,IAAI,CAAC,IAAI,0CAA0C,CAAC,CAAC;QACnG,CAAC;QAED,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,qDAAqD;YACrD,MAAM,UAAU,GAAG,IAAI,CAAC,KAAK,CAAC,KAAK,EAAE,EAAE,OAAO,EAAE,QAAQ,EAAE,CAAc,CAAC;YAEzE,MAAM,CAAC,UAAU,EAAE,eAAe,EAAE,UAAU,CAAC,GAAG,UAAU,CAAC,KAAK,CAAC;YAEnE,qBAAqB;YACrB,MAAM,UAAU,GAAG,UAAU,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,eAAe,GAAG,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,UAAU,EAAE,CAAC,EAAE,UAAU,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,CAAC,CAAA;YAEvG,OAAO,UAAyB,CAAC;QACrC,CAAC,CAAC,CAAA;IACN,CAAC;IAGD,IAAI,cAAc;QACd,OAAO,IAAI,CAAC,eAAe,CAAC;IAChC,CAAC;IAGD,IAAI,cAAc,CAAC,IAAa;QAC5B,IAAI,CAAC,eAAe,GAAG,IAAI,CAAC;IAChC,CAAC;;AAKL,EAAE,CAAC,aAAa,CAAC,aAAa,CAAC,QAAQ,CAAC,CAAC"}
@@ -1,40 +0,0 @@
1
- import * as tf from "@tensorflow/tfjs";
2
- import { type ActivationIdentifier } from "@tensorflow/tfjs-layers/dist/keras_format/activation_config";
3
- export interface UNetArgs {
4
- /**
5
- * The starting number of filters.
6
- */
7
- filters: number;
8
- /**
9
- * The number of categories. For binary segmentation, `units=1`.
10
- */
11
- units: number;
12
- /**
13
- * The activation of the final output convolution layer. Defaults to `sigmoid` if `categories=1`, else `softmax`.
14
- */
15
- activation?: ActivationIdentifier;
16
- /**
17
- * The depth of the U-Net or the number of contractions and the number of expansions.
18
- */
19
- depth: number;
20
- /**
21
- * Adds residual connections to transform the model into a ResUNet. Defaults to `false`.
22
- */
23
- residual?: boolean;
24
- /**
25
- * Adds batch normalization to convolutions. Best used for batched inputs. Defaults to `false`.
26
- */
27
- batchNorm?: boolean;
28
- /**
29
- * Set the unbatched input shape of the U-Net in the format `[height, width, channels]`. Defaults to `[null, null, 3]`. If set, only channels is mandatory.
30
- */
31
- inputShape?: [number | null, number | null, number];
32
- }
33
- export type UNetModelArgs = UNetArgs & Omit<tf.SequentialArgs, "layers">;
34
- export declare class UNetModel extends tf.Sequential {
35
- constructor(args: UNetModelArgs);
36
- summary(lineLength?: number, positions?: number[], printFn?: (message?: any, ...optionalParams: any[]) => void): void;
37
- }
38
- export declare function createUNet({ filters, depth, units, activation, residual, batchNorm, inputShape }: UNetModelArgs): tf.LayersModel;
39
- export declare function loadUNetModel(pathOrIOHandler: string | tf.io.IOHandler, options?: tf.io.LoadOptions): Promise<tf.LayersModel>;
40
- //# sourceMappingURL=u_net.d.ts.map
@@ -1 +0,0 @@
1
- {"version":3,"file":"u_net.d.ts","sourceRoot":"","sources":["../../../src/models/u_net.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AACvC,OAAO,EAAE,KAAK,oBAAoB,EAAE,MAAM,6DAA6D,CAAC;AAGxG,MAAM,WAAW,QAAQ;IACrB;;OAEG;IACH,OAAO,EAAE,MAAM,CAAC;IAChB;;OAEG;IACH,KAAK,EAAE,MAAM,CAAC;IACd;;OAEG;IACH,UAAU,CAAC,EAAE,oBAAoB,CAAC;IAClC;;OAEG;IACH,KAAK,EAAE,MAAM,CAAC;IACd;;OAEG;IACH,QAAQ,CAAC,EAAE,OAAO,CAAC;IACnB;;OAEG;IACH,SAAS,CAAC,EAAE,OAAO,CAAC;IACpB;;OAEG;IACH,UAAU,CAAC,EAAE,CAAC,MAAM,GAAG,IAAI,EAAE,MAAM,GAAG,IAAI,EAAE,MAAM,CAAC,CAAC;CACvD;AAGD,MAAM,MAAM,aAAa,GAAG,QAAQ,GAAG,IAAI,CAAC,EAAE,CAAC,cAAc,EAAE,QAAQ,CAAC,CAAC;AAGzE,qBAAa,SAAU,SAAQ,EAAE,CAAC,UAAU;gBAE5B,IAAI,EAAE,aAAa;IAsBtB,OAAO,CAAC,UAAU,CAAC,EAAE,MAAM,EAAE,SAAS,CAAC,EAAE,MAAM,EAAE,EAAE,OAAO,CAAC,EAAE,CAAC,OAAO,CAAC,EAAE,GAAG,EAAE,GAAG,cAAc,EAAE,GAAG,EAAE,KAAK,IAAI,GAAG,IAAI;CAIjI;AAGD,wBAAgB,UAAU,CAAC,EAAE,OAAO,EAAE,KAAK,EAAE,KAAK,EAAE,UAAU,EAAE,QAAgB,EAAE,SAAiB,EAAE,UAA4B,EAAE,EAAE,aAAa,kBA4CjJ;AAGD,wBAAsB,aAAa,CAAC,eAAe,EAAE,MAAM,GAAG,EAAE,CAAC,EAAE,CAAC,SAAS,EAAE,OAAO,CAAC,EAAE,EAAE,CAAC,EAAE,CAAC,WAAW,2BAOzG"}