@epfml/discojs 0.0.1 → 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +12 -11
- package/dist/aggregation.d.ts +4 -2
- package/dist/aggregation.js +20 -6
- package/dist/client/base.d.ts +1 -1
- package/dist/client/decentralized/base.d.ts +43 -0
- package/dist/client/decentralized/base.js +243 -0
- package/dist/client/decentralized/clear_text.d.ts +13 -0
- package/dist/client/decentralized/clear_text.js +78 -0
- package/dist/client/decentralized/index.d.ts +4 -0
- package/dist/client/decentralized/index.js +9 -0
- package/dist/client/decentralized/messages.d.ts +37 -0
- package/dist/client/decentralized/messages.js +15 -0
- package/dist/client/decentralized/sec_agg.d.ts +18 -0
- package/dist/client/decentralized/sec_agg.js +169 -0
- package/dist/client/decentralized/secret_shares.d.ts +5 -0
- package/dist/client/decentralized/secret_shares.js +58 -0
- package/dist/client/decentralized/types.d.ts +1 -0
- package/dist/client/decentralized/types.js +2 -0
- package/dist/client/federated.d.ts +4 -4
- package/dist/client/federated.js +5 -8
- package/dist/client/index.d.ts +1 -1
- package/dist/client/index.js +3 -3
- package/dist/client/local.js +5 -3
- package/dist/dataset/data_loader/data_loader.d.ts +7 -1
- package/dist/dataset/data_loader/image_loader.d.ts +5 -3
- package/dist/dataset/data_loader/image_loader.js +64 -18
- package/dist/dataset/data_loader/index.d.ts +1 -1
- package/dist/dataset/data_loader/tabular_loader.d.ts +3 -3
- package/dist/dataset/data_loader/tabular_loader.js +27 -18
- package/dist/dataset/dataset_builder.d.ts +3 -2
- package/dist/dataset/dataset_builder.js +29 -17
- package/dist/index.d.ts +6 -4
- package/dist/index.js +13 -5
- package/dist/informant/graph_informant.d.ts +10 -0
- package/dist/informant/graph_informant.js +23 -0
- package/dist/informant/index.d.ts +3 -0
- package/dist/informant/index.js +9 -0
- package/dist/informant/training_informant/base.d.ts +31 -0
- package/dist/informant/training_informant/base.js +82 -0
- package/dist/informant/training_informant/decentralized.d.ts +5 -0
- package/dist/informant/training_informant/decentralized.js +22 -0
- package/dist/informant/training_informant/federated.d.ts +14 -0
- package/dist/informant/training_informant/federated.js +32 -0
- package/dist/informant/training_informant/index.d.ts +4 -0
- package/dist/informant/training_informant/index.js +11 -0
- package/dist/informant/training_informant/local.d.ts +6 -0
- package/dist/informant/training_informant/local.js +20 -0
- package/dist/logging/index.d.ts +1 -0
- package/dist/logging/index.js +3 -1
- package/dist/logging/trainer_logger.d.ts +1 -1
- package/dist/logging/trainer_logger.js +5 -5
- package/dist/memory/base.d.ts +17 -48
- package/dist/memory/empty.d.ts +6 -4
- package/dist/memory/empty.js +8 -2
- package/dist/memory/index.d.ts +1 -1
- package/dist/privacy.js +3 -3
- package/dist/serialization/model.d.ts +1 -1
- package/dist/serialization/model.js +2 -2
- package/dist/serialization/weights.js +2 -2
- package/dist/task/display_information.d.ts +2 -2
- package/dist/task/display_information.js +6 -5
- package/dist/task/summary.d.ts +5 -0
- package/dist/task/summary.js +23 -0
- package/dist/task/training_information.d.ts +3 -0
- package/dist/tasks/cifar10.js +5 -3
- package/dist/tasks/lus_covid.d.ts +1 -1
- package/dist/tasks/lus_covid.js +50 -13
- package/dist/tasks/mnist.js +4 -2
- package/dist/tasks/simple_face.d.ts +1 -1
- package/dist/tasks/simple_face.js +13 -17
- package/dist/tasks/titanic.js +9 -7
- package/dist/tfjs.d.ts +2 -0
- package/dist/tfjs.js +6 -0
- package/dist/training/disco.d.ts +3 -1
- package/dist/training/disco.js +14 -6
- package/dist/training/trainer/distributed_trainer.d.ts +1 -1
- package/dist/training/trainer/distributed_trainer.js +5 -1
- package/dist/training/trainer/local_trainer.d.ts +4 -3
- package/dist/training/trainer/local_trainer.js +6 -9
- package/dist/training/trainer/round_tracker.js +3 -0
- package/dist/training/trainer/trainer.d.ts +15 -15
- package/dist/training/trainer/trainer.js +57 -43
- package/dist/training/trainer/trainer_builder.js +8 -15
- package/dist/types.d.ts +1 -1
- package/dist/validation/index.d.ts +1 -0
- package/dist/validation/index.js +5 -0
- package/dist/validation/validator.d.ts +20 -0
- package/dist/validation/validator.js +106 -0
- package/package.json +2 -3
- package/dist/client/decentralized.d.ts +0 -23
- package/dist/client/decentralized.js +0 -275
- package/dist/testing/tester.d.ts +0 -5
- package/dist/testing/tester.js +0 -21
- package/dist/training_informant.d.ts +0 -88
- package/dist/training_informant.js +0 -135
package/README.md
CHANGED
|
@@ -2,31 +2,32 @@
|
|
|
2
2
|
|
|
3
3
|
discojs contains the core code of disco.
|
|
4
4
|
|
|
5
|
-
##
|
|
5
|
+
## Installation
|
|
6
6
|
|
|
7
|
-
The app is running under Node
|
|
7
|
+
The app is running under Node.js v16. NPM is a package manager for the JavaScript runtime environment Node.js.
|
|
8
|
+
We recommend using [nvm](https://github.com/nvm-sh/nvm) for installing both Node.js and NPM.
|
|
8
9
|
|
|
9
|
-
|
|
10
|
-
To start the application (running locally) run the following command.
|
|
11
|
-
Note: the application is currently developed using [NPM 7.6.3](https://www.npmjs.com/package/npm/v/7.6.3).
|
|
10
|
+
To install the application's dependencies, run the following command:
|
|
12
11
|
|
|
13
12
|
```
|
|
14
|
-
npm
|
|
13
|
+
npm ci
|
|
15
14
|
```
|
|
16
15
|
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
> **⚠ WARNING: Apple Silicon.**
|
|
20
|
-
> `TensorFlow.js` in version `3.13.0` currently supports for M1 mac laptops. However, make sure you have an `arm` node executable installed (not `x86`). It can be checked using:
|
|
16
|
+
> **⚠ WARNING: Apple Silicon.**
|
|
17
|
+
> `TensorFlow.js` version `3.13.0` and newer support Apple M1 processors. To use them, make sure you have an `arm` node executable installed (not `x86`). It can be checked using:
|
|
21
18
|
|
|
22
19
|
```
|
|
23
20
|
node -p "process.arch"
|
|
24
21
|
```
|
|
25
22
|
|
|
23
|
+
which should return `arm64`.
|
|
24
|
+
|
|
26
25
|
## Build
|
|
27
26
|
|
|
28
|
-
In order to enable the
|
|
27
|
+
In order to enable the browser and server modules to use the `discojs` package, we must build discojs:
|
|
29
28
|
|
|
30
29
|
```
|
|
31
30
|
npm run build
|
|
32
31
|
```
|
|
32
|
+
|
|
33
|
+
This invokes the TypeScript compiler (`tsc`). To recompile from scratch, simply `rm -rf dist/` before running `npm run build` again.
|
package/dist/aggregation.d.ts
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
1
|
-
import {
|
|
1
|
+
import { List } from 'immutable';
|
|
2
2
|
import { Weights } from './types';
|
|
3
|
-
export declare function
|
|
3
|
+
export declare function sumWeights(peersWeights: List<Weights>): Weights;
|
|
4
|
+
export declare function subtractWeights(peersWeights: List<Weights>): Weights;
|
|
5
|
+
export declare function averageWeights(peersWeights: List<Weights>): Weights;
|
package/dist/aggregation.js
CHANGED
|
@@ -1,19 +1,33 @@
|
|
|
1
1
|
"use strict";
|
|
2
2
|
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
-
exports.averageWeights = void 0;
|
|
4
|
-
|
|
3
|
+
exports.averageWeights = exports.subtractWeights = exports.sumWeights = void 0;
|
|
4
|
+
var tslib_1 = require("tslib");
|
|
5
|
+
var tf = (0, tslib_1.__importStar)(require("@tensorflow/tfjs"));
|
|
6
|
+
function applyArithmeticToWeights(peersWeights, tfjsArithmeticFunction) {
|
|
5
7
|
var _a;
|
|
8
|
+
console.log('Aggregating a list of', peersWeights.size, 'weight vectors.');
|
|
6
9
|
var firstWeightSize = (_a = peersWeights.first()) === null || _a === void 0 ? void 0 : _a.length;
|
|
7
10
|
if (firstWeightSize === undefined) {
|
|
8
11
|
throw new Error('no weights to average');
|
|
9
12
|
}
|
|
10
13
|
if (!peersWeights.rest().every(function (ws) { return ws.length === firstWeightSize; })) {
|
|
11
|
-
throw new Error('
|
|
14
|
+
throw new Error('weights dimensions are different for some of the summands');
|
|
12
15
|
}
|
|
13
|
-
var numberOfPeers = peersWeights.size;
|
|
14
16
|
var peersAverageWeights = peersWeights.reduce(function (accum, weights) {
|
|
15
|
-
return accum.map(function (w, i) { return w
|
|
16
|
-
})
|
|
17
|
+
return accum.map(function (w, i) { return tfjsArithmeticFunction(w, weights[i]); });
|
|
18
|
+
});
|
|
17
19
|
return peersAverageWeights;
|
|
18
20
|
}
|
|
21
|
+
/**
 * Element-wise sum of all weight arrays in the list, using tf.add.
 * Validation (empty list, mismatched lengths) is done by
 * applyArithmeticToWeights.
 */
function sumWeights(peersWeights) {
    var summed = applyArithmeticToWeights(peersWeights, tf.add);
    return summed;
}
|
|
24
|
+
exports.sumWeights = sumWeights;
|
|
25
|
+
/**
 * Element-wise subtraction over the list, using tf.sub. The fold runs from the
 * first entry onwards, so the first entry is the minuend and every later
 * entry is subtracted from the running result.
 */
function subtractWeights(peersWeights) {
    var difference = applyArithmeticToWeights(peersWeights, tf.sub);
    return difference;
}
|
|
28
|
+
exports.subtractWeights = subtractWeights;
|
|
29
|
+
/**
 * Element-wise mean of all weight arrays in the list: the sum computed by
 * sumWeights, with each resulting entry divided by the number of list entries.
 */
function averageWeights(peersWeights) {
    var peerCount = peersWeights.size;
    var totals = sumWeights(peersWeights);
    return totals.map(function (total) { return total.div(peerCount); });
}
|
|
19
33
|
exports.averageWeights = averageWeights;
|
package/dist/client/base.d.ts
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
/// <reference types="node" />
|
|
2
2
|
import * as tf from '@tensorflow/tfjs';
|
|
3
3
|
import { Task } from '@/task';
|
|
4
|
-
import { TrainingInformant } from '@/
|
|
4
|
+
import { TrainingInformant } from '@/informant';
|
|
5
5
|
import { Weights } from '@/types';
|
|
6
6
|
export declare abstract class Base {
|
|
7
7
|
readonly url: URL;
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
/// <reference types="node" />
|
|
2
|
+
import { List } from 'immutable';
|
|
3
|
+
import isomorphic from 'isomorphic-ws';
|
|
4
|
+
import { URL } from 'url';
|
|
5
|
+
import { Task } from '@/task';
|
|
6
|
+
import { TrainingInformant, Weights } from '../..';
|
|
7
|
+
import { Base as ClientBase } from '../base';
|
|
8
|
+
import * as messages from './messages';
|
|
9
|
+
import { PeerID } from './types';
|
|
10
|
+
/**
|
|
11
|
+
* Abstract class for decentralized clients, executes onRoundEndCommunication as well as connecting
|
|
12
|
+
* to the signaling server
|
|
13
|
+
*/
|
|
14
|
+
export declare abstract class Base extends ClientBase {
|
|
15
|
+
readonly url: URL;
|
|
16
|
+
readonly task: Task;
|
|
17
|
+
protected minimumReadyPeers: number;
|
|
18
|
+
protected maxShareValue: number;
|
|
19
|
+
constructor(url: URL, task: Task);
|
|
20
|
+
protected server?: isomorphic.WebSocket;
|
|
21
|
+
protected peers: PeerID[];
|
|
22
|
+
protected peersLocked: boolean;
|
|
23
|
+
protected ID: number;
|
|
24
|
+
protected pauseUntil(condition: () => boolean): Promise<void>;
|
|
25
|
+
protected sendMessagetoPeer(message: unknown): void;
|
|
26
|
+
protected sendReadyMessage(round: number): void;
|
|
27
|
+
private instanceOfMessageGeneral;
|
|
28
|
+
private instanceOfServerClientIDMessage;
|
|
29
|
+
private instanceOfServerReadyClients;
|
|
30
|
+
protected connectServer(url: URL): Promise<isomorphic.WebSocket>;
|
|
31
|
+
/**
|
|
32
|
+
* Initialize the WebSocket connection to the signaling server used to reach the peers.
|
|
33
|
+
*/
|
|
34
|
+
connect(): Promise<void>;
|
|
35
|
+
/**
|
|
36
|
+
* Disconnection process when user quits the task.
|
|
37
|
+
*/
|
|
38
|
+
disconnect(): Promise<void>;
|
|
39
|
+
onTrainEndCommunication(_: Weights, trainingInformant: TrainingInformant): Promise<void>;
|
|
40
|
+
onRoundEndCommunication(updatedWeights: Weights, staleWeights: Weights, round: number, trainingInformant: TrainingInformant): Promise<Weights>;
|
|
41
|
+
abstract sendAndReceiveWeights(noisyWeights: Weights, round: number, trainingInformant: TrainingInformant): Promise<List<Weights>>;
|
|
42
|
+
abstract clientHandle(msg: messages.messageGeneral): void;
|
|
43
|
+
}
|
|
@@ -0,0 +1,243 @@
|
|
|
1
|
+
"use strict";
|
|
2
|
+
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
+
exports.Base = void 0;
|
|
4
|
+
var tslib_1 = require("tslib");
|
|
5
|
+
var isomorphic_ws_1 = (0, tslib_1.__importDefault)(require("isomorphic-ws"));
|
|
6
|
+
var msgpack_lite_1 = (0, tslib_1.__importDefault)(require("msgpack-lite"));
|
|
7
|
+
var url_1 = require("url");
|
|
8
|
+
var __1 = require("../..");
|
|
9
|
+
var base_1 = require("../base");
|
|
10
|
+
var messages = (0, tslib_1.__importStar)(require("./messages"));
|
|
11
|
+
// Time to wait between network checks in milliseconds.
|
|
12
|
+
var TICK = 100;
|
|
13
|
+
// Time to wait for the others in milliseconds.
|
|
14
|
+
var MAX_WAIT_PER_ROUND = 10000;
|
|
15
|
+
/**
|
|
16
|
+
* Abstract class for decentralized clients, executes onRoundEndCommunication as well as connecting
|
|
17
|
+
* to the signaling server
|
|
18
|
+
*/
|
|
19
|
+
var Base = /** @class */ (function (_super) {
|
|
20
|
+
(0, tslib_1.__extends)(Base, _super);
|
|
21
|
+
/*
Builds a decentralized client for the given server URL and task.
Starts with no known peers, an unlocked peer list and a placeholder ID,
and reads its two tunables from the task's training information.
*/
function Base(url, task) {
    var self = _super.call(this, url, task) || this;
    self.url = url;
    self.task = task;
    // IDs of the peers this client sends its messages to
    self.peers = [];
    self.peersLocked = false;
    // placeholder; replaced when the server sends a serverClientIDMessage
    self.ID = 0;
    // Reads one optional setting from the task, falling back when the
    // training information or the setting itself is null/undefined.
    var readSetting = function (name, fallback) {
        var info = self.task.trainingInformation;
        var value = info === null || info === undefined ? undefined : info[name];
        return value !== null && value !== undefined ? value : fallback;
    };
    self.minimumReadyPeers = readSetting('minimumReadyPeers', 3);
    self.maxShareValue = readSetting('maxShareValue', 100);
    return self;
}
|
|
36
|
+
/*
Polls `condition` every TICK milliseconds.
Resolves as soon as it returns a truthy value, rejects with Error('timeout')
once more than MAX_WAIT_PER_ROUND milliseconds have elapsed, and rejects with
the thrown value if `condition` itself throws. In every case the polling
timer is cleared so that it cannot keep firing after the promise settled.
*/
Base.prototype.pauseUntil = function (condition) {
    return new Promise(function (resolve, reject) {
        var startedAt = Date.now();
        var timer = setInterval(function () {
            var satisfied;
            try {
                satisfied = condition();
            }
            catch (err) {
                // Without this, a throwing condition would leave the interval
                // running forever and the promise pending.
                clearInterval(timer);
                reject(err);
                return;
            }
            if (satisfied) {
                console.log('resolved after', Date.now() - startedAt, 'ms');
                clearInterval(timer);
                resolve();
            }
            else if (Date.now() - startedAt > MAX_WAIT_PER_ROUND) { // Timeout
                console.log('rejected after', Date.now() - startedAt, 'ms');
                clearInterval(timer);
                reject(new Error('timeout'));
            }
        }, TICK);
    });
};
|
|
63
|
+
/*
Hands an already-encoded message to the server socket, which currently relays
it to its destination; throws if no server connection exists.
Expected to change once messages travel peer to peer.
*/
Base.prototype.sendMessagetoPeer = function (message) {
    var socket = this.server;
    if (socket !== undefined) {
        socket.send(message);
        return;
    }
    throw new Error("Undefined Server, can't send message");
};
|
|
72
|
+
/*
Tells the server that this client is ready for the given round by sending a
msgpack-encoded clientReadyMessage; throws if no server connection exists.
*/
Base.prototype.sendReadyMessage = function (round) {
    // Broadcast our readiness
    var payload = msgpack_lite_1.default.encode({
        type: messages.messageType.clientReadyMessage,
        round: round
    });
    var socket = this.server;
    if (socket === undefined) {
        throw new Error('server undefined, could not connect peers');
    }
    socket.send(payload);
};
|
|
84
|
+
/*
True when the value is a non-null object carrying a `type` property, i.e. it
can be treated as a messageGeneral and dispatched on its type.
*/
Base.prototype.instanceOfMessageGeneral = function (msg) {
    if (typeof msg !== 'object' || msg === null) {
        return false;
    }
    return 'type' in msg;
};
|
|
90
|
+
/*
True when the message type is serverClientIDMessage, the message through
which the server assigns this client its ID.
*/
Base.prototype.instanceOfServerClientIDMessage = function (msg) {
    var actualType = msg.type;
    var expectedType = messages.messageType.serverClientIDMessage;
    return actualType === expectedType;
};
|
|
96
|
+
/*
True when the message type is serverReadyClients, the message that carries
the list of peer IDs ready to share updates.
*/
Base.prototype.instanceOfServerReadyClients = function (msg) {
    var actualType = msg.type;
    var expectedType = messages.messageType.serverReadyClients;
    return actualType === expectedType;
};
|
|
102
|
+
/*
Opens a WebSocket to the given URL and installs the handler for messages
received by this client. Resolves with the socket once it is open and
rejects if the socket reports an error or cannot be constructed.

Incoming frames must be ArrayBuffers holding a msgpack-encoded message:
 - serverClientIDMessage updates this client's ID,
 - serverReadyClients sets the peer list, but only while it is unlocked,
 - any other typed message is forwarded to clientHandle.
A frame that cannot be processed is logged and dropped. The handler is
synchronous on purpose: as an async function its failures were discarded as
unhandled promise rejections.
*/
Base.prototype.connectServer = function (url) {
    var _this = this;
    // Everything runs inside the executor so that a synchronous failure while
    // creating the socket still surfaces as a rejected promise.
    return new Promise(function (resolve, reject) {
        var ws = new isomorphic_ws_1.default.WebSocket(url);
        ws.binaryType = 'arraybuffer';
        ws.onmessage = function (event) {
            try {
                if (!(event.data instanceof ArrayBuffer)) {
                    throw new Error('server did not send an ArrayBuffer');
                }
                var msg = msgpack_lite_1.default.decode(new Uint8Array(event.data));
                // check message type to choose correct action
                if (!_this.instanceOfMessageGeneral(msg)) {
                    return;
                }
                if (_this.instanceOfServerClientIDMessage(msg)) {
                    // updated ID
                    _this.ID = msg.peerID;
                }
                else if (_this.instanceOfServerReadyClients(msg)) {
                    // updated connected peers, first list per round wins
                    if (!_this.peersLocked) {
                        _this.peers = msg.peerList;
                        _this.peersLocked = true;
                    }
                }
                else {
                    _this.clientHandle(msg);
                }
            }
            catch (err) {
                console.error('dropping message from server:', err);
            }
        };
        ws.onerror = function (err) {
            return reject(new Error("connecting server: " + err.message));
        }; // eslint-disable-line @typescript-eslint/restrict-template-expressions
        ws.onopen = function () { return resolve(ws); };
    });
};
|
|
153
|
+
/**
 * Connects to the server over WebSocket and stores the socket in `server`.
 * The socket URL is derived from `url`: http becomes ws, https becomes wss,
 * and "deai/<taskID>" is appended to the path. Any other protocol rejects.
 */
Base.prototype.connect = function () {
    var client = this;
    var socketProtocols = { 'http:': 'ws:', 'https:': 'wss:' };
    return new Promise(function (resolve) {
        var serverURL = new url_1.URL('', client.url.href);
        var current = client.url.protocol;
        if (!Object.prototype.hasOwnProperty.call(socketProtocols, current)) {
            throw new Error("unknown protocol: " + client.url.protocol);
        }
        serverURL.protocol = socketProtocols[current];
        serverURL.pathname += "deai/" + client.task.taskID;
        resolve(client.connectServer(serverURL));
    }).then(function (socket) {
        client.server = socket;
    });
};
|
|
183
|
+
/**
 * Disconnection process when the user quits the task: closes the server
 * socket if there is one and forgets it. Per-peer teardown is not done here.
 */
Base.prototype.disconnect = function () {
    var client = this;
    return new Promise(function (resolve) {
        var socket = client.server;
        if (socket !== null && socket !== undefined) {
            socket.close();
        }
        client.server = undefined;
        resolve();
    });
};
|
|
198
|
+
/*
Called once training is over: only records a message on the training
informant. The weights argument is ignored.
*/
Base.prototype.onTrainEndCommunication = function (_, trainingInformant) {
    return new Promise(function (resolve) {
        // TODO: enter seeding mode?
        trainingInformant.addMessage('Training finished.');
        resolve();
    });
};
|
|
207
|
+
/*
Runs the communication step at the end of a local training round:
 1. resets and unlocks the peer list, then tells the server we are ready,
 2. waits until at least minimumReadyPeers peers are known,
 3. applies differential privacy to the update, exchanges weights through
    sendAndReceiveWeights and returns the average of what it returned.
This step is best-effort: if anything in it fails (timeout or otherwise) the
cause is logged and the locally updated weights are returned unchanged.
*/
Base.prototype.onRoundEndCommunication = function (updatedWeights, staleWeights, round, trainingInformant) {
    return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
        var noisyWeights, finalWeights, error_1, reason;
        var _this = this;
        return (0, tslib_1.__generator)(this, function (_a) {
            switch (_a.label) {
                case 0:
                    _a.trys.push([0, 3, , 4]);
                    // reset peer list at each round of training to make sure client waits for updated peerList from server
                    this.peers = [];
                    this.peersLocked = false;
                    // centralized phase of communication --> client tells server that they have finished a local round and are ready to aggregate
                    this.sendReadyMessage(round);
                    // wait for peers to be connected before sending any update information
                    return [4 /*yield*/, this.pauseUntil(function () { return _this.peers.length >= _this.minimumReadyPeers; })];
                case 1:
                    _a.sent();
                    // Apply DP to updates that will be sent
                    noisyWeights = __1.privacy.addDifferentialPrivacy(updatedWeights, staleWeights, this.task);
                    return [4 /*yield*/, this.sendAndReceiveWeights(noisyWeights, round, trainingInformant)];
                case 2:
                    finalWeights = _a.sent();
                    return [2 /*return*/, __1.aggregation.averageWeights(finalWeights)];
                case 3:
                    error_1 = _a.sent();
                    // Not every failure is a timeout: report what actually happened
                    // instead of hiding it behind a fixed message.
                    reason = error_1 instanceof Error ? error_1.message : String(error_1);
                    console.log('Round communication failed (' + reason + '), training will continue');
                    return [2 /*return*/, updatedWeights];
                case 4: return [2 /*return*/];
            }
        });
    });
};
|
|
241
|
+
return Base;
|
|
242
|
+
}(base_1.Base));
|
|
243
|
+
exports.Base = Base;
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
import { List } from 'immutable';
|
|
2
|
+
import { TrainingInformant, Weights } from '../..';
|
|
3
|
+
import { Base } from './base';
|
|
4
|
+
import * as messages from './messages';
|
|
5
|
+
/**
|
|
6
|
+
* Decentralized client that does not utilize secure aggregation, but sends model updates in clear text
|
|
7
|
+
*/
|
|
8
|
+
export declare class ClearText extends Base {
|
|
9
|
+
protected receivedWeights: List<Weights>;
|
|
10
|
+
sendAndReceiveWeights(noisyWeights: Weights, round: number, trainingInformant: TrainingInformant): Promise<List<Weights>>;
|
|
11
|
+
private instanceOfClientWeightsMessageServer;
|
|
12
|
+
clientHandle(msg: messages.messageGeneral): void;
|
|
13
|
+
}
|
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
"use strict";
|
|
2
|
+
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
+
exports.ClearText = void 0;
|
|
4
|
+
var tslib_1 = require("tslib");
|
|
5
|
+
var immutable_1 = require("immutable");
|
|
6
|
+
var msgpack_lite_1 = (0, tslib_1.__importDefault)(require("msgpack-lite"));
|
|
7
|
+
var __1 = require("../..");
|
|
8
|
+
var base_1 = require("./base");
|
|
9
|
+
var messages = (0, tslib_1.__importStar)(require("./messages"));
|
|
10
|
+
/**
|
|
11
|
+
* Decentralized client that does not utilize secure aggregation, but sends model updates in clear text
|
|
12
|
+
*/
|
|
13
|
+
var ClearText = /** @class */ (function (_super) {
|
|
14
|
+
(0, tslib_1.__extends)(ClearText, _super);
|
|
15
|
+
/*
Forwards all constructor arguments to the parent class and starts with an
empty list of weights received from other clients.
*/
function ClearText() {
    var instance = this;
    if (_super !== null) {
        var fromParent = _super.apply(this, arguments);
        if (fromParent) {
            instance = fromParent;
        }
    }
    // list of weights received from other clients
    instance.receivedWeights = (0, immutable_1.List)();
    return instance;
}
|
|
21
|
+
/*
Single-phase exchange for the clear-text client:
clears the weights received so far, encodes this client's weights, sends one
clientWeightsMessageServer per known peer, then waits until as many weights
have been received as there are peers and returns that list.
Throws if there is no server connection once the weights are encoded.
*/
ClearText.prototype.sendAndReceiveWeights = function (noisyWeights, round, trainingInformant) {
    var client = this;
    return new Promise(function (resolve) {
        // reset received fields at beginning of each round
        client.receivedWeights = client.receivedWeights.clear();
        resolve(__1.serialization.weights.encode(noisyWeights));
    }).then(function (weightsToSend) {
        if (client.server === undefined) {
            throw new Error('server undefined so we cannot send weights through it');
        }
        // PHASE 1 COMMUNICATION --> one weights message per peer (only phase for clear_text)
        for (var idx = 0; idx < client.peers.length; idx += 1) {
            var packed = msgpack_lite_1.default.encode({
                type: messages.messageType.clientWeightsMessageServer,
                peerID: client.ID,
                weights: weightsToSend,
                destination: client.peers[idx]
            });
            client.sendMessagetoPeer(packed);
        }
        // wait to receive all weights from peers
        return client.pauseUntil(function () { return client.receivedWeights.size >= client.peers.length; });
    }).then(function () {
        return client.receivedWeights;
    });
};
|
|
57
|
+
/*
True when the message type is clientWeightsMessageServer, i.e. the message
carries weights sent by a peer.
*/
ClearText.prototype.instanceOfClientWeightsMessageServer = function (msg) {
    var actualType = msg.type;
    var expectedType = messages.messageType.clientWeightsMessageServer;
    return actualType === expectedType;
};
|
|
63
|
+
/*
Handles the messages the base class does not consume itself.
A weights message is decoded and appended to the received weights;
any other message type is rejected with an error.
*/
ClearText.prototype.clientHandle = function (msg) {
    if (!this.instanceOfClientWeightsMessageServer(msg)) {
        throw new Error('Unexpected Message Type');
    }
    // update received weights by one weights reception
    var decoded = __1.serialization.weights.decode(msg.weights);
    this.receivedWeights = this.receivedWeights.push(decoded);
};
|
|
76
|
+
return ClearText;
|
|
77
|
+
}(base_1.Base));
|
|
78
|
+
exports.ClearText = ClearText;
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
"use strict";
|
|
2
|
+
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
+
exports.messages = exports.SecAgg = exports.ClearText = void 0;
|
|
4
|
+
var tslib_1 = require("tslib");
|
|
5
|
+
var clear_text_1 = require("./clear_text");
|
|
6
|
+
Object.defineProperty(exports, "ClearText", { enumerable: true, get: function () { return clear_text_1.ClearText; } });
|
|
7
|
+
var sec_agg_1 = require("./sec_agg");
|
|
8
|
+
Object.defineProperty(exports, "SecAgg", { enumerable: true, get: function () { return sec_agg_1.SecAgg; } });
|
|
9
|
+
exports.messages = (0, tslib_1.__importStar)(require("./messages"));
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
import { weights } from '../../serialization';
|
|
2
|
+
import { PeerID } from './types';
|
|
3
|
+
export declare enum messageType {
|
|
4
|
+
serverClientIDMessage = 0,
|
|
5
|
+
clientReadyMessage = 1,
|
|
6
|
+
serverReadyClients = 2,
|
|
7
|
+
clientWeightsMessageServer = 3,
|
|
8
|
+
clientSharesMessageServer = 4,
|
|
9
|
+
clientPartialSumsMessageServer = 5
|
|
10
|
+
}
|
|
11
|
+
export interface messageGeneral {
|
|
12
|
+
type: messageType;
|
|
13
|
+
}
|
|
14
|
+
export interface serverClientIDMessage extends messageGeneral {
|
|
15
|
+
peerID: PeerID;
|
|
16
|
+
}
|
|
17
|
+
export interface clientReadyMessage extends messageGeneral {
|
|
18
|
+
round: number;
|
|
19
|
+
}
|
|
20
|
+
export interface clientWeightsMessageServer extends messageGeneral {
|
|
21
|
+
peerID: PeerID;
|
|
22
|
+
weights: weights.Encoded;
|
|
23
|
+
destination: PeerID;
|
|
24
|
+
}
|
|
25
|
+
export interface clientSharesMessageServer extends messageGeneral {
|
|
26
|
+
peerID: PeerID;
|
|
27
|
+
weights: weights.Encoded;
|
|
28
|
+
destination: PeerID;
|
|
29
|
+
}
|
|
30
|
+
export interface clientPartialSumsMessageServer extends messageGeneral {
|
|
31
|
+
peerID: PeerID;
|
|
32
|
+
partials: weights.Encoded;
|
|
33
|
+
destination: PeerID;
|
|
34
|
+
}
|
|
35
|
+
export interface serverReadyClients extends messageGeneral {
|
|
36
|
+
peerList: PeerID[];
|
|
37
|
+
}
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
"use strict";
|
|
2
|
+
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
+
exports.messageType = void 0;
|
|
4
|
+
// Numeric enum of message kinds, mapped in both directions (name -> value
// and value -> name). A name's position in the list below is its value.
var messageType;
(function (registry) {
    var orderedNames = [
        // Phase 0 communication (just between server and client)
        'serverClientIDMessage',
        'clientReadyMessage',
        'serverReadyClients',
        // Phase 1 communication (between client and peers)
        'clientWeightsMessageServer',
        'clientSharesMessageServer',
        // Phase 2 communication (between client and peers)
        'clientPartialSumsMessageServer'
    ];
    orderedNames.forEach(function (name, value) {
        registry[name] = value;
        registry[value] = name;
    });
})(messageType = exports.messageType || (exports.messageType = {}));
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
import { List } from 'immutable';
|
|
2
|
+
import { TrainingInformant, Weights } from '../..';
|
|
3
|
+
import { Base } from './base';
|
|
4
|
+
import * as messages from './messages';
|
|
5
|
+
/**
|
|
6
|
+
* Decentralized client that utilizes secure aggregation so client updates remain private
|
|
7
|
+
*/
|
|
8
|
+
export declare class SecAgg extends Base {
|
|
9
|
+
private receivedShares;
|
|
10
|
+
private receivedPartialSums;
|
|
11
|
+
private mySum;
|
|
12
|
+
private sendShares;
|
|
13
|
+
private sendPartialSums;
|
|
14
|
+
sendAndReceiveWeights(noisyWeights: Weights, round: number, trainingInformant: TrainingInformant): Promise<List<Weights>>;
|
|
15
|
+
private instanceOfClientSharesMessageServer;
|
|
16
|
+
private instanceOfClientPartialSumsMessageServer;
|
|
17
|
+
clientHandle(msg: messages.messageGeneral): void;
|
|
18
|
+
}
|