@epfml/discojs 3.0.1-p20241007204240.0 → 3.0.1-p20241024094708.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48) hide show
  1. package/dist/aggregator/{base.d.ts → aggregator.d.ts} +24 -31
  2. package/dist/aggregator/{base.js → aggregator.js} +48 -36
  3. package/dist/aggregator/get.d.ts +2 -2
  4. package/dist/aggregator/get.js +4 -4
  5. package/dist/aggregator/index.d.ts +1 -4
  6. package/dist/aggregator/index.js +1 -1
  7. package/dist/aggregator/mean.d.ts +4 -4
  8. package/dist/aggregator/mean.js +5 -15
  9. package/dist/aggregator/secure.d.ts +4 -4
  10. package/dist/aggregator/secure.js +7 -17
  11. package/dist/client/client.d.ts +71 -17
  12. package/dist/client/client.js +118 -17
  13. package/dist/client/decentralized/decentralized_client.d.ts +11 -13
  14. package/dist/client/decentralized/decentralized_client.js +121 -84
  15. package/dist/client/decentralized/messages.d.ts +12 -6
  16. package/dist/client/decentralized/messages.js +9 -8
  17. package/dist/client/event_connection.js +2 -2
  18. package/dist/client/federated/federated_client.d.ts +1 -13
  19. package/dist/client/federated/federated_client.js +15 -94
  20. package/dist/client/federated/messages.d.ts +6 -11
  21. package/dist/client/local_client.d.ts +1 -0
  22. package/dist/client/local_client.js +3 -0
  23. package/dist/client/messages.d.ts +14 -7
  24. package/dist/client/messages.js +13 -11
  25. package/dist/default_tasks/cifar10.js +1 -1
  26. package/dist/default_tasks/lus_covid.js +1 -0
  27. package/dist/default_tasks/mnist.js +1 -1
  28. package/dist/default_tasks/simple_face.js +1 -0
  29. package/dist/default_tasks/titanic.js +1 -0
  30. package/dist/default_tasks/wikitext.js +1 -0
  31. package/dist/index.d.ts +0 -2
  32. package/dist/serialization/coder.d.ts +4 -0
  33. package/dist/serialization/coder.js +51 -0
  34. package/dist/serialization/index.d.ts +2 -0
  35. package/dist/serialization/index.js +1 -0
  36. package/dist/serialization/model.d.ts +1 -2
  37. package/dist/serialization/model.js +9 -24
  38. package/dist/serialization/weights.d.ts +2 -3
  39. package/dist/serialization/weights.js +15 -26
  40. package/dist/task/task_handler.d.ts +5 -5
  41. package/dist/task/task_handler.js +21 -15
  42. package/dist/task/training_information.d.ts +1 -2
  43. package/dist/task/training_information.js +6 -8
  44. package/dist/training/disco.d.ts +4 -1
  45. package/dist/training/trainer.js +1 -1
  46. package/dist/utils/event_emitter.d.ts +3 -3
  47. package/dist/utils/event_emitter.js +10 -9
  48. package/package.json +2 -3
@@ -1,22 +1,28 @@
1
- import axios from 'axios';
2
1
  import createDebug from "debug";
3
- import { Map } from 'immutable';
4
- import { serialization } from '../index.js';
5
- import { isTask } from './task.js';
2
+ import { Map } from "immutable";
3
+ import { serialization } from "../index.js";
4
+ import { isTask } from "./task.js";
6
5
  const debug = createDebug("discojs:task:handlers");
7
- const TASK_ENDPOINT = 'tasks';
8
- export async function pushTask(url, task, model) {
9
- await axios.post(url.href + TASK_ENDPOINT, {
10
- task,
11
- model: await serialization.model.encode(model),
12
- weights: await serialization.weights.encode(model.weights)
6
+ function urlToTasks(base) {
7
+ const ret = new URL(base);
8
+ ret.pathname += "tasks";
9
+ return ret;
10
+ }
11
+ export async function pushTask(base, task, model) {
12
+ await fetch(urlToTasks(base), {
13
+ method: "POST",
14
+ body: JSON.stringify({
15
+ task,
16
+ model: await serialization.model.encode(model),
17
+ weights: await serialization.weights.encode(model.weights),
18
+ }),
13
19
  });
14
20
  }
15
- export async function fetchTasks(url) {
16
- const response = await axios.get(new URL(TASK_ENDPOINT, url).href);
17
- const tasks = response.data;
21
+ export async function fetchTasks(base) {
22
+ const response = await fetch(urlToTasks(base));
23
+ const tasks = await response.json();
18
24
  if (!Array.isArray(tasks)) {
19
- throw new Error('Expected to receive an array of Tasks when fetching tasks');
25
+ throw new Error("Expected to receive an array of Tasks when fetching tasks");
20
26
  }
21
27
  else if (!tasks.every(isTask)) {
22
28
  for (const task of tasks) {
@@ -24,7 +30,7 @@ export async function fetchTasks(url) {
24
30
  debug("task has invalid format: :O", task);
25
31
  }
26
32
  }
27
- throw new Error('invalid tasks response, the task object received is not well formatted');
33
+ throw new Error("invalid tasks response, the task object received is not well formatted");
28
34
  }
29
35
  return Map(tasks.map((t) => [t.id, t]));
30
36
  }
@@ -18,10 +18,9 @@ export interface TrainingInformation {
18
18
  LABEL_LIST?: string[];
19
19
  scheme: 'decentralized' | 'federated' | 'local';
20
20
  privacy?: Privacy;
21
- decentralizedSecure?: boolean;
22
21
  maxShareValue?: number;
23
22
  minNbOfParticipants: number;
24
- aggregator?: 'mean' | 'secure';
23
+ aggregationStrategy?: 'mean' | 'secure';
25
24
  tokenizer?: string | PreTrainedTokenizer;
26
25
  maxSequenceLength?: number;
27
26
  tensorBackend: 'tfjs' | 'gpt';
@@ -24,7 +24,7 @@ export function isTrainingInformation(raw) {
24
24
  if (typeof raw !== 'object' || raw === null) {
25
25
  return false;
26
26
  }
27
- const { IMAGE_H, IMAGE_W, LABEL_LIST, aggregator, batchSize, dataType, decentralizedSecure, privacy, epochs, inputColumns, maxShareValue, minNbOfParticipants, outputColumns, preprocessingFunctions, roundDuration, scheme, validationSplit, tokenizer, maxSequenceLength, tensorBackend } = raw;
27
+ const { IMAGE_H, IMAGE_W, LABEL_LIST, aggregationStrategy, batchSize, dataType, privacy, epochs, inputColumns, maxShareValue, minNbOfParticipants, outputColumns, preprocessingFunctions, roundDuration, scheme, validationSplit, tokenizer, maxSequenceLength, tensorBackend } = raw;
28
28
  if (typeof dataType !== 'string' ||
29
29
  typeof epochs !== 'number' ||
30
30
  typeof batchSize !== 'number' ||
@@ -33,8 +33,7 @@ export function isTrainingInformation(raw) {
33
33
  typeof minNbOfParticipants !== 'number' ||
34
34
  (tokenizer !== undefined && typeof tokenizer !== 'string' && !(tokenizer instanceof PreTrainedTokenizer)) ||
35
35
  (maxSequenceLength !== undefined && typeof maxSequenceLength !== 'number') ||
36
- (aggregator !== undefined && typeof aggregator !== 'string') ||
37
- (decentralizedSecure !== undefined && typeof decentralizedSecure !== 'boolean') ||
36
+ (aggregationStrategy !== undefined && typeof aggregationStrategy !== 'string') ||
38
37
  (privacy !== undefined && !isPrivacy(privacy)) ||
39
38
  (maxShareValue !== undefined && typeof maxShareValue !== 'number') ||
40
39
  (IMAGE_H !== undefined && typeof IMAGE_H !== 'number') ||
@@ -45,8 +44,8 @@ export function isTrainingInformation(raw) {
45
44
  (preprocessingFunctions !== undefined && !Array.isArray(preprocessingFunctions))) {
46
45
  return false;
47
46
  }
48
- if (aggregator !== undefined) {
49
- switch (aggregator) {
47
+ if (aggregationStrategy !== undefined) {
48
+ switch (aggregationStrategy) {
50
49
  case 'mean': break;
51
50
  case 'secure': break;
52
51
  default: return false;
@@ -58,7 +57,7 @@ export function isTrainingInformation(raw) {
58
57
  case 'text': break;
59
58
  default: return false;
60
59
  }
61
- // interdepences on data type
60
+ // interdependencies on data type
62
61
  if (dataType === 'image') {
63
62
  if (typeof IMAGE_H !== 'number' || typeof IMAGE_W !== 'number') {
64
63
  return false;
@@ -87,10 +86,9 @@ export function isTrainingInformation(raw) {
87
86
  IMAGE_W,
88
87
  IMAGE_H,
89
88
  LABEL_LIST,
90
- aggregator,
89
+ aggregationStrategy,
91
90
  batchSize,
92
91
  dataType,
93
- decentralizedSecure,
94
92
  privacy,
95
93
  epochs,
96
94
  inputColumns,
@@ -7,7 +7,10 @@ interface DiscoConfig {
7
7
  scheme: TrainingInformation["scheme"];
8
8
  logger: Logger;
9
9
  }
10
- export type RoundStatus = "Waiting for more participants" | "Retrieving peers' information" | "Updating the model with other participants' models" | "Training the model on the data you connected";
10
+ export type RoundStatus = 'not enough participants' | // Server notification to wait for more participants
11
+ 'updating model' | // fetching/aggregating local updates into a global model
12
+ 'local training' | // Training the model locally
13
+ 'connecting to peers';
11
14
  /**
12
15
  * Top-level class handling distributed training from a client's perspective. It is meant to be
13
16
  * a convenient object providing a reduced yet complete API that wraps model training and
@@ -62,7 +62,7 @@ export class Trainer {
62
62
  }
63
63
  return {
64
64
  epochs: epochsLogs,
65
- participants: this.#client.nbOfParticipants,
65
+ participants: this.#client.getNbOfParticipants(),
66
66
  };
67
67
  }
68
68
  }
@@ -1,11 +1,11 @@
1
- type Listener<T> = (_: T) => void;
1
+ type Listener<T> = (_: T) => void | Promise<void>;
2
2
  /**
3
3
  * Call handlers on given events
4
4
  *
5
5
  * @typeParam I object/mapping from event name to emitted value type
6
6
  */
7
7
  export declare class EventEmitter<I extends Record<string, unknown>> {
8
- private listeners;
8
+ #private;
9
9
  /**
10
10
  * @param initialListeners object/mapping of event name to listener, as if using `on` on created instance
11
11
  */
@@ -13,7 +13,7 @@ export declare class EventEmitter<I extends Record<string, unknown>> {
13
13
  [E in keyof I]?: Listener<I[E]>;
14
14
  });
15
15
  /**
16
- * Register listener to call on event
16
+ * Register listener to call on event.
17
17
  *
18
18
  * @param event event name to listen to
19
19
  * @param listener handler to call
@@ -6,7 +6,8 @@ import { List } from 'immutable';
6
6
  * @typeParam I object/mapping from event name to emitted value type
7
7
  */
8
8
  export class EventEmitter {
9
- listeners = {};
9
+ // List of callbacks to run per event
10
+ #listeners = {};
10
11
  /**
11
12
  * @param initialListeners object/mapping of event name to listener, as if using `on` on created instance
12
13
  */
@@ -19,14 +20,14 @@ export class EventEmitter {
19
20
  }
20
21
  }
21
22
  /**
22
- * Register listener to call on event
23
+ * Register listener to call on event.
23
24
  *
24
25
  * @param event event name to listen to
25
26
  * @param listener handler to call
26
27
  */
27
28
  on(event, listener) {
28
- const eventListeners = this.listeners[event] ?? List();
29
- this.listeners[event] = eventListeners.push([false, listener]);
29
+ const eventListeners = this.#listeners[event] ?? List();
30
+ this.#listeners[event] = eventListeners.push([false, listener]);
30
31
  }
31
32
  /**
32
33
  * Register listener to call once on next event
@@ -35,8 +36,8 @@ export class EventEmitter {
35
36
  * @param listener handler to call next time
36
37
  */
37
38
  once(event, listener) {
38
- const eventListeners = this.listeners[event] ?? List();
39
- this.listeners[event] = eventListeners.push([true, listener]);
39
+ const eventListeners = this.#listeners[event] ?? List();
40
+ this.#listeners[event] = eventListeners.push([true, listener]);
40
41
  }
41
42
  /**
42
43
  * Send value to registered listeners of event name
@@ -45,9 +46,9 @@ export class EventEmitter {
45
46
  * @param value what to call listeners with
46
47
  */
47
48
  emit(event, value) {
48
- const eventListeners = this.listeners[event] ?? List();
49
- this.listeners[event] = eventListeners.filterNot(([once]) => once);
50
- eventListeners.forEach(([_, listener]) => { listener(value); });
49
+ const eventListeners = this.#listeners[event] ?? List();
50
+ this.#listeners[event] = eventListeners.filterNot(([once]) => once);
51
+ eventListeners.forEach(async ([_, listener]) => { await listener(value); });
51
52
  }
52
53
  }
53
54
  /** `EventEmitter` for all events */
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@epfml/discojs",
3
- "version": "3.0.1-p20241007204240.0",
3
+ "version": "3.0.1-p20241024094708.0",
4
4
  "type": "module",
5
5
  "main": "dist/index.js",
6
6
  "types": "dist/index.d.ts",
@@ -19,13 +19,12 @@
19
19
  },
20
20
  "homepage": "https://github.com/epfml/disco#readme",
21
21
  "dependencies": {
22
+ "@msgpack/msgpack": "^3.0.0-beta2",
22
23
  "@tensorflow/tfjs": "4",
23
24
  "@xenova/transformers": "2",
24
- "axios": "1",
25
25
  "immutable": "4",
26
26
  "isomorphic-wrtc": "1",
27
27
  "isomorphic-ws": "5",
28
- "msgpack-lite": "0.1",
29
28
  "simple-peer": "9",
30
29
  "tslib": "2",
31
30
  "ws": "8"