@stellarapp/tfjs-stellar 1.0.3 → 1.0.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (134) hide show
  1. package/README.md +17 -0
  2. package/dist/index.d.ts +3 -1
  3. package/dist/index.d.ts.map +1 -1
  4. package/dist/index.js +3 -1
  5. package/dist/index.js.map +1 -1
  6. package/dist/kv_cache.d.ts +2 -0
  7. package/dist/kv_cache.d.ts.map +1 -1
  8. package/dist/kv_cache.js +6 -0
  9. package/dist/kv_cache.js.map +1 -1
  10. package/dist/models/index.d.ts +2 -1
  11. package/dist/models/index.d.ts.map +1 -1
  12. package/dist/models/index.js +2 -1
  13. package/dist/models/index.js.map +1 -1
  14. package/package.json +1 -1
  15. package/dist/jest.config.d.ts +0 -8
  16. package/dist/jest.config.d.ts.map +0 -1
  17. package/dist/jest.config.js +0 -147
  18. package/dist/jest.config.js.map +0 -1
  19. package/dist/src/index.d.ts +0 -6
  20. package/dist/src/index.d.ts.map +0 -1
  21. package/dist/src/index.js +0 -6
  22. package/dist/src/index.js.map +0 -1
  23. package/dist/src/kv_cache.d.ts +0 -53
  24. package/dist/src/kv_cache.d.ts.map +0 -1
  25. package/dist/src/kv_cache.js +0 -135
  26. package/dist/src/kv_cache.js.map +0 -1
  27. package/dist/src/layers/cached_rope_multihead_attention.d.ts +0 -31
  28. package/dist/src/layers/cached_rope_multihead_attention.d.ts.map +0 -1
  29. package/dist/src/layers/cached_rope_multihead_attention.js +0 -76
  30. package/dist/src/layers/cached_rope_multihead_attention.js.map +0 -1
  31. package/dist/src/layers/cached_rope_multihead_attention.test.d.ts +0 -2
  32. package/dist/src/layers/cached_rope_multihead_attention.test.d.ts.map +0 -1
  33. package/dist/src/layers/cached_rope_multihead_attention.test.js +0 -43
  34. package/dist/src/layers/cached_rope_multihead_attention.test.js.map +0 -1
  35. package/dist/src/layers/gpt_decoder_block.d.ts +0 -34
  36. package/dist/src/layers/gpt_decoder_block.d.ts.map +0 -1
  37. package/dist/src/layers/gpt_decoder_block.js +0 -51
  38. package/dist/src/layers/gpt_decoder_block.js.map +0 -1
  39. package/dist/src/layers/index.d.ts +0 -17
  40. package/dist/src/layers/index.d.ts.map +0 -1
  41. package/dist/src/layers/index.js +0 -33
  42. package/dist/src/layers/index.js.map +0 -1
  43. package/dist/src/layers/multihead_attention.d.ts +0 -106
  44. package/dist/src/layers/multihead_attention.d.ts.map +0 -1
  45. package/dist/src/layers/multihead_attention.js +0 -269
  46. package/dist/src/layers/multihead_attention.js.map +0 -1
  47. package/dist/src/layers/multihead_attention.test.d.ts +0 -2
  48. package/dist/src/layers/multihead_attention.test.d.ts.map +0 -1
  49. package/dist/src/layers/multihead_attention.test.js +0 -160
  50. package/dist/src/layers/multihead_attention.test.js.map +0 -1
  51. package/dist/src/layers/positional_encoding.d.ts +0 -37
  52. package/dist/src/layers/positional_encoding.d.ts.map +0 -1
  53. package/dist/src/layers/positional_encoding.js +0 -115
  54. package/dist/src/layers/positional_encoding.js.map +0 -1
  55. package/dist/src/layers/positional_encoding.test.d.ts +0 -2
  56. package/dist/src/layers/positional_encoding.test.d.ts.map +0 -1
  57. package/dist/src/layers/positional_encoding.test.js +0 -95
  58. package/dist/src/layers/positional_encoding.test.js.map +0 -1
  59. package/dist/src/layers/rotary_position_embedding.d.ts +0 -39
  60. package/dist/src/layers/rotary_position_embedding.d.ts.map +0 -1
  61. package/dist/src/layers/rotary_position_embedding.js +0 -99
  62. package/dist/src/layers/rotary_position_embedding.js.map +0 -1
  63. package/dist/src/layers/rotary_position_embedding.test.d.ts +0 -2
  64. package/dist/src/layers/rotary_position_embedding.test.d.ts.map +0 -1
  65. package/dist/src/layers/rotary_position_embedding.test.js +0 -88
  66. package/dist/src/layers/rotary_position_embedding.test.js.map +0 -1
  67. package/dist/src/layers/token_and_positional_embedding.d.ts +0 -47
  68. package/dist/src/layers/token_and_positional_embedding.d.ts.map +0 -1
  69. package/dist/src/layers/token_and_positional_embedding.js +0 -109
  70. package/dist/src/layers/token_and_positional_embedding.js.map +0 -1
  71. package/dist/src/layers/token_and_positional_embedding.test.d.ts +0 -2
  72. package/dist/src/layers/token_and_positional_embedding.test.d.ts.map +0 -1
  73. package/dist/src/layers/token_and_positional_embedding.test.js +0 -58
  74. package/dist/src/layers/token_and_positional_embedding.test.js.map +0 -1
  75. package/dist/src/layers/transformer_decoder.d.ts +0 -69
  76. package/dist/src/layers/transformer_decoder.d.ts.map +0 -1
  77. package/dist/src/layers/transformer_decoder.js +0 -182
  78. package/dist/src/layers/transformer_decoder.js.map +0 -1
  79. package/dist/src/layers/transformer_decoder.test.d.ts +0 -2
  80. package/dist/src/layers/transformer_decoder.test.d.ts.map +0 -1
  81. package/dist/src/layers/transformer_decoder.test.js +0 -72
  82. package/dist/src/layers/transformer_decoder.test.js.map +0 -1
  83. package/dist/src/layers/transformer_encoder.d.ts +0 -55
  84. package/dist/src/layers/transformer_encoder.d.ts.map +0 -1
  85. package/dist/src/layers/transformer_encoder.js +0 -175
  86. package/dist/src/layers/transformer_encoder.js.map +0 -1
  87. package/dist/src/layers/transformer_encoder.test.d.ts +0 -2
  88. package/dist/src/layers/transformer_encoder.test.d.ts.map +0 -1
  89. package/dist/src/layers/transformer_encoder.test.js +0 -58
  90. package/dist/src/layers/transformer_encoder.test.js.map +0 -1
  91. package/dist/src/losses/dice.d.ts +0 -30
  92. package/dist/src/losses/dice.d.ts.map +0 -1
  93. package/dist/src/losses/dice.js +0 -93
  94. package/dist/src/losses/dice.js.map +0 -1
  95. package/dist/src/losses/index.d.ts +0 -2
  96. package/dist/src/losses/index.d.ts.map +0 -1
  97. package/dist/src/losses/index.js +0 -2
  98. package/dist/src/losses/index.js.map +0 -1
  99. package/dist/src/masks.d.ts +0 -20
  100. package/dist/src/masks.d.ts.map +0 -1
  101. package/dist/src/masks.js +0 -37
  102. package/dist/src/masks.js.map +0 -1
  103. package/dist/src/metrics.d.ts +0 -20
  104. package/dist/src/metrics.d.ts.map +0 -1
  105. package/dist/src/metrics.js +0 -28
  106. package/dist/src/metrics.js.map +0 -1
  107. package/dist/src/models/gpt_model.d.ts +0 -94
  108. package/dist/src/models/gpt_model.d.ts.map +0 -1
  109. package/dist/src/models/gpt_model.js +0 -154
  110. package/dist/src/models/gpt_model.js.map +0 -1
  111. package/dist/src/models/index.d.ts +0 -3
  112. package/dist/src/models/index.d.ts.map +0 -1
  113. package/dist/src/models/index.js +0 -3
  114. package/dist/src/models/index.js.map +0 -1
  115. package/dist/src/models/llm_model.d.ts +0 -87
  116. package/dist/src/models/llm_model.d.ts.map +0 -1
  117. package/dist/src/models/llm_model.js +0 -245
  118. package/dist/src/models/llm_model.js.map +0 -1
  119. package/dist/src/models/u_net.d.ts +0 -40
  120. package/dist/src/models/u_net.d.ts.map +0 -1
  121. package/dist/src/models/u_net.js +0 -151
  122. package/dist/src/models/u_net.js.map +0 -1
  123. package/dist/src/tfjs_types.d.ts +0 -10
  124. package/dist/src/tfjs_types.d.ts.map +0 -1
  125. package/dist/src/tfjs_types.js +0 -2
  126. package/dist/src/tfjs_types.js.map +0 -1
  127. package/dist/src/utils.d.ts +0 -28
  128. package/dist/src/utils.d.ts.map +0 -1
  129. package/dist/src/utils.js +0 -63
  130. package/dist/src/utils.js.map +0 -1
  131. package/dist/src/utils.test.d.ts +0 -2
  132. package/dist/src/utils.test.d.ts.map +0 -1
  133. package/dist/src/utils.test.js +0 -73
  134. package/dist/src/utils.test.js.map +0 -1
@@ -1,175 +0,0 @@
1
- import * as tf from "@tensorflow/tfjs";
2
- import { MultiHeadAttention } from "@/layers/multihead_attention";
3
- /**
4
- * This class implements the transformer encoder architecture from the 2017 paper
5
- * Attention Is All You Need.
6
- *
7
- * This layer accepts exactly one tensor input with the shape
8
- * `[ batch, sequences, embedding dims ]`.
9
- *
10
- * @param numHeads number of attention heads to use
11
- * @param embedDim the embedding size of the input (input embeddings, typically the last dimension)
12
- * @param causal use causal masking, default `false` for encoders
13
- * @param dropout use dropout during the attention calculations, default `0.1`
14
- * @param activation the activation of the intermediate feed forward layer, default `relu`
15
- * @param dimsFeedForward the size of the intermediate feed forward layer, default `2048`
16
- * @param useBias use bias for the dense sublayers and multiHead attention's dense sublayers, default `true`
17
- */
18
export class TransformerEncoder extends tf.layers.Layer {
    // Static name read by the `tf.serialization.registerClass` call below.
    static className = "TransformerEncoder";
    // self attention sub-block
    selfAttention;
    selfAttentionDropout;
    selfAttentionNorm;
    // feed forward sub-block
    reluLayer;
    linearLayer;
    feedForwardDropout;
    // NOTE(review): "feedFowardNorm" is misspelled; the name is kept so any
    // existing external access to this property keeps working.
    feedFowardNorm;
    // hyperparameters, echoed back by getConfig()
    numHeads;
    embedDim;
    causal;
    useBias;
    dropout;
    activation;
    dimsFeedForward;

    /**
     * Stores the hyperparameters (applying defaults) and instantiates every
     * sublayer. Weights are not created here; see `build`.
     *
     * @param numHeads number of attention heads, forwarded to MultiHeadAttention
     * @param embedDim embedding size; also the unit count of the output dense layer
     * @param causal forwarded to MultiHeadAttention, default `false`
     * @param useBias forwarded to the attention and dense sublayers, default `true`
     * @param dropout rate for attention and both dropout layers, default `0.1`
     * @param activation activation of the intermediate dense layer, default `relu`
     * @param dimsFeedForward units of the intermediate dense layer, default `2048`
     * @param args remaining properties, passed to the base `tf.layers.Layer` constructor
     * @throws Error when `dropout` is not a number within `[0, 1)`
     */
    constructor({ numHeads, embedDim, causal, useBias, dropout, activation, dimsFeedForward, ...args }) {
        super(args);
        this.numHeads = numHeads;
        this.embedDim = embedDim;
        this.causal = causal ?? false;
        this.useBias = useBias ?? true;
        this.dropout = dropout ?? 0.1;
        this.activation = activation ?? "relu";
        this.dimsFeedForward = dimsFeedForward ?? 2048;

        // Enforce the full documented range: the previous check only rejected
        // values >= 1, letting negative rates and NaN slip through.
        if (!(this.dropout >= 0 && this.dropout < 1)) {
            throw Error(`${this.getClassName()}::constructor dropout must be within [0, 1)`);
        }

        // self attention sub-block
        this.selfAttention = new MultiHeadAttention({
            numHeads: this.numHeads,
            embedDim: this.embedDim,
            useBias: this.useBias,
            dropout: this.dropout,
            causal: this.causal
        });
        this.selfAttentionDropout = tf.layers.dropout({ rate: this.dropout });
        this.selfAttentionNorm = tf.layers.layerNormalization({ epsilon: 1e-6 });

        // feed forward sub-block
        this.reluLayer = tf.layers.dense({
            units: this.dimsFeedForward,
            activation: this.activation,
            useBias: this.useBias
        });
        this.linearLayer = tf.layers.dense({
            units: this.embedDim,
            activation: "linear",
            useBias: this.useBias
        });
        this.feedForwardDropout = tf.layers.dropout({ rate: this.dropout });
        this.feedFowardNorm = tf.layers.layerNormalization({ epsilon: 1e-6 });
    }

    /**
     * Forward propagation: self attention sub-block followed by the feed
     * forward sub-block.
     *
     * @param inputs a single tensor, or an array holding exactly one tensor
     * @param kwargs forwarded to the attention and dropout sublayers
     * @returns the output of the feed forward sub-block
     * @throws Error when `inputs` is an array whose length is not 1
     */
    call(inputs, kwargs) {
        // validate the input tensors
        let input = inputs;
        if (Array.isArray(inputs)) {
            if (inputs.length !== 1) {
                // getClassName must be *called*; interpolating the bare method
                // printed its source code into the message.
                throw Error(`${this.getClassName()}::call ${this.name} expects exactly 1 tensor` +
                    ` input, got ${inputs.length} inputs instead.`);
            }
            input = inputs[0];
        }

        // perform forward propagation
        return tf.tidy(() => {
            const attention = this.selfAttentionBlock(input, kwargs);
            return this.feedForwardBlock(attention, kwargs);
        });
    }

    /**
     * Attention -> dropout -> residual add -> layer normalization.
     */
    selfAttentionBlock(x, kwargs) {
        return tf.tidy(() => {
            const residual = x;
            let attention = this.selfAttention.apply(x, kwargs);
            attention = this.selfAttentionDropout.apply(attention, kwargs);
            attention = tf.add(attention, residual);
            return this.selfAttentionNorm.apply(attention);
        });
    }

    /**
     * Dense (activation) -> dense (linear) -> dropout -> residual add ->
     * layer normalization.
     */
    feedForwardBlock(x, kwargs) {
        return tf.tidy(() => {
            const residual = x;
            let feedForward = this.reluLayer.apply(x);
            feedForward = this.linearLayer.apply(feedForward);
            feedForward = this.feedForwardDropout.apply(feedForward, kwargs);
            feedForward = tf.add(feedForward, residual);
            return this.feedFowardNorm.apply(feedForward);
        });
    }

    /**
     * Initialize the sublayers' weights and track them to enable backpropagation.
     *
     * @param inputShape a single rank 3 shape, or an array holding one such shape
     * @throws Error when not exactly one rank 3 shape is supplied
     */
    build(inputShape) {
        let inputShapes = [];
        if (Array.isArray(inputShape) && Array.isArray(inputShape[0])) {
            // input is an array of shapes
            inputShapes = inputShape;
        }
        else if (inputShape.length !== 0) {
            // input is a single shape
            inputShapes = [inputShape];
        }

        // expects only 1 rank 3 tensor input
        if (inputShapes.length !== 1 || inputShapes[0].length !== 3) {
            throw Error(`${this.getClassName()}::build ${this.name} expects a single input shape of [batch, seq, embed_dim], got ${JSON.stringify(inputShape)}`);
        }

        // initialize self attention sub-block's weights
        this.selfAttention.build(inputShape);
        this.selfAttentionNorm.build(inputShape);

        // initialize feedforward sub-block's weights; each layer is built with
        // the output shape of the layer that feeds it
        const reluLayerOutputShape = this.reluLayer.computeOutputShape(inputShape);
        const linearLayerOutputShape = this.linearLayer.computeOutputShape(reluLayerOutputShape);
        this.reluLayer.build(inputShape);
        this.linearLayer.build(reluLayerOutputShape);
        this.feedFowardNorm.build(linearLayerOutputShape);

        // track sublayers' weights (the order fixes the index used for renaming below)
        const sublayers = [
            this.selfAttention,
            this.selfAttentionDropout,
            this.selfAttentionNorm,
            this.reluLayer,
            this.linearLayer,
            this.feedForwardDropout,
            this.feedFowardNorm
        ];
        this.trainableWeights = sublayers.flatMap((layer) => layer.trainableWeights);

        // rename the weights otherwise they'll take on the default naming and overlap
        // each other which breaks model loading due to duplicate weight names
        this.trainableWeights.forEach((weight, index) => {
            const uniqueName = `${this.getClassName()}_${index}`;
            weight.name += uniqueName;
            weight.originalName += uniqueName;
        });

        super.build(inputShape);
    }

    /**
     * Save the layer's hyperparameters for serialization.
     *
     * @returns the hyperparameters merged with the base layer config; on a key
     *          clash the base config's value wins
     */
    getConfig() {
        const baseConfig = super.getConfig();
        const config = {
            numHeads: this.numHeads,
            embedDim: this.embedDim,
            causal: this.causal,
            useBias: this.useBias,
            dropout: this.dropout,
            activation: this.activation,
            dimsFeedForward: this.dimsFeedForward
        };
        Object.assign(config, baseConfig);
        return config;
    }
}
tf.serialization.registerClass(TransformerEncoder);
175
- //# sourceMappingURL=transformer_encoder.js.map
@@ -1 +0,0 @@
1
- {"version":3,"file":"transformer_encoder.js","sourceRoot":"","sources":["../../../src/layers/transformer_encoder.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAIvC,OAAO,EAAE,kBAAkB,EAA+B,MAAM,8BAA8B,CAAC;AAS/F;;;;;;;;;;;;;;GAcG;AACH,MAAM,OAAO,kBAAmB,SAAQ,EAAE,CAAC,MAAM,CAAC,KAAK;IACnD,MAAM,CAAC,SAAS,GAAG,oBAAoB,CAAC;IAEvB,aAAa,CAAkB;IAC/B,oBAAoB,CAAkB;IACtC,iBAAiB,CAAkB;IAEnC,SAAS,CAAkB;IAC3B,WAAW,CAAkB;IAC7B,kBAAkB,CAAkB;IACpC,cAAc,CAAkB;IAEhC,QAAQ,CAAS;IACjB,QAAQ,CAAS;IACjB,MAAM,CAAU;IAChB,OAAO,CAAU;IACjB,OAAO,CAAS;IAChB,UAAU,CAAuB;IACjC,eAAe,CAAS;IAGzC,YAAY,EAAE,QAAQ,EAAE,QAAQ,EAAE,MAAM,EAAE,OAAO,EAAE,OAAO,EAAE,UAAU,EAAE,eAAe,EAAE,GAAG,IAAI,EAA0B;QACtH,KAAK,CAAC,IAAI,CAAC,CAAC;QAEZ,IAAI,CAAC,QAAQ,GAAG,QAAQ,CAAC;QACzB,IAAI,CAAC,QAAQ,GAAG,QAAQ,CAAC;QACzB,IAAI,CAAC,MAAM,GAAG,MAAM,IAAI,KAAK,CAAC;QAC9B,IAAI,CAAC,OAAO,GAAG,OAAO,IAAI,IAAI,CAAC;QAC/B,IAAI,CAAC,OAAO,GAAG,OAAO,IAAI,GAAG,CAAC;QAC9B,IAAI,CAAC,UAAU,GAAG,UAAU,IAAI,MAAM,CAAC;QACvC,IAAI,CAAC,eAAe,GAAG,eAAe,IAAI,IAAI,CAAC;QAE/C,IAAI,IAAI,CAAC,OAAO,IAAI,CAAC,EAAE,CAAC;YACpB,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,6CAA6C,CAAC,CAAC;QACrF,CAAC;QAED,2BAA2B;QAC3B,IAAI,CAAC,aAAa,GAAG,IAAI,kBAAkB,CAAC;YACxC,QAAQ,EAAE,IAAI,CAAC,QAAQ,EAAE,QAAQ,EAAE,IAAI,CAAC,QAAQ,EAAE,OAAO,EAAE,IAAI,CAAC,OAAO;YACvE,OAAO,EAAE,IAAI,CAAC,OAAO,EAAE,MAAM,EAAE,IAAI,CAAC,MAAM;SAC7C,CAAC,CAAC;QACH,IAAI,CAAC,oBAAoB,GAAG,EAAE,CAAC,MAAM,CAAC,OAAO,CAAC,EAAE,IAAI,EAAE,IAAI,CAAC,OAAO,EAAE,CAAC,CAAA;QACrE,IAAI,CAAC,iBAAiB,GAAG,EAAE,CAAC,MAAM,CAAC,kBAAkB,CAAC,EAAE,OAAO,EAAE,IAAI,EAAE,CAAC,CAAC;QAEzE,yBAAyB;QACzB,IAAI,CAAC,SAAS,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;YAC7B,KAAK,EAAE,IAAI,CAAC,eAAe,EAAE,UAAU,EAAE,IAAI,CAAC,UAAU;YACxD,OAAO,EAAE,IAAI,CAAC,OAAO;SACxB,CAAC,CAAC;QACH,IAAI,CAAC,WAAW,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC;YAC/B,KAAK,EAAE,IAAI,CAAC,QAAQ,EAAE,UAAU,EAAE,QAAQ;YAC1C,OAAO,EAAE,IAAI,CAAC,OAAO;SACxB,CAAC,CAAC;QACH,IAAI,CAAC,kBAAkB,GAAG,EAAE,CAAC,MAAM,CAAC,OAAO,CAAC,EAAE,IAAI,EAAE,IAAI,CAAC,OAAO,EAAE,CAAC,CAAC;QACpE,IAAI,CAAC,
cAAc,GAAG,EAAE,CAAC,MAAM,CAAC,kBAAkB,CAAC,EAAE,OAAO,EAAE,IAAI,EAAE,CAAC,CAAC;IAC1E,CAAC;IAGD;;OAEG;IACM,IAAI,CAAC,MAA+B,EAAE,MAAc;QACzD,6BAA6B;QAC7B,IAAI,KAAgB,CAAC;QAErB,IAAI,KAAK,CAAC,OAAO,CAAC,MAAM,CAAC,EAAE,CAAC;YACxB,IAAI,MAAM,CAAC,MAAM,IAAI,CAAC,EAAE,CAAC;gBACrB,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,UAAU,IAAI,CAAC,IAAI,2BAA2B;oBAC1E,eAAe,MAAM,CAAC,MAAM,kBAAkB,CAAC,CAAC;YACxD,CAAC;YAED,KAAK,GAAG,MAAM,CAAC,CAAC,CAAC,CAAC;QACtB,CAAC;aAAM,CAAC;YACJ,KAAK,GAAG,MAAM,CAAC;QACnB,CAAC;QAED,8BAA8B;QAC9B,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,MAAM,SAAS,GAAG,IAAI,CAAC,kBAAkB,CAAC,KAAK,EAAE,MAAM,CAAC,CAAC;YACzD,MAAM,WAAW,GAAG,IAAI,CAAC,gBAAgB,CAAC,SAAS,EAAE,MAAM,CAAC,CAAC;YAE7D,OAAO,WAAW,CAAC;QACvB,CAAC,CAAC,CAAC;IACP,CAAC;IAGO,kBAAkB,CAAC,CAAY,EAAE,MAAc;QACnD,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,MAAM,QAAQ,GAAG,CAAC,CAAC;YAEnB,IAAI,SAAS,GAAG,IAAI,CAAC,aAAa,CAAC,KAAK,CAAC,CAAC,EAAE,MAAM,CAAc,CAAC;YACjE,SAAS,GAAG,IAAI,CAAC,oBAAoB,CAAC,KAAK,CAAC,SAAS,EAAE,MAAM,CAAc,CAAC;YAC5E,SAAS,GAAG,EAAE,CAAC,GAAG,CAAC,SAAS,EAAE,QAAQ,CAAC,CAAC;YACxC,SAAS,GAAG,IAAI,CAAC,iBAAiB,CAAC,KAAK,CAAC,SAAS,CAAc,CAAC;YAEjE,OAAO,SAAS,CAAC;QACrB,CAAC,CAAC,CAAC;IACP,CAAC;IAGO,gBAAgB,CAAC,CAAY,EAAE,MAAc;QACjD,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;YAChB,MAAM,QAAQ,GAAG,CAAC,CAAC;YAEnB,IAAI,WAAW,GAAG,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;YAC1C,WAAW,GAAG,IAAI,CAAC,WAAW,CAAC,KAAK,CAAC,WAAW,CAAC,CAAC;YAClD,WAAW,GAAG,IAAI,CAAC,kBAAkB,CAAC,KAAK,CAAC,WAAW,EAAE,MAAM,CAAc,CAAC;YAC9E,WAAW,GAAG,EAAE,CAAC,GAAG,CAAC,WAAW,EAAE,QAAQ,CAAC,CAAC;YAC5C,WAAW,GAAG,IAAI,CAAC,cAAc,CAAC,KAAK,CAAC,WAAW,CAAc,CAAC;YAElE,OAAO,WAAW,CAAC;QACvB,CAAC,CAAC,CAAC;IACP,CAAC;IAGD;;OAEG;IACM,KAAK,CAAC,UAAiC;QAC5C,IAAI,YAAY,GAAe,EAAE,CAAC;QAElC,IAAI,KAAK,CAAC,OAAO,CAAC,UAAU,CAAC,IAAI,KAAK,CAAC,OAAO,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,EAAE,CAAC;YAC5D,8BAA8B;YAC9B,YAAY,GAAG,UAAwB,CAAC;QAC5C,CAAC;aAAM,IAAI,UAAU,CAAC,MAAM,IAAI,CAAC,EAAE,CAAC;YAChC,0BAA0B;YAC1B,YAAY,GAAG,CAAC,UAAsB,CAAC,CAAC;QAC5C,CAAC;QAED,qCAAqC;QACrC,IAAI,YAAY,CAAC,MAAM,IAAI,CAAC,
IAAI,YAAY,CAAC,CAAC,CAAC,CAAC,MAAM,IAAI,CAAC,EAAE,CAAC;YAC1D,MAAM,KAAK,CAAC,GAAG,IAAI,CAAC,YAAY,EAAE,WAAW,IAAI,CAAC,IAAI,iEAAiE,IAAI,CAAC,SAAS,CAAC,UAAU,CAAC,EAAE,CAAC,CAAA;QACxJ,CAAC;QAED,gDAAgD;QAChD,IAAI,CAAC,aAAa,CAAC,KAAK,CAAC,UAAU,CAAC,CAAC;QACrC,IAAI,CAAC,iBAAiB,CAAC,KAAK,CAAC,UAAU,CAAC,CAAC;QAEzC,8CAA8C;QAC9C,MAAM,oBAAoB,GAAG,IAAI,CAAC,SAAS,CAAC,kBAAkB,CAAC,UAAU,CAAC,CAAC;QAC3E,MAAM,sBAAsB,GAAG,IAAI,CAAC,WAAW,CAAC,kBAAkB,CAAC,oBAAoB,CAAC,CAAC;QAEzF,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,UAAU,CAAC,CAAC;QACjC,IAAI,CAAC,WAAW,CAAC,KAAK,CAAC,oBAAoB,CAAC,CAAC;QAC7C,IAAI,CAAC,cAAc,CAAC,KAAK,CAAC,sBAAsB,CAAC,CAAC;QAElD,2BAA2B;QAC3B,IAAI,CAAC,gBAAgB,GAAG;YACpB,GAAG,IAAI,CAAC,aAAa,CAAC,gBAAgB;YACtC,GAAG,IAAI,CAAC,oBAAoB,CAAC,gBAAgB;YAC7C,GAAG,IAAI,CAAC,iBAAiB,CAAC,gBAAgB;YAC1C,GAAG,IAAI,CAAC,SAAS,CAAC,gBAAgB;YAClC,GAAG,IAAI,CAAC,WAAW,CAAC,gBAAgB;YACpC,GAAG,IAAI,CAAC,kBAAkB,CAAC,gBAAgB;YAC3C,GAAG,IAAI,CAAC,cAAc,CAAC,gBAAgB;SAC1C,CAAC;QAEF,8EAA8E;QAC9E,sEAAsE;QACtE,IAAI,QAAQ,GAAG,CAAC,CAAC;QAEjB,KAAK,MAAM,MAAM,IAAI,IAAI,CAAC,gBAAgB,EAAE,CAAC;YACzC,MAAM,WAAW,GAAG,GAAG,IAAI,CAAC,YAAY,EAAE,IAAI,QAAQ,EAAE,CAAC;YACxD,MAAc,CAAC,IAAI,IAAI,WAAW,CAAC;YACnC,MAAc,CAAC,YAAY,IAAI,WAAW,CAAC;YAC5C,QAAQ,EAAE,CAAC;QACf,CAAC;QAED,KAAK,CAAC,KAAK,CAAC,UAAU,CAAC,CAAC;IAC5B,CAAC;IAGD;;OAEG;IACM,SAAS;QACd,MAAM,WAAW,GAAG,KAAK,CAAC,SAAS,EAAE,CAAC;QAEtC,MAAM,MAAM,GAAG;YACX,QAAQ,EAAE,IAAI,CAAC,QAAQ;YACvB,QAAQ,EAAE,IAAI,CAAC,QAAQ;YACvB,MAAM,EAAE,IAAI,CAAC,MAAM;YACnB,OAAO,EAAE,IAAI,CAAC,OAAO;YACrB,OAAO,EAAE,IAAI,CAAC,OAAO;YACrB,UAAU,EAAE,IAAI,CAAC,UAAU;YAC3B,eAAe,EAAE,IAAI,CAAC,eAAe;SACxC,CAAC;QAEF,MAAM,CAAC,MAAM,CAAC,MAAM,EAAE,WAAW,CAAC,CAAC;QAEnC,OAAO,MAAM,CAAC;IAClB,CAAC;;AAIL,EAAE,CAAC,aAAa,CAAC,aAAa,CAAC,kBAAkB,CAAC,CAAC"}
@@ -1,2 +0,0 @@
1
- export {};
2
- //# sourceMappingURL=transformer_encoder.test.d.ts.map
@@ -1 +0,0 @@
1
- {"version":3,"file":"transformer_encoder.test.d.ts","sourceRoot":"","sources":["../../../src/layers/transformer_encoder.test.ts"],"names":[],"mappings":""}
@@ -1,58 +0,0 @@
1
- import * as tf from "@tensorflow/tfjs";
2
- import { TransformerEncoder } from "@/layers/transformer_encoder";
3
- // disables warning for using the faster node backend,
4
- // https://github.com/tensorflow/tfjs/issues/5349#issuecomment-885170504
5
- tf.env().set('IS_NODE', false);
6
// Jest suite for the TransformerEncoder layer: output shape, accepted call
// forms, head/embedding divisibility, input rank validation and config export.
describe("TransformerEncoder tests", () => {
    it("should return an output with the same shape as the input", () => {
        const input = tf.randomUniform([2, 3, 10]);
        const encoder = new TransformerEncoder({
            numHeads: 2, embedDim: input.shape.at(-1),
            dropout: 0.5, activation: "gelu", dimsFeedForward: 512, useBias: true
        });
        const output = encoder.apply(input);
        // compare every dimension, not only the rank, to match the test's claim
        expect(output.shape).toEqual(input.shape);
    });

    test("correct forward calls", () => {
        const input = tf.randomUniform([2, 3, 10]);
        const embedDim = input.shape.at(-1);

        const encoder = new TransformerEncoder({ numHeads: 2, embedDim });
        expect(() => encoder.apply(input)).not.toThrow();
        expect(() => encoder.apply([input])).not.toThrow();

        const causal = new TransformerEncoder({ numHeads: 2, embedDim, causal: true });
        expect(() => causal.apply(input)).not.toThrow();
        expect(() => causal.apply([input])).not.toThrow();
    });

    it("should fail to instantiate a layer if the input's embedding dimension is not divisible by the heads count", () => {
        const input = tf.randomUniform([2, 3, 10]);
        const embedDim = input.shape.at(-1);

        // 10 is not divisible by 3, but it is divisible by 5
        expect(() => new TransformerEncoder({ numHeads: 3, embedDim })).toThrow();
        expect(() => new TransformerEncoder({ numHeads: 5, embedDim })).not.toThrow();
    });

    it("should not accept non-rank 3 tensor inputs", () => {
        const incorrect_input = tf.randomUniform([2, 3, 10, 10]);
        const incorrect_input2 = tf.randomUniform([2, 3]);
        const correct_input = tf.randomUniform([2, 3, 10]);

        // size the layer from the valid input (same value, 10, as before)
        const encoder = new TransformerEncoder({ numHeads: 2, embedDim: correct_input.shape.at(-1) });

        const rejected = [
            [correct_input, correct_input],
            incorrect_input,
            incorrect_input2,
            [correct_input, incorrect_input],
            [incorrect_input, correct_input],
            [correct_input, incorrect_input2],
            [incorrect_input2, correct_input]
        ];
        for (const inputs of rejected) {
            expect(() => encoder.apply(inputs)).toThrow();
        }
    });

    it("should accept exactly one input", () => {
        const input = tf.randomUniform([2, 3, 10]);
        const encoder = new TransformerEncoder({ numHeads: 1, embedDim: input.shape.at(-1) });

        expect(() => encoder.apply(input)).not.toThrow();
        expect(() => encoder.apply([input])).not.toThrow();

        expect(() => encoder.apply([])).toThrow();
        expect(() => encoder.apply([input, input])).toThrow();
        expect(() => encoder.apply([input, input, input])).toThrow();
    });

    it("should return a non-empty config dict", () => {
        const input = tf.randomUniform([2, 3, 10]);
        const encoder = new TransformerEncoder({ numHeads: 1, embedDim: input.shape.at(-1) });
        // Assert on the key count: comparing the key array itself to 0 with
        // `.not.toBe(0)` can never fail.
        expect(Object.keys(encoder.getConfig()).length).not.toBe(0);
    });
});
58
- //# sourceMappingURL=transformer_encoder.test.js.map
@@ -1 +0,0 @@
1
- {"version":3,"file":"transformer_encoder.test.js","sourceRoot":"","sources":["../../../src/layers/transformer_encoder.test.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAEvC,OAAO,EAAE,kBAAkB,EAAE,MAAM,8BAA8B,CAAC;AAElE,sDAAsD;AACtD,wEAAwE;AACxE,EAAE,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC,SAAS,EAAE,KAAK,CAAC,CAAC;AAG/B,QAAQ,CAAC,0BAA0B,EAAE,GAAG,EAAE;IACtC,EAAE,CAAC,0DAA0D,EAAE,GAAG,EAAE;QAChE,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE3C,MAAM,OAAO,GAAG,IAAI,kBAAkB,CAAC;YACnC,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE;YAC1C,OAAO,EAAE,GAAG,EAAE,UAAU,EAAE,MAAM,EAAE,eAAe,EAAE,GAAG,EAAE,OAAO,EAAE,IAAI;SACxE,CAAC,CAAC;QAEH,MAAM,MAAM,GAAG,OAAO,CAAC,KAAK,CAAC,KAAK,CAAc,CAAC;QAEjD,MAAM,CAAC,MAAM,CAAC,KAAK,CAAC,MAAM,CAAC,CAAC,IAAI,CAAC,KAAK,CAAC,KAAK,CAAC,MAAM,CAAC,CAAC;IACzD,CAAC,CAAC,CAAA;IAGF,IAAI,CAAC,uBAAuB,EAAE,GAAG,EAAE;QAC/B,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE3C,MAAM,OAAO,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QACvF,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACjD,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAEnD,MAAM,MAAM,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,MAAM,EAAE,IAAI,EAAE,CAAC,CAAC;QACpG,MAAM,CAAC,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAChD,MAAM,CAAC,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;IACtD,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,uGAAuG,EAAE,GAAG,EAAE;QAC7G,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE3C,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAC
/F,MAAM,CAAC,GAAG,EAAE,CAAC,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;IACvG,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,4CAA4C,EAAE,GAAG,EAAE;QAClD,MAAM,eAAe,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC;QACzD,MAAM,gBAAgB,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;QAClD,MAAM,aAAa,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAGnD,MAAM,OAAO,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,eAAe,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QACjG,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,aAAa,EAAE,aAAa,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAEtE,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,eAAe,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACvD,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,gBAAgB,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAExD,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,aAAa,EAAE,eAAe,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACxE,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,eAAe,EAAE,aAAa,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAExE,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,aAAa,EAAE,gBAAgB,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACzE,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,gBAAgB,EAAE,aAAa,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;IAC7E,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,iCAAiC,EAAE,GAAG,EAAE;QACvC,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE3C,MAAM,OAAO,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QACvF,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,KAAK,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QACjD,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,EAAE,CAAC;QAEnD,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QAC1C,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,KAAK,EAAE,KAAK,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;
QACtD,MAAM,CAAC,GAAG,EAAE,CAAC,OAAO,CAAC,KAAK,CAAC,CAAC,KAAK,EAAE,KAAK,EAAE,KAAK,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAA;IAChE,CAAC,CAAC,CAAA;IAGF,EAAE,CAAC,uCAAuC,EAAE,GAAG,EAAE;QAC7C,MAAM,KAAK,GAAG,EAAE,CAAC,aAAa,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,CAAC;QAE3C,MAAM,OAAO,GAAG,IAAI,kBAAkB,CAAC,EAAE,QAAQ,EAAE,CAAC,EAAE,QAAQ,EAAE,KAAK,CAAC,KAAK,CAAC,EAAE,CAAC,CAAC,CAAC,CAAE,EAAE,CAAC,CAAC;QACvF,MAAM,CAAC,MAAM,CAAC,IAAI,CAAC,OAAO,CAAC,SAAS,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;IACzD,CAAC,CAAC,CAAA;AACN,CAAC,CAAC,CAAA"}
@@ -1,30 +0,0 @@
1
- import * as tf from "@tensorflow/tfjs";
2
- export declare function diceBinaryStandard(y_true: tf.Tensor, y_pred: tf.Tensor): tf.Tensor;
3
- export declare function diceBinaryGlobal(y_true: tf.Tensor, y_pred: tf.Tensor): tf.Tensor;
4
- export declare function diceCategoricalStandard(y_true: tf.Tensor, y_pred: tf.Tensor): tf.Tensor;
5
- export declare function diceCategoricalGeneralized(y_true: tf.Tensor, y_pred: tf.Tensor): tf.Tensor;
6
- export declare function diceCategoricalGlobal(y_true: tf.Tensor, y_pred: tf.Tensor): tf.Tensor;
7
- /**
8
- * Calculates the Sorensen-Dice coefficient and the binary cross entropy losses.
9
- * Both have equal weight.
10
- *
11
- * @param y_true the label tensor
12
- * @param y_pred the prediction tensor (not sparse)
13
- * @returns a tensor of shape `[ batch ]`
14
- */
15
- export declare function diceBinaryCrossentropy(y_true: tf.Tensor, y_pred: tf.Tensor): tf.Tensor;
16
- /**
17
- * Calculates the Sorensen-Dice coefficient and the categorical cross entropy losses.
18
- * Both have equal weight. Expects dense (non-sparse) label tensors.
19
- *
20
- * This does not support sparse tensors because TFJS's
21
- * sparseCategoricalCrossentropy loss onehots the label
22
- * and calls categoricalCrossentropy. See
23
- * https://github.com/tensorflow/tfjs/blob/0fc04d958ea592f3b8db79a8b3b497b5c8904097/tfjs-layers/src/losses.ts#L143-L146
24
- *
25
- * @param y_true the label
26
- * @param y_pred the prediction tensor (not sparse)
27
- * @returns a tensor of shape `[ batch ]`
28
- */
29
- export declare function diceCategoricalCrossentropy(y_true: tf.Tensor, y_pred: tf.Tensor): tf.Tensor;
30
- //# sourceMappingURL=dice.d.ts.map
@@ -1 +0,0 @@
1
- {"version":3,"file":"dice.d.ts","sourceRoot":"","sources":["../../../src/losses/dice.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAWvC,wBAAgB,kBAAkB,CAAC,MAAM,EAAE,EAAE,CAAC,MAAM,EAAE,MAAM,EAAE,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,CAclF;AAQD,wBAAgB,gBAAgB,CAAC,MAAM,EAAE,EAAE,CAAC,MAAM,EAAE,MAAM,EAAE,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,CAahF;AAOD,wBAAgB,uBAAuB,CAAC,MAAM,EAAE,EAAE,CAAC,MAAM,EAAE,MAAM,EAAE,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,CAUvF;AAOD,wBAAgB,0BAA0B,CAAC,MAAM,EAAE,EAAE,CAAC,MAAM,EAAE,MAAM,EAAE,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,CAgB1F;AAOD,wBAAgB,qBAAqB,CAAC,MAAM,EAAE,EAAE,CAAC,MAAM,EAAE,MAAM,EAAE,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,CAWrF;AAOD;;;;;;;GAOG;AACH,wBAAgB,sBAAsB,CAAC,MAAM,EAAE,EAAE,CAAC,MAAM,EAAE,MAAM,EAAE,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,CAMtF;AAOD;;;;;;;;;;;;GAYG;AACH,wBAAgB,2BAA2B,CAAC,MAAM,EAAE,EAAE,CAAC,MAAM,EAAE,MAAM,EAAE,EAAE,CAAC,MAAM,GAAG,EAAE,CAAC,MAAM,CAM3F"}
@@ -1,93 +0,0 @@
1
- import * as tf from "@tensorflow/tfjs";
2
- import { categoricalCrossentropy, binaryCrossentropy } from "@tensorflow/tfjs-layers/dist/losses";
3
- const epsilon = 1e-7;
4
- const REDUCE_HW = [1, 2]; // reduce over width and height
5
- const REDUCE_BHW = [0, 1, 2]; // reduce over batch, width, height
6
- const REDUCE_BHWC = [0, 1, 2, 3]; // reduce all dimensions
7
- // Standard (Sorensen) Dice Loss
8
- export function diceBinaryStandard(y_true, y_pred) {
9
- const y_true_flat = tf.reshape(y_true, [y_true.shape[0], -1]);
10
- const y_pred_flat = tf.reshape(y_pred, [y_pred.shape[0], -1]);
11
- const intersection = tf.sum(tf.mul(y_true_flat, y_pred_flat), 1);
12
- const union = tf.add(tf.sum(y_true_flat, 1), tf.sum(y_pred_flat, 1));
13
- const dice = tf.div(intersection.mul(2).add(epsilon), union.add(epsilon));
14
- return tf.scalar(1).sub(dice);
15
- }
16
- // prevents minification of function name which TFJS relies on
17
- Object.defineProperty(diceBinaryStandard, "name", { value: "diceBinaryStandard", configurable: false });
18
- // https://github.com/keras-team/keras/blob/v3.3.3/keras/src/losses/losses.py#L1983-L2010
19
- export function diceBinaryGlobal(y_true, y_pred) {
20
- const y_true_flat = tf.reshape(y_true, [-1]);
21
- const y_pred_flat = tf.reshape(y_pred, [-1]);
22
- const intersection = tf.sum(tf.mul(y_true_flat, y_pred_flat));
23
- const union = tf.add(tf.sum(y_true_flat), tf.sum(y_pred_flat));
24
- const dice = tf.div(intersection.mul(2).add(epsilon), union.add(epsilon));
25
- return tf.scalar(1).sub(dice);
26
- }
27
- // prevents minification of function name which TFJS relies on
28
- Object.defineProperty(diceBinaryGlobal, "name", { value: "diceBinaryGlobal", configurable: false });
29
- export function diceCategoricalStandard(y_true, y_pred) {
30
- const intersection = tf.sum(tf.mul(y_true, y_pred), REDUCE_HW);
31
- const union = tf.add(y_true, y_pred).sum(REDUCE_HW);
32
- const dice = tf.div(intersection.mul(2).add(epsilon), union.add(epsilon));
33
- return tf.scalar(1).sub(tf.mean(dice, -1));
34
- }
35
- // prevents minification of function name which TFJS relies on
36
- Object.defineProperty(diceCategoricalStandard, "name", { value: "diceCategoricalStandard", configurable: false });
37
- export function diceCategoricalGeneralized(y_true, y_pred) {
38
- // this is done twice so we calculate it once
39
- const y_true_sum = y_true.sum(REDUCE_BHW);
40
- const weighting = tf.div(1, y_true_sum.square().add(epsilon));
41
- const intersection = tf.sum(tf.mul(y_true, y_pred), REDUCE_BHW).mul(weighting).sum();
42
- const union = tf.add(y_true_sum, y_pred.sum(REDUCE_BHW)).mul(weighting).sum();
43
- const dice = tf.div(intersection.mul(2).add(epsilon), union.add(epsilon));
44
- return tf.scalar(1).sub(dice);
45
- }
46
- // prevents minification of function name which TFJS relies on
47
- Object.defineProperty(diceCategoricalGeneralized, "name", { value: "diceCategoricalGeneralized", configurable: false });
48
- export function diceCategoricalGlobal(y_true, y_pred) {
49
- const intersection = tf.sum(tf.mul(y_true, y_pred), REDUCE_BHWC);
50
- const union = tf.add(tf.sum(y_true, REDUCE_BHWC), tf.sum(y_pred, REDUCE_BHWC));
51
- const dice = tf.div(intersection.mul(2).add(epsilon), union.add(epsilon));
52
- return tf.scalar(1).sub(dice);
53
- }
54
- // prevents minification of function name which TFJS relies on
55
- Object.defineProperty(diceCategoricalGlobal, "name", { value: "diceCategoricalGlobal", configurable: false });
56
- /**
57
- * Calculates the Sorensen-Dice coefficient and the binary cross entropy losses.
58
- * Both have equal weight.
59
- *
60
- * @param y_true the label tensor
61
- * @param y_pred the prediction tensor (not sparse)
62
- * @returns a tensor of shape `[ batch ]`
63
- */
64
- export function diceBinaryCrossentropy(y_true, y_pred) {
65
- // reduce cross entropy shape from [B, H, W] to [B] to match dice
66
- const bce = binaryCrossentropy(y_true, y_pred).mean(REDUCE_HW);
67
- const dice = diceBinaryStandard(y_true, y_pred);
68
- return tf.add(bce.mul(0.5), dice.mul(0.5));
69
- }
70
- // prevents minification of function name which TFJS relies on
71
- Object.defineProperty(diceBinaryCrossentropy, "name", { value: "diceBinaryCrossentropy", configurable: false });
72
- /**
73
- * Calculates the Sorensen-Dice coefficient and the categorical cross entropy losses.
74
- * Both have equal weight. Expects dense (non-sparse) label tensors.
75
- *
76
- * This does not support sparse tensors because TFJS's
77
- * sparseCategoricalCrossentropy loss onehots the label
78
- * and calls categoricalCrossentropy. See
79
- * https://github.com/tensorflow/tfjs/blob/0fc04d958ea592f3b8db79a8b3b497b5c8904097/tfjs-layers/src/losses.ts#L143-L146
80
- *
81
- * @param y_true the label
82
- * @param y_pred the prediction tensor (not sparse)
83
- * @returns a tensor of shape `[ batch ]`
84
- */
85
- export function diceCategoricalCrossentropy(y_true, y_pred) {
86
- // reduce cross entropy shape from [B, H, W] to [B] to match dice
87
- const cce = categoricalCrossentropy(y_true, y_pred).mean(REDUCE_HW);
88
- const dice = diceCategoricalStandard(y_true, y_pred);
89
- return tf.add(cce.mul(0.5), dice.mul(0.5));
90
- }
91
- // prevents minification of function name which TFJS relies on
92
- Object.defineProperty(diceCategoricalCrossentropy, "name", { value: "diceCategoricalCrossentropy", configurable: false });
93
- //# sourceMappingURL=dice.js.map
@@ -1 +0,0 @@
1
- {"version":3,"file":"dice.js","sourceRoot":"","sources":["../../../src/losses/dice.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AACvC,OAAO,EAAE,uBAAuB,EAAE,kBAAkB,EAAE,MAAM,qCAAqC,CAAC;AAElG,MAAM,OAAO,GAAG,IAAI,CAAC;AAErB,MAAM,SAAS,GAAG,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,+BAA+B;AACzD,MAAM,UAAU,GAAG,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,mCAAmC;AACjE,MAAM,WAAW,GAAG,CAAC,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,wBAAwB;AAG1D,gCAAgC;AAChC,MAAM,UAAU,kBAAkB,CAAC,MAAiB,EAAE,MAAiB;IAEnE,MAAM,WAAW,GAAG,EAAE,CAAC,OAAO,CAAC,MAAM,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC;IAC9D,MAAM,WAAW,GAAG,EAAE,CAAC,OAAO,CAAC,MAAM,EAAE,CAAC,MAAM,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC;IAE9D,MAAM,YAAY,GAAG,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,WAAW,EAAE,WAAW,CAAC,EAAE,CAAC,CAAC,CAAC;IACjE,MAAM,KAAK,GAAG,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,WAAW,EAAE,CAAC,CAAC,EAAE,EAAE,CAAC,GAAG,CAAC,WAAW,EAAE,CAAC,CAAC,CAAC,CAAC;IAErE,MAAM,IAAI,GAAG,EAAE,CAAC,GAAG,CACf,YAAY,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,CAAC,EAChC,KAAK,CAAC,GAAG,CAAC,OAAO,CAAC,CACrB,CAAC;IAEF,OAAO,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC;AAClC,CAAC;AAGD,8DAA8D;AAC9D,MAAM,CAAC,cAAc,CAAC,kBAAkB,EAAE,MAAM,EAAE,EAAE,KAAK,EAAE,oBAAoB,EAAE,YAAY,EAAE,KAAK,EAAE,CAAC,CAAC;AAGxG,yFAAyF;AACzF,MAAM,UAAU,gBAAgB,CAAC,MAAiB,EAAE,MAAiB;IACjE,MAAM,WAAW,GAAG,EAAE,CAAC,OAAO,CAAC,MAAM,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;IAC7C,MAAM,WAAW,GAAG,EAAE,CAAC,OAAO,CAAC,MAAM,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;IAE7C,MAAM,YAAY,GAAG,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,WAAW,EAAE,WAAW,CAAC,CAAC,CAAC;IAC9D,MAAM,KAAK,GAAG,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,WAAW,CAAC,EAAE,EAAE,CAAC,GAAG,CAAC,WAAW,CAAC,CAAC,CAAC;IAE/D,MAAM,IAAI,GAAG,EAAE,CAAC,GAAG,CACf,YAAY,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,CAAC,EAChC,KAAK,CAAC,GAAG,CAAC,OAAO,CAAC,CACrB,CAAC;IAEF,OAAO,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC;AAClC,CAAC;AAGD,8DAA8D;AAC9D,MAAM,C
AAC,cAAc,CAAC,gBAAgB,EAAE,MAAM,EAAE,EAAE,KAAK,EAAE,kBAAkB,EAAE,YAAY,EAAE,KAAK,EAAE,CAAC,CAAC;AAGpG,MAAM,UAAU,uBAAuB,CAAC,MAAiB,EAAE,MAAiB;IACxE,MAAM,YAAY,GAAG,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,MAAM,EAAE,MAAM,CAAC,EAAE,SAAS,CAAC,CAAC;IAC/D,MAAM,KAAK,GAAG,EAAE,CAAC,GAAG,CAAC,MAAM,EAAE,MAAM,CAAC,CAAC,GAAG,CAAC,SAAS,CAAC,CAAC;IAEpD,MAAM,IAAI,GAAG,EAAE,CAAC,GAAG,CACf,YAAY,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,CAAC,EAChC,KAAK,CAAC,GAAG,CAAC,OAAO,CAAC,CACrB,CAAC;IAEF,OAAO,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,EAAE,CAAC,IAAI,CAAC,IAAI,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC;AAC/C,CAAC;AAGD,8DAA8D;AAC9D,MAAM,CAAC,cAAc,CAAC,uBAAuB,EAAE,MAAM,EAAE,EAAE,KAAK,EAAE,yBAAyB,EAAE,YAAY,EAAE,KAAK,EAAE,CAAC,CAAC;AAGlH,MAAM,UAAU,0BAA0B,CAAC,MAAiB,EAAE,MAAiB;IAE3E,6CAA6C;IAC7C,MAAM,UAAU,GAAG,MAAM,CAAC,GAAG,CAAC,UAAU,CAAC,CAAC;IAE1C,MAAM,SAAS,GAAG,EAAE,CAAC,GAAG,CAAC,CAAC,EAAE,UAAU,CAAC,MAAM,EAAE,CAAC,GAAG,CAAC,OAAO,CAAC,CAAC,CAAC;IAE9D,MAAM,YAAY,GAAG,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,MAAM,EAAE,MAAM,CAAC,EAAE,UAAU,CAAC,CAAC,GAAG,CAAC,SAAS,CAAC,CAAC,GAAG,EAAE,CAAC;IACrF,MAAM,KAAK,GAAG,EAAE,CAAC,GAAG,CAAC,UAAU,EAAE,MAAM,CAAC,GAAG,CAAC,UAAU,CAAC,CAAC,CAAC,GAAG,CAAC,SAAS,CAAC,CAAC,GAAG,EAAE,CAAC;IAE9E,MAAM,IAAI,GAAG,EAAE,CAAC,GAAG,CACf,YAAY,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,CAAC,EAChC,KAAK,CAAC,GAAG,CAAC,OAAO,CAAC,CACrB,CAAC;IAEF,OAAO,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC;AAClC,CAAC;AAGD,8DAA8D;AAC9D,MAAM,CAAC,cAAc,CAAC,0BAA0B,EAAE,MAAM,EAAE,EAAE,KAAK,EAAE,4BAA4B,EAAE,YAAY,EAAE,KAAK,EAAE,CAAC,CAAC;AAGxH,MAAM,UAAU,qBAAqB,CAAC,MAAiB,EAAE,MAAiB;IAEtE,MAAM,YAAY,GAAG,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,MAAM,EAAE,MAAM,CAAC,EAAE,WAAW,CAAC,CAAC;IACjE,MAAM,KAAK,GAAG,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,MAAM,EAAE,WAAW,CAAC,EAAE,EAAE,CAAC,GAAG,CAAC,MAAM,EAAE,WAAW,CAAC,CAAC,CAAC;IAE/E,MAAM,IAAI,GAAG,EAAE,CAAC,GAAG,CACf,YAAY,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,OAAO,CAAC,EAChC,KAAK,CAAC,GAAG,CAAC,OAAO,CAAC,CACrB,CAAC;IAEF,OAAO,EAAE,CAAC,MAAM,CAAC,C
AAC,CAAC,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC;AAClC,CAAC;AAGD,8DAA8D;AAC9D,MAAM,CAAC,cAAc,CAAC,qBAAqB,EAAE,MAAM,EAAE,EAAE,KAAK,EAAE,uBAAuB,EAAE,YAAY,EAAE,KAAK,EAAE,CAAC,CAAC;AAG9G;;;;;;;GAOG;AACH,MAAM,UAAU,sBAAsB,CAAC,MAAiB,EAAE,MAAiB;IACvE,iEAAiE;IACjE,MAAM,GAAG,GAAG,kBAAkB,CAAC,MAAM,EAAE,MAAM,CAAC,CAAC,IAAI,CAAC,SAAS,CAAC,CAAC;IAC/D,MAAM,IAAI,GAAG,kBAAkB,CAAC,MAAM,EAAE,MAAM,CAAC,CAAC;IAEhD,OAAO,EAAE,CAAC,GAAG,CAAC,GAAG,CAAC,GAAG,CAAC,GAAG,CAAC,EAAE,IAAI,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,CAAC;AAC/C,CAAC;AAGD,8DAA8D;AAC9D,MAAM,CAAC,cAAc,CAAC,sBAAsB,EAAE,MAAM,EAAE,EAAE,KAAK,EAAE,wBAAwB,EAAE,YAAY,EAAE,KAAK,EAAE,CAAC,CAAC;AAGhH;;;;;;;;;;;;GAYG;AACH,MAAM,UAAU,2BAA2B,CAAC,MAAiB,EAAE,MAAiB;IAC5E,iEAAiE;IACjE,MAAM,GAAG,GAAG,uBAAuB,CAAC,MAAM,EAAE,MAAM,CAAC,CAAC,IAAI,CAAC,SAAS,CAAC,CAAC;IACpE,MAAM,IAAI,GAAG,uBAAuB,CAAC,MAAM,EAAE,MAAM,CAAC,CAAC;IAErD,OAAO,EAAE,CAAC,GAAG,CAAC,GAAG,CAAC,GAAG,CAAC,GAAG,CAAC,EAAE,IAAI,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC,CAAC;AAC/C,CAAC;AAGD,8DAA8D;AAC9D,MAAM,CAAC,cAAc,CAAC,2BAA2B,EAAE,MAAM,EAAE,EAAE,KAAK,EAAE,6BAA6B,EAAE,YAAY,EAAE,KAAK,EAAE,CAAC,CAAC"}
@@ -1,2 +0,0 @@
1
- export * from "./dice";
2
- //# sourceMappingURL=index.d.ts.map
@@ -1 +0,0 @@
1
- {"version":3,"file":"index.d.ts","sourceRoot":"","sources":["../../../src/losses/index.ts"],"names":[],"mappings":"AAAA,cAAc,QAAQ,CAAC"}
@@ -1,2 +0,0 @@
1
- export * from "./dice";
2
- //# sourceMappingURL=index.js.map
@@ -1 +0,0 @@
1
- {"version":3,"file":"index.js","sourceRoot":"","sources":["../../../src/losses/index.ts"],"names":[],"mappings":"AAAA,cAAc,QAAQ,CAAC"}
@@ -1,20 +0,0 @@
1
- import * as tf from "@tensorflow/tfjs";
2
- /**
3
- * Generate a causal mask used in self-attention to prevent tokens from looking
4
- * ahead. The values in the upper right portion of the mask matrix are set to
5
- * -1e7 so that they have no impact during scaled dot product attention.
6
- */
7
- export declare function causal(query_seq_length: number, key_seq_length: number): tf.Tensor<tf.Rank>;
8
- /**
9
- * Generate a self-attention mask that prevents packed sequences from crossing document
10
- * boundaries and attending to each other. The result is a tensor of diagonally
11
- * positioned blocks of zeroes (allow attention) and -1e7 values (prevent attention).
12
- * The latter is scored zero during the scaled dot product attention's softmax operation.
13
- *
14
- * @param boundaries an array of ones (denotes start of a new sample or document) and zeroes
15
- *
16
- * Example boundary of 3 samples that are packed into one:
17
- * `[1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0]`
18
- */
19
- export declare function packing(boundaries: Int32Array): tf.Tensor<tf.Rank>;
20
- //# sourceMappingURL=masks.d.ts.map
@@ -1 +0,0 @@
1
- {"version":3,"file":"masks.d.ts","sourceRoot":"","sources":["../../src/masks.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAGvC;;;;GAIG;AACH,wBAAgB,MAAM,CAAC,gBAAgB,EAAE,MAAM,EAAE,cAAc,EAAE,MAAM,sBAItE;AAGD;;;;;;;;;;GAUG;AACH,wBAAgB,OAAO,CAAC,UAAU,EAAE,UAAU,sBAc7C"}
package/dist/src/masks.js DELETED
@@ -1,37 +0,0 @@
1
- import * as tf from "@tensorflow/tfjs";
2
- /**
3
- * Generate a causal mask used in self-attention to prevent tokens from looking
4
- * ahead. The values in the upper right portion of the mask matrix are set to
5
- * -1e7 so that they have no impact during scaled dot product attention.
6
- */
7
- export function causal(query_seq_length, key_seq_length) {
8
- return tf.linalg.bandPart(tf.ones([query_seq_length, key_seq_length]), -1, 0)
9
- .sub(1)
10
- .mul(1e7);
11
- }
12
- /**
13
- * Generate a self-attention mask that prevents packed sequences from crossing document
14
- * boundaries and attending to each other. The result is a tensor of diagonally
15
- * positioned blocks of zeroes (allow attention) and -1e7 values (prevent attention).
16
- * The latter is scored zero during the scaled dot product attention's softmax operation.
17
- *
18
- * @param boundaries an array of ones (denotes start of a new sample or document) and zeroes
19
- *
20
- * Example boundary of 3 samples that are packed into one:
21
- * `[1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0]`
22
- */
23
- export function packing(boundaries) {
24
- // see images at
25
- // https://reddit.com/r/LocalLLaMA/comments/197efaz/training_llama_mistral_and_mixtralmoe_faster_with/
26
- return tf.tidy(() => {
27
- // cumsum transforms the tensor such that each sequence in the pack gets its own id,
28
- const partitions = tf.tensor1d(boundaries).cumsum();
29
- return partitions.expandDims(1)
30
- .equal(partitions.expandDims(0))
31
- .sub(1)
32
- .mul(1e7)
33
- // introduce a head dimension so it can be broadcasted
34
- .expandDims(0);
35
- });
36
- }
37
- //# sourceMappingURL=masks.js.map
@@ -1 +0,0 @@
1
- {"version":3,"file":"masks.js","sourceRoot":"","sources":["../../src/masks.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,MAAM,kBAAkB,CAAC;AAGvC;;;;GAIG;AACH,MAAM,UAAU,MAAM,CAAC,gBAAwB,EAAE,cAAsB;IACnE,OAAO,EAAE,CAAC,MAAM,CAAC,QAAQ,CAAC,EAAE,CAAC,IAAI,CAAC,CAAC,gBAAgB,EAAE,cAAc,CAAC,CAAC,EAAE,CAAC,CAAC,EAAE,CAAC,CAAC;SACxE,GAAG,CAAC,CAAC,CAAC;SACN,GAAG,CAAC,GAAG,CAAC,CAAC;AAClB,CAAC;AAGD;;;;;;;;;;GAUG;AACH,MAAM,UAAU,OAAO,CAAC,UAAsB;IAC1C,gBAAgB;IAChB,sGAAsG;IACtG,OAAO,EAAE,CAAC,IAAI,CAAC,GAAG,EAAE;QAChB,oFAAoF;QACpF,MAAM,UAAU,GAAG,EAAE,CAAC,QAAQ,CAAC,UAAU,CAAC,CAAC,MAAM,EAAE,CAAC;QAEpD,OAAO,UAAU,CAAC,UAAU,CAAC,CAAC,CAAC;aAC1B,KAAK,CAAC,UAAU,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC;aAC/B,GAAG,CAAC,CAAC,CAAC;aACN,GAAG,CAAC,GAAG,CAAC;YACT,sDAAsD;aACrD,UAAU,CAAC,CAAC,CAAC,CAAC;IACvB,CAAC,CAAC,CAAA;AACN,CAAC"}
@@ -1,20 +0,0 @@
1
- import { Tensor } from "@tensorflow/tfjs";
2
- /**
3
- * Applies the recall metric with the prediction rounded based on a threshold
4
- *
5
- * @param y_true the label tensor
6
- * @param y_pred the prediction tensor
7
- * @param threshold threshold value to be considered a positive prediction, defaults to `0.5`
8
- * @returns
9
- */
10
- export declare function recall(y_true: Tensor, y_pred: Tensor, threshold?: number): Tensor<import("@tensorflow/tfjs-core").Rank>;
11
- /**
12
- * Applies the precision metric with the prediction rounded based on a threshold
13
- *
14
- * @param y_true the label tensor
15
- * @param y_pred the prediction tensor
16
- * @param threshold threshold value to be considered a positive prediction, defaults to `0.5`
17
- * @returns
18
- */
19
- export declare function precision(y_true: Tensor, y_pred: Tensor, threshold?: number): Tensor<import("@tensorflow/tfjs-core").Rank>;
20
- //# sourceMappingURL=metrics.d.ts.map
@@ -1 +0,0 @@
1
- {"version":3,"file":"metrics.d.ts","sourceRoot":"","sources":["../../src/metrics.ts"],"names":[],"mappings":"AAAA,OAAO,EAAW,MAAM,EAAE,MAAM,kBAAkB,CAAC;AAGnD;;;;;;;GAOG;AACH,wBAAgB,MAAM,CAAC,MAAM,EAAE,MAAM,EAAE,MAAM,EAAE,MAAM,EAAE,SAAS,GAAE,MAAY,gDAE7E;AAKD;;;;;;;GAOG;AACH,wBAAgB,SAAS,CAAC,MAAM,EAAE,MAAM,EAAE,MAAM,EAAE,MAAM,EAAE,SAAS,GAAE,MAAY,gDAEhF"}
@@ -1,28 +0,0 @@
1
- import { metrics } from "@tensorflow/tfjs";
2
- /**
3
- * Applies the recall metric with the prediction rounded based on a threshold
4
- *
5
- * @param y_true the label tensor
6
- * @param y_pred the prediction tensor
7
- * @param threshold threshold value to be considered a positive prediction, defaults to `0.5`
8
- * @returns
9
- */
10
- export function recall(y_true, y_pred, threshold = 0.5) {
11
- return metrics.recall(y_true, y_pred.greaterEqual(threshold));
12
- }
13
- // prevents minification of function name which TFJS relies on
14
- Object.defineProperty(recall, "name", { value: "recall", configurable: false });
15
- /**
16
- * Applies the precision metric with the prediction rounded based on a threshold
17
- *
18
- * @param y_true the label tensor
19
- * @param y_pred the prediction tensor
20
- * @param threshold threshold value to be considered a positive prediction, defaults to `0.5`
21
- * @returns
22
- */
23
- export function precision(y_true, y_pred, threshold = 0.5) {
24
- return metrics.precision(y_true, y_pred.greaterEqual(threshold));
25
- }
26
- // prevents minification of function name which TFJS relies on
27
- Object.defineProperty(precision, "name", { value: "precision", configurable: false });
28
- //# sourceMappingURL=metrics.js.map
@@ -1 +0,0 @@
1
- {"version":3,"file":"metrics.js","sourceRoot":"","sources":["../../src/metrics.ts"],"names":[],"mappings":"AAAA,OAAO,EAAE,OAAO,EAAU,MAAM,kBAAkB,CAAC;AAGnD;;;;;;;GAOG;AACH,MAAM,UAAU,MAAM,CAAC,MAAc,EAAE,MAAc,EAAE,YAAoB,GAAG;IAC1E,OAAO,OAAO,CAAC,MAAM,CAAC,MAAM,EAAE,MAAM,CAAC,YAAY,CAAC,SAAS,CAAC,CAAC,CAAC;AAClE,CAAC;AAED,8DAA8D;AAC9D,MAAM,CAAC,cAAc,CAAC,MAAM,EAAE,MAAM,EAAE,EAAE,KAAK,EAAE,QAAQ,EAAE,YAAY,EAAE,KAAK,EAAE,CAAC,CAAC;AAEhF;;;;;;;GAOG;AACH,MAAM,UAAU,SAAS,CAAC,MAAc,EAAE,MAAc,EAAE,YAAoB,GAAG;IAC7E,OAAO,OAAO,CAAC,SAAS,CAAC,MAAM,EAAE,MAAM,CAAC,YAAY,CAAC,SAAS,CAAC,CAAC,CAAC;AACrE,CAAC;AAED,8DAA8D;AAC9D,MAAM,CAAC,cAAc,CAAC,SAAS,EAAE,MAAM,EAAE,EAAE,KAAK,EAAE,WAAW,EAAE,YAAY,EAAE,KAAK,EAAE,CAAC,CAAC"}
@@ -1,94 +0,0 @@
1
- import * as tf from "@tensorflow/tfjs";
2
- import { type LossOrMetricFn } from "@/tfjs_types";
3
- import { LlmModel, type LlmModelArgs } from "@/models/llm_model";
4
- import { KvCacheContainer } from "@/kv_cache";
5
- import { type DisposeResult } from "@tensorflow/tfjs-layers/dist/engine/topology";
6
- export interface GptModelArgs extends LlmModelArgs {
7
- /**
8
- * Number of heads per attention layer.
9
- */
10
- numHeads: number;
11
- /**
12
- * Number of GPT decoder blocks.
13
- */
14
- numLayers: number;
15
- /**
16
- * The embedding size of each token.
17
- */
18
- embedDim: number;
19
- /**
20
- * The vocabulary size of the embedding layer and number of units of the output
21
- * layer. This is also the tokenizer vocabulary size.
22
- */
23
- vocabSize: number;
24
- /**
25
- * Pad the embeddings' vocab size and output layer's units to the next nearest
26
- * multiple of 64 to optimize hardware efficiency. Defaults to `true`.
27
- *
28
- * For example: if a tokenizer has 50,257 tokens, the model uses 50,304 for the
29
- * vocab size and output units count.
30
- */
31
- padToMultipleOf64?: boolean;
32
- }
33
- /**
34
- * This is a subclass of tf.Sequential that creates a GPT-like model and
35
- * automatically handles padding (and masking) the vocab size for hardware
36
- * efficiency.
37
- *
38
- * Example:
39
- *
40
- * ```javascript
41
- *
42
- * const model = new GptModel({ numLayers: 1, numHeads: 1, embedDim: 16, vocabSize: 64 });
43
- * model.compile({ loss: "sparseCategoricalCrossentropy", optimizer: "adam" });
44
- *
45
- * // use fitDataset() instead of fit for masking support
46
- * model.fitDataset(your_batched_generator_dataset, { epochs: 1 });
47
- *
48
- * const kv_cache = new KvCacheContainer(your_preferred_max_sequence_length);
49
- *
50
- * // use generate() and predictNextToken() instead of predict() for masking and auto memory cleanup
51
- * model.generate(tokenized_tensor1d_input, kv_cache, onPredict_callback)
52
- *
53
- *
54
- * ```
55
- */
56
- export declare class GptModel extends LlmModel {
57
- static className: string;
58
- protected readonly numHeads: number;
59
- protected readonly numLayers: number;
60
- protected readonly embedDim: number;
61
- protected readonly vocabSize: number;
62
- protected readonly padToMultipleOf64: boolean;
63
- protected readonly vocabSizePadded: number;
64
- protected vocab_padding_mask?: tf.Tensor1D;
65
- /**
66
- * DO NOT add layers in the constructor or it will break tf.loadLayersModel().
67
- * It should be done in build() instead.
68
- */
69
- constructor(args: GptModelArgs);
70
- protected fitBatch(xs: tf.Tensor, ys: tf.Tensor, loss_mask: tf.Tensor | undefined, loss_function: LossOrMetricFn, other_masks?: {
71
- [key: string]: tf.Tensor | undefined;
72
- }): {
73
- y_pred: tf.Tensor<tf.Rank>;
74
- loss: tf.Scalar;
75
- };
76
- /**
77
- * Overrides LlmModel.predictNextToken to add softmax before argMax because the final
78
- * dense layer doesn't have an activation.
79
- *
80
- * TODO: implement temperature and multinomial sampling so that the model has varied outputs
81
- */
82
- predictNextToken(input: tf.Tensor2D, kv_cache: KvCacheContainer): tf.Tensor2D;
83
- build(inputShape?: tf.Shape | tf.Shape[]): void;
84
- dispose(): DisposeResult;
85
- getConfig(): {
86
- numHeads: number;
87
- numLayers: number;
88
- embedDim: number;
89
- vocabSize: number;
90
- vocabSizePadded: number;
91
- padToMultipleOf64: boolean;
92
- };
93
- }
94
- //# sourceMappingURL=gpt_model.d.ts.map