@epfml/discojs 3.0.1-p20240902094132.0 → 3.0.1-p20240902162912.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. package/dist/aggregator/base.d.ts +16 -2
  2. package/dist/aggregator/base.js +25 -3
  3. package/dist/aggregator/mean.d.ts +1 -0
  4. package/dist/aggregator/mean.js +11 -6
  5. package/dist/aggregator/secure.js +1 -1
  6. package/dist/client/{base.d.ts → client.d.ts} +13 -30
  7. package/dist/client/{base.js → client.js} +10 -20
  8. package/dist/client/decentralized/{base.d.ts → decentralized_client.d.ts} +5 -5
  9. package/dist/client/decentralized/{base.js → decentralized_client.js} +20 -16
  10. package/dist/client/decentralized/index.d.ts +1 -1
  11. package/dist/client/decentralized/index.js +1 -1
  12. package/dist/client/decentralized/messages.d.ts +7 -2
  13. package/dist/client/decentralized/messages.js +4 -2
  14. package/dist/client/event_connection.js +2 -2
  15. package/dist/client/federated/federated_client.d.ts +44 -0
  16. package/dist/client/federated/federated_client.js +210 -0
  17. package/dist/client/federated/index.d.ts +1 -1
  18. package/dist/client/federated/index.js +1 -1
  19. package/dist/client/federated/messages.d.ts +17 -2
  20. package/dist/client/federated/messages.js +3 -1
  21. package/dist/client/index.d.ts +2 -2
  22. package/dist/client/index.js +2 -2
  23. package/dist/client/local_client.d.ts +10 -0
  24. package/dist/client/local_client.js +14 -0
  25. package/dist/client/messages.d.ts +6 -8
  26. package/dist/client/messages.js +23 -7
  27. package/dist/client/utils.js +1 -1
  28. package/dist/default_tasks/cifar10.js +1 -1
  29. package/dist/default_tasks/lus_covid.js +1 -0
  30. package/dist/default_tasks/mnist.js +1 -1
  31. package/dist/default_tasks/simple_face.js +2 -1
  32. package/dist/default_tasks/titanic.js +2 -1
  33. package/dist/default_tasks/wikitext.js +1 -0
  34. package/dist/index.d.ts +4 -1
  35. package/dist/index.js +1 -0
  36. package/dist/logging/logger.d.ts +1 -1
  37. package/dist/task/index.d.ts +0 -1
  38. package/dist/task/index.js +0 -1
  39. package/dist/task/task.d.ts +0 -2
  40. package/dist/task/task.js +2 -4
  41. package/dist/task/training_information.d.ts +1 -1
  42. package/dist/task/training_information.js +3 -3
  43. package/dist/training/disco.d.ts +11 -12
  44. package/dist/training/disco.js +19 -34
  45. package/dist/training/index.d.ts +1 -1
  46. package/dist/training/trainer.d.ts +3 -2
  47. package/dist/training/trainer.js +12 -5
  48. package/dist/utils/event_emitter.js +1 -3
  49. package/package.json +1 -1
  50. package/dist/client/federated/base.d.ts +0 -38
  51. package/dist/client/federated/base.js +0 -130
  52. package/dist/client/local.d.ts +0 -5
  53. package/dist/client/local.js +0 -6
  54. package/dist/task/digest.d.ts +0 -5
  55. package/dist/task/digest.js +0 -14
@@ -0,0 +1,210 @@
1
+ import createDebug from "debug";
2
+ import { serialization } from "../../index.js";
3
+ import { Client } from "../client.js";
4
+ import { type } from "../messages.js";
5
+ import { waitMessage, waitMessageWithTimeout, WebSocketServer, } from "../event_connection.js";
6
+ import * as messages from "./messages.js";
7
+ // Debug logger for this module, under the "discojs:client:federated" namespace.
+ const debug = createDebug("discojs:client:federated");
8
+ /**
9
+ * Arbitrary, fixed node id under which the federated server is registered with the aggregator.
10
+ * The server acts as a node within the network: in the federated setting implemented
11
+ * by this client class, it is the only node this client communicates with.
12
+ */
13
+ const SERVER_NODE_ID = "federated-server-node-id";
14
/**
 * Client class that communicates with a centralized, federated server when training
 * a specific task in the federated setting.
 *
 * The server is registered with the aggregator as a single node (SERVER_NODE_ID).
 * At the end of each round the local weights are sent to the server, and the global
 * weights it answers with are returned as the round's result.
 */
export class FederatedClient extends Client {
  // Number of federated contributors in the session, including this client but
  // excluding the server: e.g., if 3 users are training a federated model, this is 3.
  // Refreshed from the server's NewFederatedNodeInfo and ReceiveServerPayload messages.
  #nbOfParticipants = 1;
  /**
   * When the server notifies clients to pause and wait until more
   * participants join, we rely on this promise to wait
   * until the server signals that the training can resume.
   * `undefined` while we are not waiting.
   */
  #promiseForMoreParticipants = undefined;
  /**
   * When the server notifies the client that they can resume training
   * after waiting for more participants, we want to be able to display what
   * we were doing before waiting (training locally or updating our model).
   * This attribute stores the status to roll back to when we stop waiting.
   */
  #previousStatus = undefined;

  /**
   * Whether the client should wait until more participants join the session,
   * i.e. whether a waiting promise has been created.
   */
  get #waitingForMoreParticipants() {
    return this.#promiseForMoreParticipants !== undefined;
  }

  /**
   * Short id prefix used in debug logs.
   * Server messages such as WaitingForMoreParticipants or EnoughParticipants can arrive
   * before NewFederatedNodeInfo has assigned our id, so this must not assume the id is set.
   * @returns the first 4 characters of our node id, or "????" if not assigned yet
   */
  #logId() {
    return this._ownId?.slice(0, 4) ?? "????";
  }

  /** The number of participants, excluding the server. */
  get nbOfParticipants() {
    return this.#nbOfParticipants;
  }

  /**
   * Opens a new WebSocket connection with the server and listens to new messages over the channel.
   * @param url WebSocket URL of the federated endpoint for this task
   * @returns the connected server
   */
  async connectServer(url) {
    const server = await WebSocketServer.connect(url, messages.isMessageFederated, // can only receive federated message types from the server
      messages.isMessageFederated);
    return server;
  }

  /**
   * Initializes the connection to the server, gets our node ID
   * as well as the latest training information: latest global model, current round and
   * whether we are waiting for more participants.
   *
   * @returns the model obtained from the base client, with its weights replaced by the
   *   server's latest global weights
   * @throws Error if the client URL protocol is neither "http:" nor "https:", or if
   *   the server assigns a node id while one is already set
   */
  async connect() {
    const model = await super.connect(); // Get the server base model
    const serverURL = new URL("", this.url.href);
    switch (this.url.protocol) {
      case "http:":
        serverURL.protocol = "ws:";
        break;
      case "https:":
        serverURL.protocol = "wss:";
        break;
      default:
        throw new Error(`unknown protocol: ${this.url.protocol}`);
    }
    serverURL.pathname += `federated/${this.task.id}`;
    this._server = await this.connectServer(serverURL);
    // Set up an event callback in case the server signals that we should
    // wait for more participants. This can fire before NewFederatedNodeInfo is
    // received (i.e. before `id` below is initialized), so the callback must not
    // reference `id`: doing so threw a ReferenceError inside server.send().
    this.server.on(type.WaitingForMoreParticipants, () => {
      debug(`[${this.#logId()}] received WaitingForMoreParticipants message from server`);
      // Display the waiting status right away
      this.emit("status", "Waiting for more participants");
      // Upon receiving a WaitingForMoreParticipants message,
      // the client will await this promise to resolve before sending its
      // local weight update
      this.#promiseForMoreParticipants = this.waitForMoreParticipants();
    });
    // As an example, assume we need at least 2 participants to train.
    // When two participants join almost at the same time, the server
    // sends a NewFederatedNodeInfo with waitForMoreParticipants=true to the first participant
    // and directly follows with an EnoughParticipants message when the 2nd participant joins.
    // However, the EnoughParticipants can arrive before the NewFederatedNodeInfo (which is much bigger),
    // so we check whether we received the EnoughParticipants before being assigned a node ID.
    let receivedEnoughParticipants = false;
    this.server.once(type.EnoughParticipants, () => {
      if (this._ownId === undefined) {
        debug(`Received EnoughParticipants message from server before the NewFederatedNodeInfo message`);
        receivedEnoughParticipants = true;
      }
    });
    this.aggregator.registerNode(SERVER_NODE_ID);
    const msg = {
      type: type.ClientConnected,
    };
    this.server.send(msg);
    const { id, waitForMoreParticipants, payload, round, nbOfParticipants } = await waitMessageWithTimeout(this.server, type.NewFederatedNodeInfo);
    // This should come right after receiving the message to make sure
    // we don't miss a subsequent message from the server.
    // We check if the server is telling us to wait for more participants
    // and we also check if an EnoughParticipants message ended up arriving
    // before the NewFederatedNodeInfo.
    if (waitForMoreParticipants && !receivedEnoughParticipants) {
      // Create a promise that resolves when enough participants join.
      // The client will await this promise before sending its local weight update.
      this.#promiseForMoreParticipants = this.waitForMoreParticipants();
    }
    if (this._ownId !== undefined) {
      throw new Error('received id from server but was already received');
    }
    this._ownId = id;
    debug(`[${id.slice(0, 4)}] joined session at round ${round} `);
    this.aggregator.setRound(round);
    this.#nbOfParticipants = nbOfParticipants;
    // Upon connecting, the server answers with a boolean
    // which indicates whether there are enough participants or not
    debug(`[${this.ownId.slice(0, 4)}] upon connecting, wait for participant flag %o`, this.#waitingForMoreParticipants);
    model.weights = serialization.weights.decode(payload);
    return model;
  }

  /**
   * Method called when the server notifies the client that there aren't enough
   * participants (anymore) to start/continue training.
   * The method creates a promise that will resolve once the server notifies
   * the client that the training can resume via a subsequent EnoughParticipants message.
   * @returns a promise which resolves when enough participants joined the session
   */
  async waitForMoreParticipants() {
    return new Promise((resolve) => {
      // "once" is important because we can't resolve the same promise multiple times
      this.server.once(type.EnoughParticipants, () => {
        // May fire before our id has been assigned (see connect), hence #logId()
        debug(`[${this.#logId()}] received EnoughParticipants message from server`);
        // Emit the last status emitted before waiting if defined
        if (this.#previousStatus !== undefined) {
          this.emit("status", this.#previousStatus);
        }
        resolve();
      });
    });
  }

  /**
   * Disconnection process when user quits the task: closes the server connection,
   * forgets the server and our node id, and removes the server node from the aggregator.
   */
  async disconnect() {
    await this.server.disconnect();
    this._server = undefined;
    this._ownId = undefined;
    this.aggregator.setNodes(this.aggregator.nodes.delete(SERVER_NODE_ID));
  }

  /**
   * Called at the beginning of each round: prepares the promise resolved by the
   * aggregator's next 'aggregation' event and displays the local-training status.
   * @returns an already-resolved promise
   */
  onRoundBeginCommunication() {
    // Prepare the result promise for the incoming round
    this.aggregationResult = new Promise((resolve) => this.aggregator.once('aggregation', resolve));
    // Save the status in case participants leave and we switch to waiting for more participants.
    // Once enough new participants join we can display the previous status again.
    this.#previousStatus = "Training the model on the data you connected";
    this.emit("status", this.#previousStatus);
    return Promise.resolve();
  }

  /**
   * Sends the local weight update to the server and waits (indefinitely) for the server's
   * global update.
   *
   * If the server asked us to wait for more participants, we first wait (also indefinitely)
   * until it notifies us that the training can resume.
   *
   * NB: for now, a fully-federated setting is assumed.
   *
   * @param weights Local weights sent to the server at the end of the local training round
   * @returns the new global weights sent by the server
   * @throws Error if no aggregation result was prepared for this round
   *   (e.g. onRoundBeginCommunication was not called)
   */
  async onRoundEndCommunication(weights) {
    if (this.aggregationResult === undefined) {
      throw new Error("local aggregation result was not set");
    }
    // First we check if we are waiting for more participants before sending our weight update
    if (this.#waitingForMoreParticipants) {
      // Wait for the promise to resolve, which takes as long as it takes for new participants to join
      debug(`[${this.ownId.slice(0, 4)}] is awaiting the promise for more participants`);
      this.emit("status", "Waiting for more participants");
      await this.#promiseForMoreParticipants;
      // Make sure to set the promise back to undefined once resolved
      this.#promiseForMoreParticipants = undefined;
    }
    // Save the status in case participants leave and we switch to waiting for more participants.
    // Once enough new participants join we can display the previous status again.
    this.#previousStatus = "Updating the model with other participants' models";
    this.emit("status", this.#previousStatus);
    // Send our local contribution to the server
    // and receive the server global update for this round as an answer to our contribution
    const payloadToServer = this.aggregator.makePayloads(weights).first();
    const msg = {
      type: type.SendPayload,
      payload: await serialization.weights.encode(payloadToServer),
      round: this.aggregator.round,
    };
    // Need to await the resulting global model right after sending our local contribution
    // to make sure we don't miss it
    debug(`[${this.ownId.slice(0, 4)}] sent its local update to the server for round ${this.aggregator.round}`);
    this.server.send(msg);
    debug(`[${this.ownId.slice(0, 4)}] is waiting for server update for round ${this.aggregator.round + 1}`);
    const { payload: payloadFromServer, round: serverRound, nbOfParticipants } = await waitMessage(this.server, type.ReceiveServerPayload); // Wait indefinitely for the server update
    this.#nbOfParticipants = nbOfParticipants; // Save the current participants
    const serverResult = serialization.weights.decode(payloadFromServer);
    this.aggregator.setRound(serverRound);
    return serverResult;
  }
}
@@ -1,2 +1,2 @@
1
- export { Base as FederatedClient } from './base.js';
1
+ export { FederatedClient } from './federated_client.js';
2
2
  export * as messages from './messages.js';
@@ -1,2 +1,2 @@
1
- export { Base as FederatedClient } from './base.js';
1
+ export { FederatedClient } from './federated_client.js';
2
2
  export * as messages from './messages.js';
@@ -1,6 +1,15 @@
1
1
  import { type weights } from '../../serialization/index.js';
2
- import { type, type AssignNodeID, type ClientConnected } from '../messages.js';
3
- export type MessageFederated = ClientConnected | SendPayload | ReceiveServerPayload | AssignNodeID;
2
+ import { type NodeID } from '../types.js';
3
+ import { type, type ClientConnected } from '../messages.js';
4
+ export type MessageFederated = ClientConnected | NewFederatedNodeInfo | SendPayload | ReceiveServerPayload | WaitingForMoreParticipants | EnoughParticipants;
5
+ /**
+ * Sent by the server in answer to ClientConnected, with everything needed to start
+ * training: the assigned node id, whether to wait for more participants, the latest
+ * global weights, the current round and the current number of participants.
+ */
+ export interface NewFederatedNodeInfo {
6
+ type: type.NewFederatedNodeInfo;
7
+ id: NodeID;
8
+ waitForMoreParticipants: boolean;
9
+ payload: weights.Encoded;
10
+ round: number;
11
+ nbOfParticipants: number;
12
+ }
4
13
  export interface SendPayload {
5
14
  type: type.SendPayload;
6
15
  payload: weights.Encoded;
@@ -12,4 +21,10 @@ export interface ReceiveServerPayload {
12
21
  round: number;
13
22
  nbOfParticipants: number;
14
23
  }
24
+ /** Sent by the server to notify clients that there are now enough participants to start training collaboratively. */
+ export interface EnoughParticipants {
25
+ type: type.EnoughParticipants;
26
+ }
27
+ /** Sent by the server to notify clients that there are not enough participants to continue training. */
+ export interface WaitingForMoreParticipants {
28
+ type: type.WaitingForMoreParticipants;
29
+ }
15
30
  /** Type guard for messages exchanged with the federated server: ClientConnected, NewFederatedNodeInfo, SendPayload, ReceiveServerPayload, WaitingForMoreParticipants or EnoughParticipants. */
  export declare function isMessageFederated(raw: unknown): raw is MessageFederated;
@@ -5,9 +5,11 @@ export function isMessageFederated(raw) {
5
5
  }
6
6
  switch (raw.type) {
7
7
  case type.ClientConnected:
8
+ case type.NewFederatedNodeInfo:
8
9
  case type.SendPayload:
9
10
  case type.ReceiveServerPayload:
10
- case type.AssignNodeID:
11
+ case type.WaitingForMoreParticipants:
12
+ case type.EnoughParticipants:
11
13
  return true;
12
14
  }
13
15
  return false;
@@ -1,8 +1,8 @@
1
- export { Base as Client } from './base.js';
1
+ export { Client } from './client.js';
2
2
  export * from './types.js';
3
3
  export * as aggregator from '../aggregator/index.js';
4
4
  export * as decentralized from './decentralized/index.js';
5
5
  export * as federated from './federated/index.js';
6
6
  export * as messages from './messages.js';
7
7
  export { getClient, timeout } from './utils.js';
8
- export { Local } from './local.js';
8
+ export { LocalClient } from './local_client.js';
@@ -1,8 +1,8 @@
1
- export { Base as Client } from './base.js';
1
+ export { Client } from './client.js';
2
2
  export * from './types.js';
3
3
  export * as aggregator from '../aggregator/index.js';
4
4
  export * as decentralized from './decentralized/index.js';
5
5
  export * as federated from './federated/index.js';
6
6
  export * as messages from './messages.js';
7
7
  export { getClient, timeout } from './utils.js';
8
- export { Local } from './local.js';
8
+ export { LocalClient } from './local_client.js';
@@ -0,0 +1,10 @@
1
+ import { WeightsContainer } from "../index.js";
2
+ import { Client } from "./client.js";
3
+ /**
4
+ * A LocalClient represents a Disco user training only on their local data without collaborating
5
+ * with anyone. Thus LocalClient does nothing during communication.
6
+ */
7
+ export declare class LocalClient extends Client {
8
+ /** Resolves immediately: nothing is exchanged when a round begins. */
+ onRoundBeginCommunication(): Promise<void>;
9
+ /** Resolves with the given local weights, unchanged. */
+ onRoundEndCommunication(weights: WeightsContainer): Promise<WeightsContainer>;
10
+ }
@@ -0,0 +1,14 @@
1
+ import { Client } from "./client.js";
2
/**
 * Client for a Disco user who trains only on their own data and never collaborates.
 * Both communication hooks are therefore no-ops.
 */
export class LocalClient extends Client {
  /**
   * Nothing to exchange before a local round.
   * @returns an already-resolved promise
   */
  onRoundBeginCommunication() {
    const nothingToExchange = Promise.resolve();
    return nothingToExchange;
  }

  /**
   * Without peers there is nothing to aggregate: the local weights are the round's result.
   * @param weights weights produced by local training
   * @returns a promise resolving to `weights` unchanged
   */
  onRoundEndCommunication(weights) {
    const roundResult = Promise.resolve(weights);
    return roundResult;
  }
}
@@ -1,23 +1,21 @@
1
1
  import type * as decentralized from './decentralized/messages.js';
2
2
  import type * as federated from './federated/messages.js';
3
- import { type NodeID } from './types.js';
4
3
  export declare enum type {
+ // NOTE(review): this release renumbers existing members (AssignNodeID = 1 is replaced by NewDecentralizedNodeInfo; SendPayload 6 -> 9, ReceiveServerPayload 7 -> 10). If these numeric values are what goes over the wire, clients and servers on different versions will misread each other's messages — confirm they are upgraded together.
5
4
  ClientConnected = 0,
6
- AssignNodeID = 1,
5
+ NewDecentralizedNodeInfo = 1,
7
6
  SignalForPeer = 2,
8
7
  PeerIsReady = 3,
9
8
  PeersForRound = 4,
10
9
  Payload = 5,
11
- SendPayload = 6,
12
- ReceiveServerPayload = 7
10
+ NewFederatedNodeInfo = 6,
11
+ WaitingForMoreParticipants = 7,
12
+ EnoughParticipants = 8,
13
+ SendPayload = 9,
14
+ ReceiveServerPayload = 10
13
15
  }
14
16
  export interface ClientConnected {
+ // Sent by a client as first point of contact to join a task; the server answers
+ // with a NewFederatedNodeInfo or NewDecentralizedNodeInfo message.
15
17
  type: type.ClientConnected;
16
18
  }
17
- export interface AssignNodeID {
18
- type: type.AssignNodeID;
19
- id: NodeID;
20
- }
21
19
  export type Message = decentralized.MessageFromServer | decentralized.MessageToServer | decentralized.PeerMessage | federated.MessageFederated;
22
20
  export type NarrowMessage<D> = Extract<Message, {
23
21
  type: D;
@@ -1,21 +1,37 @@
1
1
  export var type;
2
2
  (function (type) {
3
3
  // Sent from client to server as first point of contact to join a task.
4
- // The server answers with an node id in a AssignNodeID message
4
+ // The server answers with a node id in a NewFederatedNodeInfo
5
+ // or NewDecentralizedNodeInfo message
5
6
  type[type["ClientConnected"] = 0] = "ClientConnected";
6
- // When a user joins a task with a ClientConnected message, the server
7
- // answers with an AssignNodeID message with its peer id.
8
- type[type["AssignNodeID"] = 1] = "AssignNodeID";
9
7
  /* Decentralized */
8
+ // When a user joins a task with a ClientConnected message, the server
9
+ // answers with its peer id and also tells the client whether we are waiting
10
+ // for more participants before starting training
11
+ type[type["NewDecentralizedNodeInfo"] = 1] = "NewDecentralizedNodeInfo";
10
12
  // Message forwarded by the server from a client to another client
11
13
  // to establish a peer-to-peer (WebRTC) connection
12
14
  type[type["SignalForPeer"] = 2] = "SignalForPeer";
15
+ // Message sent by nodes to server signaling they are ready to
16
+ // start the next round
13
17
  type[type["PeerIsReady"] = 3] = "PeerIsReady";
18
+ // Sent by the server to participating peers containing the list
19
+ // of peers for the round
14
20
  type[type["PeersForRound"] = 4] = "PeersForRound";
21
+ // The weight update
15
22
  type[type["Payload"] = 5] = "Payload";
16
- // Federated
17
- type[type["SendPayload"] = 6] = "SendPayload";
18
- type[type["ReceiveServerPayload"] = 7] = "ReceiveServerPayload";
23
+ /* Federated */
24
+ // The server answers the ClientConnected message with the necessary information
25
+ // to start training: node id, latest global model weights, current round, etc.
26
+ type[type["NewFederatedNodeInfo"] = 6] = "NewFederatedNodeInfo";
27
+ // Message sent by server to notify clients that there are not enough
28
+ // participants to continue training
29
+ type[type["WaitingForMoreParticipants"] = 7] = "WaitingForMoreParticipants";
30
+ // Message sent by server to notify clients that there are now enough
31
+ // participants to start training collaboratively
32
+ type[type["EnoughParticipants"] = 8] = "EnoughParticipants";
33
+ // Sent by a client to the server with its local weight update and round number
+ type[type["SendPayload"] = 9] = "SendPayload";
34
+ // Sent by the server to clients with the new global weights, the round and the number of participants
+ type[type["ReceiveServerPayload"] = 10] = "ReceiveServerPayload";
19
35
  })(type || (type = {}));
20
36
  export function hasMessageType(raw) {
21
37
  if (typeof raw !== 'object' || raw === null) {
@@ -13,7 +13,7 @@ export function getClient(trainingScheme, serverURL, task, aggregator) {
13
13
  case 'federated':
14
14
  return new clients.federated.FederatedClient(serverURL, task, aggregator);
15
15
  case 'local':
16
- return new clients.Local(serverURL, task, aggregator);
16
+ return new clients.LocalClient(serverURL, task, aggregator);
17
17
  default: {
18
18
  const _ = trainingScheme;
19
19
  throw new Error('should never happen');
@@ -32,7 +32,7 @@ export const cifar10 = {
32
32
  scheme: 'decentralized',
33
33
  privacy: { clippingRadius: 20, noiseScale: 1 },
34
34
  decentralizedSecure: true,
35
- minimumReadyPeers: 3,
35
+ minNbOfParticipants: 3,
36
36
  maxShareValue: 100,
37
37
  tensorBackend: 'tfjs'
38
38
  }
@@ -29,6 +29,7 @@ export const lusCovid = {
29
29
  LABEL_LIST: ['COVID-Positive', 'COVID-Negative'],
30
30
  dataType: 'image',
31
31
  scheme: 'federated',
32
+ minNbOfParticipants: 2,
32
33
  tensorBackend: 'tfjs'
33
34
  }
34
35
  };
@@ -31,7 +31,7 @@ export const mnist = {
31
31
  LABEL_LIST: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'],
32
32
  scheme: 'decentralized',
33
33
  decentralizedSecure: true,
34
- minimumReadyPeers: 3,
34
+ minNbOfParticipants: 3,
35
35
  maxShareValue: 100,
36
36
  tensorBackend: 'tfjs'
37
37
  }
@@ -28,7 +28,8 @@ export const simpleFace = {
28
28
  IMAGE_H: 200,
29
29
  IMAGE_W: 200,
30
30
  LABEL_LIST: ['child', 'adult'],
31
- scheme: 'federated', // secure aggregation not yet implemented for federated
31
+ scheme: 'federated',
32
+ minNbOfParticipants: 2,
32
33
  tensorBackend: 'tfjs'
33
34
  }
34
35
  };
@@ -62,7 +62,8 @@ export const titanic = {
62
62
  outputColumns: [
63
63
  'Survived'
64
64
  ],
65
- scheme: 'federated', // secure aggregation not yet implemented for FeAI
65
+ scheme: 'federated',
66
+ minNbOfParticipants: 2,
66
67
  tensorBackend: 'tfjs'
67
68
  }
68
69
  };
@@ -26,6 +26,7 @@ export const wikitext = {
26
26
  modelID: 'llm-raw-model',
27
27
  preprocessingFunctions: [data.TextPreprocessing.Tokenize, data.TextPreprocessing.LeftPadding],
28
28
  scheme: 'federated',
29
+ minNbOfParticipants: 2,
29
30
  epochs: 6,
30
31
  // Unused by wikitext because data already comes split
31
32
  // But if set to 0 then the webapp doesn't display the validation metrics
package/dist/index.d.ts CHANGED
@@ -1,5 +1,7 @@
1
1
  export * as data from './dataset/index.js';
2
2
  export * as serialization from './serialization/index.js';
3
+ export { Encoded as EncodedModel } from './serialization/model.js';
4
+ export { Encoded as EncodedWeights } from './serialization/weights.js';
3
5
  export * as training from './training/index.js';
4
6
  export * as privacy from './privacy.js';
5
7
  export * as client from './client/index.js';
@@ -7,13 +9,14 @@ export * as aggregator from './aggregator/index.js';
7
9
  export { WeightsContainer, aggregation } from './weights/index.js';
8
10
  export { Logger, ConsoleLogger } from './logging/index.js';
9
11
  export { Memory, type ModelInfo, type ModelSource, Empty as EmptyMemory } from './memory/index.js';
10
- export { Disco, RoundLogs } from './training/index.js';
12
+ export { Disco, RoundLogs, RoundStatus } from './training/index.js';
11
13
  export { Validator } from './validation/index.js';
12
14
  export { Model, BatchLogs, EpochLogs, ValidationMetrics } from './models/index.js';
13
15
  export * as models from './models/index.js';
14
16
  export * from './task/index.js';
15
17
  export * as defaultTasks from './default_tasks/index.js';
16
18
  export * as async_iterator from "./utils/async_iterator.js";
19
+ export { EventEmitter } from "./utils/event_emitter.js";
17
20
  export { Dataset } from "./dataset/index.js";
18
21
  export * from "./dataset/types.js";
19
22
  export * from "./types.js";
package/dist/index.js CHANGED
@@ -14,6 +14,7 @@ export * as models from './models/index.js';
14
14
  export * from './task/index.js';
15
15
  export * as defaultTasks from './default_tasks/index.js';
16
16
  export * as async_iterator from "./utils/async_iterator.js";
17
+ export { EventEmitter } from "./utils/event_emitter.js";
17
18
  export { Dataset } from "./dataset/index.js";
18
19
  export * from "./dataset/types.js"; // TODO merge with above
19
20
  export * from "./types.js";
@@ -1,6 +1,6 @@
1
1
  export interface Logger {
2
2
  /**
3
- * Logs sucess message (in green)
3
+ * Logs success message (in green)
4
4
  * @param message - message to be displayed
5
5
  */
6
6
  success(message: string): void;
@@ -1,6 +1,5 @@
1
1
  export { isTask, type Task, isTaskID, type TaskID } from './task.js';
2
2
  export { type TaskProvider } from './task_provider.js';
3
- export { isDigest, type Digest } from './digest.js';
4
3
  export { isDisplayInformation, type DisplayInformation } from './display_information.js';
5
4
  export type { TrainingInformation } from './training_information.js';
6
5
  export { pushTask, fetchTasks } from './task_handler.js';
@@ -1,4 +1,3 @@
1
1
  export { isTask, isTaskID } from './task.js';
2
- export { isDigest } from './digest.js';
3
2
  export { isDisplayInformation } from './display_information.js';
4
3
  export { pushTask, fetchTasks } from './task_handler.js';
@@ -1,10 +1,8 @@
1
1
  import { type DisplayInformation } from './display_information.js';
2
2
  import { type TrainingInformation } from './training_information.js';
3
- import { type Digest } from './digest.js';
4
3
  export type TaskID = string;
5
4
  export interface Task {
+ // NOTE(review): the optional `digest` field was removed in this release and isTask no longer
+ // checks it: a `digest` key on raw input is now neither validated nor rejected.
6
5
  id: TaskID;
7
- digest?: Digest;
8
6
  displayInformation: DisplayInformation;
9
7
  trainingInformation: TrainingInformation;
10
8
  }
package/dist/task/task.js CHANGED
@@ -1,6 +1,5 @@
1
1
  import { isDisplayInformation } from './display_information.js';
2
2
  import { isTrainingInformation } from './training_information.js';
3
- import { isDigest } from './digest.js';
4
3
  export function isTaskID(obj) {
5
4
  return typeof obj === 'string';
6
5
  }
@@ -8,14 +7,13 @@ export function isTask(raw) {
8
7
  if (typeof raw !== 'object' || raw === null) {
9
8
  return false;
10
9
  }
11
- const { id, digest, displayInformation, trainingInformation } = raw;
10
+ const { id, displayInformation, trainingInformation } = raw;
12
11
  if (!isTaskID(id) ||
13
- (digest !== undefined && !isDigest(digest)) ||
14
12
  !isDisplayInformation(displayInformation) ||
15
13
  !isTrainingInformation(trainingInformation)) {
16
14
  return false;
17
15
  }
18
- const repack = { id, digest, displayInformation, trainingInformation };
16
+ const repack = { id, displayInformation, trainingInformation };
19
17
  const _correct = repack;
20
18
  const _total = repack;
21
19
  return true;
@@ -21,7 +21,7 @@ export interface TrainingInformation {
21
21
  privacy?: Privacy;
22
22
  decentralizedSecure?: boolean;
23
23
  maxShareValue?: number;
24
- minimumReadyPeers?: number;
24
+ minNbOfParticipants: number;
25
25
  aggregator?: 'mean' | 'secure';
26
26
  tokenizer?: string | PreTrainedTokenizer;
27
27
  maxSequenceLength?: number;
@@ -24,20 +24,20 @@ export function isTrainingInformation(raw) {
24
24
  if (typeof raw !== 'object' || raw === null) {
25
25
  return false;
26
26
  }
27
- const { IMAGE_H, IMAGE_W, LABEL_LIST, aggregator, batchSize, dataType, decentralizedSecure, privacy, epochs, inputColumns, maxShareValue, minimumReadyPeers, modelID, outputColumns, preprocessingFunctions, roundDuration, scheme, validationSplit, tokenizer, maxSequenceLength, tensorBackend } = raw;
27
+ const { IMAGE_H, IMAGE_W, LABEL_LIST, aggregator, batchSize, dataType, decentralizedSecure, privacy, epochs, inputColumns, maxShareValue, minNbOfParticipants, modelID, outputColumns, preprocessingFunctions, roundDuration, scheme, validationSplit, tokenizer, maxSequenceLength, tensorBackend } = raw;
28
28
  if (typeof dataType !== 'string' ||
29
29
  typeof modelID !== 'string' ||
30
30
  typeof epochs !== 'number' ||
31
31
  typeof batchSize !== 'number' ||
32
32
  typeof roundDuration !== 'number' ||
33
33
  typeof validationSplit !== 'number' ||
34
+ typeof minNbOfParticipants !== 'number' ||
34
35
  (tokenizer !== undefined && typeof tokenizer !== 'string' && !(tokenizer instanceof PreTrainedTokenizer)) ||
35
36
  (maxSequenceLength !== undefined && typeof maxSequenceLength !== 'number') ||
36
37
  (aggregator !== undefined && typeof aggregator !== 'string') ||
37
38
  (decentralizedSecure !== undefined && typeof decentralizedSecure !== 'boolean') ||
38
39
  (privacy !== undefined && !isPrivacy(privacy)) ||
39
40
  (maxShareValue !== undefined && typeof maxShareValue !== 'number') ||
40
- (minimumReadyPeers !== undefined && typeof minimumReadyPeers !== 'number') ||
41
41
  (IMAGE_H !== undefined && typeof IMAGE_H !== 'number') ||
42
42
  (IMAGE_W !== undefined && typeof IMAGE_W !== 'number') ||
43
43
  (LABEL_LIST !== undefined && !isStringArray(LABEL_LIST)) ||
@@ -96,7 +96,7 @@ export function isTrainingInformation(raw) {
96
96
  epochs,
97
97
  inputColumns,
98
98
  maxShareValue,
99
- minimumReadyPeers,
99
+ minNbOfParticipants,
100
100
  modelID,
101
101
  outputColumns,
102
102
  preprocessingFunctions,