@epfml/discojs 3.0.1-p20241119093954.0 → 3.0.1-p20241206133538.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -149,6 +149,8 @@ export class Client extends EventEmitter {
  }
  url.pathname += `tasks/${this.task.id}/model.json`;
  const response = await fetch(url);
+ if (!response.ok)
+ throw new Error(`fetch: HTTP status ${response.status}`);
  const encoded = new Uint8Array(await response.arrayBuffer());
  return await serialization.model.decode(encoded);
  }
@@ -2,7 +2,7 @@ import createDebug from "debug";
  import { serialization } from "../../index.js";
  import { Client, shortenId } from "../client.js";
  import { type } from "../messages.js";
- import { waitMessage, waitMessageWithTimeout, WebSocketServer, } from "../event_connection.js";
+ import { waitMessage, WebSocketServer, } from "../event_connection.js";
  import * as messages from "./messages.js";
  const debug = createDebug("discojs:client:federated");
  /**
@@ -53,7 +53,7 @@ export class FederatedClient extends Client {
  type: type.ClientConnected,
  };
  this.server.send(msg);
- const { id, waitForMoreParticipants, payload, round, nbOfParticipants } = await waitMessageWithTimeout(this.server, type.NewFederatedNodeInfo);
+ const { id, waitForMoreParticipants, payload, round, nbOfParticipants } = await waitMessage(this.server, type.NewFederatedNodeInfo);
  // This should come right after receiving the message to make sure
  // we don't miss a subsequent message from the server
  // We check if the server is telling us to wait for more participants
@@ -25,15 +25,22 @@ export declare class Dataset<T> implements AsyncIterable<T> {
  * @param ratio between 0 (all on left) and 1 (all on right)
  */
  split(ratio: number): [Dataset<T>, Dataset<T>];
- /** Slice into chunks
+ /** Create batches of `size` elements with potential overlap.
+ * Last batch is smaller if dataset isn't perfectly divisible
  *
- * Last slice is smaller if dataset isn't perfectly divisible
+ * If overlap is set to a positive integer, the last `overlap` elements of a batch
+ * are the first `overlap` elements of the next batch.
+ *
+ * This method is tailored to create text sequences where each token's label is the following token.
+ * In order to have a label for the last token of the input sequence, we include the first token
+ * of the next sequence (i.e. with an overlap of 1).
  *
  * @param size count of element per chunk
+ * @param overlap number of elements overlapping between two consecutive batches
  */
- batch(size: number): Dataset<Batched<T>>;
- /** Flatten chunks */
- unbatch<U>(this: Dataset<Batched<U>>): Dataset<U>;
+ batch(size: number, overlap?: number): Dataset<Batched<T>>;
+ /** Flatten batches/arrays of elements */
+ flatten<U>(this: Dataset<DatasetLike<U>>): Dataset<U>;
  /** Join side-by-side
  *
  * Stops as soon as one runs out
@@ -41,6 +48,12 @@ export declare class Dataset<T> implements AsyncIterable<T> {
  * @param other right side
  **/
  zip<U>(other: Dataset<U> | DatasetLike<U>): Dataset<[T, U]>;
+ /**
+ * Repeat the dataset `times` times
+ * @param times number of times to repeat the dataset, if undefined, the dataset is repeated indefinitely
+ * @returns a dataset repeated `times` times
+ */
+ repeat(times?: number): Dataset<T>;
  /** Compute size
  *
  * This is a costly operation as we need to go through the whole Dataset.
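Note on the new `batch` overlap parameter: as the docstring above explains, it exists to build next-token training windows. A minimal usage sketch in TypeScript (the variable names and token values are our own illustration, not part of the package):

    // Each element is one token id; size 5 with overlap 1 makes the last
    // token of a window reappear as the first token of the next window,
    // so inputs tokens[0..n-1] always have labels tokens[1..n] available.
    const tokens = new Dataset([1, 2, 3, 4, 5, 6, 7, 8, 9]);
    const windows = tokens.batch(5, 1);
    // yields List [1, 2, 3, 4, 5] then List [5, 6, 7, 8, 9]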
@@ -1,6 +1,22 @@
  import createDebug from "debug";
  import { List, Range } from "immutable";
  const debug = createDebug("discojs:dataset");
+ /** Convert a DatasetLike object to an async generator */
+ async function* datasetLikeToGenerator(content) {
+ let iter;
+ if (typeof content === "function")
+ iter = content();
+ else if (Symbol.asyncIterator in content)
+ iter = content[Symbol.asyncIterator]();
+ else
+ iter = content[Symbol.iterator]();
+ while (true) {
+ const result = await iter.next();
+ if (result.done === true)
+ break;
+ yield result.value;
+ }
+ }
  /** Immutable series of data */
  export class Dataset {
  #content;
@@ -11,19 +27,7 @@ export class Dataset {
  */
  constructor(content) {
  this.#content = async function* () {
- let iter;
- if (typeof content === "function")
- iter = content();
- else if (Symbol.asyncIterator in content)
- iter = content[Symbol.asyncIterator]();
- else
- iter = content[Symbol.iterator]();
- while (true) {
- const result = await iter.next();
- if (result.done === true)
- break;
- yield result.value;
- }
+ yield* datasetLikeToGenerator(content);
  };
  }
  [Symbol.asyncIterator]() {
@@ -87,19 +91,31 @@ export class Dataset {
  }.bind(this)),
  ];
  }
- /** Slice into chunks
+ /** Create batches of `size` elements with potential overlap.
+ * Last batch is smaller if dataset isn't perfectly divisible
+ *
+ * If overlap is set to a positive integer, the last `overlap` elements of a batch
+ * are the first `overlap` elements of the next batch.
  *
- * Last slice is smaller if dataset isn't perfectly divisible
+ * This method is tailored to create text sequences where each token's label is the following token.
+ * In order to have a label for the last token of the input sequence, we include the first token
+ * of the next sequence (i.e. with an overlap of 1).
  *
  * @param size count of element per chunk
+ * @param overlap number of elements overlapping between two consecutive batches
  */
- batch(size) {
+ batch(size, overlap = 0) {
  if (size <= 0 || !Number.isInteger(size))
  throw new Error("invalid size");
+ if (overlap >= size || !Number.isInteger(overlap))
+ throw new Error("invalid overlap");
  return new Dataset(async function* () {
  const iter = this[Symbol.asyncIterator]();
+ let overlapped = List();
  for (;;) {
- const batch = List(await Promise.all(Range(0, size).map(() => iter.next()))).flatMap((res) => {
+ const batch = List(
+ // get the first elements of the next batch
+ await Promise.all(Range(overlapped.size, size).map(() => iter.next()))).flatMap((res) => {
  if (res.done)
  return [];
  else
@@ -107,18 +123,21 @@ export class Dataset {
  });
  if (batch.isEmpty())
  break;
- yield batch;
+ // yield the current batch with the first elements of the next batch
+ yield overlapped.concat(batch);
+ overlapped = batch.takeLast(overlap);
  // iterator couldn't generate more
- if (batch.size < size)
+ if (batch.size < size - overlap)
  break;
  }
  }.bind(this));
  }
- /** Flatten chunks */
- unbatch() {
+ /** Flatten batches/arrays of elements */
+ flatten() {
  return new Dataset(async function* () {
- for await (const batch of this)
- yield* batch;
+ for await (const batch of this) {
+ yield* datasetLikeToGenerator(batch);
+ }
  }.bind(this));
  }
  /** Join side-by-side
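Since `flatten` now funnels each element through `datasetLikeToGenerator`, it accepts any DatasetLike element (plain arrays, sync or async iterables, or nested Datasets), where the old `unbatch` only handled List batches. A hedged sketch of the resulting behavior, with illustrative values of our own:

    const nested = new Dataset([[1, 2], [3, 4]]);
    for await (const n of nested.flatten())
      console.log(n); // 1, 2, 3, 4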
@@ -141,6 +160,22 @@ export class Dataset {
  }
  }.bind(this));
  }
+ /**
+ * Repeat the dataset `times` times
+ * @param times number of times to repeat the dataset, if undefined, the dataset is repeated indefinitely
+ * @returns a dataset repeated `times` times
+ */
+ repeat(times) {
+ if (times !== undefined && (!Number.isInteger(times) || times < 1))
+ throw new Error("times needs to be a positive integer or undefined");
+ return new Dataset(async function* () {
+ let loop = 0;
+ do {
+ yield* this;
+ loop++;
+ } while (times === undefined || loop < times);
+ }.bind(this));
+ }
  /** Compute size
  *
  * This is a costly operation as we need to go through the whole Dataset.
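A short illustrative sketch of the new `repeat` (values are ours): with an argument it loops a fixed number of times; with no argument it cycles forever, so the consumer must bound iteration itself:

    const pair = new Dataset([1, 2]);
    for await (const n of pair.repeat(2))
      console.log(n); // 1, 2, 1, 2
    // pair.repeat() would yield 1, 2, 1, 2, ... indefinitely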
@@ -4,3 +4,4 @@ export type Batched<T> = List<T>;
  export { Image };
  export type Tabular = Partial<Record<string, string>>;
  export type Text = string;
+ export type TokenizedText = List<number>;
@@ -4,3 +4,4 @@ export { mnist } from './mnist.js';
  export { simpleFace } from './simple_face.js';
  export { titanic } from './titanic.js';
  export { wikitext } from './wikitext.js';
+ export { tinderDog } from './tinder_dog.js';
@@ -4,3 +4,4 @@ export { mnist } from './mnist.js';
  export { simpleFace } from './simple_face.js';
  export { titanic } from './titanic.js';
  export { wikitext } from './wikitext.js';
+ export { tinderDog } from './tinder_dog.js';
@@ -0,0 +1,2 @@
+ import type { TaskProvider } from '../index.js';
+ export declare const tinderDog: TaskProvider<'image'>;
@@ -0,0 +1,72 @@
+ import * as tf from '@tensorflow/tfjs';
+ import { models } from '../index.js';
+ export const tinderDog = {
+ getTask() {
+ return {
+ id: 'tinder_dog',
+ displayInformation: {
+ taskTitle: 'GDHF 2024 | TinderDog',
+ summary: {
+ preview: 'Which dog is the cutest....or not?',
+ overview: "Binary classification model for dog cuteness."
+ },
+ model: 'The model is a simple Convolutional Neural Network composed of two convolutional layers with ReLU activations and max pooling layers, followed by a fully connected output layer. The data preprocessing reshapes images into 64x64 pixels and normalizes values between 0 and 1',
+ dataFormatInformation: 'Accepted image formats are .png .jpg and .jpeg.',
+ dataExampleText: '',
+ dataExampleImage: 'https://storage.googleapis.com/deai-313515.appspot.com/tinder_dog_preview.png',
+ sampleDatasetLink: 'https://storage.googleapis.com/deai-313515.appspot.com/tinder_dog.zip',
+ sampleDatasetInstructions: 'Opening the link should start downloading a zip file which you can unzip. To connect the data, pick one of the data splits (the folder 0 for example) and use the CSV option below to select the file named "labels.csv". You can now connect the images located in the same folder.'
+ },
+ trainingInformation: {
+ epochs: 10,
+ roundDuration: 2,
+ validationSplit: 0, // nicer plot for GDHF demo
+ batchSize: 10,
+ dataType: 'image',
+ IMAGE_H: 64,
+ IMAGE_W: 64,
+ LABEL_LIST: ['Cute dogs', 'Less cute dogs'],
+ scheme: 'federated',
+ aggregationStrategy: 'mean',
+ minNbOfParticipants: 3,
+ tensorBackend: 'tfjs'
+ }
+ };
+ },
+ async getModel() {
+ const seed = 42; // set a seed to ensure reproducibility during GDHF demo
+ const imageHeight = this.getTask().trainingInformation.IMAGE_H;
+ const imageWidth = this.getTask().trainingInformation.IMAGE_W;
+ const imageChannels = 3;
+ const model = tf.sequential();
+ model.add(tf.layers.conv2d({
+ inputShape: [imageHeight, imageWidth, imageChannels],
+ kernelSize: 5,
+ filters: 8,
+ activation: 'relu',
+ kernelInitializer: tf.initializers.heNormal({ seed })
+ }));
+ model.add(tf.layers.conv2d({
+ kernelSize: 5, filters: 16, activation: 'relu',
+ kernelInitializer: tf.initializers.heNormal({ seed })
+ }));
+ model.add(tf.layers.maxPooling2d({ poolSize: 2, strides: 2 }));
+ model.add(tf.layers.dropout({ rate: 0.25, seed }));
+ model.add(tf.layers.flatten());
+ model.add(tf.layers.dense({
+ units: 32, activation: 'relu',
+ kernelInitializer: tf.initializers.heNormal({ seed })
+ }));
+ model.add(tf.layers.dropout({ rate: 0.25, seed }));
+ model.add(tf.layers.dense({
+ units: 2, activation: 'softmax',
+ kernelInitializer: tf.initializers.heNormal({ seed })
+ }));
+ model.compile({
+ optimizer: tf.train.adam(0.0005),
+ loss: 'categoricalCrossentropy',
+ metrics: ['accuracy']
+ });
+ return Promise.resolve(new models.TFJS('image', model));
+ }
+ };
@@ -31,14 +31,16 @@ export const wikitext = {
  // But if set to 0 then the webapp doesn't display the validation metrics
  validationSplit: 0.1,
  roundDuration: 2,
- batchSize: 1, // If set too high (e.g. 16) firefox raises a WebGL error
+ batchSize: 8, // If set too high firefox raises a WebGL error
  tokenizer: 'Xenova/gpt2',
- maxSequenceLength: 128,
+ contextLength: 64,
  tensorBackend: 'gpt'
  }
  };
  },
  getModel() {
- return Promise.resolve(new models.GPT());
+ return Promise.resolve(new models.GPT({
+ contextLength: this.getTask().trainingInformation.contextLength,
+ }));
  }
  };
@@ -1,8 +1,8 @@
  type GPTModelType = 'gpt2' | 'gpt2-medium' | 'gpt2-large' | 'gpt2-xl' | 'gpt-mini' | 'gpt-micro' | 'gpt-nano';
  export interface GPTConfig {
  lr: number;
- blockSize: number;
- vocabSize: number;
+ contextLength: number;
+ vocabSize?: number;
  modelType: GPTModelType;
  name?: string;
  evaluate?: boolean;
@@ -11,22 +11,27 @@ export interface GPTConfig {
  maxIter?: number;
  weightDecay?: number;
  verbose?: 0 | 1;
- bias?: boolean;
  debug?: boolean;
  dropout?: number;
  residDrop?: number;
  embdDrop?: number;
- tokEmb?: boolean;
- lmHead?: boolean;
  nLayer?: number;
  nHead?: number;
  nEmbd?: number;
+ seed?: number;
  }
- export declare const DEFAULT_CONFIG: Required<GPTConfig>;
+ export declare const DefaultGPTConfig: Required<GPTConfig>;
  export type ModelSize = {
  nLayer: number;
  nHead: number;
  nEmbd: number;
  };
  export declare function getModelSizes(modelType: GPTModelType): Required<ModelSize>;
+ export interface GenerationConfig {
+ doSample: boolean;
+ temperature: number;
+ topk: number;
+ seed: number;
+ }
+ export declare const DefaultGenerationConfig: Required<GenerationConfig>;
  export {};
@@ -1,6 +1,6 @@
  // for a benchmark of performance, see https://github.com/epfml/disco/pull/659
- export const DEFAULT_CONFIG = {
- name: 'transformer',
+ export const DefaultGPTConfig = {
+ name: 'transformer', // prefix for the model layer names
  lr: 0.001,
  weightDecay: 0,
  maxIter: 10,
@@ -9,18 +9,16 @@ export const DEFAULT_CONFIG = {
  evaluate: true,
  maxEvalBatches: 12,
  evaluateEvery: 100,
- blockSize: 128,
- vocabSize: 50258,
- bias: true,
+ contextLength: 128,
+ vocabSize: 50257,
  debug: false,
  dropout: 0.2,
  residDrop: 0.2,
  embdDrop: 0.2,
- tokEmb: true,
- lmHead: true,
  nLayer: 3,
  nHead: 3,
  nEmbd: 48,
+ seed: Math.random(),
  };
  export function getModelSizes(modelType) {
  switch (modelType) {
@@ -40,3 +38,9 @@ export function getModelSizes(modelType) {
  return { nLayer: 3, nHead: 3, nEmbd: 48 };
  }
  }
+ export const DefaultGenerationConfig = {
+ temperature: 1.0,
+ doSample: false,
+ seed: Math.random(),
+ topk: 50
+ };
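These defaults are merged with caller-supplied options in GPT.predict further down (via Object.assign({}, DefaultGenerationConfig, options)), so a caller overrides only the fields it names. An illustrative sketch of the merge semantics:

    // doSample and seed keep their defaults; only the named fields change
    const config = Object.assign({}, DefaultGenerationConfig, { temperature: 0.8, topk: 10 });
    // -> { temperature: 0.8, doSample: false, seed: <random>, topk: 10 }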
@@ -1,23 +1,20 @@
  /**
- * this code is taken from gpt-tfjs with modifications from @peacefulotter and @lukemovement
+ * Source: https://github.com/zemlyansky/gpt-tfjs and https://github.com/karpathy/build-nanogpt
+ * With modifications from @peacefulotter, @lukemovement and the Disco team
  **/
  import * as tf from "@tensorflow/tfjs";
  import type { Batched, Dataset, DataFormat } from "../../index.js";
  import { WeightsContainer } from "../../index.js";
  import { BatchLogs, Model, EpochLogs } from "../index.js";
- import { type GPTConfig } from "./config.js";
+ import type { GPTConfig, GenerationConfig } from './config.js';
  export type GPTSerialization = {
  weights: WeightsContainer;
  config?: GPTConfig;
  };
- interface PredictConfig {
- temperature: number;
- doSample: boolean;
- }
  export declare class GPT extends Model<"text"> {
  #private;
  private readonly model;
- constructor(partialConfig?: GPTConfig, layersModel?: tf.LayersModel);
+ constructor(partialConfig?: Partial<GPTConfig>, layersModel?: tf.LayersModel);
  /**
  * The GPT train methods wraps the model.fitDataset call in a for loop to act as a generator (of logs)
  * This allows for getting logs and stopping training without callbacks.
@@ -28,7 +25,7 @@ export declare class GPT extends Model<"text"> {
  * @param tracker
  */
  train(trainingDataset: Dataset<Batched<DataFormat.ModelEncoded["text"]>>, validationDataset?: Dataset<Batched<DataFormat.ModelEncoded["text"]>>): AsyncGenerator<BatchLogs, EpochLogs>;
- predict(batch: Batched<DataFormat.ModelEncoded["text"][0]>, options?: Partial<PredictConfig>): Promise<Batched<DataFormat.ModelEncoded["text"][1]>>;
+ predict(batch: Batched<DataFormat.ModelEncoded["text"][0]>, options?: Partial<GenerationConfig>): Promise<Batched<DataFormat.ModelEncoded["text"][1]>>;
  get config(): Required<GPTConfig>;
  get weights(): WeightsContainer;
  set weights(ws: WeightsContainer);
@@ -37,4 +34,3 @@ export declare class GPT extends Model<"text"> {
  extract(): tf.LayersModel;
  [Symbol.dispose](): void;
  }
- export {};
@@ -1,5 +1,6 @@
  /**
- * this code is taken from gpt-tfjs with modifications from @peacefulotter and @lukemovement
+ * Source: https://github.com/zemlyansky/gpt-tfjs and https://github.com/karpathy/build-nanogpt
+ * With modifications from @peacefulotter, @lukemovement and the Disco team
  **/
  import createDebug from "debug";
  import { List, Range } from "immutable";
@@ -7,12 +8,12 @@ import * as tf from "@tensorflow/tfjs";
  import { WeightsContainer } from "../../index.js";
  import { Model, EpochLogs } from "../index.js";
  import { GPTModel } from "./model.js";
- import { DEFAULT_CONFIG } from "./config.js";
  import evaluate from "./evaluate.js";
+ import { DefaultGPTConfig, DefaultGenerationConfig } from './config.js';
  const debug = createDebug("discojs:models:gpt");
  export class GPT extends Model {
  model;
- #blockSize;
+ #contextLength;
  #maxBatchCount;
  #vocabSize;
  constructor(partialConfig, layersModel) {
@@ -20,9 +21,9 @@ export class GPT extends Model {
  const model = new GPTModel(partialConfig, layersModel);
  model.compile();
  this.model = model;
- this.#blockSize = partialConfig?.blockSize ?? DEFAULT_CONFIG.blockSize;
- this.#maxBatchCount = partialConfig?.maxIter ?? DEFAULT_CONFIG.maxIter;
- this.#vocabSize = partialConfig?.vocabSize ?? DEFAULT_CONFIG.vocabSize;
+ this.#contextLength = partialConfig?.contextLength ?? DefaultGPTConfig.contextLength;
+ this.#maxBatchCount = partialConfig?.maxIter ?? DefaultGPTConfig.maxIter;
+ this.#vocabSize = partialConfig?.vocabSize ?? DefaultGPTConfig.vocabSize;
  }
  /**
  * The GPT train methods wraps the model.fitDataset call in a for loop to act as a generator (of logs)
@@ -85,16 +86,21 @@ export class GPT extends Model {
  }));
  }
  async predict(batch, options) {
- const config = {
- temperature: 1.0,
- doSample: false,
- ...options,
- };
+ // overwrite default with user config
+ const config = Object.assign({}, DefaultGenerationConfig, options);
  return List(await Promise.all(batch.map((tokens) => this.#predictSingle(tokens, config))));
  }
+ /**
+ * Generate the next token after the input sequence.
+ * In other words, takes an input tensor of shape (prompt length T) and returns a tensor of shape (T+1)
+ *
+ * @param token input tokens of shape (T,). T is truncated to the model's context length
+ * @param config generation config: temperature, doSample, topk
+ * @returns the next token predicted by the model
+ */
  async #predictSingle(tokens, config) {
  // slice input tokens if longer than context length
- tokens = tokens.slice(-this.#blockSize);
+ tokens = tokens.slice(-this.#contextLength);
  const input = tf.tidy(() => tf.tensor1d(tokens.toArray(), "int32").expandDims(0));
  const logits = tf.tidy(() => {
  const output = this.model.predict(input);
@@ -111,9 +117,24 @@ export class GPT extends Model {
  .div(config.temperature)
  .softmax());
  logits.dispose();
- const next = tf.tidy(() => config.doSample
- ? tf.multinomial(probs, 1).squeeze([0])
- : probs.argMax());
+ const next = tf.tidy(() => {
+ if (config.doSample) {
+ // returns topk biggest values among the `vocab_size` probabilities and the corresponding tokens indices
+ // both shapes are (config.topk,)
+ const { values: topkProbs, indices: topkTokens } = tf.topk(probs, config.topk);
+ // sample an index from the top-k probabilities
+ // e.g. [[0.1, 0.4, 0.3], [0.1, 0.2, 0.5]] -> [[1], [2]]
+ // note: multinomial does not need the input to sum to 1
+ const selectedIndices = tf.multinomial(topkProbs, 1, config.seed, false); // (B, )
+ // return the corresponding token from the sampled indices (one per sequence in the batch).
+ // if for some reason the probabilities are NaN, selectedIndices will be out of bounds
+ return topkTokens.gather(selectedIndices).squeeze([0]); // (1)
+ }
+ else {
+ // greedy decoding: return the token with the highest probability
+ return probs.argMax();
+ }
+ });
  probs.dispose();
  const ret = await next.array();
  next.dispose();
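The sampling branch above is standard top-k sampling: keep only the k most probable tokens, then draw from that restricted distribution. A self-contained sketch of the same technique with @tensorflow/tfjs (the function name and shapes are ours; it assumes a single sequence, as in #predictSingle):

    import * as tf from "@tensorflow/tfjs";

    function sampleTopK(probs: tf.Tensor1D, k: number, seed: number): tf.Tensor {
      return tf.tidy(() => {
        // keep the k largest probabilities and their token ids, both shaped (k,)
        const { values, indices } = tf.topk(probs, k);
        // draw one position among the k kept tokens; normalized=false tells
        // tfjs the inputs are not required to sum to 1
        const pick = tf.multinomial(values, 1, seed, false);
        // map the sampled position back to a vocabulary token id
        return indices.gather(pick).squeeze([0]);
      });
    }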