@epfml/discojs 2.1.2-p20240722093114.0 → 2.1.2-p20240723143623.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -100,8 +100,8 @@ export declare abstract class Base<T> {
100
100
  */
101
101
  log(step: AggregationStep, from?: client.NodeID): void;
102
102
  /**
103
- * Sets the aggregator's TF.js model.
104
- * @param model The new TF.js model
103
+ * Sets the aggregator's model.
104
+ * @param model The new model
105
105
  */
106
106
  setModel(model: Model): void;
107
107
  /**
@@ -109,6 +109,7 @@ export declare abstract class Base<T> {
109
109
  * peer/client within the network, whom we are communicating with during this aggregation
110
110
  * round.
111
111
  * @param nodeId The node to be added
112
+ * @returns True is the node wasn't already in the list of nodes, False if already included
112
113
  */
113
114
  registerNode(nodeId: client.NodeID): boolean;
114
115
  /**
@@ -97,7 +97,7 @@ export class Base {
97
97
  log(step, from) {
98
98
  switch (step) {
99
99
  case AggregationStep.ADD:
100
- console.log(`> Adding contribution from node ${from ?? '"unknown"'} for round (${this.communicationRound}, ${this.round})`);
100
+ console.log(`Adding contribution from node ${from ?? '"unknown"'} for round (${this.communicationRound}, ${this.round})`);
101
101
  break;
102
102
  case AggregationStep.UPDATE:
103
103
  if (from === undefined) {
@@ -116,8 +116,8 @@ export class Base {
116
116
  }
117
117
  }
118
118
  /**
119
- * Sets the aggregator's TF.js model.
120
- * @param model The new TF.js model
119
+ * Sets the aggregator's model.
120
+ * @param model The new model
121
121
  */
122
122
  setModel(model) {
123
123
  this._model = model;
@@ -127,6 +127,7 @@ export class Base {
127
127
  * peer/client within the network, whom we are communicating with during this aggregation
128
128
  * round.
129
129
  * @param nodeId The node to be added
130
+ * @returns True is the node wasn't already in the list of nodes, False if already included
130
131
  */
131
132
  registerNode(nodeId) {
132
133
  if (!this.nodes.has(nodeId)) {
@@ -1,16 +1,30 @@
1
1
  import type { Task } from '../index.js';
2
2
  import { aggregator } from '../index.js';
3
+ import type { Model } from "../index.js";
4
+ type AggregatorOptions = Partial<{
5
+ model: Model;
6
+ scheme: Task['trainingInformation']['scheme'];
7
+ roundCutOff: number;
8
+ threshold: number;
9
+ thresholdType: 'relative' | 'absolute';
10
+ }>;
3
11
  /**
4
- * Enumeration of the available types of aggregator.
5
- */
6
- export declare enum AggregatorChoice {
7
- MEAN = 0,
8
- SECURE = 1,
9
- BANDIT = 2
10
- }
11
- /**
12
- * Provides the aggregator object adequate to the given task.
13
- * @param task The task
12
+ * Initializes an aggregator according to the task definition, the training scheme and the aggregator parameters.
13
+ * Here is the ordered list of parameters used to define the aggregator and its default behavior:
14
+ * task.trainingInformation.aggregator > options.scheme > task.trainingInformation.scheme
15
+ *
16
+ * If `task.trainingInformation.aggregator` is defined, we initialize the chosen aggregator with `options` parameter values.
17
+ * Otherwise, we default to a MeanAggregator for both training schemes.
18
+ *
19
+ * For the MeanAggregator we rely on `options.scheme` and fallback on `task.trainingInformation.scheme` to infer default values.
20
+ * Unless specified otherwise, for federated learning or local training the aggregator default to waiting
21
+ * for a single contribution to trigger a model update.
22
+ * (the server's model update for federated learning or our own contribution if training locally)
23
+ * For decentralized learning the aggregator defaults to waiting for every nodes' contribution to trigger a model update.
24
+ *
25
+ * @param task The task object associated with the current training session
26
+ * @param options Options passed down to the aggregator's constructor
14
27
  * @returns The aggregator
15
28
  */
16
- export declare function getAggregator(task: Task): aggregator.Aggregator;
29
+ export declare function getAggregator(task: Task, options?: AggregatorOptions): aggregator.Aggregator;
30
+ export {};
@@ -1,31 +1,48 @@
1
1
  import { aggregator } from '../index.js';
2
2
  /**
3
- * Enumeration of the available types of aggregator.
4
- */
5
- export var AggregatorChoice;
6
- (function (AggregatorChoice) {
7
- AggregatorChoice[AggregatorChoice["MEAN"] = 0] = "MEAN";
8
- AggregatorChoice[AggregatorChoice["SECURE"] = 1] = "SECURE";
9
- AggregatorChoice[AggregatorChoice["BANDIT"] = 2] = "BANDIT";
10
- })(AggregatorChoice || (AggregatorChoice = {}));
11
- /**
12
- * Provides the aggregator object adequate to the given task.
13
- * @param task The task
3
+ * Initializes an aggregator according to the task definition, the training scheme and the aggregator parameters.
4
+ * Here is the ordered list of parameters used to define the aggregator and its default behavior:
5
+ * task.trainingInformation.aggregator > options.scheme > task.trainingInformation.scheme
6
+ *
7
+ * If `task.trainingInformation.aggregator` is defined, we initialize the chosen aggregator with `options` parameter values.
8
+ * Otherwise, we default to a MeanAggregator for both training schemes.
9
+ *
10
+ * For the MeanAggregator we rely on `options.scheme` and fallback on `task.trainingInformation.scheme` to infer default values.
11
+ * Unless specified otherwise, for federated learning or local training the aggregator default to waiting
12
+ * for a single contribution to trigger a model update.
13
+ * (the server's model update for federated learning or our own contribution if training locally)
14
+ * For decentralized learning the aggregator defaults to waiting for every nodes' contribution to trigger a model update.
15
+ *
16
+ * @param task The task object associated with the current training session
17
+ * @param options Options passed down to the aggregator's constructor
14
18
  * @returns The aggregator
15
19
  */
16
- export function getAggregator(task) {
17
- const error = new Error('not implemented');
18
- switch (task.trainingInformation.aggregator) {
19
- case AggregatorChoice.MEAN:
20
- return new aggregator.MeanAggregator();
21
- case AggregatorChoice.BANDIT:
22
- throw error;
23
- case AggregatorChoice.SECURE:
24
- if (task.trainingInformation.scheme !== 'decentralized') {
20
+ export function getAggregator(task, options = {}) {
21
+ const aggregatorType = task.trainingInformation.aggregator ?? 'mean';
22
+ const scheme = options.scheme ?? task.trainingInformation.scheme;
23
+ switch (aggregatorType) {
24
+ case 'mean':
25
+ if (scheme === 'decentralized') {
26
+ // If options are not specified, we default to expecting a contribution from all peers, so we set the threshold to 100%
27
+ options = {
28
+ model: undefined, roundCutOff: undefined, threshold: 1, thresholdType: 'relative',
29
+ ...options
30
+ };
31
+ }
32
+ else {
33
+ // If scheme == 'federated' then we only expect the server's contribution at each round
34
+ // so we set the aggregation threshold to 1 contribution
35
+ // If scheme == 'local' then we only expect our own contribution
36
+ options = {
37
+ model: undefined, roundCutOff: undefined, threshold: 1, thresholdType: 'absolute',
38
+ ...options
39
+ };
40
+ }
41
+ return new aggregator.MeanAggregator(options.model, options.roundCutOff, options.threshold, options.thresholdType);
42
+ case 'secure':
43
+ if (scheme !== 'decentralized') {
25
44
  throw new Error('secure aggregation is currently supported for decentralized only');
26
45
  }
27
- return new aggregator.SecureAggregator();
28
- default:
29
- return new aggregator.MeanAggregator();
46
+ return new aggregator.SecureAggregator(options.model, task.trainingInformation.maxShareValue);
30
47
  }
31
48
  }
@@ -3,5 +3,5 @@ import type { Base } from './base.js';
3
3
  export { Base as AggregatorBase, AggregationStep } from './base.js';
4
4
  export { MeanAggregator } from './mean.js';
5
5
  export { SecureAggregator } from './secure.js';
6
- export { getAggregator, AggregatorChoice } from './get.js';
6
+ export { getAggregator } from './get.js';
7
7
  export type Aggregator = Base<WeightsContainer>;
@@ -1,4 +1,4 @@
1
1
  export { Base as AggregatorBase, AggregationStep } from './base.js';
2
2
  export { MeanAggregator } from './mean.js';
3
3
  export { SecureAggregator } from './secure.js';
4
- export { getAggregator, AggregatorChoice } from './get.js';
4
+ export { getAggregator } from './get.js';
@@ -1,18 +1,37 @@
1
1
  import type { Map } from "immutable";
2
2
  import { Base as Aggregator } from "./base.js";
3
3
  import type { Model, WeightsContainer, client } from "../index.js";
4
- /** Mean aggregator whose aggregation step consists in computing the mean of the received weights. */
4
+ type ThresholdType = 'relative' | 'absolute';
5
+ /**
6
+ * Mean aggregator whose aggregation step consists in computing the mean of the received weights.
7
+ *
8
+ */
5
9
  export declare class MeanAggregator extends Aggregator<WeightsContainer> {
6
10
  #private;
7
11
  /**
8
- * @param threshold - how many contributions for trigger an aggregation step.
9
- * - relative: 0 < t <= 1, thus requiring t * |nodes| contributions
10
- * - absolute: t > 1, thus requiring t contributions
12
+ * Create a mean aggregator that averages all weight updates received when a specified threshold is met.
13
+ * By default, initializes an aggregator that waits for 100% of the nodes' contributions and that
14
+ * only accepts contributions from the current round (drops contributions from previous rounds).
15
+ *
16
+ * @param threshold - how many contributions trigger an aggregation step.
17
+ * It can be relative (a proportion): 0 < t <= 1, requiring t * |nodes| contributions.
18
+ * Important: to specify 100% of the nodes, set `threshold = 1` and `thresholdType = 'relative'`.
19
+ * It can be an absolute number, if t >=1 (then t has to be an integer), the aggregator waits fot t contributions
20
+ * Note, to specify waiting for a single contribution (such as a federated client only waiting for the server weight update),
21
+ * set `threshold = 1` and `thresholdType = 'absolute'`
22
+ * @param thresholdType 'relative' or 'absolute', defaults to 'relative'. Is only used to clarify the case when threshold = 1,
23
+ * If `threshold != 1` then the specified thresholdType is ignored and overwritten
24
+ * If `thresholdType = 'absolute'` then `threshold = 1` means waiting for 1 contribution
25
+ * if `thresholdType = 'relative'` then `threshold = 1`` means 100% of this.nodes' contributions,
26
+ * @param roundCutoff - from how many past rounds do we still accept contributions.
27
+ * If 0 then only accept contributions from the current round,
28
+ * if 1 then the current round and the previous one, etc.
11
29
  */
12
- constructor(model?: Model, roundCutoff?: number, threshold?: number);
30
+ constructor(model?: Model, roundCutoff?: number, threshold?: number, thresholdType?: ThresholdType);
13
31
  /** Checks whether the contributions buffer is full. */
14
32
  isFull(): boolean;
15
33
  add(nodeId: client.NodeID, contribution: WeightsContainer, round: number, currentContributions?: number): boolean;
16
34
  aggregate(): void;
17
35
  makePayloads(weights: WeightsContainer): Map<client.NodeID, WeightsContainer>;
18
36
  }
37
+ export {};
@@ -1,34 +1,82 @@
1
1
  import { AggregationStep, Base as Aggregator } from "./base.js";
2
2
  import { aggregation } from "../index.js";
3
- /** Mean aggregator whose aggregation step consists in computing the mean of the received weights. */
3
+ /**
4
+ * Mean aggregator whose aggregation step consists in computing the mean of the received weights.
5
+ *
6
+ */
4
7
  export class MeanAggregator extends Aggregator {
5
8
  #threshold;
9
+ #thresholdType;
6
10
  /**
7
- * @param threshold - how many contributions for trigger an aggregation step.
8
- * - relative: 0 < t <= 1, thus requiring t * |nodes| contributions
9
- * - absolute: t > 1, thus requiring t contributions
11
+ * Create a mean aggregator that averages all weight updates received when a specified threshold is met.
12
+ * By default, initializes an aggregator that waits for 100% of the nodes' contributions and that
13
+ * only accepts contributions from the current round (drops contributions from previous rounds).
14
+ *
15
+ * @param threshold - how many contributions trigger an aggregation step.
16
+ * It can be relative (a proportion): 0 < t <= 1, requiring t * |nodes| contributions.
17
+ * Important: to specify 100% of the nodes, set `threshold = 1` and `thresholdType = 'relative'`.
18
+ * It can be an absolute number, if t >=1 (then t has to be an integer), the aggregator waits fot t contributions
19
+ * Note, to specify waiting for a single contribution (such as a federated client only waiting for the server weight update),
20
+ * set `threshold = 1` and `thresholdType = 'absolute'`
21
+ * @param thresholdType 'relative' or 'absolute', defaults to 'relative'. Is only used to clarify the case when threshold = 1,
22
+ * If `threshold != 1` then the specified thresholdType is ignored and overwritten
23
+ * If `thresholdType = 'absolute'` then `threshold = 1` means waiting for 1 contribution
24
+ * if `thresholdType = 'relative'` then `threshold = 1`` means 100% of this.nodes' contributions,
25
+ * @param roundCutoff - from how many past rounds do we still accept contributions.
26
+ * If 0 then only accept contributions from the current round,
27
+ * if 1 then the current round and the previous one, etc.
10
28
  */
11
- // TODO no way to require a single contribution
12
- constructor(model, roundCutoff = 0, threshold = 1) {
29
+ constructor(model, roundCutoff = 0, threshold = 1, thresholdType) {
13
30
  if (threshold <= 0)
14
- throw new Error("threshold must be striclty positive");
15
- if (threshold > 1 && !Number.isInteger(threshold))
16
- throw new Error("absolute thresholds must be integeral");
31
+ throw new Error("threshold must be strictly positive");
32
+ if (threshold > 1 && (!Number.isInteger(threshold)))
33
+ throw new Error("absolute thresholds must be integral");
17
34
  super(model, roundCutoff, 1);
18
35
  this.#threshold = threshold;
36
+ if (threshold < 1) {
37
+ // Throw exception if threshold and thresholdType are conflicting
38
+ if (thresholdType === 'absolute') {
39
+ throw new Error(`thresholdType has been set to 'absolute' but choosing threshold=${threshold} implies that thresholdType should be 'relative'.`);
40
+ }
41
+ this.#thresholdType = 'relative';
42
+ }
43
+ else if (threshold > 1) {
44
+ // Throw exception if threshold and thresholdType are conflicting
45
+ if (thresholdType === 'relative') {
46
+ throw new Error(`thresholdType has been set to 'relative' but choosing threshold=${threshold} implies that thresholdType should be 'absolute'.`);
47
+ }
48
+ this.#thresholdType = 'absolute';
49
+ }
50
+ // remaining case: threshold == 1
51
+ else {
52
+ // Print a warning regarding the default behavior when thresholdType is not specified
53
+ if (thresholdType === undefined) {
54
+ console.warn("[WARN] Setting the aggregator's threshold to 100% of the nodes' contributions by default. " +
55
+ "To instead wait for a single contribution, set thresholdType = 'absolute'");
56
+ this.#thresholdType = 'relative';
57
+ }
58
+ else {
59
+ this.#thresholdType = thresholdType;
60
+ }
61
+ }
19
62
  }
20
63
  /** Checks whether the contributions buffer is full. */
21
64
  isFull() {
22
- const actualThreshold = this.#threshold <= 1
65
+ const thresholdValue = this.#thresholdType == 'relative'
23
66
  ? this.#threshold * this.nodes.size
24
67
  : this.#threshold;
25
- return (this.contributions.get(0)?.size ?? 0) >= actualThreshold;
68
+ return (this.contributions.get(0)?.size ?? 0) >= thresholdValue;
26
69
  }
27
70
  add(nodeId, contribution, round, currentContributions = 0) {
28
71
  if (currentContributions !== 0)
29
72
  throw new Error("only a single communication round");
30
- if (!this.nodes.has(nodeId) || !this.isWithinRoundCutoff(round))
73
+ if (!this.nodes.has(nodeId) || !this.isWithinRoundCutoff(round)) {
74
+ if (!this.nodes.has(nodeId))
75
+ console.warn("Contribution rejected because node id is not registered");
76
+ if (!this.isWithinRoundCutoff(round))
77
+ console.warn(`Contribution rejected because round ${round} is not within round cutoff`);
31
78
  return false;
79
+ }
32
80
  this.log(this.contributions.hasIn([0, nodeId])
33
81
  ? AggregationStep.UPDATE
34
82
  : AggregationStep.ADD, nodeId);
@@ -1,4 +1,3 @@
1
- import type { Set } from 'immutable';
2
1
  import type { Model, Task, WeightsContainer } from '../index.js';
3
2
  import type { NodeID } from './types.js';
4
3
  import type { EventConnection } from './event_connection.js';
@@ -70,7 +69,7 @@ export declare abstract class Base {
70
69
  * @param _round The current training round
71
70
  */
72
71
  onRoundEndCommunication(_weights: WeightsContainer, _round: number): Promise<void>;
73
- get nodes(): Set<NodeID>;
72
+ get nbOfParticipants(): number;
74
73
  get ownId(): NodeID;
75
74
  get server(): EventConnection;
76
75
  }
@@ -70,8 +70,12 @@ export class Base {
70
70
  * @param _round The current training round
71
71
  */
72
72
  async onRoundEndCommunication(_weights, _round) { }
73
- get nodes() {
74
- return this.aggregator.nodes;
73
+ // Number of contributors to a collaborative session
74
+ // If decentralized, it should be the number of peers
75
+ // If federated, it should the number of participants excluding the server
76
+ // If local it should be 1
77
+ get nbOfParticipants() {
78
+ return this.aggregator.nodes.size; // overriden by the federated client
75
79
  }
76
80
  get ownId() {
77
81
  if (this._ownId === undefined) {
@@ -1,7 +1,5 @@
1
1
  import { type WeightsContainer } from '../../index.js';
2
2
  import { Client } from '../index.js';
3
- import { type PeerConnection } from '../event_connection.js';
4
- import * as messages from './messages.js';
5
3
  /**
6
4
  * Represents a decentralized client in a network of peers. Peers coordinate each other with the
7
5
  * help of the network's server, yet only exchange payloads between each other. Communication
@@ -14,19 +12,38 @@ export declare class Base extends Client {
14
12
  */
15
13
  private pool?;
16
14
  private connections?;
15
+ private get isDisconnected();
17
16
  /**
18
- * Send message to server that this client is ready for the next training round.
17
+ * Public method called by disco.ts when starting training. This method sends
18
+ * a message to the server asking to join the task and be assigned a client ID.
19
+ *
20
+ * The peer also establishes a WebSocket connection with the server to then
21
+ * create peer-to-peer WebRTC connections with peers. The server is used to exchange
22
+ * peers network information.
19
23
  */
20
- private waitForPeers;
21
- protected sendMessagetoPeer(peer: PeerConnection, msg: messages.PeerMessage): void;
24
+ connect(): Promise<void>;
22
25
  /**
23
- * Creation of the WebSocket for the server, connection of client to that WebSocket,
24
- * deals with message reception from the decentralized client's perspective (messages received by client).
26
+ * Create a WebSocket connection with the server
27
+ * The client then waits for the server to forward it other client's network information.
28
+ * Upon receiving other peer's information, the clients establish a peer-to-peer WebRTC connection.
25
29
  */
26
30
  private connectServer;
27
- connect(): Promise<void>;
28
31
  disconnect(): Promise<void>;
32
+ /**
33
+ * At the beginning of a round, each peer tells the server it is ready to proceed
34
+ * The server answers with the list of all peers connected for the round
35
+ * Given the list, the peers then create peer-to-peer connections with each other.
36
+ * When connected, one peer creates a promise for every other peer's weight update
37
+ * and waits for it to resolve.
38
+ *
39
+ */
29
40
  onRoundBeginCommunication(_: WeightsContainer, round: number): Promise<void>;
30
- onRoundEndCommunication(weights: WeightsContainer, round: number): Promise<void>;
41
+ /**
42
+ * At each communication rounds, awaits peers contributions and add them to the client's aggregator.
43
+ * This method is used as callback by getPeers when connecting to the rounds' peers
44
+ * @param connections
45
+ * @param round
46
+ */
31
47
  private receivePayloads;
48
+ onRoundEndCommunication(weights: WeightsContainer, round: number): Promise<void>;
32
49
  }