@epfml/discojs 3.0.1-p20240902100041.0 → 3.0.1-p20240904094219.0

This diff shows the changes between two publicly available versions of the package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
Files changed (62)
  1. package/dist/aggregator/base.d.ts +16 -2
  2. package/dist/aggregator/base.js +25 -3
  3. package/dist/aggregator/mean.d.ts +1 -0
  4. package/dist/aggregator/mean.js +11 -6
  5. package/dist/aggregator/secure.js +1 -1
  6. package/dist/client/{base.d.ts → client.d.ts} +13 -30
  7. package/dist/client/{base.js → client.js} +10 -20
  8. package/dist/client/decentralized/{base.d.ts → decentralized_client.d.ts} +5 -5
  9. package/dist/client/decentralized/{base.js → decentralized_client.js} +20 -16
  10. package/dist/client/decentralized/index.d.ts +1 -1
  11. package/dist/client/decentralized/index.js +1 -1
  12. package/dist/client/decentralized/messages.d.ts +7 -2
  13. package/dist/client/decentralized/messages.js +4 -2
  14. package/dist/client/event_connection.js +2 -2
  15. package/dist/client/federated/federated_client.d.ts +44 -0
  16. package/dist/client/federated/federated_client.js +210 -0
  17. package/dist/client/federated/index.d.ts +1 -1
  18. package/dist/client/federated/index.js +1 -1
  19. package/dist/client/federated/messages.d.ts +17 -2
  20. package/dist/client/federated/messages.js +3 -1
  21. package/dist/client/index.d.ts +2 -2
  22. package/dist/client/index.js +2 -2
  23. package/dist/client/local_client.d.ts +10 -0
  24. package/dist/client/local_client.js +14 -0
  25. package/dist/client/messages.d.ts +6 -8
  26. package/dist/client/messages.js +23 -7
  27. package/dist/client/utils.js +1 -1
  28. package/dist/default_tasks/cifar10.js +1 -2
  29. package/dist/default_tasks/lus_covid.js +1 -1
  30. package/dist/default_tasks/mnist.js +1 -2
  31. package/dist/default_tasks/simple_face.js +2 -2
  32. package/dist/default_tasks/titanic.js +2 -2
  33. package/dist/default_tasks/wikitext.js +1 -1
  34. package/dist/index.d.ts +4 -2
  35. package/dist/index.js +1 -1
  36. package/dist/logging/logger.d.ts +1 -1
  37. package/dist/serialization/model.js +18 -9
  38. package/dist/task/index.d.ts +0 -1
  39. package/dist/task/index.js +0 -1
  40. package/dist/task/task.d.ts +0 -2
  41. package/dist/task/task.js +2 -4
  42. package/dist/task/training_information.d.ts +1 -2
  43. package/dist/task/training_information.js +3 -5
  44. package/dist/training/disco.d.ts +14 -16
  45. package/dist/training/disco.js +22 -46
  46. package/dist/training/index.d.ts +1 -1
  47. package/dist/training/trainer.d.ts +3 -2
  48. package/dist/training/trainer.js +12 -5
  49. package/dist/utils/event_emitter.js +1 -3
  50. package/package.json +1 -1
  51. package/dist/client/federated/base.d.ts +0 -38
  52. package/dist/client/federated/base.js +0 -130
  53. package/dist/client/local.d.ts +0 -5
  54. package/dist/client/local.js +0 -6
  55. package/dist/memory/base.d.ts +0 -111
  56. package/dist/memory/base.js +0 -9
  57. package/dist/memory/empty.d.ts +0 -20
  58. package/dist/memory/empty.js +0 -43
  59. package/dist/memory/index.d.ts +0 -2
  60. package/dist/memory/index.js +0 -2
  61. package/dist/task/digest.d.ts +0 -5
  62. package/dist/task/digest.js +0 -14
@@ -5,7 +5,6 @@ interface Privacy {
5
5
  noiseScale?: number;
6
6
  }
7
7
  export interface TrainingInformation {
8
- modelID: string;
9
8
  epochs: number;
10
9
  roundDuration: number;
11
10
  validationSplit: number;
@@ -21,7 +20,7 @@ export interface TrainingInformation {
21
20
  privacy?: Privacy;
22
21
  decentralizedSecure?: boolean;
23
22
  maxShareValue?: number;
24
- minimumReadyPeers?: number;
23
+ minNbOfParticipants: number;
25
24
  aggregator?: 'mean' | 'secure';
26
25
  tokenizer?: string | PreTrainedTokenizer;
27
26
  maxSequenceLength?: number;
@@ -24,20 +24,19 @@ export function isTrainingInformation(raw) {
24
24
  if (typeof raw !== 'object' || raw === null) {
25
25
  return false;
26
26
  }
27
- const { IMAGE_H, IMAGE_W, LABEL_LIST, aggregator, batchSize, dataType, decentralizedSecure, privacy, epochs, inputColumns, maxShareValue, minimumReadyPeers, modelID, outputColumns, preprocessingFunctions, roundDuration, scheme, validationSplit, tokenizer, maxSequenceLength, tensorBackend } = raw;
27
+ const { IMAGE_H, IMAGE_W, LABEL_LIST, aggregator, batchSize, dataType, decentralizedSecure, privacy, epochs, inputColumns, maxShareValue, minNbOfParticipants, outputColumns, preprocessingFunctions, roundDuration, scheme, validationSplit, tokenizer, maxSequenceLength, tensorBackend } = raw;
28
28
  if (typeof dataType !== 'string' ||
29
- typeof modelID !== 'string' ||
30
29
  typeof epochs !== 'number' ||
31
30
  typeof batchSize !== 'number' ||
32
31
  typeof roundDuration !== 'number' ||
33
32
  typeof validationSplit !== 'number' ||
33
+ typeof minNbOfParticipants !== 'number' ||
34
34
  (tokenizer !== undefined && typeof tokenizer !== 'string' && !(tokenizer instanceof PreTrainedTokenizer)) ||
35
35
  (maxSequenceLength !== undefined && typeof maxSequenceLength !== 'number') ||
36
36
  (aggregator !== undefined && typeof aggregator !== 'string') ||
37
37
  (decentralizedSecure !== undefined && typeof decentralizedSecure !== 'boolean') ||
38
38
  (privacy !== undefined && !isPrivacy(privacy)) ||
39
39
  (maxShareValue !== undefined && typeof maxShareValue !== 'number') ||
40
- (minimumReadyPeers !== undefined && typeof minimumReadyPeers !== 'number') ||
41
40
  (IMAGE_H !== undefined && typeof IMAGE_H !== 'number') ||
42
41
  (IMAGE_W !== undefined && typeof IMAGE_W !== 'number') ||
43
42
  (LABEL_LIST !== undefined && !isStringArray(LABEL_LIST)) ||
@@ -96,8 +95,7 @@ export function isTrainingInformation(raw) {
96
95
  epochs,
97
96
  inputColumns,
98
97
  maxShareValue,
99
- minimumReadyPeers,
100
- modelID,
98
+ minNbOfParticipants,
101
99
  outputColumns,
102
100
  preprocessingFunctions,
103
101
  roundDuration,
@@ -1,32 +1,34 @@
1
- import { client as clients, BatchLogs, EpochLogs, Logger, Memory, Task, TrainingInformation } from "../index.js";
1
+ import { client as clients, BatchLogs, EpochLogs, Logger, Task, TrainingInformation } from "../index.js";
2
2
  import type { TypedLabeledDataset } from "../index.js";
3
3
  import type { Aggregator } from "../aggregator/index.js";
4
+ import { EventEmitter } from "../utils/event_emitter.js";
4
5
  import { RoundLogs, Trainer } from "./trainer.js";
5
- interface Config {
6
+ interface DiscoConfig {
6
7
  scheme: TrainingInformation["scheme"];
7
8
  logger: Logger;
8
- memory: Memory;
9
9
  }
10
+ export type RoundStatus = "Waiting for more participants" | "Retrieving peers' information" | "Updating the model with other participants' models" | "Training the model on the data you connected";
10
11
  /**
11
12
  * Top-level class handling distributed training from a client's perspective. It is meant to be
12
- * a convenient object providing a reduced yet complete API that wraps model training,
13
- * communication with nodes, logs and model memory.
13
+ * a convenient object providing a reduced yet complete API that wraps model training and
14
+ * communication with nodes.
14
15
  */
15
- export declare class Disco {
16
+ export declare class Disco extends EventEmitter<{
17
+ 'status': RoundStatus;
18
+ }> {
16
19
  #private;
17
20
  readonly trainer: Trainer;
18
- private constructor();
19
21
  /**
20
22
  * Connect to the given task and get ready to train.
21
23
  *
22
- * Will load the model from memory if available or fetch it from the server.
23
- *
24
+ * @param task
24
25
  * @param clientConfig client to connect with or parameters on how to create one.
25
- **/
26
- static fromTask(task: Task, clientConfig: clients.Client | URL | {
26
+ * @param config the DiscoConfig
27
+ */
28
+ constructor(task: Task, clientConfig: clients.Client | URL | {
27
29
  aggregator: Aggregator;
28
30
  url: URL;
29
- }, config: Partial<Config>): Promise<Disco>;
31
+ }, config: Partial<DiscoConfig>);
30
32
  /** Train on dataset, yielding logs of every round. */
31
33
  trainByRound(dataset: TypedLabeledDataset): AsyncGenerator<RoundLogs>;
32
34
  /** Train on dataset, yielding logs of every epoch. */
@@ -42,10 +44,6 @@ export declare class Disco {
42
44
  * If you don't care about the whole process, use one of the other train methods.
43
45
  **/
44
46
  train(dataset: TypedLabeledDataset): AsyncGenerator<AsyncGenerator<AsyncGenerator<BatchLogs, EpochLogs>, RoundLogs>>;
45
- /**
46
- * Stops the ongoing training instance without disconnecting the client.
47
- */
48
- pause(): Promise<void>;
49
47
  /**
50
48
  * Completely stops the ongoing training instance.
51
49
  */
@@ -1,38 +1,31 @@
1
- import { async_iterator, client as clients, ConsoleLogger, EmptyMemory, } from "../index.js";
1
+ import { async_iterator, client as clients, ConsoleLogger, } from "../index.js";
2
2
  import { getAggregator } from "../aggregator/index.js";
3
3
  import { enumerate, split } from "../utils/async_iterator.js";
4
+ import { EventEmitter } from "../utils/event_emitter.js";
4
5
  import { Trainer } from "./trainer.js";
5
6
  import { labeledDatasetToDataSplit } from "../dataset/data/helpers.js";
6
7
  /**
7
8
  * Top-level class handling distributed training from a client's perspective. It is meant to be
8
- * a convenient object providing a reduced yet complete API that wraps model training,
9
- * communication with nodes, logs and model memory.
9
+ * a convenient object providing a reduced yet complete API that wraps model training and
10
+ * communication with nodes.
10
11
  */
11
- export class Disco {
12
+ export class Disco extends EventEmitter {
12
13
  trainer;
13
14
  #client;
14
15
  #logger;
15
- #memory;
16
16
  #task;
17
- constructor(trainer, task, client, memory, logger) {
18
- this.trainer = trainer;
19
- this.#client = client;
20
- this.#logger = logger;
21
- this.#memory = memory;
22
- this.#task = task;
23
- }
24
17
  /**
25
18
  * Connect to the given task and get ready to train.
26
19
  *
27
- * Will load the model from memory if available or fetch it from the server.
28
- *
20
+ * @param task
29
21
  * @param clientConfig client to connect with or parameters on how to create one.
30
- **/
31
- static async fromTask(task, clientConfig, config) {
32
- const { scheme, logger, memory } = {
22
+ * @param config the DiscoConfig
23
+ */
24
+ constructor(task, clientConfig, config) {
25
+ super();
26
+ const { scheme, logger } = {
33
27
  scheme: task.trainingInformation.scheme,
34
28
  logger: new ConsoleLogger(),
35
- memory: new EmptyMemory(),
36
29
  ...config,
37
30
  };
38
31
  let client;
@@ -52,18 +45,12 @@ export class Disco {
52
45
  }
53
46
  if (client.task !== task)
54
47
  throw new Error("client not setup for given task");
55
- let model;
56
- const memoryInfo = {
57
- type: "working",
58
- taskID: task.id,
59
- name: task.trainingInformation.modelID,
60
- tensorBackend: task.trainingInformation.tensorBackend,
61
- };
62
- if (await memory.contains(memoryInfo))
63
- model = await memory.getModel(memoryInfo);
64
- else
65
- model = await client.getLatestModel();
66
- return new Disco(new Trainer(task, model, client), task, client, memory, logger);
48
+ this.#logger = logger;
49
+ this.#client = client;
50
+ this.#task = task;
51
+ this.trainer = new Trainer(task, client);
52
+ // Simply propagate the training status events emitted by the client
53
+ this.#client.on('status', status => this.emit('status', status));
67
54
  }
68
55
  /** Train on dataset, yielding logs of every round. */
69
56
  async *trainByRound(dataset) {
@@ -106,11 +93,12 @@ export class Disco {
106
93
  * If you don't care about the whole process, use one of the other train methods.
107
94
  **/
108
95
  async *train(dataset) {
109
- this.#logger.success("Training started.");
96
+ this.#logger.success("Training started");
110
97
  const data = await labeledDatasetToDataSplit(this.#task, dataset);
111
98
  const trainData = data.train.preprocess().batch().dataset;
112
99
  const validationData = data.validation?.preprocess().batch().dataset ?? trainData;
113
- await this.#client.connect();
100
+ // the client fetches the latest weights upon connection
101
+ this.trainer.model = await this.#client.connect();
114
102
  for await (const [round, epochs] of enumerate(this.trainer.train(trainData, validationData))) {
115
103
  yield async function* () {
116
104
  const [gen, returnedRoundLogs] = split(epochs);
@@ -123,6 +111,7 @@ export class Disco {
123
111
  ` Epoch: ${epoch}`,
124
112
  ` Training loss: ${epochLogs.training.loss}`,
125
113
  ` Training accuracy: ${epochLogs.training.accuracy}`,
114
+ ` Peak memory: ${epochLogs.peakMemory}`,
126
115
  epochLogs.validation !== undefined
127
116
  ? ` Validation loss: ${epochLogs.validation.loss}`
128
117
  : "",
@@ -133,26 +122,13 @@ export class Disco {
133
122
  }
134
123
  return await returnedRoundLogs;
135
124
  }.bind(this)();
136
- await this.#memory.updateWorkingModel({
137
- type: "working",
138
- taskID: this.#task.id,
139
- name: this.#task.trainingInformation.modelID,
140
- tensorBackend: this.#task.trainingInformation.tensorBackend,
141
- }, this.trainer.model);
142
125
  }
143
- this.#logger.success("Training finished.");
144
- }
145
- /**
146
- * Stops the ongoing training instance without disconnecting the client.
147
- */
148
- async pause() {
149
- await this.trainer.stopTraining();
126
+ this.#logger.success("Training finished");
150
127
  }
151
128
  /**
152
129
  * Completely stops the ongoing training instance.
153
130
  */
154
131
  async close() {
155
- await this.pause();
156
132
  await this.#client.disconnect();
157
133
  }
158
134
  }
@@ -1,2 +1,2 @@
1
- export { Disco } from './disco.js';
1
+ export { Disco, RoundStatus } from './disco.js';
2
2
  export { RoundLogs, Trainer } from './trainer.js';
@@ -9,8 +9,9 @@ export interface RoundLogs {
9
9
  /** Train a model and exchange with others **/
10
10
  export declare class Trainer {
11
11
  #private;
12
- readonly model: Model;
13
- constructor(task: Task, model: Model, client: Client);
12
+ get model(): Model;
13
+ set model(model: Model);
14
+ constructor(task: Task, client: Client);
14
15
  stopTraining(): Promise<void>;
15
16
  train(dataset: tf.data.Dataset<tf.TensorContainer>, valDataset: tf.data.Dataset<tf.TensorContainer>): AsyncGenerator<AsyncGenerator<AsyncGenerator<BatchLogs, EpochLogs>, RoundLogs>, void>;
16
17
  }
@@ -4,14 +4,21 @@ import { privacy } from "../index.js";
4
4
  import * as async_iterator from "../utils/async_iterator.js";
5
5
  /** Train a model and exchange with others **/
6
6
  export class Trainer {
7
- model;
8
7
  #client;
9
8
  #roundDuration;
10
9
  #epochs;
11
10
  #privacy;
11
+ #model;
12
12
  #training;
13
- constructor(task, model, client) {
14
- this.model = model;
13
+ get model() {
14
+ if (this.#model === undefined)
15
+ throw new Error("trainer's model has not been set");
16
+ return this.#model;
17
+ }
18
+ set model(model) {
19
+ this.#model = model;
20
+ }
21
+ constructor(task, client) {
15
22
  this.#client = client;
16
23
  this.#roundDuration = task.trainingInformation.roundDuration;
17
24
  this.#epochs = task.trainingInformation.epochs;
@@ -37,12 +44,12 @@ export class Trainer {
37
44
  const totalRound = Math.trunc(this.#epochs / this.#roundDuration);
38
45
  let previousRoundWeights;
39
46
  for (let round = 0; round < totalRound; round++) {
40
- await this.#client.onRoundBeginCommunication(this.model.weights, round);
47
+ await this.#client.onRoundBeginCommunication();
41
48
  yield this.#runRound(dataset, valDataset);
42
49
  let localWeights = this.model.weights;
43
50
  if (this.#privacy !== undefined)
44
51
  localWeights = await applyPrivacy(previousRoundWeights, localWeights, this.#privacy);
45
- const networkWeights = await this.#client.onRoundEndCommunication(localWeights, round);
52
+ const networkWeights = await this.#client.onRoundEndCommunication(localWeights);
46
53
  this.model.weights = previousRoundWeights = networkWeights;
47
54
  }
48
55
  }
@@ -47,9 +47,7 @@ export class EventEmitter {
47
47
  emit(event, value) {
48
48
  const eventListeners = this.listeners[event] ?? List();
49
49
  this.listeners[event] = eventListeners.filterNot(([once]) => once);
50
- eventListeners.forEach(([_, listener]) => {
51
- listener(value);
52
- });
50
+ eventListeners.forEach(([_, listener]) => { listener(value); });
53
51
  }
54
52
  }
55
53
  /** `EventEmitter` for all events */
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@epfml/discojs",
3
- "version": "3.0.1-p20240902100041.0",
3
+ "version": "3.0.1-p20240904094219.0",
4
4
  "type": "module",
5
5
  "main": "dist/index.js",
6
6
  "types": "dist/index.d.ts",
@@ -1,38 +0,0 @@
1
- import { type WeightsContainer } from "../../index.js";
2
- import { Base as Client } from "../base.js";
3
- /**
4
- * Client class that communicates with a centralized, federated server, when training
5
- * a specific task in the federated setting.
6
- */
7
- export declare class Base extends Client {
8
- #private;
9
- /**
10
- * Arbitrary node id assigned to the federated server which we are communicating with.
11
- * Indeed, the server acts as a node within the network. In the federated setting described
12
- * by this client class, the server is the only node which we are communicating with.
13
- */
14
- static readonly SERVER_NODE_ID = "federated-server-node-id";
15
- get nbOfParticipants(): number;
16
- /**
17
- * Opens a new WebSocket connection with the server and listens to new messages over the channel
18
- */
19
- private connectServer;
20
- /**
21
- * Initializes the connection to the server and get our own node id.
22
- * TODO: In the federated setting, should return the current server-side round
23
- * for the task.
24
- */
25
- connect(): Promise<void>;
26
- /**
27
- * Disconnection process when user quits the task.
28
- */
29
- disconnect(): Promise<void>;
30
- onRoundBeginCommunication(): Promise<void>;
31
- onRoundEndCommunication(weights: WeightsContainer, round: number): Promise<WeightsContainer>;
32
- /**
33
- * Send a message containing our local weight updates to the federated server.
34
- * And waits for the server to reply with the most recent aggregated weights
35
- * @param payload The weight updates to send
36
- */
37
- private sendPayloadAndReceiveResult;
38
- }
@@ -1,130 +0,0 @@
1
- import createDebug from "debug";
2
- import { serialization, } from "../../index.js";
3
- import { Base as Client } from "../base.js";
4
- import { type } from "../messages.js";
5
- import { waitMessageWithTimeout, WebSocketServer, } from "../event_connection.js";
6
- import * as messages from "./messages.js";
7
- const debug = createDebug("discojs:client:federated");
8
- /**
9
- * Client class that communicates with a centralized, federated server, when training
10
- * a specific task in the federated setting.
11
- */
12
- export class Base extends Client {
13
- /**
14
- * Arbitrary node id assigned to the federated server which we are communicating with.
15
- * Indeed, the server acts as a node within the network. In the federated setting described
16
- * by this client class, the server is the only node which we are communicating with.
17
- */
18
- static SERVER_NODE_ID = "federated-server-node-id";
19
- // Total number of other federated contributors, including this client, excluding the server
20
- // E.g., if 3 users are training a federated model, nbOfParticipants is 3
21
- #nbOfParticipants = 1;
22
- // the number of participants excluding the server
23
- get nbOfParticipants() {
24
- return this.#nbOfParticipants;
25
- }
26
- /**
27
- * Opens a new WebSocket connection with the server and listens to new messages over the channel
28
- */
29
- async connectServer(url) {
30
- const server = await WebSocketServer.connect(url, messages.isMessageFederated, messages.isMessageFederated);
31
- return server;
32
- }
33
- /**
34
- * Initializes the connection to the server and get our own node id.
35
- * TODO: In the federated setting, should return the current server-side round
36
- * for the task.
37
- */
38
- async connect() {
39
- const serverURL = new URL("", this.url.href);
40
- switch (this.url.protocol) {
41
- case "http:":
42
- serverURL.protocol = "ws:";
43
- break;
44
- case "https:":
45
- serverURL.protocol = "wss:";
46
- break;
47
- default:
48
- throw new Error(`unknown protocol: ${this.url.protocol}`);
49
- }
50
- serverURL.pathname += `feai/${this.task.id}`;
51
- this._server = await this.connectServer(serverURL);
52
- this.aggregator.registerNode(Base.SERVER_NODE_ID);
53
- const msg = {
54
- type: type.ClientConnected,
55
- };
56
- this.server.send(msg);
57
- const received = await waitMessageWithTimeout(this.server, type.AssignNodeID);
58
- debug(`[${received.id}] assign id generated by the server`);
59
- this._ownId = received.id;
60
- }
61
- /**
62
- * Disconnection process when user quits the task.
63
- */
64
- async disconnect() {
65
- await this.server.disconnect();
66
- this._server = undefined;
67
- this._ownId = undefined;
68
- this.aggregator.setNodes(this.aggregator.nodes.delete(Base.SERVER_NODE_ID));
69
- return Promise.resolve();
70
- }
71
- onRoundBeginCommunication() {
72
- // Prepare the result promise for the incoming round
73
- this.aggregationResult = new Promise((resolve) => this.aggregator.once('aggregation', resolve));
74
- return Promise.resolve();
75
- }
76
- async onRoundEndCommunication(weights, round) {
77
- // NB: For now, we suppose a fully-federated setting.
78
- if (this.aggregationResult === undefined) {
79
- throw new Error("local aggregation result was not set");
80
- }
81
- // Send our local contribution to the server
82
- // and receive the most recent weights as an answer to our contribution
83
- const serverResult = await this.sendPayloadAndReceiveResult(this.aggregator.makePayloads(weights).first());
84
- if (serverResult !== undefined &&
85
- this.aggregator.add(Base.SERVER_NODE_ID, serverResult, round, 0)) {
86
- // Regular case: the server sends us its aggregation result which will serve our
87
- // own aggregation result.
88
- }
89
- else {
90
- // Unexpected case: for some reason, the server result is stale.
91
- // We proceed to the next round without its result.
92
- debug(`[${this.ownId}] server result is either stale or not received`);
93
- this.aggregator.nextRound();
94
- }
95
- return await this.aggregationResult;
96
- }
97
- /**
98
- * Send a message containing our local weight updates to the federated server.
99
- * And waits for the server to reply with the most recent aggregated weights
100
- * @param payload The weight updates to send
101
- */
102
- async sendPayloadAndReceiveResult(payload) {
103
- const msg = {
104
- type: type.SendPayload,
105
- payload: await serialization.weights.encode(payload),
106
- round: this.aggregator.round,
107
- };
108
- this.server.send(msg);
109
- // Waits for the server's result for its current (most recent) round and add it to our aggregator.
110
- // Updates the aggregator's round if it's behind the server's.
111
- try {
112
- // It is important than the client immediately awaits the server result or it may miss it
113
- const { payload, round, nbOfParticipants } = await waitMessageWithTimeout(this.server, type.ReceiveServerPayload);
114
- const serverRound = round;
115
- this.#nbOfParticipants = nbOfParticipants; // Save the current participants
116
- // Store the server result only if it is not stale
117
- if (this.aggregator.round <= round) {
118
- const serverResult = serialization.weights.decode(payload);
119
- // Update the local round to match the server's
120
- if (this.aggregator.round < serverRound) {
121
- this.aggregator.setRound(serverRound);
122
- }
123
- return serverResult;
124
- }
125
- }
126
- catch (e) {
127
- debug(`[${this.ownId}] while receiving results: %o`, e);
128
- }
129
- }
130
- }
@@ -1,5 +0,0 @@
1
- import { WeightsContainer } from "../weights/weights_container.js";
2
- import { Base } from "./base.js";
3
- export declare class Local extends Base {
4
- onRoundEndCommunication(weights: WeightsContainer): Promise<WeightsContainer>;
5
- }
@@ -1,6 +0,0 @@
1
- import { Base } from "./base.js";
2
- export class Local extends Base {
3
- onRoundEndCommunication(weights) {
4
- return Promise.resolve(weights);
5
- }
6
- }
@@ -1,111 +0,0 @@
1
- import type { Model, TaskID } from '../index.js';
2
- /**
3
- * Type of models stored in memory. Stored models can either be a model currently
4
- * being trained ("working model") or a regular model saved in memory ("saved model").
5
- * There can only be a single working model for a given task.
6
- */
7
- type StoredModelType = 'saved' | 'working';
8
- /**
9
- * Model information which uniquely identifies a model in memory.
10
- */
11
- export interface ModelInfo {
12
- type: StoredModelType;
13
- version?: number;
14
- taskID: TaskID;
15
- name: string;
16
- tensorBackend: 'gpt' | 'tfjs';
17
- }
18
- /**
19
- * A model source uniquely identifies a model stored in memory.
20
- * It can be in the form of either a model info object or an ID
21
- * (one-to-one mapping between the two)
22
- */
23
- export type ModelSource = ModelInfo | string;
24
- /**
25
- * Represents a model memory system, providing functions to fetch, save, delete and update models.
26
- * Stored models can either be a model currently being trained ("working model") or a regular model
27
- * saved in memory ("saved model"). There can only be a single working model for a given task.
28
- */
29
- export declare abstract class Memory {
30
- /**
31
- * Fetches the model identified by the given model source.
32
- * @param source The model source
33
- * @returns The model
34
- */
35
- abstract getModel(source: ModelSource): Promise<Model>;
36
- /**
37
- * Removes the model identified by the given model source from memory.
38
- * @param source The model source
39
- * @returns The model
40
- */
41
- abstract deleteModel(source: ModelSource): Promise<void>;
42
- /**
43
- * Replaces the corresponding working model with the saved model identified by the given model source.
44
- * @param source The model source
45
- */
46
- abstract loadModel(source: ModelSource): Promise<void>;
47
- /**
48
- * Fetches metadata for the model identified by the given model source.
49
- * If the model does not exist in memory, returns undefined.
50
- * @param source The model source
51
- * @returns The model metadata or undefined
52
- */
53
- abstract getModelMetadata(source: ModelSource): Promise<object | undefined>;
54
- /**
55
- * Replaces the working model identified by the given source with the newly provided model.
56
- * @param source The model source
57
- * @param model The new model
58
- */
59
- abstract updateWorkingModel(source: ModelSource, model: Model): Promise<void>;
60
- /**
61
- * Creates a saved model copy from the working model identified by the given model source.
62
- * Returns the saved model's path.
63
- * @param source The model source
64
- * @returns The saved model's path
65
- */
66
- abstract saveWorkingModel(source: ModelSource): Promise<string | undefined>;
67
- /**
68
- * Saves the newly provided model to the given model source.
69
- * Returns the saved model's path
70
- * @param source The model source
71
- * @param model The new model
72
- * @returns The saved model's path
73
- */
74
- abstract saveModel(source: ModelSource, model: Model): Promise<string | undefined>;
75
- /**
76
- * Moves the model identified by the model source to a file system. This is platform-dependent.
77
- * @param source The model source
78
- */
79
- abstract downloadModel(source: ModelSource): Promise<void>;
80
- /**
81
- * Checks whether the model memory contains the model identified by the given source.
82
- * @param source The model source
83
- * @returns True if the memory contains the model, false otherwise
84
- */
85
- abstract contains(source: ModelSource): Promise<boolean>;
86
- /**
87
- * Computes the path in memory corresponding to the given model source, be it a path or model information.
88
- * This is used to easily switch between model path and information, which are both unique model identifiers
89
- * with a one-to-one equivalence. Returns undefined instead if no path could be inferred from the given
90
- * model source.
91
- * @param source The model source
92
- * @returns The model path
93
- */
94
- abstract getModelMemoryPath(source: ModelSource): string | undefined;
95
- /**
96
- * Computes the model information corresponding to the given model source, be it a path or model information.
97
- * This is used to easily switch between model path and information, which are both unique model identifiers
98
- * with a one-to-one equivalence. Returns undefined instead if no unique model information could be inferred
99
- * from the given model source.
100
- * @param source The model source
101
- * @returns The model information
102
- */
103
- abstract getModelInfo(source: ModelSource): ModelInfo | undefined;
104
- /**
105
- * Computes the lowest version a model source can have without conflicting with model versions currently in memory.
106
- * @param source The model source
107
- * @returns The duplicated model source
108
- */
109
- abstract duplicateSource(source: ModelSource): Promise<ModelSource | undefined>;
110
- }
111
- export {};
@@ -1,9 +0,0 @@
1
- // only used browser-side
2
- // TODO: replace IO type
3
- /**
4
- * Represents a model memory system, providing functions to fetch, save, delete and update models.
5
- * Stored models can either be a model currently being trained ("working model") or a regular model
6
- * saved in memory ("saved model"). There can only be a single working model for a given task.
7
- */
8
- export class Memory {
9
- }
@@ -1,20 +0,0 @@
1
- import type { Model } from '../index.js';
2
- import type { ModelInfo } from './base.js';
3
- import { Memory } from './base.js';
4
- /**
5
- * Represents an empty model memory.
6
- */
7
- export declare class Empty extends Memory {
8
- getModelMetadata(): Promise<undefined>;
9
- contains(): Promise<boolean>;
10
- getModel(): Promise<Model>;
11
- loadModel(): Promise<void>;
12
- updateWorkingModel(): Promise<void>;
13
- saveWorkingModel(): Promise<undefined>;
14
- saveModel(): Promise<undefined>;
15
- deleteModel(): Promise<void>;
16
- downloadModel(): Promise<void>;
17
- getModelMemoryPath(): string;
18
- getModelInfo(): ModelInfo;
19
- duplicateSource(): Promise<undefined>;
20
- }