@epfml/discojs 0.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (104) hide show
  1. package/README.md +32 -0
  2. package/dist/aggregation.d.ts +3 -0
  3. package/dist/aggregation.js +19 -0
  4. package/dist/async_buffer.d.ts +41 -0
  5. package/dist/async_buffer.js +98 -0
  6. package/dist/async_informant.d.ts +20 -0
  7. package/dist/async_informant.js +69 -0
  8. package/dist/client/base.d.ts +36 -0
  9. package/dist/client/base.js +34 -0
  10. package/dist/client/decentralized.d.ts +23 -0
  11. package/dist/client/decentralized.js +275 -0
  12. package/dist/client/federated.d.ts +30 -0
  13. package/dist/client/federated.js +221 -0
  14. package/dist/client/index.d.ts +4 -0
  15. package/dist/client/index.js +11 -0
  16. package/dist/client/local.d.ts +8 -0
  17. package/dist/client/local.js +34 -0
  18. package/dist/dataset/data_loader/data_loader.d.ts +16 -0
  19. package/dist/dataset/data_loader/data_loader.js +10 -0
  20. package/dist/dataset/data_loader/image_loader.d.ts +14 -0
  21. package/dist/dataset/data_loader/image_loader.js +93 -0
  22. package/dist/dataset/data_loader/index.d.ts +3 -0
  23. package/dist/dataset/data_loader/index.js +9 -0
  24. package/dist/dataset/data_loader/tabular_loader.d.ts +29 -0
  25. package/dist/dataset/data_loader/tabular_loader.js +88 -0
  26. package/dist/dataset/dataset_builder.d.ts +17 -0
  27. package/dist/dataset/dataset_builder.js +80 -0
  28. package/dist/dataset/index.d.ts +2 -0
  29. package/dist/dataset/index.js +7 -0
  30. package/dist/index.d.ts +17 -0
  31. package/dist/index.js +34 -0
  32. package/dist/logging/console_logger.d.ts +18 -0
  33. package/dist/logging/console_logger.js +33 -0
  34. package/dist/logging/index.d.ts +2 -0
  35. package/dist/logging/index.js +7 -0
  36. package/dist/logging/logger.d.ts +12 -0
  37. package/dist/logging/logger.js +9 -0
  38. package/dist/logging/trainer_logger.d.ts +24 -0
  39. package/dist/logging/trainer_logger.js +59 -0
  40. package/dist/memory/base.d.ts +53 -0
  41. package/dist/memory/base.js +9 -0
  42. package/dist/memory/empty.d.ts +12 -0
  43. package/dist/memory/empty.js +69 -0
  44. package/dist/memory/index.d.ts +3 -0
  45. package/dist/memory/index.js +9 -0
  46. package/dist/memory/model_type.d.ts +4 -0
  47. package/dist/memory/model_type.js +9 -0
  48. package/dist/model_actor.d.ts +16 -0
  49. package/dist/model_actor.js +20 -0
  50. package/dist/privacy.d.ts +12 -0
  51. package/dist/privacy.js +60 -0
  52. package/dist/serialization/index.d.ts +2 -0
  53. package/dist/serialization/index.js +6 -0
  54. package/dist/serialization/model.d.ts +5 -0
  55. package/dist/serialization/model.js +55 -0
  56. package/dist/serialization/weights.d.ts +5 -0
  57. package/dist/serialization/weights.js +62 -0
  58. package/dist/task/data_example.d.ts +5 -0
  59. package/dist/task/data_example.js +24 -0
  60. package/dist/task/display_information.d.ts +15 -0
  61. package/dist/task/display_information.js +53 -0
  62. package/dist/task/index.d.ts +3 -0
  63. package/dist/task/index.js +8 -0
  64. package/dist/task/model_compile_data.d.ts +6 -0
  65. package/dist/task/model_compile_data.js +12 -0
  66. package/dist/task/task.d.ts +10 -0
  67. package/dist/task/task.js +32 -0
  68. package/dist/task/training_information.d.ts +29 -0
  69. package/dist/task/training_information.js +2 -0
  70. package/dist/tasks/cifar10.d.ts +4 -0
  71. package/dist/tasks/cifar10.js +74 -0
  72. package/dist/tasks/index.d.ts +5 -0
  73. package/dist/tasks/index.js +9 -0
  74. package/dist/tasks/lus_covid.d.ts +4 -0
  75. package/dist/tasks/lus_covid.js +48 -0
  76. package/dist/tasks/mnist.d.ts +4 -0
  77. package/dist/tasks/mnist.js +56 -0
  78. package/dist/tasks/simple_face.d.ts +4 -0
  79. package/dist/tasks/simple_face.js +88 -0
  80. package/dist/tasks/titanic.d.ts +4 -0
  81. package/dist/tasks/titanic.js +86 -0
  82. package/dist/testing/tester.d.ts +5 -0
  83. package/dist/testing/tester.js +21 -0
  84. package/dist/training/disco.d.ts +12 -0
  85. package/dist/training/disco.js +62 -0
  86. package/dist/training/index.d.ts +2 -0
  87. package/dist/training/index.js +7 -0
  88. package/dist/training/trainer/distributed_trainer.d.ts +21 -0
  89. package/dist/training/trainer/distributed_trainer.js +60 -0
  90. package/dist/training/trainer/local_trainer.d.ts +10 -0
  91. package/dist/training/trainer/local_trainer.js +37 -0
  92. package/dist/training/trainer/round_tracker.d.ts +30 -0
  93. package/dist/training/trainer/round_tracker.js +44 -0
  94. package/dist/training/trainer/trainer.d.ts +66 -0
  95. package/dist/training/trainer/trainer.js +146 -0
  96. package/dist/training/trainer/trainer_builder.d.ts +25 -0
  97. package/dist/training/trainer/trainer_builder.js +102 -0
  98. package/dist/training/training_schemes.d.ts +5 -0
  99. package/dist/training/training_schemes.js +10 -0
  100. package/dist/training_informant.d.ts +88 -0
  101. package/dist/training_informant.js +135 -0
  102. package/dist/types.d.ts +4 -0
  103. package/dist/types.js +2 -0
  104. package/package.json +48 -0
package/README.md ADDED
@@ -0,0 +1,32 @@
1
+ # discojs
2
+
3
+ discojs contains the core code of disco.
4
+
5
+ ## Node Installation and NPM installation
6
+
7
+ The app is running under Node 15.12.0. It can be downloaded from [here](https://nodejs.org/en/download/releases/).
8
+
9
+ NPM is a package manager for the JavaScript runtime environment Node.js.
10
+ To install the project's dependencies locally, run the following command.
11
+ Note: the application is currently developed using [NPM 7.6.3](https://www.npmjs.com/package/npm/v/7.6.3).
12
+
13
+ ```
14
+ npm install
15
+ ```
16
+
17
+ This command installs the libraries the project depends on (defined in `package.json` and `package-lock.json`). It only needs to be run the first time you set up the project.
18
+
19
+ > **⚠ WARNING: Apple Silicon.**
20
+ > `TensorFlow.js` version `3.13.0` supports Apple M1 laptops. However, make sure you have an `arm64` Node executable installed (not `x86`). You can check this with:
21
+
22
+ ```
23
+ node -p "process.arch"
24
+ ```
25
+
26
+ ## Build
27
+
28
+ To make the `discojs` package usable from the browser, it must first be built:
29
+
30
+ ```
31
+ npm run build
32
+ ```
@@ -0,0 +1,3 @@
1
+ import { Set } from 'immutable';
2
+ import { Weights } from './types';
/**
 * Element-wise average of the weights contributed by several peers.
 *
 * @param peersWeights - one weights list per peer; every list must have the
 *   same length
 * @returns a new weights list whose i-th entry is the mean of the peers'
 *   i-th entries
 * @throws Error when the set is empty or the lists differ in length
 */
export declare function averageWeights(
    peersWeights: Set<Weights>
): Weights;
@@ -0,0 +1,19 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.averageWeights = void 0;
/**
 * Average the weights received from several peers, element by element.
 *
 * Each entry of `peersWeights` is one peer's list of weights. Entry `i` of the
 * result is the sum of every peer's entry `i` divided by the number of peers.
 * Weights are expected to expose `add`, `div` and `dispose`, as
 * TensorFlow.js tensors do.
 *
 * Every intermediate sum is disposed once it has been consumed, so averaging
 * N peers no longer leaks N-1 extra copies of the model. The weights owned by
 * the caller (the members of `peersWeights`) are never disposed.
 *
 * @param peersWeights immutable Set of weights lists, one per peer
 * @returns a new list of averaged weights
 * @throws Error('no weights to average') if the set is empty
 * @throws Error('variable weights size') if the lists differ in length
 */
function averageWeights(peersWeights) {
    var firstWeights = peersWeights.first();
    var firstWeightSize = firstWeights == null ? undefined : firstWeights.length;
    if (firstWeightSize === undefined) {
        throw new Error('no weights to average');
    }
    var otherWeights = peersWeights.rest();
    if (!otherWeights.every(function (ws) { return ws.length === firstWeightSize; })) {
        throw new Error('variable weights size');
    }
    var numberOfPeers = peersWeights.size;
    // Release a list of temporaries created below; the caller's own list
    // (used as the initial accumulator) must be left untouched.
    var disposeIntermediate = function (ws) {
        if (ws !== firstWeights) {
            ws.forEach(function (w) { w.dispose(); });
        }
    };
    var summed = otherWeights.reduce(function (accum, weights) {
        var next = accum.map(function (w, i) { return w.add(weights[i]); });
        disposeIntermediate(accum);
        return next;
    }, firstWeights);
    var averaged = summed.map(function (w) { return w.div(numberOfPeers); });
    disposeIntermediate(summed);
    return averaged;
}
19
+ exports.averageWeights = averageWeights;
@@ -0,0 +1,41 @@
1
+ import { AsyncInformant } from './async_informant';
2
+ import { TaskID } from './task';
/**
 * Buffer collecting at most one weights update per user, which triggers an
 * aggregation once `bufferCapacity` distinct users have contributed.
 *
 * Lifecycle:
 * - starts at round 0 with an empty buffer (user id -> weights);
 * - an update is accepted only if it is recent enough, i.e.
 *   `this.round - round <= roundCutoff`;
 * - a user who already contributed has their entry overwritten, so the buffer
 *   never holds two entries for the same id;
 * - once the buffer is full, `aggregateAndStoreWeights` receives every
 *   buffered value, the round is incremented and the buffer is emptied.
 *
 * @remarks
 * taskID: task the buffered weights belong to.
 * bufferCapacity: number of entries at which the buffer counts as full.
 * buffer: user id -> latest contributed weights.
 * round: current round of the buffer.
 * roundCutoff: how many rounds behind an update may be and still be accepted.
 */
export declare class AsyncBuffer<T> {
    readonly taskID: TaskID;
    buffer: Map<string, T>;
    round: number;
    private readonly bufferCapacity;
    private readonly aggregateAndStoreWeights;
    private readonly roundCutoff;
    private observer;
    constructor(
        taskID: TaskID,
        bufferCapacity: number,
        aggregateAndStoreWeights: (weights: T[]) => Promise<void>,
        roundCutoff?: number
    );
    /** Set the observer notified after each aggregation (replaces any previous one). */
    registerObserver(observer: AsyncInformant<T>): void;
    /** Whether the buffer holds at least `bufferCapacity` entries. */
    bufferIsFull(): boolean;
    private updateWeightsIfBufferIsFull;
    /** Whether `round` lags the current round by more than `roundCutoff`. */
    isNotWithinRoundCutoff(round: number): boolean;
    /**
     * Buffer `weights` from user `id`, computed during `round`.
     * Updates that are too old are rejected.
     * @param id user contributing the weights
     * @param weights the contribution
     * @param round round in which the contribution was computed
     * @returns true if the weights were added, and false otherwise
     */
    add(id: string, weights: T, round: number): Promise<boolean>;
}
@@ -0,0 +1,98 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.AsyncBuffer = void 0;
4
+ var tslib_1 = require("tslib");
/**
 * The AsyncWeightsBuffer class holds and manipulates information about the
 * async weights buffer. It works as follows:
 *
 * Setup: Init round to zero and create empty buffer (a map from user id to weights)
 *
 * - When a user adds weights only do so when they are recent weights: i.e. this.round - round <= roundCutoff.
 * - If a user already added weights, update them. (-> there can be at most one entry of weights per id in a buffer).
 * - When the buffer is full, call aggregateAndStoreWeights with the weights in the buffer and then increment round by one and reset the buffer.
 *
 * @remarks
 * taskID: corresponds to the task that weights correspond to.
 * bufferCapacity: size of the buffer.
 * buffer: holds a map of users to their added weights.
 * round: the latest round of the weight buffer.
 * roundCutoff: cutoff for accepted rounds.
 */
var AsyncBuffer = /** @class */ (function () {
    /**
     * @param taskID task the buffered weights belong to
     * @param bufferCapacity number of entries at which the buffer counts as full
     * @param aggregateAndStoreWeights async callback receiving every buffered
     *   value (in Map insertion order) once the buffer is full
     * @param roundCutoff maximum accepted lag in rounds; defaults to 0
     */
    function AsyncBuffer(taskID, bufferCapacity, aggregateAndStoreWeights, roundCutoff) {
        if (roundCutoff === void 0) { roundCutoff = 0; }
        this.taskID = taskID;
        this.bufferCapacity = bufferCapacity;
        this.aggregateAndStoreWeights = aggregateAndStoreWeights;
        this.roundCutoff = roundCutoff;
        // user id -> latest weights of that user for the current round
        this.buffer = new Map();
        this.round = 0;
    }
    // Store the single observer notified after each aggregation; a later
    // call replaces the previous observer.
    AsyncBuffer.prototype.registerObserver = function (observer) {
        this.observer = observer;
    };
    // TODO do not test private
    // True once the number of distinct contributors reaches bufferCapacity.
    AsyncBuffer.prototype.bufferIsFull = function () {
        return this.buffer.size >= this.bufferCapacity;
    };
    // Compiled async method (tslib __awaiter/__generator state machine).
    // If the buffer is full: aggregate every buffered value, then increment
    // the round, notify the observer and empty the buffer.
    // The observer is notified *before* buffer.clear(), so it can still read
    // the size of the buffer that was just aggregated.
    // NOTE(review): the aggregation is awaited, so `add` calls arriving in the
    // meantime write into the same buffer. They can re-trigger a second
    // aggregation (the buffer is still full) and are wiped by clear() below.
    // Confirm whether this race is acceptable.
    AsyncBuffer.prototype.updateWeightsIfBufferIsFull = function () {
        var _a;
        return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
            var allWeights;
            return (0, tslib_1.__generator)(this, function (_b) {
                switch (_b.label) {
                    case 0:
                        // Not full: jump straight to the end (label 2).
                        if (!this.bufferIsFull()) return [3 /*break*/, 2];
                        allWeights = Array.from(this.buffer.values());
                        return [4 /*yield*/, this.aggregateAndStoreWeights(allWeights)];
                    case 1:
                        // Resumed once the aggregation has resolved.
                        _b.sent();
                        this.round += 1;
                        (_a = this.observer) === null || _a === void 0 ? void 0 : _a.update();
                        this.buffer.clear();
                        console.log('\n************************************************************');
                        console.log("Buffer is full; Aggregating weights and starting round: " + this.round + "\n");
                        _b.label = 2;
                    case 2: return [2 /*return*/];
                }
            });
        });
    };
    // TODO do not test private
    // True when `round` lags the current round by more than roundCutoff.
    AsyncBuffer.prototype.isNotWithinRoundCutoff = function (round) {
        // Note that always this.round >= round
        // NOTE(review): this is not enforced; a round ahead of this.round
        // yields a negative lag and is therefore accepted.
        return this.round - round > this.roundCutoff;
    };
    /**
     * Add weights originating from weights of a given round.
     * Only add to buffer if the given round is not old.
     * Compiled async method: adds/overwrites the entry for `id`, then awaits
     * a possible aggregation before resolving.
     * @param id user contributing the weights
     * @param weights
     * @param round
     * @returns true if weights were added, and false otherwise
     */
    AsyncBuffer.prototype.add = function (id, weights, round) {
        return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
            var weightsUpdatedByUser, msg;
            return (0, tslib_1.__generator)(this, function (_a) {
                switch (_a.label) {
                    case 0:
                        if (this.isNotWithinRoundCutoff(round)) {
                            console.log("Did not add weights of " + id + " to buffer. Due to old round update: " + round + ", current round is " + this.round);
                            return [2 /*return*/, false];
                        }
                        weightsUpdatedByUser = this.buffer.has(id);
                        msg = weightsUpdatedByUser ? '\tUpdating' : '-> Adding new';
                        console.log(msg + " weights of " + id + " to buffer.");
                        // Overwrites any previous entry: at most one per id.
                        this.buffer.set(id, weights);
                        return [4 /*yield*/, this.updateWeightsIfBufferIsFull()];
                    case 1:
                        _a.sent();
                        return [2 /*return*/, true];
                }
            });
        });
    };
    return AsyncBuffer;
}());
98
+ exports.AsyncBuffer = AsyncBuffer;
@@ -0,0 +1,20 @@
1
+ import { AsyncBuffer } from './async_buffer';
/**
 * Observer of an AsyncBuffer that keeps running statistics about its
 * aggregation rounds: current round, participants in the last round, total
 * participants so far and average participants per round.
 */
export declare class AsyncInformant<T> {
    private readonly asyncBuffer;
    private round;
    private currentNumberOfParticipants;
    private totalNumberOfParticipants;
    private averageNumberOfParticipants;
    /** Registers itself as the observer of `asyncBuffer`. */
    constructor(asyncBuffer: AsyncBuffer<T>);
    /** Refresh every statistic from the observed buffer. */
    update(): void;
    private updateRound;
    private updateNumberOfParticipants;
    private updateAverageNumberOfParticipants;
    private updateTotalNumberOfParticipants;
    getCurrentRound(): number;
    getNumberOfParticipants(): number;
    getTotalNumberOfParticipants(): number;
    getAverageNumberOfParticipants(): number;
    /** Every statistic at once, keyed by name. */
    getAllStatistics(): {
        round: number;
        currentNumberOfParticipants: number;
        totalNumberOfParticipants: number;
        averageNumberOfParticipants: number;
    };
    /** Log every statistic to the console. */
    printAllInfos(): void;
}
@@ -0,0 +1,69 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.AsyncInformant = void 0;
/**
 * Collects statistics about the aggregation rounds of an AsyncBuffer.
 *
 * The informant registers itself as the buffer's observer. On each `update()`
 * it copies the buffer's round and uses the number of buffered entries as the
 * number of participants of that round.
 */
var AsyncInformant = /** @class */ (function () {
    /**
     * @param asyncBuffer buffer to observe; must expose `round`, `buffer`
     *   (a Map), `taskID` and `registerObserver`
     */
    function AsyncInformant(asyncBuffer) {
        this.asyncBuffer = asyncBuffer;
        this.round = 0;
        this.currentNumberOfParticipants = 0;
        this.totalNumberOfParticipants = 0;
        this.averageNumberOfParticipants = 0;
        this.asyncBuffer.registerObserver(this);
    }
    // Update functions
    /**
     * Refresh every statistic from the observed buffer.
     * Reads `asyncBuffer.buffer.size`, so it must run before the buffer of
     * the finished round is cleared.
     */
    AsyncInformant.prototype.update = function () {
        // DEBUG
        console.log('Before update');
        this.printAllInfos();
        this.updateRound();
        this.updateNumberOfParticipants();
        // DEBUG
        console.log('After update');
        this.printAllInfos();
    };
    AsyncInformant.prototype.updateRound = function () {
        this.round = this.asyncBuffer.round;
    };
    AsyncInformant.prototype.updateNumberOfParticipants = function () {
        this.currentNumberOfParticipants = this.asyncBuffer.buffer.size;
        this.updateTotalNumberOfParticipants(this.currentNumberOfParticipants);
        this.updateAverageNumberOfParticipants();
    };
    AsyncInformant.prototype.updateAverageNumberOfParticipants = function () {
        // Average over completed rounds. At round 0 the plain division would
        // give NaN (0 / 0) or Infinity, so report 0 instead.
        this.averageNumberOfParticipants = this.round > 0
            ? this.totalNumberOfParticipants / this.round
            : 0;
    };
    AsyncInformant.prototype.updateTotalNumberOfParticipants = function (currentNumberOfParticipants) {
        this.totalNumberOfParticipants += currentNumberOfParticipants;
    };
    // Getter functions
    AsyncInformant.prototype.getCurrentRound = function () {
        return this.round;
    };
    AsyncInformant.prototype.getNumberOfParticipants = function () {
        return this.currentNumberOfParticipants;
    };
    AsyncInformant.prototype.getTotalNumberOfParticipants = function () {
        return this.totalNumberOfParticipants;
    };
    AsyncInformant.prototype.getAverageNumberOfParticipants = function () {
        return this.averageNumberOfParticipants;
    };
    AsyncInformant.prototype.getAllStatistics = function () {
        return {
            round: this.getCurrentRound(),
            currentNumberOfParticipants: this.getNumberOfParticipants(),
            totalNumberOfParticipants: this.getTotalNumberOfParticipants(),
            averageNumberOfParticipants: this.getAverageNumberOfParticipants()
        };
    };
    // Debug
    AsyncInformant.prototype.printAllInfos = function () {
        console.log('task : ', this.asyncBuffer.taskID);
        console.log('round : ', this.getCurrentRound());
        console.log('participants : ', this.getNumberOfParticipants());
        console.log('total : ', this.getTotalNumberOfParticipants());
        console.log('average : ', this.getAverageNumberOfParticipants());
    };
    return AsyncInformant;
}());
69
+ exports.AsyncInformant = AsyncInformant;
@@ -0,0 +1,36 @@
1
+ /// <reference types="node" />
2
+ import * as tf from '@tensorflow/tfjs';
3
+ import { Task } from '@/task';
4
+ import { TrainingInformant } from '@/training_informant';
5
+ import { Weights } from '@/types';
/**
 * Common base of every client. It holds the server URL and the task, and
 * declares the communication hooks the trainer calls during training.
 */
export declare abstract class Base {
    readonly url: URL;
    readonly task: Task;
    constructor(url: URL, task: Task);
    /** Establish the connection between this client and its server. */
    abstract connect(): Promise<void>;
    /** Tear down the connection established by `connect`. */
    abstract disconnect(): Promise<void>;
    /**
     * Download the task's model from `<url>/tasks/<taskID>/model.json` and
     * deserialize it.
     */
    getLatestModel(): Promise<tf.LayersModel>;
    /**
     * Hook run once training ends. The training manager binds it to the
     * `onTrainEnd` callback of the TFJS training loop.
     */
    abstract onTrainEndCommunication(weights: Weights, trainingInformant: TrainingInformant): Promise<void>;
    /**
     * Hook run whenever a local round has ended.
     *
     * @param updatedWeights weights after the local round
     * @param staleWeights weights before the local round
     * @param round index of the round that just ended
     * @param trainingInformant sink for training status messages
     * @returns the weights produced by this communication step
     */
    abstract onRoundEndCommunication(updatedWeights: Weights, staleWeights: Weights, round: number, trainingInformant: TrainingInformant): Promise<Weights>;
}
@@ -0,0 +1,34 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.Base = void 0;
4
+ var tslib_1 = require("tslib");
5
+ var axios_1 = (0, tslib_1.__importDefault)(require("axios"));
6
+ var serialization = (0, tslib_1.__importStar)(require("../serialization"));
/**
 * Base class shared by the concrete clients. It stores the server URL and the
 * task, and can fetch the task's current model from that server.
 */
var Base = /** @class */ (function () {
    /**
     * @param url base URL of the server
     * @param task task this client trains on
     */
    function Base(url, task) {
        this.url = url;
        this.task = task;
    }
    /**
     * Fetch `<url>/tasks/<taskID>/model.json` with axios and decode the
     * response body with `serialization.model.decode`.
     * Compiled async method (tslib __awaiter/__generator state machine).
     */
    Base.prototype.getLatestModel = function () {
        return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
            var url, response;
            return (0, tslib_1.__generator)(this, function (_a) {
                switch (_a.label) {
                    case 0:
                        // Fresh URL object, so editing its pathname leaves this.url untouched.
                        url = new URL('', this.url.href);
                        // Make sure the base path ends with '/' before appending the route.
                        if (!url.pathname.endsWith('/')) {
                            url.pathname += '/';
                        }
                        url.pathname += "tasks/" + this.task.taskID + "/model.json";
                        return [4 /*yield*/, axios_1.default.get(url.href)];
                    case 1:
                        response = _a.sent();
                        return [4 /*yield*/, serialization.model.decode(response.data)];
                    case 2: return [2 /*return*/, _a.sent()];
                }
            });
        });
    };
    return Base;
}());
34
+ exports.Base = Base;
@@ -0,0 +1,23 @@
1
+ import { TrainingInformant, Weights } from '..';
2
+ import { Base } from './base';
/**
 * Client for decentralized training. Peers are discovered through a
 * WebSocket signaling server for the task, and weights are then exchanged
 * directly between peers over simple-peer connections.
 */
export declare class Decentralized extends Base {
    private server?;
    private peers;
    private readonly weights;
    private connectServer;
    private connectNewPeer;
    /** Open the connection to the task's signaling server. */
    connect(): Promise<void>;
    /** Destroy every peer connection and close the server connection. */
    disconnect(): Promise<void>;
    /**
     * Send this round's weights (with differential privacy applied) to the
     * connected peers, wait for theirs, and resolve to the aggregated result.
     */
    onRoundEndCommunication(updatedWeights: Weights, staleWeights: Weights, epoch: number, trainingInformant: TrainingInformant): Promise<Weights>;
    /** Report the end of training. */
    onTrainEndCommunication(_: Weights, trainingInformant: TrainingInformant): Promise<void>;
}
@@ -0,0 +1,275 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.Decentralized = void 0;
4
+ var tslib_1 = require("tslib");
5
+ var immutable_1 = require("immutable");
6
+ var isomorphic_ws_1 = (0, tslib_1.__importDefault)(require("isomorphic-ws"));
7
+ var msgpack_lite_1 = (0, tslib_1.__importDefault)(require("msgpack-lite"));
8
+ var simple_peer_1 = (0, tslib_1.__importDefault)(require("simple-peer"));
9
+ var url_1 = require("url");
10
+ var __1 = require("..");
11
+ var base_1 = require("./base");
/**
 * Type guard for a message exchanged between peers.
 *
 * A valid message is a non-null object whose keys are exactly `epoch` and
 * `weights`, where `epoch` is a number and `weights` passes
 * `serialization.weights.isEncoded`.
 */
function isPeerMessage(data) {
    if (typeof data !== 'object' || data === null) {
        return false;
    }
    var expectedKeys = immutable_1.Set.of('epoch', 'weights');
    var actualKeys = (0, immutable_1.Set)(Object.keys(data));
    if (!actualKeys.equals(expectedKeys)) {
        return false;
    }
    var candidate = data;
    return typeof candidate.epoch === 'number' &&
        !!__1.serialization.weights.isEncoded(candidate.weights);
}
/**
 * Type guard for the server's opening message: an array of numbers, namely
 * the ids of the peers this client should connect to.
 */
function isServerOpeningMessage(msg) {
    return msg instanceof Array &&
        msg.every(function (elem) { return typeof elem === 'number'; });
}
/**
 * Type guard for a signal relayed by the server on behalf of a peer: a
 * `[peerID, signal]` pair where `peerID` is a number and `signal` is a
 * Uint8Array (msgpack-encoded by the sender).
 */
function isServerPeerMessage(msg) {
    if (!(msg instanceof Array) || msg.length !== 2) {
        return false;
    }
    var peerID = msg[0];
    var signal = msg[1];
    return typeof peerID === 'number' && signal instanceof Uint8Array;
}
// Polling period, in milliseconds, while waiting for the peers' weights.
var TICK = 100;
// Upper bound, in milliseconds, on how long a round waits for the others.
var MAX_WAIT_PER_ROUND = 10 * 1000;
/**
 * Client for decentralized training.
 *
 * Peers discover each other through a WebSocket signaling server dedicated to
 * the task (`<url>/deai/tasks/<taskID>`). Weights are then exchanged directly
 * over simple-peer connections. Every message is msgpack-encoded.
 */
var Decentralized = /** @class */ (function (_super) {
    (0, tslib_1.__extends)(Decentralized, _super);
    function Decentralized() {
        var _this = _super !== null && _super.apply(this, arguments) || this;
        // peer id -> SimplePeer connection
        _this.peers = (0, immutable_1.Map)();
        // peer id -> List of the weights that peer sent, indexed by epoch.
        // Immutable collection: every update must be assigned back.
        _this.weights = (0, immutable_1.Map)();
        return _this;
    }
    /**
     * Open the WebSocket to the signaling server and route its messages.
     * The opening message lists the peers to initiate connections with. The
     * following messages relay signals on behalf of a peer, and a connection
     * is created for a peer that has not been seen yet.
     * Resolves with the socket once it is open.
     */
    Decentralized.prototype.connectServer = function (url) {
        return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
            var ws;
            var _this = this;
            return (0, tslib_1.__generator)(this, function (_a) {
                switch (_a.label) {
                    case 0:
                        ws = new isomorphic_ws_1.default.WebSocket(url);
                        ws.binaryType = 'arraybuffer';
                        ws.onmessage = function (event) {
                            if (!(event.data instanceof ArrayBuffer)) {
                                throw new Error('server did not send an ArrayBuffer');
                            }
                            var msg = msgpack_lite_1.default.decode(new Uint8Array(event.data));
                            if (isServerOpeningMessage(msg)) {
                                console.debug('server sent us the list of peer to connect to:', msg);
                                if (_this.peers.size !== 0) {
                                    throw new Error('server already gave us a list of peers');
                                }
                                _this.peers = (0, immutable_1.Map)((0, immutable_1.List)(msg)
                                    .map(function (id) { return [id, _this.connectNewPeer(id, true)]; }));
                            }
                            else if (isServerPeerMessage(msg)) {
                                var _a = (0, tslib_1.__read)(msg, 2), peerID = _a[0], encodedSignal = _a[1];
                                var signal = msgpack_lite_1.default.decode(encodedSignal);
                                console.debug('server on behalf of', peerID, 'sent', signal);
                                var peer = _this.peers.get(peerID);
                                if (peer === undefined) {
                                    peer = _this.connectNewPeer(peerID, false);
                                    _this.peers = _this.peers.set(peerID, peer);
                                }
                                peer.signal(signal);
                            }
                            else {
                                throw new Error('server sent an invalid msg');
                            }
                        };
                        return [4 /*yield*/, new Promise(function (resolve, reject) {
                                ws.onerror = function (err) { return reject(new Error("connecting server: " + err)); };
                                ws.onopen = function () { return resolve(ws); };
                            })];
                    case 1: return [2 /*return*/, _a.sent()];
                }
            });
        });
    };
    // connect a new peer
    //
    // if initiator is true, we start the connection on our side
    // see SimplePeer.Options.initiator for more info
    Decentralized.prototype.connectNewPeer = function (peerID, initiator) {
        var _this = this;
        console.debug('connect new peer with initiator: ', initiator);
        var peer = new simple_peer_1.default({
            initiator: initiator,
            config: {
                // Copy of simple-peer's default ICE servers; no TURN server is added.
                iceServers: (0, immutable_1.List)(simple_peer_1.default.config.iceServers)
                    .toArray()
            }
        });
        // Forward our local signaling data to the peer through the server.
        peer.on('signal', function (signal) {
            console.debug('local', peerID, 'is signaling', signal);
            if (_this.server === undefined) {
                throw new Error('server closed but received a signal');
            }
            var msg = [peerID, msgpack_lite_1.default.encode(signal)];
            _this.server.send(msgpack_lite_1.default.encode(msg));
        });
        // Record the weights a peer sent for a given epoch.
        peer.on('data', function (data) {
            var message = msgpack_lite_1.default.decode(data);
            if (!isPeerMessage(message)) {
                throw new Error("invalid message received from " + peerID);
            }
            var weights = __1.serialization.weights.decode(message.weights);
            console.debug('peer', peerID, 'sent weights', weights);
            var peerWeights = _this.weights.get(peerID, (0, immutable_1.List)());
            if (peerWeights.get(message.epoch) !== undefined) {
                throw new Error("weights from " + peerID + " already received");
            }
            // Immutable Map/List: `set` returns a new collection, so the result
            // must be stored back, otherwise the received weights are lost.
            _this.weights = _this.weights.set(peerID, peerWeights.set(message.epoch, weights));
        });
        peer.on('connect', function () { return console.info('connected to peer', peerID); });
        // TODO better error handling
        peer.on('error', function (err) { throw err; });
        return peer;
    };
    /**
     * Initialize the connection to the peers and to the other nodes.
     */
    Decentralized.prototype.connect = function () {
        return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
            var serverURL, _a;
            return (0, tslib_1.__generator)(this, function (_b) {
                switch (_b.label) {
                    case 0:
                        serverURL = new url_1.URL('', this.url.href);
                        // Join with exactly one '/': appending "/deai/..." to a path
                        // that already ends in '/' would produce "//deai/...".
                        if (!serverURL.pathname.endsWith('/')) {
                            serverURL.pathname += '/';
                        }
                        serverURL.pathname += "deai/tasks/" + this.task.taskID;
                        _a = this;
                        return [4 /*yield*/, this.connectServer(serverURL)];
                    case 1:
                        _a.server = _b.sent();
                        return [2 /*return*/];
                }
            });
        });
    };
    /**
     * Disconnection process when user quits the task.
     * Also forgets every received weight, since peer ids may be reused by the
     * server after a reconnection.
     */
    Decentralized.prototype.disconnect = function () {
        var _a;
        return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
            return (0, tslib_1.__generator)(this, function (_b) {
                this.peers.forEach(function (peer) { return peer.destroy(); });
                this.peers = (0, immutable_1.Map)();
                this.weights = (0, immutable_1.Map)();
                (_a = this.server) === null || _a === void 0 ? void 0 : _a.close();
                this.server = undefined;
                return [2 /*return*/];
            });
        });
    };
    /**
     * Share this epoch's weights with the connected peers and average what
     * they send back.
     *
     * 1. Differential privacy is applied to the update, and the result is
     *    sent to every connected peer.
     * 2. Every TICK ms, checks whether each connected peer has sent its
     *    weights for `epoch`. Gives up waiting after MAX_WAIT_PER_ROUND ms.
     * 3. Resolves to the average of the weights received for `epoch`. If none
     *    arrived (e.g. no connected peer), resolves to `updatedWeights`.
     */
    Decentralized.prototype.onRoundEndCommunication = function (updatedWeights, staleWeights, epoch, trainingInformant) {
        return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
            var noisyWeights, msg, encodedMsg, getWeights, hasWeightsOfAllConnectedPeers, timeoutError, receivedWeights;
            var _a;
            var _this = this;
            return (0, tslib_1.__generator)(this, function (_b) {
                switch (_b.label) {
                    case 0:
                        noisyWeights = __1.privacy.addDifferentialPrivacy(updatedWeights, staleWeights, this.task);
                        _a = {
                            epoch: epoch
                        };
                        return [4 /*yield*/, __1.serialization.weights.encode(noisyWeights)];
                    case 1:
                        msg = (_a.weights = _b.sent(),
                            _a);
                        encodedMsg = msgpack_lite_1.default.encode(msg);
                        this.peers
                            .filter(function (peer) { return peer.connected; })
                            .forEach(function (peer, peerID) {
                            trainingInformant.addMessage("Sending weights to peer " + peerID);
                            trainingInformant.updateWhoReceivedMyModel("peer " + peerID);
                            peer.send(encodedMsg);
                        });
                        // Weights recorded for this epoch, one entry per peer that has
                        // sent anything (undefined if it has not sent this epoch yet).
                        getWeights = function () {
                            return _this.weights
                                .valueSeq()
                                .map(function (epochesWeights) { return epochesWeights.get(epoch); });
                        };
                        // True once every currently connected peer has sent this epoch.
                        hasWeightsOfAllConnectedPeers = function () {
                            return _this.peers
                                .filter(function (peer) { return peer.connected; })
                                .keySeq()
                                .every(function (peerID) {
                                var epochesWeights = _this.weights.get(peerID);
                                return epochesWeights !== undefined && epochesWeights.get(epoch) !== undefined;
                            });
                        };
                        timeoutError = new Error('timeout');
                        return [4 /*yield*/, new Promise(function (resolve, reject) {
                                var interval = setInterval(function () {
                                    if (hasWeightsOfAllConnectedPeers()) {
                                        clearInterval(interval);
                                        // Cancel the pending timeout so it does not outlive the round.
                                        clearTimeout(timeout);
                                        resolve();
                                    }
                                }, TICK);
                                var timeout = setTimeout(function () {
                                    clearInterval(interval);
                                    reject(timeoutError);
                                }, MAX_WAIT_PER_ROUND);
                            }).catch(function (err) {
                                // A timeout is expected: continue with what was received.
                                if (err !== timeoutError) {
                                    throw err;
                                }
                            })];
                    case 2:
                        _b.sent();
                        receivedWeights = getWeights()
                            .filter(function (weights) { return weights !== undefined; })
                            .toSet();
                        if (receivedWeights.size === 0) {
                            // Nothing to average with: keep the local weights.
                            trainingInformant.addMessage('No weights received from peers');
                            return [2 /*return*/, updatedWeights];
                        }
                        // Average weights
                        trainingInformant.addMessage('Averaging weights');
                        trainingInformant.updateNbrUpdatesWithOthers(1);
                        // Return the new "received" weights
                        return [2 /*return*/, __1.aggregation.averageWeights(receivedWeights)];
                }
            });
        });
    };
    /** Report the end of training; no communication happens yet. */
    Decentralized.prototype.onTrainEndCommunication = function (_, trainingInformant) {
        return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
            return (0, tslib_1.__generator)(this, function (_a) {
                // TODO: enter seeding mode?
                trainingInformant.addMessage('Training finished.');
                return [2 /*return*/];
            });
        });
    };
    return Decentralized;
}(base_1.Base));
275
+ exports.Decentralized = Decentralized;