@epfml/discojs 0.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +32 -0
- package/dist/aggregation.d.ts +3 -0
- package/dist/aggregation.js +19 -0
- package/dist/async_buffer.d.ts +41 -0
- package/dist/async_buffer.js +98 -0
- package/dist/async_informant.d.ts +20 -0
- package/dist/async_informant.js +69 -0
- package/dist/client/base.d.ts +36 -0
- package/dist/client/base.js +34 -0
- package/dist/client/decentralized.d.ts +23 -0
- package/dist/client/decentralized.js +275 -0
- package/dist/client/federated.d.ts +30 -0
- package/dist/client/federated.js +221 -0
- package/dist/client/index.d.ts +4 -0
- package/dist/client/index.js +11 -0
- package/dist/client/local.d.ts +8 -0
- package/dist/client/local.js +34 -0
- package/dist/dataset/data_loader/data_loader.d.ts +16 -0
- package/dist/dataset/data_loader/data_loader.js +10 -0
- package/dist/dataset/data_loader/image_loader.d.ts +14 -0
- package/dist/dataset/data_loader/image_loader.js +93 -0
- package/dist/dataset/data_loader/index.d.ts +3 -0
- package/dist/dataset/data_loader/index.js +9 -0
- package/dist/dataset/data_loader/tabular_loader.d.ts +29 -0
- package/dist/dataset/data_loader/tabular_loader.js +88 -0
- package/dist/dataset/dataset_builder.d.ts +17 -0
- package/dist/dataset/dataset_builder.js +80 -0
- package/dist/dataset/index.d.ts +2 -0
- package/dist/dataset/index.js +7 -0
- package/dist/index.d.ts +17 -0
- package/dist/index.js +34 -0
- package/dist/logging/console_logger.d.ts +18 -0
- package/dist/logging/console_logger.js +33 -0
- package/dist/logging/index.d.ts +2 -0
- package/dist/logging/index.js +7 -0
- package/dist/logging/logger.d.ts +12 -0
- package/dist/logging/logger.js +9 -0
- package/dist/logging/trainer_logger.d.ts +24 -0
- package/dist/logging/trainer_logger.js +59 -0
- package/dist/memory/base.d.ts +53 -0
- package/dist/memory/base.js +9 -0
- package/dist/memory/empty.d.ts +12 -0
- package/dist/memory/empty.js +69 -0
- package/dist/memory/index.d.ts +3 -0
- package/dist/memory/index.js +9 -0
- package/dist/memory/model_type.d.ts +4 -0
- package/dist/memory/model_type.js +9 -0
- package/dist/model_actor.d.ts +16 -0
- package/dist/model_actor.js +20 -0
- package/dist/privacy.d.ts +12 -0
- package/dist/privacy.js +60 -0
- package/dist/serialization/index.d.ts +2 -0
- package/dist/serialization/index.js +6 -0
- package/dist/serialization/model.d.ts +5 -0
- package/dist/serialization/model.js +55 -0
- package/dist/serialization/weights.d.ts +5 -0
- package/dist/serialization/weights.js +62 -0
- package/dist/task/data_example.d.ts +5 -0
- package/dist/task/data_example.js +24 -0
- package/dist/task/display_information.d.ts +15 -0
- package/dist/task/display_information.js +53 -0
- package/dist/task/index.d.ts +3 -0
- package/dist/task/index.js +8 -0
- package/dist/task/model_compile_data.d.ts +6 -0
- package/dist/task/model_compile_data.js +12 -0
- package/dist/task/task.d.ts +10 -0
- package/dist/task/task.js +32 -0
- package/dist/task/training_information.d.ts +29 -0
- package/dist/task/training_information.js +2 -0
- package/dist/tasks/cifar10.d.ts +4 -0
- package/dist/tasks/cifar10.js +74 -0
- package/dist/tasks/index.d.ts +5 -0
- package/dist/tasks/index.js +9 -0
- package/dist/tasks/lus_covid.d.ts +4 -0
- package/dist/tasks/lus_covid.js +48 -0
- package/dist/tasks/mnist.d.ts +4 -0
- package/dist/tasks/mnist.js +56 -0
- package/dist/tasks/simple_face.d.ts +4 -0
- package/dist/tasks/simple_face.js +88 -0
- package/dist/tasks/titanic.d.ts +4 -0
- package/dist/tasks/titanic.js +86 -0
- package/dist/testing/tester.d.ts +5 -0
- package/dist/testing/tester.js +21 -0
- package/dist/training/disco.d.ts +12 -0
- package/dist/training/disco.js +62 -0
- package/dist/training/index.d.ts +2 -0
- package/dist/training/index.js +7 -0
- package/dist/training/trainer/distributed_trainer.d.ts +21 -0
- package/dist/training/trainer/distributed_trainer.js +60 -0
- package/dist/training/trainer/local_trainer.d.ts +10 -0
- package/dist/training/trainer/local_trainer.js +37 -0
- package/dist/training/trainer/round_tracker.d.ts +30 -0
- package/dist/training/trainer/round_tracker.js +44 -0
- package/dist/training/trainer/trainer.d.ts +66 -0
- package/dist/training/trainer/trainer.js +146 -0
- package/dist/training/trainer/trainer_builder.d.ts +25 -0
- package/dist/training/trainer/trainer_builder.js +102 -0
- package/dist/training/training_schemes.d.ts +5 -0
- package/dist/training/training_schemes.js +10 -0
- package/dist/training_informant.d.ts +88 -0
- package/dist/training_informant.js +135 -0
- package/dist/types.d.ts +4 -0
- package/dist/types.js +2 -0
- package/package.json +48 -0
package/README.md
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
# discojs
|
|
2
|
+
|
|
3
|
+
discojs contains the core code of disco.
|
|
4
|
+
|
|
5
|
+
## Node and NPM installation
|
|
6
|
+
|
|
7
|
+
The app runs on Node 15.12.0, which can be downloaded from [here](https://nodejs.org/en/download/releases/).
|
|
8
|
+
|
|
9
|
+
NPM is a package manager for the JavaScript runtime environment Node.js.
|
|
10
|
+
To set up the application for running locally, install its dependencies with the following command:
|
|
11
|
+
Note: the application is currently developed using [NPM 7.6.3](https://www.npmjs.com/package/npm/v/7.6.3).
|
|
12
|
+
|
|
13
|
+
```
|
|
14
|
+
npm install
|
|
15
|
+
```
|
|
16
|
+
|
|
17
|
+
This command installs the libraries required to run the application (as defined in `package.json` and `package-lock.json`). It only needs to be run the first time you use the app.
|
|
18
|
+
|
|
19
|
+
> **⚠ WARNING: Apple Silicon.**
|
|
20
|
+
> `TensorFlow.js` version `3.13.0` supports M1 Mac laptops. However, make sure you have an `arm` Node executable installed (not `x86`). You can check this with:
|
|
21
|
+
|
|
22
|
+
```
|
|
23
|
+
node -p "process.arch"
|
|
24
|
+
```
|
|
25
|
+
|
|
26
|
+
## Build
|
|
27
|
+
|
|
28
|
+
To make the `discojs` package usable in the browser, build it:
|
|
29
|
+
|
|
30
|
+
```
|
|
31
|
+
npm run build
|
|
32
|
+
```
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
"use strict";
|
|
2
|
+
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
+
exports.averageWeights = void 0;
|
|
4
|
+
function averageWeights(peersWeights) {
|
|
5
|
+
var _a;
|
|
6
|
+
var firstWeightSize = (_a = peersWeights.first()) === null || _a === void 0 ? void 0 : _a.length;
|
|
7
|
+
if (firstWeightSize === undefined) {
|
|
8
|
+
throw new Error('no weights to average');
|
|
9
|
+
}
|
|
10
|
+
if (!peersWeights.rest().every(function (ws) { return ws.length === firstWeightSize; })) {
|
|
11
|
+
throw new Error('variable weights size');
|
|
12
|
+
}
|
|
13
|
+
var numberOfPeers = peersWeights.size;
|
|
14
|
+
var peersAverageWeights = peersWeights.reduce(function (accum, weights) {
|
|
15
|
+
return accum.map(function (w, i) { return w.add(weights[i]); });
|
|
16
|
+
}).map(function (w) { return w.div(numberOfPeers); });
|
|
17
|
+
return peersAverageWeights;
|
|
18
|
+
}
|
|
19
|
+
exports.averageWeights = averageWeights;
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
import { AsyncInformant } from './async_informant';
|
|
2
|
+
import { TaskID } from './task';
|
|
3
|
+
/**
 * Buffer that collects at most one weight update per user and, once
 * `bufferCapacity` users have contributed, hands all of them to
 * `aggregateAndStoreWeights`, increments `round` and empties itself.
 *
 * An update computed for a round lagging more than `roundCutoff` rounds
 * behind `round` is rejected; a second update from the same user replaces
 * the first one.
 */
export declare class AsyncBuffer<T> {
    /** Task the buffered weights belong to. */
    readonly taskID: TaskID;
    /** Latest round of the buffer (starts at 0). */
    round: number;
    /** Buffered contributions, keyed by user id. */
    buffer: Map<string, T>;
    /** Number of distinct users that triggers an aggregation. */
    private readonly bufferCapacity;
    /** Maximum accepted lag, in rounds, of an incoming update. */
    private readonly roundCutoff;
    /** Callback receiving every buffered value once the buffer is full. */
    private readonly aggregateAndStoreWeights;
    /** Single observer notified after each aggregation. */
    private observer;
    constructor(taskID: TaskID, bufferCapacity: number, aggregateAndStoreWeights: (weights: T[]) => Promise<void>, roundCutoff?: number);
    /** Replace the observer notified after each aggregation. */
    registerObserver(observer: AsyncInformant<T>): void;
    /** Whether the buffer holds at least `bufferCapacity` entries. */
    bufferIsFull(): boolean;
    /** Whether an update from `round` lags more than `roundCutoff` rounds behind `this.round`. */
    isNotWithinRoundCutoff(round: number): boolean;
    /**
     * Buffer the weights of user `id`, computed during `round`, unless that
     * round is too old; then aggregate if the buffer became full.
     * @param id user identifier (buffer key)
     * @param weights the user's contribution
     * @param round round the weights originate from
     * @returns true if the weights were buffered, false otherwise
     */
    add(id: string, weights: T, round: number): Promise<boolean>;
    private updateWeightsIfBufferIsFull;
}
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
"use strict";
|
|
2
|
+
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
+
exports.AsyncBuffer = void 0;
|
|
4
|
+
var tslib_1 = require("tslib");
|
|
5
|
+
/**
|
|
6
|
+
* The AsyncWeightsBuffer class holds and manipulates information about the
|
|
7
|
+
* async weights buffer. It works as follows:
|
|
8
|
+
*
|
|
9
|
+
* Setup: Init round to zero and create empty buffer (a map from user id to weights)
|
|
10
|
+
*
|
|
11
|
+
* - When a user adds weights only do so when they are recent weights: i.e. this.round - round <= roundCutoff.
|
|
12
|
+
* - If a user already added weights, update them. (-> there can be at most one entry of weights per id in a buffer).
|
|
13
|
+
* - When the buffer is full, call aggregateAndStoreWeights with the weights in the buffer and then increment round by one and reset the buffer.
|
|
14
|
+
*
|
|
15
|
+
* @remarks
|
|
16
|
+
* taskID: corresponds to the task that weights correspond to.
|
|
17
|
+
* bufferCapacity: size of the buffer.
|
|
18
|
+
* buffer: holds a map of users to their added weights.
|
|
19
|
+
* round: the latest round of the weight buffer.
|
|
20
|
+
* roundCutoff: cutoff for accepted rounds.
|
|
21
|
+
*/
|
|
22
|
+
var AsyncBuffer = /** @class */ (function () {
|
|
23
|
+
function AsyncBuffer(taskID, bufferCapacity, aggregateAndStoreWeights, roundCutoff) {
|
|
24
|
+
if (roundCutoff === void 0) { roundCutoff = 0; }
|
|
25
|
+
this.taskID = taskID;
|
|
26
|
+
this.bufferCapacity = bufferCapacity;
|
|
27
|
+
this.aggregateAndStoreWeights = aggregateAndStoreWeights;
|
|
28
|
+
this.roundCutoff = roundCutoff;
|
|
29
|
+
this.buffer = new Map();
|
|
30
|
+
this.round = 0;
|
|
31
|
+
}
|
|
32
|
+
AsyncBuffer.prototype.registerObserver = function (observer) {
|
|
33
|
+
this.observer = observer;
|
|
34
|
+
};
|
|
35
|
+
// TODO do not test private
|
|
36
|
+
AsyncBuffer.prototype.bufferIsFull = function () {
|
|
37
|
+
return this.buffer.size >= this.bufferCapacity;
|
|
38
|
+
};
|
|
39
|
+
AsyncBuffer.prototype.updateWeightsIfBufferIsFull = function () {
|
|
40
|
+
var _a;
|
|
41
|
+
return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
|
|
42
|
+
var allWeights;
|
|
43
|
+
return (0, tslib_1.__generator)(this, function (_b) {
|
|
44
|
+
switch (_b.label) {
|
|
45
|
+
case 0:
|
|
46
|
+
if (!this.bufferIsFull()) return [3 /*break*/, 2];
|
|
47
|
+
allWeights = Array.from(this.buffer.values());
|
|
48
|
+
return [4 /*yield*/, this.aggregateAndStoreWeights(allWeights)];
|
|
49
|
+
case 1:
|
|
50
|
+
_b.sent();
|
|
51
|
+
this.round += 1;
|
|
52
|
+
(_a = this.observer) === null || _a === void 0 ? void 0 : _a.update();
|
|
53
|
+
this.buffer.clear();
|
|
54
|
+
console.log('\n************************************************************');
|
|
55
|
+
console.log("Buffer is full; Aggregating weights and starting round: " + this.round + "\n");
|
|
56
|
+
_b.label = 2;
|
|
57
|
+
case 2: return [2 /*return*/];
|
|
58
|
+
}
|
|
59
|
+
});
|
|
60
|
+
});
|
|
61
|
+
};
|
|
62
|
+
// TODO do not test private
|
|
63
|
+
AsyncBuffer.prototype.isNotWithinRoundCutoff = function (round) {
|
|
64
|
+
// Note that always this.round >= round
|
|
65
|
+
return this.round - round > this.roundCutoff;
|
|
66
|
+
};
|
|
67
|
+
/**
|
|
68
|
+
* Add weights originating from weights of a given round.
|
|
69
|
+
* Only add to buffer if the given round is not old.
|
|
70
|
+
* @param weights
|
|
71
|
+
* @param round
|
|
72
|
+
* @returns true if weights were added, and false otherwise
|
|
73
|
+
*/
|
|
74
|
+
AsyncBuffer.prototype.add = function (id, weights, round) {
|
|
75
|
+
return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
|
|
76
|
+
var weightsUpdatedByUser, msg;
|
|
77
|
+
return (0, tslib_1.__generator)(this, function (_a) {
|
|
78
|
+
switch (_a.label) {
|
|
79
|
+
case 0:
|
|
80
|
+
if (this.isNotWithinRoundCutoff(round)) {
|
|
81
|
+
console.log("Did not add weights of " + id + " to buffer. Due to old round update: " + round + ", current round is " + this.round);
|
|
82
|
+
return [2 /*return*/, false];
|
|
83
|
+
}
|
|
84
|
+
weightsUpdatedByUser = this.buffer.has(id);
|
|
85
|
+
msg = weightsUpdatedByUser ? '\tUpdating' : '-> Adding new';
|
|
86
|
+
console.log(msg + " weights of " + id + " to buffer.");
|
|
87
|
+
this.buffer.set(id, weights);
|
|
88
|
+
return [4 /*yield*/, this.updateWeightsIfBufferIsFull()];
|
|
89
|
+
case 1:
|
|
90
|
+
_a.sent();
|
|
91
|
+
return [2 /*return*/, true];
|
|
92
|
+
}
|
|
93
|
+
});
|
|
94
|
+
});
|
|
95
|
+
};
|
|
96
|
+
return AsyncBuffer;
|
|
97
|
+
}());
|
|
98
|
+
exports.AsyncBuffer = AsyncBuffer;
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
import { AsyncBuffer } from './async_buffer';
|
|
2
|
+
/**
 * Observer of an AsyncBuffer that tracks round statistics: the current
 * round, the participants of the last aggregation, and the cumulative and
 * mean number of participants.
 */
export declare class AsyncInformant<T> {
    /** Buffer being observed. */
    private readonly asyncBuffer;
    private round;
    private currentNumberOfParticipants;
    private totalNumberOfParticipants;
    private averageNumberOfParticipants;
    /** Creates the informant and registers it as the buffer's observer. */
    constructor(asyncBuffer: AsyncBuffer<T>);
    /** Refreshes every statistic from the observed buffer. */
    update(): void;
    getCurrentRound(): number;
    getNumberOfParticipants(): number;
    getTotalNumberOfParticipants(): number;
    getAverageNumberOfParticipants(): number;
    /** All statistics at once, keyed by name. */
    getAllStatistics(): Record<'round' | 'currentNumberOfParticipants' | 'totalNumberOfParticipants' | 'averageNumberOfParticipants', number>;
    /** Logs the task ID and all statistics to the console. */
    printAllInfos(): void;
    private updateRound;
    private updateNumberOfParticipants;
    private updateAverageNumberOfParticipants;
    private updateTotalNumberOfParticipants;
}
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
"use strict";
|
|
2
|
+
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
+
exports.AsyncInformant = void 0;
|
|
4
|
+
var AsyncInformant = /** @class */ (function () {
|
|
5
|
+
function AsyncInformant(asyncBuffer) {
|
|
6
|
+
this.asyncBuffer = asyncBuffer;
|
|
7
|
+
this.round = 0;
|
|
8
|
+
this.currentNumberOfParticipants = 0;
|
|
9
|
+
this.totalNumberOfParticipants = 0;
|
|
10
|
+
this.averageNumberOfParticipants = 0;
|
|
11
|
+
this.asyncBuffer.registerObserver(this);
|
|
12
|
+
}
|
|
13
|
+
// Update functions
|
|
14
|
+
AsyncInformant.prototype.update = function () {
|
|
15
|
+
// DEBUG
|
|
16
|
+
console.log('Before update');
|
|
17
|
+
this.printAllInfos();
|
|
18
|
+
this.updateRound();
|
|
19
|
+
this.updateNumberOfParticipants();
|
|
20
|
+
// DEBUG
|
|
21
|
+
console.log('After update');
|
|
22
|
+
this.printAllInfos();
|
|
23
|
+
};
|
|
24
|
+
AsyncInformant.prototype.updateRound = function () {
|
|
25
|
+
this.round = this.asyncBuffer.round;
|
|
26
|
+
};
|
|
27
|
+
AsyncInformant.prototype.updateNumberOfParticipants = function () {
|
|
28
|
+
this.currentNumberOfParticipants = this.asyncBuffer.buffer.size;
|
|
29
|
+
this.updateTotalNumberOfParticipants(this.currentNumberOfParticipants);
|
|
30
|
+
this.updateAverageNumberOfParticipants();
|
|
31
|
+
};
|
|
32
|
+
AsyncInformant.prototype.updateAverageNumberOfParticipants = function () {
|
|
33
|
+
this.averageNumberOfParticipants = this.totalNumberOfParticipants / this.round;
|
|
34
|
+
};
|
|
35
|
+
AsyncInformant.prototype.updateTotalNumberOfParticipants = function (currentNumberOfParticipants) {
|
|
36
|
+
this.totalNumberOfParticipants += currentNumberOfParticipants;
|
|
37
|
+
};
|
|
38
|
+
// Getter functions
|
|
39
|
+
AsyncInformant.prototype.getCurrentRound = function () {
|
|
40
|
+
return this.round;
|
|
41
|
+
};
|
|
42
|
+
AsyncInformant.prototype.getNumberOfParticipants = function () {
|
|
43
|
+
return this.currentNumberOfParticipants;
|
|
44
|
+
};
|
|
45
|
+
AsyncInformant.prototype.getTotalNumberOfParticipants = function () {
|
|
46
|
+
return this.totalNumberOfParticipants;
|
|
47
|
+
};
|
|
48
|
+
AsyncInformant.prototype.getAverageNumberOfParticipants = function () {
|
|
49
|
+
return this.averageNumberOfParticipants;
|
|
50
|
+
};
|
|
51
|
+
AsyncInformant.prototype.getAllStatistics = function () {
|
|
52
|
+
return {
|
|
53
|
+
round: this.getCurrentRound(),
|
|
54
|
+
currentNumberOfParticipants: this.getNumberOfParticipants(),
|
|
55
|
+
totalNumberOfParticipants: this.getTotalNumberOfParticipants(),
|
|
56
|
+
averageNumberOfParticipants: this.getAverageNumberOfParticipants()
|
|
57
|
+
};
|
|
58
|
+
};
|
|
59
|
+
// Debug
|
|
60
|
+
AsyncInformant.prototype.printAllInfos = function () {
|
|
61
|
+
console.log('task : ', this.asyncBuffer.taskID);
|
|
62
|
+
console.log('round : ', this.getCurrentRound());
|
|
63
|
+
console.log('participants : ', this.getNumberOfParticipants());
|
|
64
|
+
console.log('total : ', this.getTotalNumberOfParticipants());
|
|
65
|
+
console.log('average : ', this.getAverageNumberOfParticipants());
|
|
66
|
+
};
|
|
67
|
+
return AsyncInformant;
|
|
68
|
+
}());
|
|
69
|
+
exports.AsyncInformant = AsyncInformant;
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
/// <reference types="node" />
import * as tf from '@tensorflow/tfjs';
// Relative paths on purpose: the `@/` alias used in the sources is not
// rewritten by tsc, so consumers of the published declarations could not
// resolve `@/task`, `@/training_informant` or `@/types`.
import { Task } from '../task';
import { TrainingInformant } from '../training_informant';
import { Weights } from '../types';
/**
 * Common base of every client: it knows the server URL and the task being
 * trained, and defines the communication hooks the trainer calls.
 */
export declare abstract class Base {
    /** Base URL of the server. */
    readonly url: URL;
    /** Task this client trains on. */
    readonly task: Task;
    constructor(url: URL, task: Task);
    /**
     * Handles the connection process from the client to any sort of
     * centralized server.
     */
    abstract connect(): Promise<void>;
    /**
     * Handles the disconnection process of the client from any sort
     * of centralized server.
     */
    abstract disconnect(): Promise<void>;
    /**
     * Downloads and decodes the server's latest model for this task.
     */
    getLatestModel(): Promise<tf.LayersModel>;
    /**
     * The training manager matches this function with the training loop's
     * onTrainEnd callback when training a TFJS model object. See the
     * training manager for more details.
     */
    abstract onTrainEndCommunication(weights: Weights, trainingInformant: TrainingInformant): Promise<void>;
    /**
     * This function will be called whenever a local round has ended.
     *
     * @param updatedWeights weights after the local round
     * @param staleWeights weights before the local round
     * @param round index of the round that just ended
     * @param trainingInformant sink for progress messages
     * @returns the weights to continue training with
     */
    abstract onRoundEndCommunication(updatedWeights: Weights, staleWeights: Weights, round: number, trainingInformant: TrainingInformant): Promise<Weights>;
}
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
"use strict";
|
|
2
|
+
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
+
exports.Base = void 0;
|
|
4
|
+
var tslib_1 = require("tslib");
|
|
5
|
+
var axios_1 = (0, tslib_1.__importDefault)(require("axios"));
|
|
6
|
+
var serialization = (0, tslib_1.__importStar)(require("../serialization"));
|
|
7
|
+
var Base = /** @class */ (function () {
|
|
8
|
+
function Base(url, task) {
|
|
9
|
+
this.url = url;
|
|
10
|
+
this.task = task;
|
|
11
|
+
}
|
|
12
|
+
Base.prototype.getLatestModel = function () {
|
|
13
|
+
return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
|
|
14
|
+
var url, response;
|
|
15
|
+
return (0, tslib_1.__generator)(this, function (_a) {
|
|
16
|
+
switch (_a.label) {
|
|
17
|
+
case 0:
|
|
18
|
+
url = new URL('', this.url.href);
|
|
19
|
+
if (!url.pathname.endsWith('/')) {
|
|
20
|
+
url.pathname += '/';
|
|
21
|
+
}
|
|
22
|
+
url.pathname += "tasks/" + this.task.taskID + "/model.json";
|
|
23
|
+
return [4 /*yield*/, axios_1.default.get(url.href)];
|
|
24
|
+
case 1:
|
|
25
|
+
response = _a.sent();
|
|
26
|
+
return [4 /*yield*/, serialization.model.decode(response.data)];
|
|
27
|
+
case 2: return [2 /*return*/, _a.sent()];
|
|
28
|
+
}
|
|
29
|
+
});
|
|
30
|
+
});
|
|
31
|
+
};
|
|
32
|
+
return Base;
|
|
33
|
+
}());
|
|
34
|
+
exports.Base = Base;
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
import { TrainingInformant, Weights } from '..';
|
|
2
|
+
import { Base } from './base';
|
|
3
|
+
/**
 * Client for decentralized training.
 *
 * A WebSocket connection to the server is used for signaling only; weights
 * are exchanged directly between peers over simple-peer connections.
 */
export declare class Decentralized extends Base {
    /** Signaling WebSocket; undefined while disconnected. */
    private server?;
    /** Connected peers, keyed by the ID the signaling server assigned to them. */
    private peers;
    /**
     * Weights received from each peer, indexed by epoch. Not `readonly`: the
     * immutable map is replaced on every update.
     */
    private weights;
    private connectServer;
    private connectNewPeer;
    /**
     * Initialize the connection to the peers and to the other nodes.
     */
    connect(): Promise<void>;
    /**
     * Disconnection process when user quits the task.
     */
    disconnect(): Promise<void>;
    /**
     * Sends this round's weights to the connected peers, waits (bounded) for
     * theirs, and returns the weights to continue training with.
     */
    onRoundEndCommunication(updatedWeights: Weights, staleWeights: Weights, epoch: number, trainingInformant: TrainingInformant): Promise<Weights>;
    onTrainEndCommunication(_: Weights, trainingInformant: TrainingInformant): Promise<void>;
}
|
|
@@ -0,0 +1,275 @@
|
|
|
1
|
+
"use strict";
|
|
2
|
+
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
+
exports.Decentralized = void 0;
|
|
4
|
+
var tslib_1 = require("tslib");
|
|
5
|
+
var immutable_1 = require("immutable");
|
|
6
|
+
var isomorphic_ws_1 = (0, tslib_1.__importDefault)(require("isomorphic-ws"));
|
|
7
|
+
var msgpack_lite_1 = (0, tslib_1.__importDefault)(require("msgpack-lite"));
|
|
8
|
+
var simple_peer_1 = (0, tslib_1.__importDefault)(require("simple-peer"));
|
|
9
|
+
var url_1 = require("url");
|
|
10
|
+
var __1 = require("..");
|
|
11
|
+
var base_1 = require("./base");
|
|
12
|
+
/**
 * Type guard for the messages peers exchange: an object whose own keys are
 * exactly `epoch` (a number) and `weights` (serialized weights accepted by
 * `serialization.weights.isEncoded`).
 *
 * @param data decoded msgpack payload received from a peer
 * @returns true when `data` has that shape
 */
function isPeerMessage(data) {
    // `typeof null` is 'object', hence the explicit null check.
    if (typeof data !== 'object' || data === null) {
        return false;
    }
    var presentKeys = (0, immutable_1.Set)(Object.keys(data));
    if (!presentKeys.equals(immutable_1.Set.of('epoch', 'weights'))) {
        return false;
    }
    var epoch = data.epoch;
    var weights = data.weights;
    return typeof epoch === 'number' && !!__1.serialization.weights.isEncoded(weights);
}
|
|
31
|
+
/**
 * Type guard for the signaling server's opening message: an array of
 * numeric peer IDs to connect to (possibly empty).
 *
 * @param msg decoded msgpack payload from the signaling server
 * @returns true when `msg` is an array containing only numbers
 */
function isServerOpeningMessage(msg) {
    return msg instanceof Array &&
        msg.every(function (entry) { return typeof entry === 'number'; });
}
|
|
42
|
+
/**
 * Type guard for a signal relayed by the server: a pair
 * `[peerID: number, encodedSignal: Uint8Array]`.
 *
 * @param msg decoded msgpack payload from the signaling server
 * @returns true when `msg` is such a pair
 */
function isServerPeerMessage(msg) {
    if (!(msg instanceof Array) || msg.length !== 2) {
        return false;
    }
    var id = msg[0];
    var signal = msg[1];
    return typeof id === 'number' && signal instanceof Uint8Array;
}
|
|
60
|
+
// Polling period, in milliseconds, used while waiting for the peers' weights.
var TICK = 100;
// Upper bound, in milliseconds, on how long a round waits for the other peers (10 s).
var MAX_WAIT_PER_ROUND = 10 * 1000;
|
|
64
|
+
/**
 * Client for decentralized training.
 *
 * A WebSocket connection to the server (under `<url>/deai/tasks/<taskID>`)
 * is used for signaling: the server first sends the list of peer IDs to
 * connect to, then relays `[peerID, signal]` messages. Weights are exchanged
 * directly between peers over simple-peer connections, msgpack-encoded.
 */
var Decentralized = /** @class */ (function (_super) {
    (0, tslib_1.__extends)(Decentralized, _super);
    function Decentralized() {
        var _this = _super !== null && _super.apply(this, arguments) || this;
        // Peers, keyed by the ID the signaling server assigned to them.
        _this.peers = (0, immutable_1.Map)();
        // Received weights: peer -> List indexed by epoch. The map is immutable,
        // so every update must reassign this field.
        _this.weights = (0, immutable_1.Map)();
        return _this;
    }
    /**
     * Open the signaling WebSocket at `url` and install its message handler.
     *
     * @param url signaling endpoint
     * @returns promise of the WebSocket, resolved once it is open
     */
    Decentralized.prototype.connectServer = function (url) {
        return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
            var ws;
            var _this = this;
            return (0, tslib_1.__generator)(this, function (_a) {
                switch (_a.label) {
                    case 0:
                        ws = new isomorphic_ws_1.default.WebSocket(url);
                        ws.binaryType = 'arraybuffer';
                        ws.onmessage = function (event) {
                            if (!(event.data instanceof ArrayBuffer)) {
                                throw new Error('server did not send an ArrayBuffer');
                            }
                            var msg = msgpack_lite_1.default.decode(new Uint8Array(event.data));
                            if (isServerOpeningMessage(msg)) {
                                console.debug('server sent us the list of peer to connect to:', msg);
                                if (_this.peers.size !== 0) {
                                    throw new Error('server already gave us a list of peers');
                                }
                                // We initiate the connection to every peer already present.
                                _this.peers = (0, immutable_1.Map)((0, immutable_1.List)(msg)
                                    .map(function (id) { return [id, _this.connectNewPeer(id, true)]; }));
                            }
                            else if (isServerPeerMessage(msg)) {
                                var _a = (0, tslib_1.__read)(msg, 2), peerID = _a[0], encodedSignal = _a[1];
                                var signal = msgpack_lite_1.default.decode(encodedSignal);
                                console.debug('server on behalf of', peerID, 'sent', signal);
                                var peer = _this.peers.get(peerID);
                                if (peer === undefined) {
                                    // Unknown peer: it initiated, we answer.
                                    peer = _this.connectNewPeer(peerID, false);
                                    _this.peers = _this.peers.set(peerID, peer);
                                }
                                peer.signal(signal);
                            }
                            else {
                                throw new Error('server sent an invalid msg');
                            }
                        };
                        return [4 /*yield*/, new Promise(function (resolve, reject) {
                                ws.onerror = function (err) { return reject(new Error("connecting server: " + err)); };
                                ws.onopen = function () { return resolve(ws); };
                            })];
                    case 1: return [2 /*return*/, _a.sent()];
                }
            });
        });
    };
    // connect a new peer
    //
    // if initiator is true, we start the connection on our side
    // see SimplePeer.Options.initiator for more info
    Decentralized.prototype.connectNewPeer = function (peerID, initiator) {
        var _this = this;
        console.debug('connect new peer with initiator: ', initiator);
        var peer = new simple_peer_1.default({
            initiator: initiator,
            config: {
                iceServers: (0, immutable_1.List)(simple_peer_1.default.config.iceServers)
                    /* .push({
                      urls: 'turn:34.77.172.69:3478',
                      credential: 'deai',
                      username: 'deai'
                    }) */
                    .toArray()
            }
        });
        // Forward our local signaling data to the peer through the server.
        peer.on('signal', function (signal) {
            console.debug('local', peerID, 'is signaling', signal);
            if (_this.server === undefined) {
                throw new Error('server closed but received a signal');
            }
            var msg = [peerID, msgpack_lite_1.default.encode(signal)];
            _this.server.send(msgpack_lite_1.default.encode(msg));
        });
        // Store the weights a peer sends, at most once per epoch.
        peer.on('data', function (data) {
            var _a;
            var message = msgpack_lite_1.default.decode(data);
            if (!isPeerMessage(message)) {
                throw new Error("invalid message received from " + peerID);
            }
            var weights = __1.serialization.weights.decode(message.weights);
            console.debug('peer', peerID, 'sent weights', weights);
            if (((_a = _this.weights.get(peer)) === null || _a === void 0 ? void 0 : _a.get(message.epoch)) !== undefined) {
                throw new Error("weights from " + peerID + " already received");
            }
            // Immutable Map#set returns a new map: keep it, otherwise the
            // received weights are silently dropped.
            _this.weights = _this.weights.set(peer, _this.weights.get(peer, (0, immutable_1.List)())
                .set(message.epoch, weights));
        });
        peer.on('connect', function () { return console.info('connected to peer', peerID); });
        // TODO better error handling
        peer.on('error', function (err) { throw err; });
        return peer;
    };
    /**
     * Initialize the connection to the peers and to the other nodes.
     */
    Decentralized.prototype.connect = function () {
        return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
            var serverURL, _a;
            return (0, tslib_1.__generator)(this, function (_b) {
                switch (_b.label) {
                    case 0:
                        serverURL = new url_1.URL('', this.url.href);
                        // Join with exactly one slash (same rule as Base.getLatestModel);
                        // a base path ending in '/' used to yield '//deai/...'.
                        if (!serverURL.pathname.endsWith('/')) {
                            serverURL.pathname += '/';
                        }
                        serverURL.pathname += "deai/tasks/" + this.task.taskID;
                        _a = this;
                        return [4 /*yield*/, this.connectServer(serverURL)];
                    case 1:
                        _a.server = _b.sent();
                        return [2 /*return*/];
                }
            });
        });
    };
    /**
     * Disconnection process when user quits the task.
     */
    Decentralized.prototype.disconnect = function () {
        var _a;
        return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
            return (0, tslib_1.__generator)(this, function (_b) {
                this.peers.forEach(function (peer) { return peer.destroy(); });
                this.peers = (0, immutable_1.Map)();
                // Forget the weights of the destroyed peers; otherwise every later
                // round would keep waiting for them until the timeout.
                this.weights = (0, immutable_1.Map)();
                (_a = this.server) === null || _a === void 0 ? void 0 : _a.close();
                this.server = undefined;
                return [2 /*return*/];
            });
        });
    };
    /**
     * Send this epoch's (differentially private) weights to every connected
     * peer, wait up to MAX_WAIT_PER_ROUND ms for the peers' weights of the
     * same epoch, and return their average. When no peer sent weights for
     * this epoch, the local `updatedWeights` are returned unchanged.
     */
    Decentralized.prototype.onRoundEndCommunication = function (updatedWeights, staleWeights, epoch, trainingInformant) {
        return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
            var noisyWeights, msg, encodedMsg, getWeights, timeoutError, receivedWeights;
            var _a;
            var _this = this;
            return (0, tslib_1.__generator)(this, function (_b) {
                switch (_b.label) {
                    case 0:
                        noisyWeights = __1.privacy.addDifferentialPrivacy(updatedWeights, staleWeights, this.task);
                        _a = {
                            epoch: epoch
                        };
                        return [4 /*yield*/, __1.serialization.weights.encode(noisyWeights)];
                    case 1:
                        msg = (_a.weights = _b.sent(),
                            _a);
                        encodedMsg = msgpack_lite_1.default.encode(msg);
                        this.peers
                            .filter(function (peer) { return peer.connected; })
                            .forEach(function (peer, peerID) {
                            trainingInformant.addMessage("Sending weights to peer " + peerID);
                            trainingInformant.updateWhoReceivedMyModel("peer " + peerID);
                            peer.send(encodedMsg);
                        });
                        // This epoch's weights from every peer that ever sent weights
                        // (undefined for peers that have not sent them yet).
                        getWeights = function () {
                            return _this.weights
                                .valueSeq()
                                .map(function (epochesWeights) { return epochesWeights.get(epoch); });
                        };
                        timeoutError = new Error('timeout');
                        return [4 /*yield*/, new Promise(function (resolve, reject) {
                                var interval = setInterval(function () {
                                    var gotAllWeights = getWeights().every(function (weights) { return weights !== undefined; });
                                    if (gotAllWeights) {
                                        clearInterval(interval);
                                        // Also cancel the pending timeout so it does not keep the
                                        // event loop alive for the rest of MAX_WAIT_PER_ROUND.
                                        clearTimeout(timeout);
                                        resolve();
                                    }
                                }, TICK);
                                var timeout = setTimeout(function () {
                                    clearInterval(interval);
                                    reject(timeoutError);
                                }, MAX_WAIT_PER_ROUND);
                            }).catch(function (err) {
                                // A timeout is expected: continue with whatever arrived.
                                if (err !== timeoutError) {
                                    throw err;
                                }
                            })];
                    case 2:
                        _b.sent();
                        receivedWeights = getWeights()
                            .filter(function (weights) { return weights !== undefined; })
                            .toSet();
                        if (receivedWeights.size === 0) {
                            // Nothing to average with (e.g. no peer connected yet);
                            // averageWeights would throw, so keep the local weights.
                            trainingInformant.addMessage('No weights received from peers, keeping local weights');
                            return [2 /*return*/, updatedWeights];
                        }
                        // Average weights
                        trainingInformant.addMessage('Averaging weights');
                        trainingInformant.updateNbrUpdatesWithOthers(1);
                        // Return the new "received" weights
                        return [2 /*return*/, __1.aggregation.averageWeights(receivedWeights)];
                }
            });
        });
    };
    Decentralized.prototype.onTrainEndCommunication = function (_, trainingInformant) {
        return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
            return (0, tslib_1.__generator)(this, function (_a) {
                // TODO: enter seeding mode?
                trainingInformant.addMessage('Training finished.');
                return [2 /*return*/];
            });
        });
    };
    return Decentralized;
}(base_1.Base));
exports.Decentralized = Decentralized;
|