@epfml/discojs 3.0.1-p20240902162912.0 → 3.0.1-p20240904094219.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/default_tasks/cifar10.js +0 -1
- package/dist/default_tasks/lus_covid.js +0 -1
- package/dist/default_tasks/mnist.js +0 -1
- package/dist/default_tasks/simple_face.js +0 -1
- package/dist/default_tasks/titanic.js +0 -1
- package/dist/default_tasks/wikitext.js +0 -1
- package/dist/index.d.ts +0 -1
- package/dist/index.js +0 -1
- package/dist/serialization/model.js +18 -9
- package/dist/task/training_information.d.ts +0 -1
- package/dist/task/training_information.js +1 -3
- package/dist/training/disco.d.ts +3 -4
- package/dist/training/disco.js +4 -13
- package/package.json +1 -1
- package/dist/memory/base.d.ts +0 -111
- package/dist/memory/base.js +0 -9
- package/dist/memory/empty.d.ts +0 -20
- package/dist/memory/empty.js +0 -43
- package/dist/memory/index.d.ts +0 -2
- package/dist/memory/index.js +0 -2
|
@@ -19,7 +19,6 @@ export const cifar10 = {
|
|
|
19
19
|
sampleDatasetInstructions: 'Opening the link should start downloading a zip file which you can unzip. To connect the data, use the CSV option below and select the file named "cifar10-labels.csv". You can now connect the images located in the "CIFAR10" folder. Note that there are only 24 images in this sample dataset which is far too few to successfully train a machine learning model.'
|
|
20
20
|
},
|
|
21
21
|
trainingInformation: {
|
|
22
|
-
modelID: 'cifar10-model',
|
|
23
22
|
epochs: 10,
|
|
24
23
|
roundDuration: 10,
|
|
25
24
|
validationSplit: 0.2,
|
|
@@ -18,7 +18,6 @@ export const lusCovid = {
|
|
|
18
18
|
sampleDatasetInstructions: 'Opening the link will take you to a Switch Drive folder. You can click on the Download button in the top right corner. Unzip the file and you will get two subfolders: "COVID-" and "COVID+". You can connect the data by using the Group option and selecting each image group in its respective field.'
|
|
19
19
|
},
|
|
20
20
|
trainingInformation: {
|
|
21
|
-
modelID: 'lus-covid-model',
|
|
22
21
|
epochs: 50,
|
|
23
22
|
roundDuration: 2,
|
|
24
23
|
validationSplit: 0.2,
|
|
@@ -18,7 +18,6 @@ export const mnist = {
|
|
|
18
18
|
sampleDatasetInstructions: 'Opening the link should start downloading a zip file which you can unzip. You can connect the data with the CSV option below using the CSV file named "mnist_labels.csv". After selecting in the CSV file, you will be able to connect the data under in the "images" folder.'
|
|
19
19
|
},
|
|
20
20
|
trainingInformation: {
|
|
21
|
-
modelID: 'mnist-model',
|
|
22
21
|
epochs: 20,
|
|
23
22
|
roundDuration: 2,
|
|
24
23
|
validationSplit: 0.2,
|
|
@@ -18,7 +18,6 @@ export const simpleFace = {
|
|
|
18
18
|
sampleDatasetInstructions: 'Opening the link should start downloading a zip file which you can unzip. Inside the "example_training_data" directory you should find the "simple_face" folder which contains the "adult" and "child" folders. To connect the data, select the Group option below and connect adults and children image groups.'
|
|
19
19
|
},
|
|
20
20
|
trainingInformation: {
|
|
21
|
-
modelID: 'simple_face-model',
|
|
22
21
|
epochs: 50,
|
|
23
22
|
roundDuration: 1,
|
|
24
23
|
validationSplit: 0.2,
|
|
@@ -45,7 +45,6 @@ export const titanic = {
|
|
|
45
45
|
sampleDatasetInstructions: 'Opening the link should start downloading a CSV file which you can drag and drop in the field below.'
|
|
46
46
|
},
|
|
47
47
|
trainingInformation: {
|
|
48
|
-
modelID: 'titanic-model',
|
|
49
48
|
epochs: 10,
|
|
50
49
|
roundDuration: 2,
|
|
51
50
|
validationSplit: 0.2,
|
package/dist/index.d.ts
CHANGED
|
@@ -8,7 +8,6 @@ export * as client from './client/index.js';
|
|
|
8
8
|
export * as aggregator from './aggregator/index.js';
|
|
9
9
|
export { WeightsContainer, aggregation } from './weights/index.js';
|
|
10
10
|
export { Logger, ConsoleLogger } from './logging/index.js';
|
|
11
|
-
export { Memory, type ModelInfo, type ModelSource, Empty as EmptyMemory } from './memory/index.js';
|
|
12
11
|
export { Disco, RoundLogs, RoundStatus } from './training/index.js';
|
|
13
12
|
export { Validator } from './validation/index.js';
|
|
14
13
|
export { Model, BatchLogs, EpochLogs, ValidationMetrics } from './models/index.js';
|
package/dist/index.js
CHANGED
|
@@ -6,7 +6,6 @@ export * as client from './client/index.js';
|
|
|
6
6
|
export * as aggregator from './aggregator/index.js';
|
|
7
7
|
export { WeightsContainer, aggregation } from './weights/index.js';
|
|
8
8
|
export { ConsoleLogger } from './logging/index.js';
|
|
9
|
-
export { Memory, Empty as EmptyMemory } from './memory/index.js';
|
|
10
9
|
export { Disco } from './training/index.js';
|
|
11
10
|
export { Validator } from './validation/index.js';
|
|
12
11
|
export { Model, EpochLogs } from './models/index.js';
|
|
@@ -8,16 +8,25 @@ export function isEncoded(raw) {
|
|
|
8
8
|
return raw instanceof Uint8Array;
|
|
9
9
|
}
|
|
10
10
|
export async function encode(model) {
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
11
|
+
let encoded;
|
|
12
|
+
switch (true) {
|
|
13
|
+
case model instanceof models.TFJS: {
|
|
14
|
+
const serialized = await model.serialize();
|
|
15
|
+
encoded = msgpack.encode([Type.TFJS, serialized]);
|
|
16
|
+
break;
|
|
17
|
+
}
|
|
18
|
+
case model instanceof models.GPT: {
|
|
19
|
+
const { weights, config } = model.serialize();
|
|
20
|
+
const serializedWeights = await serialization.weights.encode(weights);
|
|
21
|
+
encoded = msgpack.encode([Type.GPT, serializedWeights, config]);
|
|
22
|
+
break;
|
|
23
|
+
}
|
|
24
|
+
default:
|
|
25
|
+
throw new Error("unknown model type");
|
|
19
26
|
}
|
|
20
|
-
|
|
27
|
+
// Node's Buffer extends Node's Uint8Array, which might not be the same
|
|
28
|
+
// as the browser's Uint8Array. we ensure here that it is.
|
|
29
|
+
return new Uint8Array(encoded);
|
|
21
30
|
}
|
|
22
31
|
export async function decode(encoded) {
|
|
23
32
|
if (!isEncoded(encoded)) {
|
|
@@ -24,9 +24,8 @@ export function isTrainingInformation(raw) {
|
|
|
24
24
|
if (typeof raw !== 'object' || raw === null) {
|
|
25
25
|
return false;
|
|
26
26
|
}
|
|
27
|
-
const { IMAGE_H, IMAGE_W, LABEL_LIST, aggregator, batchSize, dataType, decentralizedSecure, privacy, epochs, inputColumns, maxShareValue, minNbOfParticipants,
|
|
27
|
+
const { IMAGE_H, IMAGE_W, LABEL_LIST, aggregator, batchSize, dataType, decentralizedSecure, privacy, epochs, inputColumns, maxShareValue, minNbOfParticipants, outputColumns, preprocessingFunctions, roundDuration, scheme, validationSplit, tokenizer, maxSequenceLength, tensorBackend } = raw;
|
|
28
28
|
if (typeof dataType !== 'string' ||
|
|
29
|
-
typeof modelID !== 'string' ||
|
|
30
29
|
typeof epochs !== 'number' ||
|
|
31
30
|
typeof batchSize !== 'number' ||
|
|
32
31
|
typeof roundDuration !== 'number' ||
|
|
@@ -97,7 +96,6 @@ export function isTrainingInformation(raw) {
|
|
|
97
96
|
inputColumns,
|
|
98
97
|
maxShareValue,
|
|
99
98
|
minNbOfParticipants,
|
|
100
|
-
modelID,
|
|
101
99
|
outputColumns,
|
|
102
100
|
preprocessingFunctions,
|
|
103
101
|
roundDuration,
|
package/dist/training/disco.d.ts
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { client as clients, BatchLogs, EpochLogs, Logger,
|
|
1
|
+
import { client as clients, BatchLogs, EpochLogs, Logger, Task, TrainingInformation } from "../index.js";
|
|
2
2
|
import type { TypedLabeledDataset } from "../index.js";
|
|
3
3
|
import type { Aggregator } from "../aggregator/index.js";
|
|
4
4
|
import { EventEmitter } from "../utils/event_emitter.js";
|
|
@@ -6,13 +6,12 @@ import { RoundLogs, Trainer } from "./trainer.js";
|
|
|
6
6
|
interface DiscoConfig {
|
|
7
7
|
scheme: TrainingInformation["scheme"];
|
|
8
8
|
logger: Logger;
|
|
9
|
-
memory: Memory;
|
|
10
9
|
}
|
|
11
10
|
export type RoundStatus = "Waiting for more participants" | "Retrieving peers' information" | "Updating the model with other participants' models" | "Training the model on the data you connected";
|
|
12
11
|
/**
|
|
13
12
|
* Top-level class handling distributed training from a client's perspective. It is meant to be
|
|
14
|
-
* a convenient object providing a reduced yet complete API that wraps model training
|
|
15
|
-
* communication with nodes
|
|
13
|
+
* a convenient object providing a reduced yet complete API that wraps model training and
|
|
14
|
+
* communication with nodes.
|
|
16
15
|
*/
|
|
17
16
|
export declare class Disco extends EventEmitter<{
|
|
18
17
|
'status': RoundStatus;
|
package/dist/training/disco.js
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { async_iterator, client as clients, ConsoleLogger,
|
|
1
|
+
import { async_iterator, client as clients, ConsoleLogger, } from "../index.js";
|
|
2
2
|
import { getAggregator } from "../aggregator/index.js";
|
|
3
3
|
import { enumerate, split } from "../utils/async_iterator.js";
|
|
4
4
|
import { EventEmitter } from "../utils/event_emitter.js";
|
|
@@ -6,14 +6,13 @@ import { Trainer } from "./trainer.js";
|
|
|
6
6
|
import { labeledDatasetToDataSplit } from "../dataset/data/helpers.js";
|
|
7
7
|
/**
|
|
8
8
|
* Top-level class handling distributed training from a client's perspective. It is meant to be
|
|
9
|
-
* a convenient object providing a reduced yet complete API that wraps model training
|
|
10
|
-
* communication with nodes
|
|
9
|
+
* a convenient object providing a reduced yet complete API that wraps model training and
|
|
10
|
+
* communication with nodes.
|
|
11
11
|
*/
|
|
12
12
|
export class Disco extends EventEmitter {
|
|
13
13
|
trainer;
|
|
14
14
|
#client;
|
|
15
15
|
#logger;
|
|
16
|
-
#memory;
|
|
17
16
|
#task;
|
|
18
17
|
/**
|
|
19
18
|
* Connect to the given task and get ready to train.
|
|
@@ -24,10 +23,9 @@ export class Disco extends EventEmitter {
|
|
|
24
23
|
*/
|
|
25
24
|
constructor(task, clientConfig, config) {
|
|
26
25
|
super();
|
|
27
|
-
const { scheme, logger
|
|
26
|
+
const { scheme, logger } = {
|
|
28
27
|
scheme: task.trainingInformation.scheme,
|
|
29
28
|
logger: new ConsoleLogger(),
|
|
30
|
-
memory: new EmptyMemory(),
|
|
31
29
|
...config,
|
|
32
30
|
};
|
|
33
31
|
let client;
|
|
@@ -49,7 +47,6 @@ export class Disco extends EventEmitter {
|
|
|
49
47
|
throw new Error("client not setup for given task");
|
|
50
48
|
this.#logger = logger;
|
|
51
49
|
this.#client = client;
|
|
52
|
-
this.#memory = memory;
|
|
53
50
|
this.#task = task;
|
|
54
51
|
this.trainer = new Trainer(task, client);
|
|
55
52
|
// Simply propagate the training status events emitted by the client
|
|
@@ -125,12 +122,6 @@ export class Disco extends EventEmitter {
|
|
|
125
122
|
}
|
|
126
123
|
return await returnedRoundLogs;
|
|
127
124
|
}.bind(this)();
|
|
128
|
-
await this.#memory.updateWorkingModel({
|
|
129
|
-
type: "working",
|
|
130
|
-
taskID: this.#task.id,
|
|
131
|
-
name: this.#task.trainingInformation.modelID,
|
|
132
|
-
tensorBackend: this.#task.trainingInformation.tensorBackend,
|
|
133
|
-
}, this.trainer.model);
|
|
134
125
|
}
|
|
135
126
|
this.#logger.success("Training finished");
|
|
136
127
|
}
|
package/package.json
CHANGED
package/dist/memory/base.d.ts
DELETED
|
@@ -1,111 +0,0 @@
|
|
|
1
|
-
import type { Model, TaskID } from '../index.js';
|
|
2
|
-
/**
|
|
3
|
-
* Type of models stored in memory. Stored models can either be a model currently
|
|
4
|
-
* being trained ("working model") or a regular model saved in memory ("saved model").
|
|
5
|
-
* There can only be a single working model for a given task.
|
|
6
|
-
*/
|
|
7
|
-
type StoredModelType = 'saved' | 'working';
|
|
8
|
-
/**
|
|
9
|
-
* Model information which uniquely identifies a model in memory.
|
|
10
|
-
*/
|
|
11
|
-
export interface ModelInfo {
|
|
12
|
-
type: StoredModelType;
|
|
13
|
-
version?: number;
|
|
14
|
-
taskID: TaskID;
|
|
15
|
-
name: string;
|
|
16
|
-
tensorBackend: 'gpt' | 'tfjs';
|
|
17
|
-
}
|
|
18
|
-
/**
|
|
19
|
-
* A model source uniquely identifies a model stored in memory.
|
|
20
|
-
* It can be in the form of either a model info object or an ID
|
|
21
|
-
* (one-to-one mapping between the two)
|
|
22
|
-
*/
|
|
23
|
-
export type ModelSource = ModelInfo | string;
|
|
24
|
-
/**
|
|
25
|
-
* Represents a model memory system, providing functions to fetch, save, delete and update models.
|
|
26
|
-
* Stored models can either be a model currently being trained ("working model") or a regular model
|
|
27
|
-
* saved in memory ("saved model"). There can only be a single working model for a given task.
|
|
28
|
-
*/
|
|
29
|
-
export declare abstract class Memory {
|
|
30
|
-
/**
|
|
31
|
-
* Fetches the model identified by the given model source.
|
|
32
|
-
* @param source The model source
|
|
33
|
-
* @returns The model
|
|
34
|
-
*/
|
|
35
|
-
abstract getModel(source: ModelSource): Promise<Model>;
|
|
36
|
-
/**
|
|
37
|
-
* Removes the model identified by the given model source from memory.
|
|
38
|
-
* @param source The model source
|
|
39
|
-
* @returns The model
|
|
40
|
-
*/
|
|
41
|
-
abstract deleteModel(source: ModelSource): Promise<void>;
|
|
42
|
-
/**
|
|
43
|
-
* Replaces the corresponding working model with the saved model identified by the given model source.
|
|
44
|
-
* @param source The model source
|
|
45
|
-
*/
|
|
46
|
-
abstract loadModel(source: ModelSource): Promise<void>;
|
|
47
|
-
/**
|
|
48
|
-
* Fetches metadata for the model identified by the given model source.
|
|
49
|
-
* If the model does not exist in memory, returns undefined.
|
|
50
|
-
* @param source The model source
|
|
51
|
-
* @returns The model metadata or undefined
|
|
52
|
-
*/
|
|
53
|
-
abstract getModelMetadata(source: ModelSource): Promise<object | undefined>;
|
|
54
|
-
/**
|
|
55
|
-
* Replaces the working model identified by the given source with the newly provided model.
|
|
56
|
-
* @param source The model source
|
|
57
|
-
* @param model The new model
|
|
58
|
-
*/
|
|
59
|
-
abstract updateWorkingModel(source: ModelSource, model: Model): Promise<void>;
|
|
60
|
-
/**
|
|
61
|
-
* Creates a saved model copy from the working model identified by the given model source.
|
|
62
|
-
* Returns the saved model's path.
|
|
63
|
-
* @param source The model source
|
|
64
|
-
* @returns The saved model's path
|
|
65
|
-
*/
|
|
66
|
-
abstract saveWorkingModel(source: ModelSource): Promise<string | undefined>;
|
|
67
|
-
/**
|
|
68
|
-
* Saves the newly provided model to the given model source.
|
|
69
|
-
* Returns the saved model's path
|
|
70
|
-
* @param source The model source
|
|
71
|
-
* @param model The new model
|
|
72
|
-
* @returns The saved model's path
|
|
73
|
-
*/
|
|
74
|
-
abstract saveModel(source: ModelSource, model: Model): Promise<string | undefined>;
|
|
75
|
-
/**
|
|
76
|
-
* Moves the model identified by the model source to a file system. This is platform-dependent.
|
|
77
|
-
* @param source The model source
|
|
78
|
-
*/
|
|
79
|
-
abstract downloadModel(source: ModelSource): Promise<void>;
|
|
80
|
-
/**
|
|
81
|
-
* Checks whether the model memory contains the model identified by the given source.
|
|
82
|
-
* @param source The model source
|
|
83
|
-
* @returns True if the memory contains the model, false otherwise
|
|
84
|
-
*/
|
|
85
|
-
abstract contains(source: ModelSource): Promise<boolean>;
|
|
86
|
-
/**
|
|
87
|
-
* Computes the path in memory corresponding to the given model source, be it a path or model information.
|
|
88
|
-
* This is used to easily switch between model path and information, which are both unique model identifiers
|
|
89
|
-
* with a one-to-one equivalence. Returns undefined instead if no path could be inferred from the given
|
|
90
|
-
* model source.
|
|
91
|
-
* @param source The model source
|
|
92
|
-
* @returns The model path
|
|
93
|
-
*/
|
|
94
|
-
abstract getModelMemoryPath(source: ModelSource): string | undefined;
|
|
95
|
-
/**
|
|
96
|
-
* Computes the model information corresponding to the given model source, be it a path or model information.
|
|
97
|
-
* This is used to easily switch between model path and information, which are both unique model identifiers
|
|
98
|
-
* with a one-to-one equivalence. Returns undefined instead if no unique model information could be inferred
|
|
99
|
-
* from the given model source.
|
|
100
|
-
* @param source The model source
|
|
101
|
-
* @returns The model information
|
|
102
|
-
*/
|
|
103
|
-
abstract getModelInfo(source: ModelSource): ModelInfo | undefined;
|
|
104
|
-
/**
|
|
105
|
-
* Computes the lowest version a model source can have without conflicting with model versions currently in memory.
|
|
106
|
-
* @param source The model source
|
|
107
|
-
* @returns The duplicated model source
|
|
108
|
-
*/
|
|
109
|
-
abstract duplicateSource(source: ModelSource): Promise<ModelSource | undefined>;
|
|
110
|
-
}
|
|
111
|
-
export {};
|
package/dist/memory/base.js
DELETED
|
@@ -1,9 +0,0 @@
|
|
|
1
|
-
// only used browser-side
|
|
2
|
-
// TODO: replace IO type
|
|
3
|
-
/**
|
|
4
|
-
* Represents a model memory system, providing functions to fetch, save, delete and update models.
|
|
5
|
-
* Stored models can either be a model currently being trained ("working model") or a regular model
|
|
6
|
-
* saved in memory ("saved model"). There can only be a single working model for a given task.
|
|
7
|
-
*/
|
|
8
|
-
export class Memory {
|
|
9
|
-
}
|
package/dist/memory/empty.d.ts
DELETED
|
@@ -1,20 +0,0 @@
|
|
|
1
|
-
import type { Model } from '../index.js';
|
|
2
|
-
import type { ModelInfo } from './base.js';
|
|
3
|
-
import { Memory } from './base.js';
|
|
4
|
-
/**
|
|
5
|
-
* Represents an empty model memory.
|
|
6
|
-
*/
|
|
7
|
-
export declare class Empty extends Memory {
|
|
8
|
-
getModelMetadata(): Promise<undefined>;
|
|
9
|
-
contains(): Promise<boolean>;
|
|
10
|
-
getModel(): Promise<Model>;
|
|
11
|
-
loadModel(): Promise<void>;
|
|
12
|
-
updateWorkingModel(): Promise<void>;
|
|
13
|
-
saveWorkingModel(): Promise<undefined>;
|
|
14
|
-
saveModel(): Promise<undefined>;
|
|
15
|
-
deleteModel(): Promise<void>;
|
|
16
|
-
downloadModel(): Promise<void>;
|
|
17
|
-
getModelMemoryPath(): string;
|
|
18
|
-
getModelInfo(): ModelInfo;
|
|
19
|
-
duplicateSource(): Promise<undefined>;
|
|
20
|
-
}
|
package/dist/memory/empty.js
DELETED
|
@@ -1,43 +0,0 @@
|
|
|
1
|
-
import { Memory } from './base.js';
|
|
2
|
-
/**
|
|
3
|
-
* Represents an empty model memory.
|
|
4
|
-
*/
|
|
5
|
-
export class Empty extends Memory {
|
|
6
|
-
getModelMetadata() {
|
|
7
|
-
return Promise.resolve(undefined);
|
|
8
|
-
}
|
|
9
|
-
contains() {
|
|
10
|
-
return Promise.resolve(false);
|
|
11
|
-
}
|
|
12
|
-
getModel() {
|
|
13
|
-
return Promise.reject(new Error('empty'));
|
|
14
|
-
}
|
|
15
|
-
loadModel() {
|
|
16
|
-
return Promise.reject(new Error('empty'));
|
|
17
|
-
}
|
|
18
|
-
updateWorkingModel() {
|
|
19
|
-
// nothing to do
|
|
20
|
-
return Promise.resolve();
|
|
21
|
-
}
|
|
22
|
-
saveWorkingModel() {
|
|
23
|
-
return Promise.resolve(undefined);
|
|
24
|
-
}
|
|
25
|
-
saveModel() {
|
|
26
|
-
return Promise.resolve(undefined);
|
|
27
|
-
}
|
|
28
|
-
async deleteModel() {
|
|
29
|
-
// nothing to do
|
|
30
|
-
}
|
|
31
|
-
downloadModel() {
|
|
32
|
-
return Promise.reject(new Error('empty'));
|
|
33
|
-
}
|
|
34
|
-
getModelMemoryPath() {
|
|
35
|
-
throw new Error('empty');
|
|
36
|
-
}
|
|
37
|
-
getModelInfo() {
|
|
38
|
-
throw new Error('empty');
|
|
39
|
-
}
|
|
40
|
-
duplicateSource() {
|
|
41
|
-
return Promise.resolve(undefined);
|
|
42
|
-
}
|
|
43
|
-
}
|
package/dist/memory/index.d.ts
DELETED
package/dist/memory/index.js
DELETED