@epfml/discojs 0.1.0 → 2.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (208) hide show
  1. package/README.md +28 -8
  2. package/dist/{async_buffer.d.ts → core/async_buffer.d.ts} +3 -3
  3. package/dist/{async_buffer.js → core/async_buffer.js} +5 -6
  4. package/dist/{async_informant.d.ts → core/async_informant.d.ts} +0 -0
  5. package/dist/{async_informant.js → core/async_informant.js} +0 -0
  6. package/dist/{client → core/client}/base.d.ts +4 -7
  7. package/dist/{client → core/client}/base.js +3 -2
  8. package/dist/core/client/decentralized/base.d.ts +32 -0
  9. package/dist/core/client/decentralized/base.js +212 -0
  10. package/dist/core/client/decentralized/clear_text.d.ts +14 -0
  11. package/dist/core/client/decentralized/clear_text.js +96 -0
  12. package/dist/{client → core/client}/decentralized/index.d.ts +0 -0
  13. package/dist/{client → core/client}/decentralized/index.js +0 -0
  14. package/dist/core/client/decentralized/messages.d.ts +41 -0
  15. package/dist/core/client/decentralized/messages.js +54 -0
  16. package/dist/core/client/decentralized/peer.d.ts +26 -0
  17. package/dist/core/client/decentralized/peer.js +210 -0
  18. package/dist/core/client/decentralized/peer_pool.d.ts +14 -0
  19. package/dist/core/client/decentralized/peer_pool.js +92 -0
  20. package/dist/core/client/decentralized/sec_agg.d.ts +22 -0
  21. package/dist/core/client/decentralized/sec_agg.js +190 -0
  22. package/dist/core/client/decentralized/secret_shares.d.ts +3 -0
  23. package/dist/core/client/decentralized/secret_shares.js +39 -0
  24. package/dist/core/client/decentralized/types.d.ts +2 -0
  25. package/dist/core/client/decentralized/types.js +7 -0
  26. package/dist/core/client/event_connection.d.ts +37 -0
  27. package/dist/core/client/event_connection.js +158 -0
  28. package/dist/core/client/federated/client.d.ts +37 -0
  29. package/dist/core/client/federated/client.js +273 -0
  30. package/dist/core/client/federated/index.d.ts +2 -0
  31. package/dist/core/client/federated/index.js +7 -0
  32. package/dist/core/client/federated/messages.d.ts +38 -0
  33. package/dist/core/client/federated/messages.js +25 -0
  34. package/dist/{client → core/client}/index.d.ts +2 -1
  35. package/dist/{client → core/client}/index.js +3 -3
  36. package/dist/{client → core/client}/local.d.ts +2 -2
  37. package/dist/{client → core/client}/local.js +0 -0
  38. package/dist/core/client/messages.d.ts +28 -0
  39. package/dist/core/client/messages.js +33 -0
  40. package/dist/core/client/utils.d.ts +2 -0
  41. package/dist/core/client/utils.js +19 -0
  42. package/dist/core/dataset/data/data.d.ts +11 -0
  43. package/dist/core/dataset/data/data.js +20 -0
  44. package/dist/core/dataset/data/data_split.d.ts +5 -0
  45. package/dist/{client/decentralized/types.js → core/dataset/data/data_split.js} +0 -0
  46. package/dist/core/dataset/data/image_data.d.ts +8 -0
  47. package/dist/core/dataset/data/image_data.js +64 -0
  48. package/dist/core/dataset/data/index.d.ts +5 -0
  49. package/dist/core/dataset/data/index.js +11 -0
  50. package/dist/core/dataset/data/preprocessing.d.ts +13 -0
  51. package/dist/core/dataset/data/preprocessing.js +33 -0
  52. package/dist/core/dataset/data/tabular_data.d.ts +8 -0
  53. package/dist/core/dataset/data/tabular_data.js +40 -0
  54. package/dist/{dataset → core/dataset}/data_loader/data_loader.d.ts +4 -11
  55. package/dist/{dataset → core/dataset}/data_loader/data_loader.js +0 -0
  56. package/dist/core/dataset/data_loader/image_loader.d.ts +17 -0
  57. package/dist/core/dataset/data_loader/image_loader.js +141 -0
  58. package/dist/core/dataset/data_loader/index.d.ts +3 -0
  59. package/dist/core/dataset/data_loader/index.js +9 -0
  60. package/dist/core/dataset/data_loader/tabular_loader.d.ts +29 -0
  61. package/dist/core/dataset/data_loader/tabular_loader.js +101 -0
  62. package/dist/core/dataset/dataset.d.ts +2 -0
  63. package/dist/{task/training_information.js → core/dataset/dataset.js} +0 -0
  64. package/dist/{dataset → core/dataset}/dataset_builder.d.ts +5 -5
  65. package/dist/{dataset → core/dataset}/dataset_builder.js +14 -10
  66. package/dist/core/dataset/index.d.ts +4 -0
  67. package/dist/core/dataset/index.js +14 -0
  68. package/dist/core/index.d.ts +18 -0
  69. package/dist/core/index.js +41 -0
  70. package/dist/{informant → core/informant}/graph_informant.d.ts +0 -0
  71. package/dist/{informant → core/informant}/graph_informant.js +0 -0
  72. package/dist/{informant → core/informant}/index.d.ts +0 -0
  73. package/dist/{informant → core/informant}/index.js +0 -0
  74. package/dist/{informant → core/informant}/training_informant/base.d.ts +3 -3
  75. package/dist/{informant → core/informant}/training_informant/base.js +3 -2
  76. package/dist/{informant → core/informant}/training_informant/decentralized.d.ts +0 -0
  77. package/dist/{informant → core/informant}/training_informant/decentralized.js +0 -0
  78. package/dist/{informant → core/informant}/training_informant/federated.d.ts +0 -0
  79. package/dist/{informant → core/informant}/training_informant/federated.js +0 -0
  80. package/dist/{informant → core/informant}/training_informant/index.d.ts +0 -0
  81. package/dist/{informant → core/informant}/training_informant/index.js +0 -0
  82. package/dist/{informant → core/informant}/training_informant/local.d.ts +2 -2
  83. package/dist/{informant → core/informant}/training_informant/local.js +2 -2
  84. package/dist/{logging → core/logging}/console_logger.d.ts +0 -0
  85. package/dist/{logging → core/logging}/console_logger.js +0 -0
  86. package/dist/{logging → core/logging}/index.d.ts +0 -0
  87. package/dist/{logging → core/logging}/index.js +0 -0
  88. package/dist/{logging → core/logging}/logger.d.ts +0 -0
  89. package/dist/{logging → core/logging}/logger.js +0 -0
  90. package/dist/{logging → core/logging}/trainer_logger.d.ts +0 -0
  91. package/dist/{logging → core/logging}/trainer_logger.js +0 -0
  92. package/dist/{memory → core/memory}/base.d.ts +2 -2
  93. package/dist/{memory → core/memory}/base.js +0 -0
  94. package/dist/{memory → core/memory}/empty.d.ts +0 -0
  95. package/dist/{memory → core/memory}/empty.js +0 -0
  96. package/dist/core/memory/index.d.ts +3 -0
  97. package/dist/core/memory/index.js +9 -0
  98. package/dist/{memory → core/memory}/model_type.d.ts +0 -0
  99. package/dist/{memory → core/memory}/model_type.js +0 -0
  100. package/dist/{privacy.d.ts → core/privacy.d.ts} +2 -3
  101. package/dist/{privacy.js → core/privacy.js} +3 -16
  102. package/dist/{serialization → core/serialization}/index.d.ts +0 -0
  103. package/dist/{serialization → core/serialization}/index.js +0 -0
  104. package/dist/{serialization → core/serialization}/model.d.ts +0 -0
  105. package/dist/{serialization → core/serialization}/model.js +0 -0
  106. package/dist/core/serialization/weights.d.ts +5 -0
  107. package/dist/{serialization → core/serialization}/weights.js +11 -9
  108. package/dist/{task → core/task}/data_example.d.ts +0 -0
  109. package/dist/{task → core/task}/data_example.js +0 -0
  110. package/dist/{task → core/task}/display_information.d.ts +5 -5
  111. package/dist/{task → core/task}/display_information.js +5 -10
  112. package/dist/{task → core/task}/index.d.ts +0 -0
  113. package/dist/{task → core/task}/index.js +0 -0
  114. package/dist/core/task/model_compile_data.d.ts +6 -0
  115. package/dist/core/task/model_compile_data.js +22 -0
  116. package/dist/{task → core/task}/summary.d.ts +0 -0
  117. package/dist/{task → core/task}/summary.js +0 -4
  118. package/dist/{task → core/task}/task.d.ts +2 -2
  119. package/dist/{task → core/task}/task.js +6 -7
  120. package/dist/{task → core/task}/training_information.d.ts +10 -14
  121. package/dist/core/task/training_information.js +66 -0
  122. package/dist/{tasks → core/tasks}/cifar10.d.ts +1 -2
  123. package/dist/{tasks → core/tasks}/cifar10.js +12 -23
  124. package/dist/core/tasks/geotags.d.ts +3 -0
  125. package/dist/core/tasks/geotags.js +67 -0
  126. package/dist/{tasks → core/tasks}/index.d.ts +2 -1
  127. package/dist/{tasks → core/tasks}/index.js +3 -2
  128. package/dist/core/tasks/lus_covid.d.ts +3 -0
  129. package/dist/{tasks → core/tasks}/lus_covid.js +26 -24
  130. package/dist/{tasks → core/tasks}/mnist.d.ts +1 -2
  131. package/dist/{tasks → core/tasks}/mnist.js +18 -16
  132. package/dist/core/tasks/simple_face.d.ts +2 -0
  133. package/dist/core/tasks/simple_face.js +41 -0
  134. package/dist/{tasks → core/tasks}/titanic.d.ts +1 -2
  135. package/dist/{tasks → core/tasks}/titanic.js +11 -11
  136. package/dist/core/training/disco.d.ts +23 -0
  137. package/dist/core/training/disco.js +130 -0
  138. package/dist/{training → core/training}/index.d.ts +0 -0
  139. package/dist/{training → core/training}/index.js +0 -0
  140. package/dist/{training → core/training}/trainer/distributed_trainer.d.ts +1 -2
  141. package/dist/{training → core/training}/trainer/distributed_trainer.js +6 -5
  142. package/dist/{training → core/training}/trainer/local_trainer.d.ts +2 -2
  143. package/dist/{training → core/training}/trainer/local_trainer.js +0 -0
  144. package/dist/{training → core/training}/trainer/round_tracker.d.ts +0 -0
  145. package/dist/{training → core/training}/trainer/round_tracker.js +0 -0
  146. package/dist/{training → core/training}/trainer/trainer.d.ts +1 -2
  147. package/dist/{training → core/training}/trainer/trainer.js +2 -2
  148. package/dist/{training → core/training}/trainer/trainer_builder.d.ts +0 -0
  149. package/dist/{training → core/training}/trainer/trainer_builder.js +0 -0
  150. package/dist/core/training/training_schemes.d.ts +5 -0
  151. package/dist/{training → core/training}/training_schemes.js +2 -2
  152. package/dist/{types.d.ts → core/types.d.ts} +0 -0
  153. package/dist/{types.js → core/types.js} +0 -0
  154. package/dist/{validation → core/validation}/index.d.ts +0 -0
  155. package/dist/{validation → core/validation}/index.js +0 -0
  156. package/dist/{validation → core/validation}/validator.d.ts +5 -8
  157. package/dist/{validation → core/validation}/validator.js +9 -11
  158. package/dist/core/weights/aggregation.d.ts +8 -0
  159. package/dist/core/weights/aggregation.js +96 -0
  160. package/dist/core/weights/index.d.ts +2 -0
  161. package/dist/core/weights/index.js +7 -0
  162. package/dist/core/weights/weights_container.d.ts +19 -0
  163. package/dist/core/weights/weights_container.js +64 -0
  164. package/dist/dataset/data_loader/image_loader.d.ts +3 -15
  165. package/dist/dataset/data_loader/image_loader.js +12 -125
  166. package/dist/dataset/data_loader/index.d.ts +2 -3
  167. package/dist/dataset/data_loader/index.js +3 -5
  168. package/dist/dataset/data_loader/tabular_loader.d.ts +3 -28
  169. package/dist/dataset/data_loader/tabular_loader.js +11 -92
  170. package/dist/imports.d.ts +2 -0
  171. package/dist/imports.js +7 -0
  172. package/dist/index.d.ts +2 -19
  173. package/dist/index.js +3 -39
  174. package/dist/memory/index.d.ts +1 -3
  175. package/dist/memory/index.js +3 -7
  176. package/dist/memory/memory.d.ts +26 -0
  177. package/dist/memory/memory.js +160 -0
  178. package/package.json +13 -26
  179. package/dist/aggregation.d.ts +0 -5
  180. package/dist/aggregation.js +0 -33
  181. package/dist/client/decentralized/base.d.ts +0 -43
  182. package/dist/client/decentralized/base.js +0 -243
  183. package/dist/client/decentralized/clear_text.d.ts +0 -13
  184. package/dist/client/decentralized/clear_text.js +0 -78
  185. package/dist/client/decentralized/messages.d.ts +0 -37
  186. package/dist/client/decentralized/messages.js +0 -15
  187. package/dist/client/decentralized/sec_agg.d.ts +0 -18
  188. package/dist/client/decentralized/sec_agg.js +0 -169
  189. package/dist/client/decentralized/secret_shares.d.ts +0 -5
  190. package/dist/client/decentralized/secret_shares.js +0 -58
  191. package/dist/client/decentralized/types.d.ts +0 -1
  192. package/dist/client/federated.d.ts +0 -30
  193. package/dist/client/federated.js +0 -218
  194. package/dist/dataset/index.d.ts +0 -2
  195. package/dist/dataset/index.js +0 -7
  196. package/dist/model_actor.d.ts +0 -16
  197. package/dist/model_actor.js +0 -20
  198. package/dist/serialization/weights.d.ts +0 -5
  199. package/dist/task/model_compile_data.d.ts +0 -6
  200. package/dist/task/model_compile_data.js +0 -12
  201. package/dist/tasks/lus_covid.d.ts +0 -4
  202. package/dist/tasks/simple_face.d.ts +0 -4
  203. package/dist/tasks/simple_face.js +0 -84
  204. package/dist/tfjs.d.ts +0 -2
  205. package/dist/tfjs.js +0 -6
  206. package/dist/training/disco.d.ts +0 -14
  207. package/dist/training/disco.js +0 -70
  208. package/dist/training/training_schemes.d.ts +0 -5
package/package.json CHANGED
@@ -1,11 +1,13 @@
1
1
  {
2
2
  "name": "@epfml/discojs",
3
- "version": "0.1.0",
3
+ "version": "2.0.0",
4
4
  "main": "dist/index.js",
5
5
  "types": "dist/index.d.ts",
6
6
  "scripts": {
7
- "build": "tsc",
8
- "test": "mocha"
7
+ "build": "cp ./src/core/index.browser.ts ./src/core/index.ts && tsc",
8
+ "build-win": "copy ..\\discojs-core\\src\\index.node.ts ..\\discojs-core\\src\\index.ts && tsc",
9
+ "test": "cp ./src/core/index.node.ts ./src/core/index.ts && mocha",
10
+ "lint": "cp ./src/core/index.browser.ts ./src/core/index.ts && npx eslint --max-warnings 0 --ignore-pattern '*.spec.ts' --ignore-pattern 'src/core' ."
9
11
  },
10
12
  "repository": {
11
13
  "type": "git",
@@ -16,32 +18,17 @@
16
18
  },
17
19
  "homepage": "https://github.com/epfml/disco#readme",
18
20
  "dependencies": {
19
- "axios": "0.27",
20
- "immutable": "4",
21
- "isomorphic-ws": "4",
22
- "lodash": "4",
23
21
  "msgpack-lite": "0.1",
24
- "simple-peer": "9",
22
+ "immutable": "4",
23
+ "@tensorflow/tfjs": "4",
25
24
  "tslib": "2",
25
+ "isomorphic-ws": "4",
26
26
  "url": "0.11",
27
- "uuid": "8",
28
- "ws": "8"
29
- },
30
- "devDependencies": {
31
- "@tensorflow/tfjs-node": "3",
32
- "@types/chai": "4",
33
- "@types/lodash": "4",
34
- "@types/mocha": "9",
27
+ "@koush/wrtc": "0.5",
28
+ "axios": "0.27",
35
29
  "@types/msgpack-lite": "0.1",
36
- "@types/simple-peer": "9",
37
- "@types/uuid": "8",
38
- "@typescript-eslint/eslint-plugin": "4",
39
- "@typescript-eslint/parser": "4",
40
- "chai": "4",
41
- "eslint": "7",
42
- "eslint-config-standard-with-typescript": "21",
43
- "mocha": "9",
44
- "ts-node": "10",
45
- "typescript": "<4.5.0"
30
+ "uuid": "8",
31
+ "ws": "8",
32
+ "simple-peer": "9"
46
33
  }
47
34
  }
@@ -1,5 +0,0 @@
1
- import { List } from 'immutable';
2
- import { Weights } from './types';
3
- export declare function sumWeights(peersWeights: List<Weights>): Weights;
4
- export declare function subtractWeights(peersWeights: List<Weights>): Weights;
5
- export declare function averageWeights(peersWeights: List<Weights>): Weights;
@@ -1,33 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.averageWeights = exports.subtractWeights = exports.sumWeights = void 0;
4
- var tslib_1 = require("tslib");
5
- var tf = (0, tslib_1.__importStar)(require("@tensorflow/tfjs"));
6
- function applyArithmeticToWeights(peersWeights, tfjsArithmeticFunction) {
7
- var _a;
8
- console.log('Aggregating a list of', peersWeights.size, 'weight vectors.');
9
- var firstWeightSize = (_a = peersWeights.first()) === null || _a === void 0 ? void 0 : _a.length;
10
- if (firstWeightSize === undefined) {
11
- throw new Error('no weights to average');
12
- }
13
- if (!peersWeights.rest().every(function (ws) { return ws.length === firstWeightSize; })) {
14
- throw new Error('weights dimensions are different for some of the summands');
15
- }
16
- var peersAverageWeights = peersWeights.reduce(function (accum, weights) {
17
- return accum.map(function (w, i) { return tfjsArithmeticFunction(w, weights[i]); });
18
- });
19
- return peersAverageWeights;
20
- }
21
- function sumWeights(peersWeights) {
22
- return applyArithmeticToWeights(peersWeights, tf.add);
23
- }
24
- exports.sumWeights = sumWeights;
25
- function subtractWeights(peersWeights) {
26
- return applyArithmeticToWeights(peersWeights, tf.sub);
27
- }
28
- exports.subtractWeights = subtractWeights;
29
- function averageWeights(peersWeights) {
30
- var numberOfPeers = peersWeights.size;
31
- return sumWeights(peersWeights).map(function (w) { return w.div(numberOfPeers); });
32
- }
33
- exports.averageWeights = averageWeights;
@@ -1,43 +0,0 @@
1
- /// <reference types="node" />
2
- import { List } from 'immutable';
3
- import isomorphic from 'isomorphic-ws';
4
- import { URL } from 'url';
5
- import { Task } from '@/task';
6
- import { TrainingInformant, Weights } from '../..';
7
- import { Base as ClientBase } from '../base';
8
- import * as messages from './messages';
9
- import { PeerID } from './types';
10
- /**
11
- * Abstract class for decentralized clients, executes onRoundEndCommunication as well as connecting
12
- * to the signaling server
13
- */
14
- export declare abstract class Base extends ClientBase {
15
- readonly url: URL;
16
- readonly task: Task;
17
- protected minimumReadyPeers: number;
18
- protected maxShareValue: number;
19
- constructor(url: URL, task: Task);
20
- protected server?: isomorphic.WebSocket;
21
- protected peers: PeerID[];
22
- protected peersLocked: boolean;
23
- protected ID: number;
24
- protected pauseUntil(condition: () => boolean): Promise<void>;
25
- protected sendMessagetoPeer(message: unknown): void;
26
- protected sendReadyMessage(round: number): void;
27
- private instanceOfMessageGeneral;
28
- private instanceOfServerClientIDMessage;
29
- private instanceOfServerReadyClients;
30
- protected connectServer(url: URL): Promise<isomorphic.WebSocket>;
31
- /**
32
- * Initialize the connection to the peers and to the other nodes.
33
- */
34
- connect(): Promise<void>;
35
- /**
36
- * Disconnection process when user quits the task.
37
- */
38
- disconnect(): Promise<void>;
39
- onTrainEndCommunication(_: Weights, trainingInformant: TrainingInformant): Promise<void>;
40
- onRoundEndCommunication(updatedWeights: Weights, staleWeights: Weights, round: number, trainingInformant: TrainingInformant): Promise<Weights>;
41
- abstract sendAndReceiveWeights(noisyWeights: Weights, round: number, trainingInformant: TrainingInformant): Promise<List<Weights>>;
42
- abstract clientHandle(msg: messages.messageGeneral): void;
43
- }
@@ -1,243 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.Base = void 0;
4
- var tslib_1 = require("tslib");
5
- var isomorphic_ws_1 = (0, tslib_1.__importDefault)(require("isomorphic-ws"));
6
- var msgpack_lite_1 = (0, tslib_1.__importDefault)(require("msgpack-lite"));
7
- var url_1 = require("url");
8
- var __1 = require("../..");
9
- var base_1 = require("../base");
10
- var messages = (0, tslib_1.__importStar)(require("./messages"));
11
- // Time to wait between network checks in milliseconds.
12
- var TICK = 100;
13
- // Time to wait for the others in milliseconds.
14
- var MAX_WAIT_PER_ROUND = 10000;
15
- /**
16
- * Abstract class for decentralized clients, executes onRoundEndCommunication as well as connecting
17
- * to the signaling server
18
- */
19
- var Base = /** @class */ (function (_super) {
20
- (0, tslib_1.__extends)(Base, _super);
21
- function Base(url, task) {
22
- var _a, _b, _c, _d;
23
- var _this = _super.call(this, url, task) || this;
24
- _this.url = url;
25
- _this.task = task;
26
- // list of peerIDs who the client will send messages to
27
- _this.peers = [];
28
- _this.peersLocked = false;
29
- // the ID of the client, set arbitrarily to 0 but gets set an actual value once it cues the signaling server
30
- // that it is ready to connect
31
- _this.ID = 0;
32
- _this.minimumReadyPeers = (_b = (_a = _this.task.trainingInformation) === null || _a === void 0 ? void 0 : _a.minimumReadyPeers) !== null && _b !== void 0 ? _b : 3;
33
- _this.maxShareValue = (_d = (_c = _this.task.trainingInformation) === null || _c === void 0 ? void 0 : _c.maxShareValue) !== null && _d !== void 0 ? _d : 100;
34
- return _this;
35
- }
36
- /*
37
- function to check if a given boolean condition is true, checks continuously until maxWait time is reached
38
- */
39
- Base.prototype.pauseUntil = function (condition) {
40
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
41
- return (0, tslib_1.__generator)(this, function (_a) {
42
- switch (_a.label) {
43
- case 0: return [4 /*yield*/, new Promise(function (resolve, reject) {
44
- var timeWas = new Date().getTime();
45
- var wait = setInterval(function () {
46
- if (condition()) {
47
- console.log('resolved after', new Date().getTime() - timeWas, 'ms');
48
- clearInterval(wait);
49
- resolve();
50
- }
51
- else if (new Date().getTime() - timeWas > MAX_WAIT_PER_ROUND) { // Timeout
52
- console.log('rejected after', new Date().getTime() - timeWas, 'ms');
53
- clearInterval(wait);
54
- reject(new Error('timeout'));
55
- }
56
- }, TICK);
57
- })];
58
- case 1: return [2 /*return*/, _a.sent()];
59
- }
60
- });
61
- });
62
- };
63
- /*
64
- function behavior will change with peer 2 peer, sends message to its destination, currently through server
65
- */
66
- Base.prototype.sendMessagetoPeer = function (message) {
67
- if (this.server === undefined) {
68
- throw new Error("Undefined Server, can't send message");
69
- }
70
- this.server.send(message);
71
- };
72
- /*
73
- send message to server that client is ready
74
- */
75
- Base.prototype.sendReadyMessage = function (round) {
76
- // Broadcast our readiness
77
- var msg = { type: messages.messageType.clientReadyMessage, round: round };
78
- var encodedMsg = msgpack_lite_1.default.encode(msg);
79
- if (this.server === undefined) {
80
- throw new Error('server undefined, could not connect peers');
81
- }
82
- this.server.send(encodedMsg);
83
- };
84
- /*
85
- checks if message is of type messageGeneral. If it is, the specific message.type can be identified.
86
- */
87
- Base.prototype.instanceOfMessageGeneral = function (msg) {
88
- return typeof msg === 'object' && msg !== null && 'type' in msg;
89
- };
90
- /*
91
- checks if message contains the client's ID number
92
- */
93
- Base.prototype.instanceOfServerClientIDMessage = function (msg) {
94
- return msg.type === messages.messageType.serverClientIDMessage;
95
- };
96
- /*
97
- checks if message contains the list of peerIDs that are ready to share updates
98
- */
99
- Base.prototype.instanceOfServerReadyClients = function (msg) {
100
- return msg.type === messages.messageType.serverReadyClients;
101
- };
102
- /*
103
- creation of the websocket for the server, connection of client to that webSocket,
104
- deals with message reception from decentralized client perspective (messages received by client)
105
- */
106
- Base.prototype.connectServer = function (url) {
107
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
108
- var ws;
109
- var _this = this;
110
- return (0, tslib_1.__generator)(this, function (_a) {
111
- switch (_a.label) {
112
- case 0:
113
- ws = new isomorphic_ws_1.default.WebSocket(url);
114
- ws.binaryType = 'arraybuffer';
115
- ws.onmessage = function (event) { return (0, tslib_1.__awaiter)(_this, void 0, void 0, function () {
116
- var msg;
117
- return (0, tslib_1.__generator)(this, function (_a) {
118
- if (!(event.data instanceof ArrayBuffer)) {
119
- throw new Error('server did not send an ArrayBuffer');
120
- }
121
- msg = msgpack_lite_1.default.decode(new Uint8Array(event.data));
122
- // check message type to choose correct action
123
- if (this.instanceOfMessageGeneral(msg)) {
124
- if (this.instanceOfServerClientIDMessage(msg)) {
125
- // updated ID
126
- this.ID = msg.peerID;
127
- }
128
- else if (this.instanceOfServerReadyClients(msg)) {
129
- // updated connected peers
130
- if (!this.peersLocked) {
131
- this.peers = msg.peerList;
132
- this.peersLocked = true;
133
- }
134
- }
135
- else {
136
- this.clientHandle(msg);
137
- }
138
- }
139
- return [2 /*return*/];
140
- });
141
- }); };
142
- return [4 /*yield*/, new Promise(function (resolve, reject) {
143
- ws.onerror = function (err) {
144
- return reject(new Error("connecting server: " + err.message));
145
- }; // eslint-disable-line @typescript-eslint/restrict-template-expressions
146
- ws.onopen = function () { return resolve(ws); };
147
- })];
148
- case 1: return [2 /*return*/, _a.sent()];
149
- }
150
- });
151
- });
152
- };
153
- /**
154
- * Initialize the connection to the peers and to the other nodes.
155
- */
156
- Base.prototype.connect = function () {
157
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
158
- var serverURL, _a;
159
- return (0, tslib_1.__generator)(this, function (_b) {
160
- switch (_b.label) {
161
- case 0:
162
- serverURL = new url_1.URL('', this.url.href);
163
- switch (this.url.protocol) {
164
- case 'http:':
165
- serverURL.protocol = 'ws:';
166
- break;
167
- case 'https:':
168
- serverURL.protocol = 'wss:';
169
- break;
170
- default:
171
- throw new Error("unknown protocol: " + this.url.protocol);
172
- }
173
- serverURL.pathname += "deai/" + this.task.taskID;
174
- _a = this;
175
- return [4 /*yield*/, this.connectServer(serverURL)];
176
- case 1:
177
- _a.server = _b.sent();
178
- return [2 /*return*/];
179
- }
180
- });
181
- });
182
- };
183
- /**
184
- * Disconnection process when user quits the task.
185
- */
186
- Base.prototype.disconnect = function () {
187
- var _a;
188
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
189
- return (0, tslib_1.__generator)(this, function (_b) {
190
- // this.peers.forEach((peer) => peer.destroy())
191
- // this.peers = Map()
192
- (_a = this.server) === null || _a === void 0 ? void 0 : _a.close();
193
- this.server = undefined;
194
- return [2 /*return*/];
195
- });
196
- });
197
- };
198
- Base.prototype.onTrainEndCommunication = function (_, trainingInformant) {
199
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
200
- return (0, tslib_1.__generator)(this, function (_a) {
201
- // TODO: enter seeding mode?
202
- trainingInformant.addMessage('Training finished.');
203
- return [2 /*return*/];
204
- });
205
- });
206
- };
207
- Base.prototype.onRoundEndCommunication = function (updatedWeights, staleWeights, round, trainingInformant) {
208
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
209
- var noisyWeights, finalWeights, Error_1;
210
- var _this = this;
211
- return (0, tslib_1.__generator)(this, function (_a) {
212
- switch (_a.label) {
213
- case 0:
214
- _a.trys.push([0, 3, , 4]);
215
- // reset peer list at each round of training to make sure client waits for updated peerList from server
216
- this.peers = [];
217
- this.peersLocked = false;
218
- // centralized phase of communication --> client tells server that they have finished a local round and are ready to aggregate
219
- this.sendReadyMessage(round);
220
- // wait for peers to be connected before sending any update information
221
- return [4 /*yield*/, this.pauseUntil(function () { return _this.peers.length >= _this.minimumReadyPeers; })
222
- // Apply DP to updates that will be sent
223
- ];
224
- case 1:
225
- // wait for peers to be connected before sending any update information
226
- _a.sent();
227
- noisyWeights = __1.privacy.addDifferentialPrivacy(updatedWeights, staleWeights, this.task);
228
- return [4 /*yield*/, this.sendAndReceiveWeights(noisyWeights, round, trainingInformant)];
229
- case 2:
230
- finalWeights = _a.sent();
231
- return [2 /*return*/, __1.aggregation.averageWeights(finalWeights)];
232
- case 3:
233
- Error_1 = _a.sent();
234
- console.log('Timeout Error Reported, training will continue');
235
- return [2 /*return*/, updatedWeights];
236
- case 4: return [2 /*return*/];
237
- }
238
- });
239
- });
240
- };
241
- return Base;
242
- }(base_1.Base));
243
- exports.Base = Base;
@@ -1,13 +0,0 @@
1
- import { List } from 'immutable';
2
- import { TrainingInformant, Weights } from '../..';
3
- import { Base } from './base';
4
- import * as messages from './messages';
5
- /**
6
- * Decentralized client that does not utilize secure aggregation, but sends model updates in clear text
7
- */
8
- export declare class ClearText extends Base {
9
- protected receivedWeights: List<Weights>;
10
- sendAndReceiveWeights(noisyWeights: Weights, round: number, trainingInformant: TrainingInformant): Promise<List<Weights>>;
11
- private instanceOfClientWeightsMessageServer;
12
- clientHandle(msg: messages.messageGeneral): void;
13
- }
@@ -1,78 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.ClearText = void 0;
4
- var tslib_1 = require("tslib");
5
- var immutable_1 = require("immutable");
6
- var msgpack_lite_1 = (0, tslib_1.__importDefault)(require("msgpack-lite"));
7
- var __1 = require("../..");
8
- var base_1 = require("./base");
9
- var messages = (0, tslib_1.__importStar)(require("./messages"));
10
- /**
11
- * Decentralized client that does not utilize secure aggregation, but sends model updates in clear text
12
- */
13
- var ClearText = /** @class */ (function (_super) {
14
- (0, tslib_1.__extends)(ClearText, _super);
15
- function ClearText() {
16
- var _this = _super !== null && _super.apply(this, arguments) || this;
17
- // list of weights received from other clients
18
- _this.receivedWeights = (0, immutable_1.List)();
19
- return _this;
20
- }
21
- ClearText.prototype.sendAndReceiveWeights = function (noisyWeights, round, trainingInformant) {
22
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
23
- var weightsToSend, i, msg, encodedMsg;
24
- var _this = this;
25
- return (0, tslib_1.__generator)(this, function (_a) {
26
- switch (_a.label) {
27
- case 0:
28
- // reset received fields at beginning of each round
29
- this.receivedWeights = this.receivedWeights.clear();
30
- return [4 /*yield*/, __1.serialization.weights.encode(noisyWeights)];
31
- case 1:
32
- weightsToSend = _a.sent();
33
- if (this.server === undefined) {
34
- throw new Error('server undefined so we cannot send weights through it');
35
- }
36
- // PHASE 1 COMMUNICATION --> create weights message and send to all peers (only one phase of communication for clear_text)
37
- for (i = 0; i < this.peers.length; i++) {
38
- msg = {
39
- type: messages.messageType.clientWeightsMessageServer,
40
- peerID: this.ID,
41
- weights: weightsToSend,
42
- destination: this.peers[i]
43
- };
44
- encodedMsg = msgpack_lite_1.default.encode(msg);
45
- this.sendMessagetoPeer(encodedMsg);
46
- }
47
- // wait to receive all weights from peers
48
- return [4 /*yield*/, this.pauseUntil(function () { return _this.receivedWeights.size >= _this.peers.length; })];
49
- case 2:
50
- // wait to receive all weights from peers
51
- _a.sent();
52
- return [2 /*return*/, this.receivedWeights];
53
- }
54
- });
55
- });
56
- };
57
- /*
58
- checks if message contains weights from a peer
59
- */
60
- ClearText.prototype.instanceOfClientWeightsMessageServer = function (msg) {
61
- return msg.type === messages.messageType.clientWeightsMessageServer;
62
- };
63
- /*
64
- handles received messages from signaling server
65
- */
66
- ClearText.prototype.clientHandle = function (msg) {
67
- if (this.instanceOfClientWeightsMessageServer(msg)) {
68
- // update received weights by one weights reception
69
- var weights = __1.serialization.weights.decode(msg.weights);
70
- this.receivedWeights = this.receivedWeights.push(weights);
71
- }
72
- else {
73
- throw new Error('Unexpected Message Type');
74
- }
75
- };
76
- return ClearText;
77
- }(base_1.Base));
78
- exports.ClearText = ClearText;
@@ -1,37 +0,0 @@
1
- import { weights } from '../../serialization';
2
- import { PeerID } from './types';
3
- export declare enum messageType {
4
- serverClientIDMessage = 0,
5
- clientReadyMessage = 1,
6
- serverReadyClients = 2,
7
- clientWeightsMessageServer = 3,
8
- clientSharesMessageServer = 4,
9
- clientPartialSumsMessageServer = 5
10
- }
11
- export interface messageGeneral {
12
- type: messageType;
13
- }
14
- export interface serverClientIDMessage extends messageGeneral {
15
- peerID: PeerID;
16
- }
17
- export interface clientReadyMessage extends messageGeneral {
18
- round: number;
19
- }
20
- export interface clientWeightsMessageServer extends messageGeneral {
21
- peerID: PeerID;
22
- weights: weights.Encoded;
23
- destination: PeerID;
24
- }
25
- export interface clientSharesMessageServer extends messageGeneral {
26
- peerID: PeerID;
27
- weights: weights.Encoded;
28
- destination: PeerID;
29
- }
30
- export interface clientPartialSumsMessageServer extends messageGeneral {
31
- peerID: PeerID;
32
- partials: weights.Encoded;
33
- destination: PeerID;
34
- }
35
- export interface serverReadyClients extends messageGeneral {
36
- peerList: PeerID[];
37
- }
@@ -1,15 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.messageType = void 0;
4
- var messageType;
5
- (function (messageType) {
6
- // Phase 0 communication (just between server and client)
7
- messageType[messageType["serverClientIDMessage"] = 0] = "serverClientIDMessage";
8
- messageType[messageType["clientReadyMessage"] = 1] = "clientReadyMessage";
9
- messageType[messageType["serverReadyClients"] = 2] = "serverReadyClients";
10
- // Phase 1 communication (between client and peers)
11
- messageType[messageType["clientWeightsMessageServer"] = 3] = "clientWeightsMessageServer";
12
- messageType[messageType["clientSharesMessageServer"] = 4] = "clientSharesMessageServer";
13
- // Phase 2 communication (between client and peers)
14
- messageType[messageType["clientPartialSumsMessageServer"] = 5] = "clientPartialSumsMessageServer";
15
- })(messageType = exports.messageType || (exports.messageType = {}));
@@ -1,18 +0,0 @@
1
- import { List } from 'immutable';
2
- import { TrainingInformant, Weights } from '../..';
3
- import { Base } from './base';
4
- import * as messages from './messages';
5
- /**
6
- * Decentralized client that utilizes secure aggregation so client updates remain private
7
- */
8
- export declare class SecAgg extends Base {
9
- private receivedShares;
10
- private receivedPartialSums;
11
- private mySum;
12
- private sendShares;
13
- private sendPartialSums;
14
- sendAndReceiveWeights(noisyWeights: Weights, round: number, trainingInformant: TrainingInformant): Promise<List<Weights>>;
15
- private instanceOfClientSharesMessageServer;
16
- private instanceOfClientPartialSumsMessageServer;
17
- clientHandle(msg: messages.messageGeneral): void;
18
- }