@epfml/discojs 3.0.1-p20240902094132.0 → 3.0.1-p20240902162912.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. package/dist/aggregator/base.d.ts +16 -2
  2. package/dist/aggregator/base.js +25 -3
  3. package/dist/aggregator/mean.d.ts +1 -0
  4. package/dist/aggregator/mean.js +11 -6
  5. package/dist/aggregator/secure.js +1 -1
  6. package/dist/client/{base.d.ts → client.d.ts} +13 -30
  7. package/dist/client/{base.js → client.js} +10 -20
  8. package/dist/client/decentralized/{base.d.ts → decentralized_client.d.ts} +5 -5
  9. package/dist/client/decentralized/{base.js → decentralized_client.js} +20 -16
  10. package/dist/client/decentralized/index.d.ts +1 -1
  11. package/dist/client/decentralized/index.js +1 -1
  12. package/dist/client/decentralized/messages.d.ts +7 -2
  13. package/dist/client/decentralized/messages.js +4 -2
  14. package/dist/client/event_connection.js +2 -2
  15. package/dist/client/federated/federated_client.d.ts +44 -0
  16. package/dist/client/federated/federated_client.js +210 -0
  17. package/dist/client/federated/index.d.ts +1 -1
  18. package/dist/client/federated/index.js +1 -1
  19. package/dist/client/federated/messages.d.ts +17 -2
  20. package/dist/client/federated/messages.js +3 -1
  21. package/dist/client/index.d.ts +2 -2
  22. package/dist/client/index.js +2 -2
  23. package/dist/client/local_client.d.ts +10 -0
  24. package/dist/client/local_client.js +14 -0
  25. package/dist/client/messages.d.ts +6 -8
  26. package/dist/client/messages.js +23 -7
  27. package/dist/client/utils.js +1 -1
  28. package/dist/default_tasks/cifar10.js +1 -1
  29. package/dist/default_tasks/lus_covid.js +1 -0
  30. package/dist/default_tasks/mnist.js +1 -1
  31. package/dist/default_tasks/simple_face.js +2 -1
  32. package/dist/default_tasks/titanic.js +2 -1
  33. package/dist/default_tasks/wikitext.js +1 -0
  34. package/dist/index.d.ts +4 -1
  35. package/dist/index.js +1 -0
  36. package/dist/logging/logger.d.ts +1 -1
  37. package/dist/task/index.d.ts +0 -1
  38. package/dist/task/index.js +0 -1
  39. package/dist/task/task.d.ts +0 -2
  40. package/dist/task/task.js +2 -4
  41. package/dist/task/training_information.d.ts +1 -1
  42. package/dist/task/training_information.js +3 -3
  43. package/dist/training/disco.d.ts +11 -12
  44. package/dist/training/disco.js +19 -34
  45. package/dist/training/index.d.ts +1 -1
  46. package/dist/training/trainer.d.ts +3 -2
  47. package/dist/training/trainer.js +12 -5
  48. package/dist/utils/event_emitter.js +1 -3
  49. package/package.json +1 -1
  50. package/dist/client/federated/base.d.ts +0 -38
  51. package/dist/client/federated/base.js +0 -130
  52. package/dist/client/local.d.ts +0 -5
  53. package/dist/client/local.js +0 -6
  54. package/dist/task/digest.d.ts +0 -5
  55. package/dist/task/digest.js +0 -14
package/dist/training/disco.d.ts CHANGED
@@ -1,32 +1,35 @@
  import { client as clients, BatchLogs, EpochLogs, Logger, Memory, Task, TrainingInformation } from "../index.js";
  import type { TypedLabeledDataset } from "../index.js";
  import type { Aggregator } from "../aggregator/index.js";
+ import { EventEmitter } from "../utils/event_emitter.js";
  import { RoundLogs, Trainer } from "./trainer.js";
- interface Config {
+ interface DiscoConfig {
  scheme: TrainingInformation["scheme"];
  logger: Logger;
  memory: Memory;
  }
+ export type RoundStatus = "Waiting for more participants" | "Retrieving peers' information" | "Updating the model with other participants' models" | "Training the model on the data you connected";
  /**
  * Top-level class handling distributed training from a client's perspective. It is meant to be
  * a convenient object providing a reduced yet complete API that wraps model training,
  * communication with nodes, logs and model memory.
  */
- export declare class Disco {
+ export declare class Disco extends EventEmitter<{
+ 'status': RoundStatus;
+ }> {
  #private;
  readonly trainer: Trainer;
- private constructor();
  /**
  * Connect to the given task and get ready to train.
  *
- * Will load the model from memory if available or fetch it from the server.
- *
+ * @param task
  * @param clientConfig client to connect with or parameters on how to create one.
- **/
- static fromTask(task: Task, clientConfig: clients.Client | URL | {
+ * @param config the DiscoConfig
+ */
+ constructor(task: Task, clientConfig: clients.Client | URL | {
  aggregator: Aggregator;
  url: URL;
- }, config: Partial<Config>): Promise<Disco>;
+ }, config: Partial<DiscoConfig>);
  /** Train on dataset, yielding logs of every round. */
  trainByRound(dataset: TypedLabeledDataset): AsyncGenerator<RoundLogs>;
  /** Train on dataset, yielding logs of every epoch. */
@@ -42,10 +45,6 @@ export declare class Disco {
  * If you don't care about the whole process, use one of the other train methods.
  **/
  train(dataset: TypedLabeledDataset): AsyncGenerator<AsyncGenerator<AsyncGenerator<BatchLogs, EpochLogs>, RoundLogs>>;
- /**
- * Stops the ongoing training instance without disconnecting the client.
- */
- pause(): Promise<void>;
  /**
  * Completely stops the ongoing training instance.
  */
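The declaration diff above replaces the static Disco.fromTask factory with a public constructor and makes Disco a typed EventEmitter with a single 'status' event carrying the RoundStatus strings. A minimal consumer sketch based only on these declarations, written in TypeScript; the task object and the server URL are placeholders, and it assumes the package still re-exports Disco and the Task type from its index:

import { Disco } from "@epfml/discojs";
import type { Task } from "@epfml/discojs";

declare const task: Task;                      // obtained elsewhere, e.g. from the server's task list
const url = new URL("http://localhost:8080");  // placeholder URL of a running Disco server

// Before this release: const disco = await Disco.fromTask(task, url, {});
// Now the constructor is synchronous; connecting happens later, inside train().
const disco = new Disco(task, url, {});

// The new 'status' event reports the RoundStatus strings declared above.
disco.on("status", (status) => console.log("round status:", status));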
package/dist/training/disco.js CHANGED
@@ -1,6 +1,7 @@
  import { async_iterator, client as clients, ConsoleLogger, EmptyMemory, } from "../index.js";
  import { getAggregator } from "../aggregator/index.js";
  import { enumerate, split } from "../utils/async_iterator.js";
+ import { EventEmitter } from "../utils/event_emitter.js";
  import { Trainer } from "./trainer.js";
  import { labeledDatasetToDataSplit } from "../dataset/data/helpers.js";
  /**
@@ -8,27 +9,21 @@ import { labeledDatasetToDataSplit } from "../dataset/data/helpers.js";
  * a convenient object providing a reduced yet complete API that wraps model training,
  * communication with nodes, logs and model memory.
  */
- export class Disco {
+ export class Disco extends EventEmitter {
  trainer;
  #client;
  #logger;
  #memory;
  #task;
- constructor(trainer, task, client, memory, logger) {
- this.trainer = trainer;
- this.#client = client;
- this.#logger = logger;
- this.#memory = memory;
- this.#task = task;
- }
  /**
  * Connect to the given task and get ready to train.
  *
- * Will load the model from memory if available or fetch it from the server.
- *
+ * @param task
  * @param clientConfig client to connect with or parameters on how to create one.
- **/
- static async fromTask(task, clientConfig, config) {
+ * @param config the DiscoConfig
+ */
+ constructor(task, clientConfig, config) {
+ super();
  const { scheme, logger, memory } = {
  scheme: task.trainingInformation.scheme,
  logger: new ConsoleLogger(),
@@ -52,18 +47,13 @@ export class Disco {
  }
  if (client.task !== task)
  throw new Error("client not setup for given task");
- let model;
- const memoryInfo = {
- type: "working",
- taskID: task.id,
- name: task.trainingInformation.modelID,
- tensorBackend: task.trainingInformation.tensorBackend,
- };
- if (await memory.contains(memoryInfo))
- model = await memory.getModel(memoryInfo);
- else
- model = await client.getLatestModel();
- return new Disco(new Trainer(task, model, client), task, client, memory, logger);
+ this.#logger = logger;
+ this.#client = client;
+ this.#memory = memory;
+ this.#task = task;
+ this.trainer = new Trainer(task, client);
+ // Simply propagate the training status events emitted by the client
+ this.#client.on('status', status => this.emit('status', status));
  }
  /** Train on dataset, yielding logs of every round. */
  async *trainByRound(dataset) {
@@ -106,11 +96,12 @@ export class Disco {
  * If you don't care about the whole process, use one of the other train methods.
  **/
  async *train(dataset) {
- this.#logger.success("Training started.");
+ this.#logger.success("Training started");
  const data = await labeledDatasetToDataSplit(this.#task, dataset);
  const trainData = data.train.preprocess().batch().dataset;
  const validationData = data.validation?.preprocess().batch().dataset ?? trainData;
- await this.#client.connect();
+ // the client fetches the latest weights upon connection
+ this.trainer.model = await this.#client.connect();
  for await (const [round, epochs] of enumerate(this.trainer.train(trainData, validationData))) {
  yield async function* () {
  const [gen, returnedRoundLogs] = split(epochs);
@@ -123,6 +114,7 @@ export class Disco {
  ` Epoch: ${epoch}`,
  ` Training loss: ${epochLogs.training.loss}`,
  ` Training accuracy: ${epochLogs.training.accuracy}`,
+ ` Peak memory: ${epochLogs.peakMemory}`,
  epochLogs.validation !== undefined
  ? ` Validation loss: ${epochLogs.validation.loss}`
  : "",
@@ -140,19 +132,12 @@ export class Disco {
  tensorBackend: this.#task.trainingInformation.tensorBackend,
  }, this.trainer.model);
  }
- this.#logger.success("Training finished.");
- }
- /**
- * Stops the ongoing training instance without disconnecting the client.
- */
- async pause() {
- await this.trainer.stopTraining();
+ this.#logger.success("Training finished");
  }
  /**
  * Completely stops the ongoing training instance.
  */
  async close() {
- await this.pause();
  await this.#client.disconnect();
  }
  }
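In the reworked train() above, the client's connect() now returns the latest model, which is assigned to trainer.model, and pause() is removed: close() only disconnects the client. A hedged TypeScript sketch of driving a run through the round-level generator, assuming a disco instance built as in the previous sketch and a dataset of the TypedLabeledDataset shape expected by trainByRound:

// `disco` and `dataset` are assumed to exist (see the previous sketch).
for await (const roundLogs of disco.trainByRound(dataset)) {
  // One entry per round of local training plus communication with the network.
  console.log("round finished:", roundLogs);
}
// pause() no longer exists; close() disconnects the underlying client.
await disco.close();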
package/dist/training/index.d.ts CHANGED
@@ -1,2 +1,2 @@
- export { Disco } from './disco.js';
+ export { Disco, RoundStatus } from './disco.js';
  export { RoundLogs, Trainer } from './trainer.js';
package/dist/training/trainer.d.ts CHANGED
@@ -9,8 +9,9 @@ export interface RoundLogs
  /** Train a model and exchange with others **/
  export declare class Trainer {
  #private;
- readonly model: Model;
- constructor(task: Task, model: Model, client: Client);
+ get model(): Model;
+ set model(model: Model);
+ constructor(task: Task, client: Client);
  stopTraining(): Promise<void>;
  train(dataset: tf.data.Dataset<tf.TensorContainer>, valDataset: tf.data.Dataset<tf.TensorContainer>): AsyncGenerator<AsyncGenerator<AsyncGenerator<BatchLogs, EpochLogs>, RoundLogs>, void>;
  }
package/dist/training/trainer.js CHANGED
@@ -4,14 +4,21 @@ import { privacy } from "../index.js";
  import * as async_iterator from "../utils/async_iterator.js";
  /** Train a model and exchange with others **/
  export class Trainer {
- model;
  #client;
  #roundDuration;
  #epochs;
  #privacy;
+ #model;
  #training;
- constructor(task, model, client) {
- this.model = model;
+ get model() {
+ if (this.#model === undefined)
+ throw new Error("trainer's model has not been set");
+ return this.#model;
+ }
+ set model(model) {
+ this.#model = model;
+ }
+ constructor(task, client) {
  this.#client = client;
  this.#roundDuration = task.trainingInformation.roundDuration;
  this.#epochs = task.trainingInformation.epochs;
@@ -37,12 +44,12 @@ export class Trainer {
  const totalRound = Math.trunc(this.#epochs / this.#roundDuration);
  let previousRoundWeights;
  for (let round = 0; round < totalRound; round++) {
- await this.#client.onRoundBeginCommunication(this.model.weights, round);
+ await this.#client.onRoundBeginCommunication();
  yield this.#runRound(dataset, valDataset);
  let localWeights = this.model.weights;
  if (this.#privacy !== undefined)
  localWeights = await applyPrivacy(previousRoundWeights, localWeights, this.#privacy);
- const networkWeights = await this.#client.onRoundEndCommunication(localWeights, round);
+ const networkWeights = await this.#client.onRoundEndCommunication(localWeights);
  this.model.weights = previousRoundWeights = networkWeights;
  }
  }
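The Trainer no longer receives a Model at construction; model is now a getter/setter over a private #model that throws until something assigns it (in Disco, the model returned by client.connect()). The same lazy-initialization pattern in a small self-contained TypeScript sketch; the names here are illustrative and not part of discojs:

// Illustrative stand-ins, not discojs types.
interface Model { weights?: unknown }

class LazyTrainer {
  #model?: Model;

  get model(): Model {
    if (this.#model === undefined)
      throw new Error("trainer's model has not been set");
    return this.#model;
  }

  set model(model: Model) {
    this.#model = model;
  }
}

const trainer = new LazyTrainer();
// trainer.model;           // would throw: the model has not been set yet
trainer.model = {};         // e.g. the model returned by client.connect()
console.log(trainer.model); // safe to read once assigned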
package/dist/utils/event_emitter.js CHANGED
@@ -47,9 +47,7 @@ export class EventEmitter
  emit(event, value) {
  const eventListeners = this.listeners[event] ?? List();
  this.listeners[event] = eventListeners.filterNot(([once]) => once);
- eventListeners.forEach(([_, listener]) => {
- listener(value);
- });
+ eventListeners.forEach(([_, listener]) => { listener(value); });
  }
  }
  /** `EventEmitter` for all events */
package/package.json CHANGED
@@ -1,6 +1,6 @@
  {
  "name": "@epfml/discojs",
- "version": "3.0.1-p20240902094132.0",
+ "version": "3.0.1-p20240902162912.0",
  "type": "module",
  "main": "dist/index.js",
  "types": "dist/index.d.ts",
package/dist/client/federated/base.d.ts DELETED
@@ -1,38 +0,0 @@
- import { type WeightsContainer } from "../../index.js";
- import { Base as Client } from "../base.js";
- /**
- * Client class that communicates with a centralized, federated server, when training
- * a specific task in the federated setting.
- */
- export declare class Base extends Client {
- #private;
- /**
- * Arbitrary node id assigned to the federated server which we are communicating with.
- * Indeed, the server acts as a node within the network. In the federated setting described
- * by this client class, the server is the only node which we are communicating with.
- */
- static readonly SERVER_NODE_ID = "federated-server-node-id";
- get nbOfParticipants(): number;
- /**
- * Opens a new WebSocket connection with the server and listens to new messages over the channel
- */
- private connectServer;
- /**
- * Initializes the connection to the server and get our own node id.
- * TODO: In the federated setting, should return the current server-side round
- * for the task.
- */
- connect(): Promise<void>;
- /**
- * Disconnection process when user quits the task.
- */
- disconnect(): Promise<void>;
- onRoundBeginCommunication(): Promise<void>;
- onRoundEndCommunication(weights: WeightsContainer, round: number): Promise<WeightsContainer>;
- /**
- * Send a message containing our local weight updates to the federated server.
- * And waits for the server to reply with the most recent aggregated weights
- * @param payload The weight updates to send
- */
- private sendPayloadAndReceiveResult;
- }
package/dist/client/federated/base.js DELETED
@@ -1,130 +0,0 @@
- import createDebug from "debug";
- import { serialization, } from "../../index.js";
- import { Base as Client } from "../base.js";
- import { type } from "../messages.js";
- import { waitMessageWithTimeout, WebSocketServer, } from "../event_connection.js";
- import * as messages from "./messages.js";
- const debug = createDebug("discojs:client:federated");
- /**
- * Client class that communicates with a centralized, federated server, when training
- * a specific task in the federated setting.
- */
- export class Base extends Client {
- /**
- * Arbitrary node id assigned to the federated server which we are communicating with.
- * Indeed, the server acts as a node within the network. In the federated setting described
- * by this client class, the server is the only node which we are communicating with.
- */
- static SERVER_NODE_ID = "federated-server-node-id";
- // Total number of other federated contributors, including this client, excluding the server
- // E.g., if 3 users are training a federated model, nbOfParticipants is 3
- #nbOfParticipants = 1;
- // the number of participants excluding the server
- get nbOfParticipants() {
- return this.#nbOfParticipants;
- }
- /**
- * Opens a new WebSocket connection with the server and listens to new messages over the channel
- */
- async connectServer(url) {
- const server = await WebSocketServer.connect(url, messages.isMessageFederated, messages.isMessageFederated);
- return server;
- }
- /**
- * Initializes the connection to the server and get our own node id.
- * TODO: In the federated setting, should return the current server-side round
- * for the task.
- */
- async connect() {
- const serverURL = new URL("", this.url.href);
- switch (this.url.protocol) {
- case "http:":
- serverURL.protocol = "ws:";
- break;
- case "https:":
- serverURL.protocol = "wss:";
- break;
- default:
- throw new Error(`unknown protocol: ${this.url.protocol}`);
- }
- serverURL.pathname += `feai/${this.task.id}`;
- this._server = await this.connectServer(serverURL);
- this.aggregator.registerNode(Base.SERVER_NODE_ID);
- const msg = {
- type: type.ClientConnected,
- };
- this.server.send(msg);
- const received = await waitMessageWithTimeout(this.server, type.AssignNodeID);
- debug(`[${received.id}] assign id generated by the server`);
- this._ownId = received.id;
- }
- /**
- * Disconnection process when user quits the task.
- */
- async disconnect() {
- await this.server.disconnect();
- this._server = undefined;
- this._ownId = undefined;
- this.aggregator.setNodes(this.aggregator.nodes.delete(Base.SERVER_NODE_ID));
- return Promise.resolve();
- }
- onRoundBeginCommunication() {
- // Prepare the result promise for the incoming round
- this.aggregationResult = new Promise((resolve) => this.aggregator.once('aggregation', resolve));
- return Promise.resolve();
- }
- async onRoundEndCommunication(weights, round) {
- // NB: For now, we suppose a fully-federated setting.
- if (this.aggregationResult === undefined) {
- throw new Error("local aggregation result was not set");
- }
- // Send our local contribution to the server
- // and receive the most recent weights as an answer to our contribution
- const serverResult = await this.sendPayloadAndReceiveResult(this.aggregator.makePayloads(weights).first());
- if (serverResult !== undefined &&
- this.aggregator.add(Base.SERVER_NODE_ID, serverResult, round, 0)) {
- // Regular case: the server sends us its aggregation result which will serve our
- // own aggregation result.
- }
- else {
- // Unexpected case: for some reason, the server result is stale.
- // We proceed to the next round without its result.
- debug(`[${this.ownId}] server result is either stale or not received`);
- this.aggregator.nextRound();
- }
- return await this.aggregationResult;
- }
- /**
- * Send a message containing our local weight updates to the federated server.
- * And waits for the server to reply with the most recent aggregated weights
- * @param payload The weight updates to send
- */
- async sendPayloadAndReceiveResult(payload) {
- const msg = {
- type: type.SendPayload,
- payload: await serialization.weights.encode(payload),
- round: this.aggregator.round,
- };
- this.server.send(msg);
- // Waits for the server's result for its current (most recent) round and add it to our aggregator.
- // Updates the aggregator's round if it's behind the server's.
- try {
- // It is important than the client immediately awaits the server result or it may miss it
- const { payload, round, nbOfParticipants } = await waitMessageWithTimeout(this.server, type.ReceiveServerPayload);
- const serverRound = round;
- this.#nbOfParticipants = nbOfParticipants; // Save the current participants
- // Store the server result only if it is not stale
- if (this.aggregator.round <= round) {
- const serverResult = serialization.weights.decode(payload);
- // Update the local round to match the server's
- if (this.aggregator.round < serverRound) {
- this.aggregator.setRound(serverRound);
- }
- return serverResult;
- }
- }
- catch (e) {
- debug(`[${this.ownId}] while receiving results: %o`, e);
- }
- }
- }
package/dist/client/local.d.ts DELETED
@@ -1,5 +0,0 @@
- import { WeightsContainer } from "../weights/weights_container.js";
- import { Base } from "./base.js";
- export declare class Local extends Base {
- onRoundEndCommunication(weights: WeightsContainer): Promise<WeightsContainer>;
- }
package/dist/client/local.js DELETED
@@ -1,6 +0,0 @@
- import { Base } from "./base.js";
- export class Local extends Base {
- onRoundEndCommunication(weights) {
- return Promise.resolve(weights);
- }
- }
package/dist/task/digest.d.ts DELETED
@@ -1,5 +0,0 @@
- export interface Digest {
- algorithm: string;
- value: string;
- }
- export declare function isDigest(raw: unknown): raw is Digest;
package/dist/task/digest.js DELETED
@@ -1,14 +0,0 @@
- export function isDigest(raw) {
- if (typeof raw !== 'object' || raw === null) {
- return false;
- }
- const { algorithm, value } = raw;
- if (!(typeof algorithm === 'string' &&
- typeof value === 'string')) {
- return false;
- }
- const repack = { algorithm, value };
- const _correct = repack;
- const _total = repack;
- return true;
- }