@epfml/discojs 3.0.1-p20240902162912.0 → 3.0.1-p20240904094219.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -19,7 +19,6 @@ export const cifar10 = {
19
19
  sampleDatasetInstructions: 'Opening the link should start downloading a zip file which you can unzip. To connect the data, use the CSV option below and select the file named "cifar10-labels.csv". You can now connect the images located in the "CIFAR10" folder. Note that there are only 24 images in this sample dataset which is far too few to successfully train a machine learning model.'
20
20
  },
21
21
  trainingInformation: {
22
- modelID: 'cifar10-model',
23
22
  epochs: 10,
24
23
  roundDuration: 10,
25
24
  validationSplit: 0.2,
@@ -18,7 +18,6 @@ export const lusCovid = {
18
18
  sampleDatasetInstructions: 'Opening the link will take you to a Switch Drive folder. You can click on the Download button in the top right corner. Unzip the file and you will get two subfolders: "COVID-" and "COVID+". You can connect the data by using the Group option and selecting each image group in its respective field.'
19
19
  },
20
20
  trainingInformation: {
21
- modelID: 'lus-covid-model',
22
21
  epochs: 50,
23
22
  roundDuration: 2,
24
23
  validationSplit: 0.2,
@@ -18,7 +18,6 @@ export const mnist = {
18
18
  sampleDatasetInstructions: 'Opening the link should start downloading a zip file which you can unzip. You can connect the data with the CSV option below using the CSV file named "mnist_labels.csv". After selecting in the CSV file, you will be able to connect the data under in the "images" folder.'
19
19
  },
20
20
  trainingInformation: {
21
- modelID: 'mnist-model',
22
21
  epochs: 20,
23
22
  roundDuration: 2,
24
23
  validationSplit: 0.2,
@@ -18,7 +18,6 @@ export const simpleFace = {
18
18
  sampleDatasetInstructions: 'Opening the link should start downloading a zip file which you can unzip. Inside the "example_training_data" directory you should find the "simple_face" folder which contains the "adult" and "child" folders. To connect the data, select the Group option below and connect adults and children image groups.'
19
19
  },
20
20
  trainingInformation: {
21
- modelID: 'simple_face-model',
22
21
  epochs: 50,
23
22
  roundDuration: 1,
24
23
  validationSplit: 0.2,
@@ -45,7 +45,6 @@ export const titanic = {
45
45
  sampleDatasetInstructions: 'Opening the link should start downloading a CSV file which you can drag and drop in the field below.'
46
46
  },
47
47
  trainingInformation: {
48
- modelID: 'titanic-model',
49
48
  epochs: 10,
50
49
  roundDuration: 2,
51
50
  validationSplit: 0.2,
@@ -23,7 +23,6 @@ export const wikitext = {
23
23
  },
24
24
  trainingInformation: {
25
25
  dataType: 'text',
26
- modelID: 'llm-raw-model',
27
26
  preprocessingFunctions: [data.TextPreprocessing.Tokenize, data.TextPreprocessing.LeftPadding],
28
27
  scheme: 'federated',
29
28
  minNbOfParticipants: 2,
package/dist/index.d.ts CHANGED
@@ -8,7 +8,6 @@ export * as client from './client/index.js';
8
8
  export * as aggregator from './aggregator/index.js';
9
9
  export { WeightsContainer, aggregation } from './weights/index.js';
10
10
  export { Logger, ConsoleLogger } from './logging/index.js';
11
- export { Memory, type ModelInfo, type ModelSource, Empty as EmptyMemory } from './memory/index.js';
12
11
  export { Disco, RoundLogs, RoundStatus } from './training/index.js';
13
12
  export { Validator } from './validation/index.js';
14
13
  export { Model, BatchLogs, EpochLogs, ValidationMetrics } from './models/index.js';
package/dist/index.js CHANGED
@@ -6,7 +6,6 @@ export * as client from './client/index.js';
6
6
  export * as aggregator from './aggregator/index.js';
7
7
  export { WeightsContainer, aggregation } from './weights/index.js';
8
8
  export { ConsoleLogger } from './logging/index.js';
9
- export { Memory, Empty as EmptyMemory } from './memory/index.js';
10
9
  export { Disco } from './training/index.js';
11
10
  export { Validator } from './validation/index.js';
12
11
  export { Model, EpochLogs } from './models/index.js';
@@ -8,16 +8,25 @@ export function isEncoded(raw) {
8
8
  return raw instanceof Uint8Array;
9
9
  }
10
10
  export async function encode(model) {
11
- if (model instanceof models.TFJS) {
12
- const serialized = await model.serialize();
13
- return msgpack.encode([Type.TFJS, serialized]);
14
- }
15
- if (model instanceof models.GPT) {
16
- const { weights, config } = model.serialize();
17
- const serializedWeights = await serialization.weights.encode(weights);
18
- return msgpack.encode([Type.GPT, serializedWeights, config]);
11
+ let encoded;
12
+ switch (true) {
13
+ case model instanceof models.TFJS: {
14
+ const serialized = await model.serialize();
15
+ encoded = msgpack.encode([Type.TFJS, serialized]);
16
+ break;
17
+ }
18
+ case model instanceof models.GPT: {
19
+ const { weights, config } = model.serialize();
20
+ const serializedWeights = await serialization.weights.encode(weights);
21
+ encoded = msgpack.encode([Type.GPT, serializedWeights, config]);
22
+ break;
23
+ }
24
+ default:
25
+ throw new Error("unknown model type");
19
26
  }
20
- throw new Error('unknown model type');
27
+ // Node's Buffer extends Node's Uint8Array, which might not be the same
28
+ // as the browser's Uint8Array. we ensure here that it is.
29
+ return new Uint8Array(encoded);
21
30
  }
22
31
  export async function decode(encoded) {
23
32
  if (!isEncoded(encoded)) {
@@ -5,7 +5,6 @@ interface Privacy {
5
5
  noiseScale?: number;
6
6
  }
7
7
  export interface TrainingInformation {
8
- modelID: string;
9
8
  epochs: number;
10
9
  roundDuration: number;
11
10
  validationSplit: number;
@@ -24,9 +24,8 @@ export function isTrainingInformation(raw) {
24
24
  if (typeof raw !== 'object' || raw === null) {
25
25
  return false;
26
26
  }
27
- const { IMAGE_H, IMAGE_W, LABEL_LIST, aggregator, batchSize, dataType, decentralizedSecure, privacy, epochs, inputColumns, maxShareValue, minNbOfParticipants, modelID, outputColumns, preprocessingFunctions, roundDuration, scheme, validationSplit, tokenizer, maxSequenceLength, tensorBackend } = raw;
27
+ const { IMAGE_H, IMAGE_W, LABEL_LIST, aggregator, batchSize, dataType, decentralizedSecure, privacy, epochs, inputColumns, maxShareValue, minNbOfParticipants, outputColumns, preprocessingFunctions, roundDuration, scheme, validationSplit, tokenizer, maxSequenceLength, tensorBackend } = raw;
28
28
  if (typeof dataType !== 'string' ||
29
- typeof modelID !== 'string' ||
30
29
  typeof epochs !== 'number' ||
31
30
  typeof batchSize !== 'number' ||
32
31
  typeof roundDuration !== 'number' ||
@@ -97,7 +96,6 @@ export function isTrainingInformation(raw) {
97
96
  inputColumns,
98
97
  maxShareValue,
99
98
  minNbOfParticipants,
100
- modelID,
101
99
  outputColumns,
102
100
  preprocessingFunctions,
103
101
  roundDuration,
@@ -1,4 +1,4 @@
1
- import { client as clients, BatchLogs, EpochLogs, Logger, Memory, Task, TrainingInformation } from "../index.js";
1
+ import { client as clients, BatchLogs, EpochLogs, Logger, Task, TrainingInformation } from "../index.js";
2
2
  import type { TypedLabeledDataset } from "../index.js";
3
3
  import type { Aggregator } from "../aggregator/index.js";
4
4
  import { EventEmitter } from "../utils/event_emitter.js";
@@ -6,13 +6,12 @@ import { RoundLogs, Trainer } from "./trainer.js";
6
6
  interface DiscoConfig {
7
7
  scheme: TrainingInformation["scheme"];
8
8
  logger: Logger;
9
- memory: Memory;
10
9
  }
11
10
  export type RoundStatus = "Waiting for more participants" | "Retrieving peers' information" | "Updating the model with other participants' models" | "Training the model on the data you connected";
12
11
  /**
13
12
  * Top-level class handling distributed training from a client's perspective. It is meant to be
14
- * a convenient object providing a reduced yet complete API that wraps model training,
15
- * communication with nodes, logs and model memory.
13
+ * a convenient object providing a reduced yet complete API that wraps model training and
14
+ * communication with nodes.
16
15
  */
17
16
  export declare class Disco extends EventEmitter<{
18
17
  'status': RoundStatus;
@@ -1,4 +1,4 @@
1
- import { async_iterator, client as clients, ConsoleLogger, EmptyMemory, } from "../index.js";
1
+ import { async_iterator, client as clients, ConsoleLogger, } from "../index.js";
2
2
  import { getAggregator } from "../aggregator/index.js";
3
3
  import { enumerate, split } from "../utils/async_iterator.js";
4
4
  import { EventEmitter } from "../utils/event_emitter.js";
@@ -6,14 +6,13 @@ import { Trainer } from "./trainer.js";
6
6
  import { labeledDatasetToDataSplit } from "../dataset/data/helpers.js";
7
7
  /**
8
8
  * Top-level class handling distributed training from a client's perspective. It is meant to be
9
- * a convenient object providing a reduced yet complete API that wraps model training,
10
- * communication with nodes, logs and model memory.
9
+ * a convenient object providing a reduced yet complete API that wraps model training and
10
+ * communication with nodes.
11
11
  */
12
12
  export class Disco extends EventEmitter {
13
13
  trainer;
14
14
  #client;
15
15
  #logger;
16
- #memory;
17
16
  #task;
18
17
  /**
19
18
  * Connect to the given task and get ready to train.
@@ -24,10 +23,9 @@ export class Disco extends EventEmitter {
24
23
  */
25
24
  constructor(task, clientConfig, config) {
26
25
  super();
27
- const { scheme, logger, memory } = {
26
+ const { scheme, logger } = {
28
27
  scheme: task.trainingInformation.scheme,
29
28
  logger: new ConsoleLogger(),
30
- memory: new EmptyMemory(),
31
29
  ...config,
32
30
  };
33
31
  let client;
@@ -49,7 +47,6 @@ export class Disco extends EventEmitter {
49
47
  throw new Error("client not setup for given task");
50
48
  this.#logger = logger;
51
49
  this.#client = client;
52
- this.#memory = memory;
53
50
  this.#task = task;
54
51
  this.trainer = new Trainer(task, client);
55
52
  // Simply propagate the training status events emitted by the client
@@ -125,12 +122,6 @@ export class Disco extends EventEmitter {
125
122
  }
126
123
  return await returnedRoundLogs;
127
124
  }.bind(this)();
128
- await this.#memory.updateWorkingModel({
129
- type: "working",
130
- taskID: this.#task.id,
131
- name: this.#task.trainingInformation.modelID,
132
- tensorBackend: this.#task.trainingInformation.tensorBackend,
133
- }, this.trainer.model);
134
125
  }
135
126
  this.#logger.success("Training finished");
136
127
  }
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@epfml/discojs",
3
- "version": "3.0.1-p20240902162912.0",
3
+ "version": "3.0.1-p20240904094219.0",
4
4
  "type": "module",
5
5
  "main": "dist/index.js",
6
6
  "types": "dist/index.d.ts",
@@ -1,111 +0,0 @@
1
- import type { Model, TaskID } from '../index.js';
2
- /**
3
- * Type of models stored in memory. Stored models can either be a model currently
4
- * being trained ("working model") or a regular model saved in memory ("saved model").
5
- * There can only be a single working model for a given task.
6
- */
7
- type StoredModelType = 'saved' | 'working';
8
- /**
9
- * Model information which uniquely identifies a model in memory.
10
- */
11
- export interface ModelInfo {
12
- type: StoredModelType;
13
- version?: number;
14
- taskID: TaskID;
15
- name: string;
16
- tensorBackend: 'gpt' | 'tfjs';
17
- }
18
- /**
19
- * A model source uniquely identifies a model stored in memory.
20
- * It can be in the form of either a model info object or an ID
21
- * (one-to-one mapping between the two)
22
- */
23
- export type ModelSource = ModelInfo | string;
24
- /**
25
- * Represents a model memory system, providing functions to fetch, save, delete and update models.
26
- * Stored models can either be a model currently being trained ("working model") or a regular model
27
- * saved in memory ("saved model"). There can only be a single working model for a given task.
28
- */
29
- export declare abstract class Memory {
30
- /**
31
- * Fetches the model identified by the given model source.
32
- * @param source The model source
33
- * @returns The model
34
- */
35
- abstract getModel(source: ModelSource): Promise<Model>;
36
- /**
37
- * Removes the model identified by the given model source from memory.
38
- * @param source The model source
39
- * @returns The model
40
- */
41
- abstract deleteModel(source: ModelSource): Promise<void>;
42
- /**
43
- * Replaces the corresponding working model with the saved model identified by the given model source.
44
- * @param source The model source
45
- */
46
- abstract loadModel(source: ModelSource): Promise<void>;
47
- /**
48
- * Fetches metadata for the model identified by the given model source.
49
- * If the model does not exist in memory, returns undefined.
50
- * @param source The model source
51
- * @returns The model metadata or undefined
52
- */
53
- abstract getModelMetadata(source: ModelSource): Promise<object | undefined>;
54
- /**
55
- * Replaces the working model identified by the given source with the newly provided model.
56
- * @param source The model source
57
- * @param model The new model
58
- */
59
- abstract updateWorkingModel(source: ModelSource, model: Model): Promise<void>;
60
- /**
61
- * Creates a saved model copy from the working model identified by the given model source.
62
- * Returns the saved model's path.
63
- * @param source The model source
64
- * @returns The saved model's path
65
- */
66
- abstract saveWorkingModel(source: ModelSource): Promise<string | undefined>;
67
- /**
68
- * Saves the newly provided model to the given model source.
69
- * Returns the saved model's path
70
- * @param source The model source
71
- * @param model The new model
72
- * @returns The saved model's path
73
- */
74
- abstract saveModel(source: ModelSource, model: Model): Promise<string | undefined>;
75
- /**
76
- * Moves the model identified by the model source to a file system. This is platform-dependent.
77
- * @param source The model source
78
- */
79
- abstract downloadModel(source: ModelSource): Promise<void>;
80
- /**
81
- * Checks whether the model memory contains the model identified by the given source.
82
- * @param source The model source
83
- * @returns True if the memory contains the model, false otherwise
84
- */
85
- abstract contains(source: ModelSource): Promise<boolean>;
86
- /**
87
- * Computes the path in memory corresponding to the given model source, be it a path or model information.
88
- * This is used to easily switch between model path and information, which are both unique model identifiers
89
- * with a one-to-one equivalence. Returns undefined instead if no path could be inferred from the given
90
- * model source.
91
- * @param source The model source
92
- * @returns The model path
93
- */
94
- abstract getModelMemoryPath(source: ModelSource): string | undefined;
95
- /**
96
- * Computes the model information corresponding to the given model source, be it a path or model information.
97
- * This is used to easily switch between model path and information, which are both unique model identifiers
98
- * with a one-to-one equivalence. Returns undefined instead if no unique model information could be inferred
99
- * from the given model source.
100
- * @param source The model source
101
- * @returns The model information
102
- */
103
- abstract getModelInfo(source: ModelSource): ModelInfo | undefined;
104
- /**
105
- * Computes the lowest version a model source can have without conflicting with model versions currently in memory.
106
- * @param source The model source
107
- * @returns The duplicated model source
108
- */
109
- abstract duplicateSource(source: ModelSource): Promise<ModelSource | undefined>;
110
- }
111
- export {};
@@ -1,9 +0,0 @@
1
- // only used browser-side
2
- // TODO: replace IO type
3
- /**
4
- * Represents a model memory system, providing functions to fetch, save, delete and update models.
5
- * Stored models can either be a model currently being trained ("working model") or a regular model
6
- * saved in memory ("saved model"). There can only be a single working model for a given task.
7
- */
8
- export class Memory {
9
- }
@@ -1,20 +0,0 @@
1
- import type { Model } from '../index.js';
2
- import type { ModelInfo } from './base.js';
3
- import { Memory } from './base.js';
4
- /**
5
- * Represents an empty model memory.
6
- */
7
- export declare class Empty extends Memory {
8
- getModelMetadata(): Promise<undefined>;
9
- contains(): Promise<boolean>;
10
- getModel(): Promise<Model>;
11
- loadModel(): Promise<void>;
12
- updateWorkingModel(): Promise<void>;
13
- saveWorkingModel(): Promise<undefined>;
14
- saveModel(): Promise<undefined>;
15
- deleteModel(): Promise<void>;
16
- downloadModel(): Promise<void>;
17
- getModelMemoryPath(): string;
18
- getModelInfo(): ModelInfo;
19
- duplicateSource(): Promise<undefined>;
20
- }
@@ -1,43 +0,0 @@
1
- import { Memory } from './base.js';
2
- /**
3
- * Represents an empty model memory.
4
- */
5
- export class Empty extends Memory {
6
- getModelMetadata() {
7
- return Promise.resolve(undefined);
8
- }
9
- contains() {
10
- return Promise.resolve(false);
11
- }
12
- getModel() {
13
- return Promise.reject(new Error('empty'));
14
- }
15
- loadModel() {
16
- return Promise.reject(new Error('empty'));
17
- }
18
- updateWorkingModel() {
19
- // nothing to do
20
- return Promise.resolve();
21
- }
22
- saveWorkingModel() {
23
- return Promise.resolve(undefined);
24
- }
25
- saveModel() {
26
- return Promise.resolve(undefined);
27
- }
28
- async deleteModel() {
29
- // nothing to do
30
- }
31
- downloadModel() {
32
- return Promise.reject(new Error('empty'));
33
- }
34
- getModelMemoryPath() {
35
- throw new Error('empty');
36
- }
37
- getModelInfo() {
38
- throw new Error('empty');
39
- }
40
- duplicateSource() {
41
- return Promise.resolve(undefined);
42
- }
43
- }
@@ -1,2 +0,0 @@
1
- export { Empty } from './empty.js';
2
- export { Memory, type ModelInfo, type ModelSource } from './base.js';
@@ -1,2 +0,0 @@
1
- export { Empty } from './empty.js';
2
- export { Memory } from './base.js';