@epfml/discojs 1.0.0 → 2.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (208)
  1. package/README.md +28 -8
  2. package/dist/{async_buffer.d.ts → core/async_buffer.d.ts} +3 -3
  3. package/dist/{async_buffer.js → core/async_buffer.js} +5 -6
  4. package/dist/{async_informant.d.ts → core/async_informant.d.ts} +0 -0
  5. package/dist/{async_informant.js → core/async_informant.js} +0 -0
  6. package/dist/{client → core/client}/base.d.ts +4 -7
  7. package/dist/{client → core/client}/base.js +3 -2
  8. package/dist/core/client/decentralized/base.d.ts +32 -0
  9. package/dist/core/client/decentralized/base.js +212 -0
  10. package/dist/core/client/decentralized/clear_text.d.ts +14 -0
  11. package/dist/core/client/decentralized/clear_text.js +96 -0
  12. package/dist/{client → core/client}/decentralized/index.d.ts +0 -0
  13. package/dist/{client → core/client}/decentralized/index.js +0 -0
  14. package/dist/core/client/decentralized/messages.d.ts +41 -0
  15. package/dist/core/client/decentralized/messages.js +54 -0
  16. package/dist/core/client/decentralized/peer.d.ts +26 -0
  17. package/dist/core/client/decentralized/peer.js +210 -0
  18. package/dist/core/client/decentralized/peer_pool.d.ts +14 -0
  19. package/dist/core/client/decentralized/peer_pool.js +92 -0
  20. package/dist/core/client/decentralized/sec_agg.d.ts +22 -0
  21. package/dist/core/client/decentralized/sec_agg.js +190 -0
  22. package/dist/core/client/decentralized/secret_shares.d.ts +3 -0
  23. package/dist/core/client/decentralized/secret_shares.js +39 -0
  24. package/dist/core/client/decentralized/types.d.ts +2 -0
  25. package/dist/core/client/decentralized/types.js +7 -0
  26. package/dist/core/client/event_connection.d.ts +37 -0
  27. package/dist/core/client/event_connection.js +158 -0
  28. package/dist/core/client/federated/client.d.ts +37 -0
  29. package/dist/core/client/federated/client.js +273 -0
  30. package/dist/core/client/federated/index.d.ts +2 -0
  31. package/dist/core/client/federated/index.js +7 -0
  32. package/dist/core/client/federated/messages.d.ts +38 -0
  33. package/dist/core/client/federated/messages.js +25 -0
  34. package/dist/{client → core/client}/index.d.ts +2 -1
  35. package/dist/{client → core/client}/index.js +3 -3
  36. package/dist/{client → core/client}/local.d.ts +2 -2
  37. package/dist/{client → core/client}/local.js +0 -0
  38. package/dist/core/client/messages.d.ts +28 -0
  39. package/dist/core/client/messages.js +33 -0
  40. package/dist/core/client/utils.d.ts +2 -0
  41. package/dist/core/client/utils.js +19 -0
  42. package/dist/core/dataset/data/data.d.ts +11 -0
  43. package/dist/core/dataset/data/data.js +20 -0
  44. package/dist/core/dataset/data/data_split.d.ts +5 -0
  45. package/dist/{client/decentralized/types.js → core/dataset/data/data_split.js} +0 -0
  46. package/dist/core/dataset/data/image_data.d.ts +8 -0
  47. package/dist/core/dataset/data/image_data.js +64 -0
  48. package/dist/core/dataset/data/index.d.ts +5 -0
  49. package/dist/core/dataset/data/index.js +11 -0
  50. package/dist/core/dataset/data/preprocessing.d.ts +13 -0
  51. package/dist/core/dataset/data/preprocessing.js +33 -0
  52. package/dist/core/dataset/data/tabular_data.d.ts +8 -0
  53. package/dist/core/dataset/data/tabular_data.js +40 -0
  54. package/dist/{dataset → core/dataset}/data_loader/data_loader.d.ts +4 -11
  55. package/dist/{dataset → core/dataset}/data_loader/data_loader.js +0 -0
  56. package/dist/core/dataset/data_loader/image_loader.d.ts +17 -0
  57. package/dist/core/dataset/data_loader/image_loader.js +141 -0
  58. package/dist/core/dataset/data_loader/index.d.ts +3 -0
  59. package/dist/core/dataset/data_loader/index.js +9 -0
  60. package/dist/core/dataset/data_loader/tabular_loader.d.ts +29 -0
  61. package/dist/core/dataset/data_loader/tabular_loader.js +101 -0
  62. package/dist/core/dataset/dataset.d.ts +2 -0
  63. package/dist/{task/training_information.js → core/dataset/dataset.js} +0 -0
  64. package/dist/{dataset → core/dataset}/dataset_builder.d.ts +5 -5
  65. package/dist/{dataset → core/dataset}/dataset_builder.js +14 -10
  66. package/dist/core/dataset/index.d.ts +4 -0
  67. package/dist/core/dataset/index.js +14 -0
  68. package/dist/core/index.d.ts +18 -0
  69. package/dist/core/index.js +41 -0
  70. package/dist/{informant → core/informant}/graph_informant.d.ts +0 -0
  71. package/dist/{informant → core/informant}/graph_informant.js +0 -0
  72. package/dist/{informant → core/informant}/index.d.ts +0 -0
  73. package/dist/{informant → core/informant}/index.js +0 -0
  74. package/dist/{informant → core/informant}/training_informant/base.d.ts +3 -3
  75. package/dist/{informant → core/informant}/training_informant/base.js +3 -2
  76. package/dist/{informant → core/informant}/training_informant/decentralized.d.ts +0 -0
  77. package/dist/{informant → core/informant}/training_informant/decentralized.js +0 -0
  78. package/dist/{informant → core/informant}/training_informant/federated.d.ts +0 -0
  79. package/dist/{informant → core/informant}/training_informant/federated.js +0 -0
  80. package/dist/{informant → core/informant}/training_informant/index.d.ts +0 -0
  81. package/dist/{informant → core/informant}/training_informant/index.js +0 -0
  82. package/dist/{informant → core/informant}/training_informant/local.d.ts +2 -2
  83. package/dist/{informant → core/informant}/training_informant/local.js +2 -2
  84. package/dist/{logging → core/logging}/console_logger.d.ts +0 -0
  85. package/dist/{logging → core/logging}/console_logger.js +0 -0
  86. package/dist/{logging → core/logging}/index.d.ts +0 -0
  87. package/dist/{logging → core/logging}/index.js +0 -0
  88. package/dist/{logging → core/logging}/logger.d.ts +0 -0
  89. package/dist/{logging → core/logging}/logger.js +0 -0
  90. package/dist/{logging → core/logging}/trainer_logger.d.ts +0 -0
  91. package/dist/{logging → core/logging}/trainer_logger.js +0 -0
  92. package/dist/{memory → core/memory}/base.d.ts +2 -2
  93. package/dist/{memory → core/memory}/base.js +0 -0
  94. package/dist/{memory → core/memory}/empty.d.ts +0 -0
  95. package/dist/{memory → core/memory}/empty.js +0 -0
  96. package/dist/core/memory/index.d.ts +3 -0
  97. package/dist/core/memory/index.js +9 -0
  98. package/dist/{memory → core/memory}/model_type.d.ts +0 -0
  99. package/dist/{memory → core/memory}/model_type.js +0 -0
  100. package/dist/{privacy.d.ts → core/privacy.d.ts} +2 -3
  101. package/dist/{privacy.js → core/privacy.js} +3 -16
  102. package/dist/{serialization → core/serialization}/index.d.ts +0 -0
  103. package/dist/{serialization → core/serialization}/index.js +0 -0
  104. package/dist/{serialization → core/serialization}/model.d.ts +0 -0
  105. package/dist/{serialization → core/serialization}/model.js +0 -0
  106. package/dist/core/serialization/weights.d.ts +5 -0
  107. package/dist/{serialization → core/serialization}/weights.js +11 -9
  108. package/dist/{task → core/task}/data_example.d.ts +0 -0
  109. package/dist/{task → core/task}/data_example.js +0 -0
  110. package/dist/{task → core/task}/display_information.d.ts +5 -5
  111. package/dist/{task → core/task}/display_information.js +5 -10
  112. package/dist/{task → core/task}/index.d.ts +0 -0
  113. package/dist/{task → core/task}/index.js +0 -0
  114. package/dist/core/task/model_compile_data.d.ts +6 -0
  115. package/dist/core/task/model_compile_data.js +22 -0
  116. package/dist/{task → core/task}/summary.d.ts +0 -0
  117. package/dist/{task → core/task}/summary.js +0 -4
  118. package/dist/{task → core/task}/task.d.ts +2 -2
  119. package/dist/{task → core/task}/task.js +6 -7
  120. package/dist/{task → core/task}/training_information.d.ts +10 -14
  121. package/dist/core/task/training_information.js +66 -0
  122. package/dist/{tasks → core/tasks}/cifar10.d.ts +1 -2
  123. package/dist/{tasks → core/tasks}/cifar10.js +12 -23
  124. package/dist/core/tasks/geotags.d.ts +3 -0
  125. package/dist/core/tasks/geotags.js +67 -0
  126. package/dist/{tasks → core/tasks}/index.d.ts +2 -1
  127. package/dist/{tasks → core/tasks}/index.js +3 -2
  128. package/dist/core/tasks/lus_covid.d.ts +3 -0
  129. package/dist/{tasks → core/tasks}/lus_covid.js +26 -24
  130. package/dist/{tasks → core/tasks}/mnist.d.ts +1 -2
  131. package/dist/{tasks → core/tasks}/mnist.js +18 -16
  132. package/dist/core/tasks/simple_face.d.ts +2 -0
  133. package/dist/core/tasks/simple_face.js +41 -0
  134. package/dist/{tasks → core/tasks}/titanic.d.ts +1 -2
  135. package/dist/{tasks → core/tasks}/titanic.js +11 -11
  136. package/dist/core/training/disco.d.ts +23 -0
  137. package/dist/core/training/disco.js +130 -0
  138. package/dist/{training → core/training}/index.d.ts +0 -0
  139. package/dist/{training → core/training}/index.js +0 -0
  140. package/dist/{training → core/training}/trainer/distributed_trainer.d.ts +1 -2
  141. package/dist/{training → core/training}/trainer/distributed_trainer.js +6 -5
  142. package/dist/{training → core/training}/trainer/local_trainer.d.ts +2 -2
  143. package/dist/{training → core/training}/trainer/local_trainer.js +0 -0
  144. package/dist/{training → core/training}/trainer/round_tracker.d.ts +0 -0
  145. package/dist/{training → core/training}/trainer/round_tracker.js +0 -0
  146. package/dist/{training → core/training}/trainer/trainer.d.ts +1 -2
  147. package/dist/{training → core/training}/trainer/trainer.js +2 -2
  148. package/dist/{training → core/training}/trainer/trainer_builder.d.ts +0 -0
  149. package/dist/{training → core/training}/trainer/trainer_builder.js +0 -0
  150. package/dist/core/training/training_schemes.d.ts +5 -0
  151. package/dist/{training → core/training}/training_schemes.js +2 -2
  152. package/dist/{types.d.ts → core/types.d.ts} +0 -0
  153. package/dist/{types.js → core/types.js} +0 -0
  154. package/dist/{validation → core/validation}/index.d.ts +0 -0
  155. package/dist/{validation → core/validation}/index.js +0 -0
  156. package/dist/{validation → core/validation}/validator.d.ts +5 -8
  157. package/dist/{validation → core/validation}/validator.js +9 -11
  158. package/dist/core/weights/aggregation.d.ts +8 -0
  159. package/dist/core/weights/aggregation.js +96 -0
  160. package/dist/core/weights/index.d.ts +2 -0
  161. package/dist/core/weights/index.js +7 -0
  162. package/dist/core/weights/weights_container.d.ts +19 -0
  163. package/dist/core/weights/weights_container.js +64 -0
  164. package/dist/dataset/data_loader/image_loader.d.ts +3 -15
  165. package/dist/dataset/data_loader/image_loader.js +12 -125
  166. package/dist/dataset/data_loader/index.d.ts +2 -3
  167. package/dist/dataset/data_loader/index.js +3 -5
  168. package/dist/dataset/data_loader/tabular_loader.d.ts +3 -28
  169. package/dist/dataset/data_loader/tabular_loader.js +11 -92
  170. package/dist/imports.d.ts +2 -0
  171. package/dist/imports.js +7 -0
  172. package/dist/index.d.ts +2 -19
  173. package/dist/index.js +3 -39
  174. package/dist/memory/index.d.ts +1 -3
  175. package/dist/memory/index.js +3 -7
  176. package/dist/memory/memory.d.ts +26 -0
  177. package/dist/memory/memory.js +160 -0
  178. package/package.json +14 -27
  179. package/dist/aggregation.d.ts +0 -5
  180. package/dist/aggregation.js +0 -33
  181. package/dist/client/decentralized/base.d.ts +0 -43
  182. package/dist/client/decentralized/base.js +0 -243
  183. package/dist/client/decentralized/clear_text.d.ts +0 -13
  184. package/dist/client/decentralized/clear_text.js +0 -78
  185. package/dist/client/decentralized/messages.d.ts +0 -37
  186. package/dist/client/decentralized/messages.js +0 -15
  187. package/dist/client/decentralized/sec_agg.d.ts +0 -18
  188. package/dist/client/decentralized/sec_agg.js +0 -169
  189. package/dist/client/decentralized/secret_shares.d.ts +0 -5
  190. package/dist/client/decentralized/secret_shares.js +0 -58
  191. package/dist/client/decentralized/types.d.ts +0 -1
  192. package/dist/client/federated.d.ts +0 -30
  193. package/dist/client/federated.js +0 -218
  194. package/dist/dataset/index.d.ts +0 -2
  195. package/dist/dataset/index.js +0 -7
  196. package/dist/model_actor.d.ts +0 -16
  197. package/dist/model_actor.js +0 -20
  198. package/dist/serialization/weights.d.ts +0 -5
  199. package/dist/task/model_compile_data.d.ts +0 -6
  200. package/dist/task/model_compile_data.js +0 -12
  201. package/dist/tasks/lus_covid.d.ts +0 -4
  202. package/dist/tasks/simple_face.d.ts +0 -4
  203. package/dist/tasks/simple_face.js +0 -84
  204. package/dist/tfjs.d.ts +0 -2
  205. package/dist/tfjs.js +0 -6
  206. package/dist/training/disco.d.ts +0 -14
  207. package/dist/training/disco.js +0 -70
  208. package/dist/training/training_schemes.d.ts +0 -5
@@ -0,0 +1,158 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.WebSocketServer = exports.PeerConnection = exports.waitMessageWithTimeout = exports.waitMessage = void 0;
4
+ var tslib_1 = require("tslib");
5
+ var isomorphic_ws_1 = (0, tslib_1.__importDefault)(require("isomorphic-ws"));
6
+ var events_1 = require("events");
7
+ var msgpack_lite_1 = (0, tslib_1.__importDefault)(require("msgpack-lite"));
8
+ var decentralizedMessages = (0, tslib_1.__importStar)(require("./decentralized/messages"));
9
+ var messages_1 = require("./messages");
10
+ var utils_1 = require("./utils");
11
/**
 * Resolves with the first message of `type` emitted by `connection`.
 *
 * The listener is attached synchronously, inside the promise executor, so a
 * message arriving right after this call is not missed. `once` is essential:
 * a promise can only be settled a single time, so the listener must not
 * outlive the first matching event.
 */
function waitMessage(connection, type) {
    return new Promise(function (settle) {
        connection.once(type, function (message) {
            settle(message);
        });
    });
}
26
+ exports.waitMessage = waitMessage;
27
/**
 * Races `waitMessage(connection, type)` against `utils.timeout(timeoutMs)` and
 * settles with whichever finishes first.
 *
 * NOTE(review): the exact timeout outcome (reject vs. resolve) is defined in
 * ./utils and is presumably a rejection. Confirm there.
 * If the timeout wins, the `once` listener installed by waitMessage stays
 * registered until a message of that type eventually arrives.
 */
function waitMessageWithTimeout(connection, type, timeoutMs) {
    return new Promise(function (settle) {
        // The executor runs synchronously: the message listener is attached
        // immediately, and a synchronous failure becomes a rejection instead
        // of an exception thrown at the caller.
        var contenders = [waitMessage(connection, type), (0, utils_1.timeout)(timeoutMs)];
        settle(Promise.race(contenders));
    });
}
37
+ exports.waitMessageWithTimeout = waitMessageWithTimeout;
38
/**
 * EventConnection over a direct peer link.
 *
 * Outgoing messages must satisfy `isPeerMessage` and are msgpack-encoded before
 * being handed to the peer. Incoming data is decoded, validated the same way,
 * and re-emitted on a private EventEmitter under `msg.type.toString()`.
 */
var PeerConnection = /** @class */ (function () {
    function PeerConnection(selfId, peer, signallingServer) {
        this.eventEmitter = new events_1.EventEmitter();
        this.selfId = selfId;
        this.peer = peer;
        this.signallingServer = signallingServer;
    }
    /**
     * Installs the peer's 'signal', 'data', 'close' and 'connect' handlers and
     * resolves (with no value) once the peer emits 'connect'. The returned
     * promise has no rejection path of its own.
     */
    PeerConnection.prototype.connect = function () {
        var conn = this;
        return new Promise(function (settle) {
            // Locally generated signalling data is relayed to the remote peer
            // through the signalling server.
            conn.peer.on('signal', function (signal) {
                console.debug(conn.selfId, 'generates signal for', conn.peer.id);
                var forward = {
                    type: messages_1.type.SignalForPeer,
                    peer: conn.peer.id,
                    signal: signal
                };
                conn.signallingServer.send(forward);
            });
            conn.peer.on('data', function (data) {
                var incoming = msgpack_lite_1.default.decode(data);
                if (!decentralizedMessages.isPeerMessage(incoming)) {
                    throw new Error("invalid message received: " + JSON.stringify(incoming));
                }
                conn.eventEmitter.emit(incoming.type.toString(), incoming);
            });
            conn.peer.on('close', function () {
                console.warn('peer', conn.peer.id, 'closed connection');
            });
            conn.peer.on('connect', function () {
                console.debug('connected new peer', conn.peer.id);
                settle(undefined);
            });
        });
    };
    /** Passes signalling data received from the remote side to the local peer. */
    PeerConnection.prototype.signal = function (signal) {
        this.peer.signal(signal);
    };
    /** Registers `handler` for every received message of `type`. */
    PeerConnection.prototype.on = function (type, handler) {
        var eventName = type.toString();
        this.eventEmitter.on(eventName, handler);
    };
    /** Registers `handler` for the next received message of `type` only. */
    PeerConnection.prototype.once = function (type, handler) {
        var eventName = type.toString();
        this.eventEmitter.once(eventName, handler);
    };
    /** Encodes and sends `msg`; throws if it is not a valid peer message. */
    PeerConnection.prototype.send = function (msg) {
        if (decentralizedMessages.isPeerMessage(msg)) {
            this.peer.send(msgpack_lite_1.default.encode(msg));
            return;
        }
        throw new Error("can't send this type of message: " + JSON.stringify(msg));
    };
    /** Destroys the underlying peer. */
    PeerConnection.prototype.disconnect = function () {
        this.peer.destroy();
    };
    return PeerConnection;
}());
99
+ exports.PeerConnection = PeerConnection;
100
/**
 * EventConnection backed by a WebSocket to the server.
 *
 * Messages travel as msgpack-encoded binary frames. Received messages are
 * re-emitted on an EventEmitter keyed by `msg.type.toString()`. The optional
 * `validateReceived` / `validateSent` predicates reject malformed traffic.
 */
var WebSocketServer = /** @class */ (function () {
    function WebSocketServer(socket, eventEmitter, validateReceived, validateSent) {
        this.socket = socket;
        this.eventEmitter = eventEmitter;
        this.validateReceived = validateReceived;
        this.validateSent = validateSent;
    }
    // Builds a human-readable reason from a WebSocket error callback argument.
    // Node's `ws` passes an ErrorEvent carrying `message`, but a browser
    // WebSocket passes a plain Event without one, which previously produced
    // "connecting server: undefined".
    function describeSocketError(err, url) {
        if (err !== null && typeof err === 'object' &&
            typeof err.message === 'string' && err.message.length > 0) {
            return err.message;
        }
        return 'could not connect to ' + String(url);
    }
    /**
     * Opens a WebSocket to `url`.
     *
     * @returns a promise resolving with the wrapping WebSocketServer once the
     *   socket is open, or rejecting with "connecting server: <reason>" if the
     *   socket reports an error first.
     */
    WebSocketServer.connect = function (url, validateReceived, validateSent) {
        return new Promise(function (resolve, reject) {
            // Prefer the browser's native WebSocket; fall back to isomorphic-ws elsewhere.
            var WS = typeof window !== 'undefined' ? window.WebSocket : isomorphic_ws_1.default.WebSocket;
            var ws = new WS(url);
            ws.binaryType = 'arraybuffer';
            var emitter = new events_1.EventEmitter();
            var server = new WebSocketServer(ws, emitter, validateReceived, validateSent);
            ws.onmessage = function (event) {
                if (!(event.data instanceof ArrayBuffer)) {
                    throw new Error('server did not send an ArrayBuffer');
                }
                var msg = msgpack_lite_1.default.decode(new Uint8Array(event.data));
                // Validate message format
                if (validateReceived && !validateReceived(msg)) {
                    throw new Error("invalid message received: " + JSON.stringify(msg));
                }
                emitter.emit(msg.type.toString(), msg);
            };
            ws.onerror = function (err) {
                reject(new Error("connecting server: " + describeSocketError(err, url)));
            };
            ws.onopen = function () {
                resolve(server);
            };
        });
    };
    /** Closes the underlying socket. */
    WebSocketServer.prototype.disconnect = function () {
        this.socket.close();
    };
    // There is no straightforward way to make sure the handler takes the correct
    // message type as a parameter, for type safety.
    WebSocketServer.prototype.on = function (type, handler) {
        this.eventEmitter.on(type.toString(), handler);
    };
    /** Registers `handler` for the next received message of `type` only. */
    WebSocketServer.prototype.once = function (type, handler) {
        this.eventEmitter.once(type.toString(), handler);
    };
    /** Encodes and sends `msg`; throws if `validateSent` is set and rejects it. */
    WebSocketServer.prototype.send = function (msg) {
        if (this.validateSent && !this.validateSent(msg)) {
            throw new Error("can't send this type of message: " + JSON.stringify(msg));
        }
        this.socket.send(msgpack_lite_1.default.encode(msg));
    };
    return WebSocketServer;
}());
158
+ exports.WebSocketServer = WebSocketServer;
@@ -0,0 +1,37 @@
1
+ import { informant, MetadataID, WeightsContainer } from '../..';
2
+ import { Base } from '../base';
3
+ import { EventConnection } from '../event_connection';
4
/**
 * Class that deals with communication with the centralized server when training
 * a specific task in the federated setting: it posts local weights and metadata,
 * and pulls the server's round, weights, statistics and metadata maps.
 */
export declare class Client extends Base {
    private readonly clientID;
    private readonly peer;
    private round;
    private serverRound?;
    private serverWeights?;
    private receivedStatistics?;
    private metadataMap?;
    protected _server?: EventConnection;
    /** The open server connection; throws when the client is not connected. */
    get server(): EventConnection;
    private connectServer;
    private sendMessage;
    /**
     * Initialize the connection to the server. TODO: In the case of FeAI,
     * should return the current server-side round for the task.
     */
    connect(): Promise<void>;
    /**
     * Disconnection process when user quits the task.
     */
    disconnect(): Promise<void>;
    /** Sends the given weights, tagged with the client's current round. */
    postWeightsToServer(weights: WeightsContainer): Promise<void>;
    /** Requests the server's latest round and weights; resolves with the round. */
    getLatestServerRound(): Promise<number | undefined>;
    /** Resolves with the server weights only when the server is ahead of this client. */
    pullRoundAndFetchWeights(): Promise<WeightsContainer | undefined>;
    /** Requests server statistics and feeds them into the given informant. */
    pullServerStatistics(trainingInformant: informant.FederatedInformant): Promise<void>;
    /** Sends one metadata value for the current task, client and round. */
    postMetadata(metadataID: MetadataID, metadata: string): Promise<void>;
    /** Requests the metadata map for `metadataId`; undefined if the server sent none. */
    getMetadataMap(metadataId: MetadataID): Promise<Map<string, unknown> | undefined>;
    /** Posts noised weights, pulls statistics, then returns the server or stale weights. */
    onRoundEndCommunication(updatedWeights: WeightsContainer, staleWeights: WeightsContainer, _: number, trainingInformant: informant.FederatedInformant): Promise<WeightsContainer>;
    /** No-op in the federated setting. */
    onTrainEndCommunication(): Promise<void>;
}
@@ -0,0 +1,273 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.Client = void 0;
4
+ var tslib_1 = require("tslib");
5
+ var uuid_1 = require("uuid");
6
+ var __1 = require("../..");
7
+ var base_1 = require("../base");
8
+ var messages = (0, tslib_1.__importStar)(require("./messages"));
9
+ var messages_1 = require("../messages");
10
+ var nodeUrl = (0, tslib_1.__importStar)(require("url"));
11
+ var event_connection_1 = require("../event_connection");
12
+ var utils_1 = require("../utils");
13
/**
 * Class that deals with communication with the centralized server when training
 * a specific task in the federated setting.
 *
 * Request/response exchanges all follow one pattern: send a message of some
 * type, then wait (bounded by MAX_WAIT_PER_ROUND) for the server's message of
 * that same type. The listener is attached in the same synchronous step as the
 * send, so a reply cannot be missed.
 */
var Client = /** @class */ (function (_super) {
    (0, tslib_1.__extends)(Client, _super);
    function Client() {
        var _this = _super !== null && _super.apply(this, arguments) || this;
        _this.clientID = (0, uuid_1.v4)();
        _this.round = 0;
        return _this;
    }
    Object.defineProperty(Client.prototype, "server", {
        /** The open server connection; throws when the client is not connected. */
        get: function () {
            if (this._server === undefined) {
                throw new Error('server undefined, not connected');
            }
            return this._server;
        },
        enumerable: false,
        configurable: true
    });
    // Opens a WebSocket connection to `url` in which both received and sent
    // messages are validated with `isMessageFederated`.
    Client.prototype.connectServer = function (url) {
        return event_connection_1.WebSocketServer.connect(url, messages.isMessageFederated, messages.isMessageFederated);
    };
    /**
     * Initialize the connection to the server. TODO: In the case of FeAI,
     * should return the current server-side round for the task.
     *
     * The task URL's http:/https: scheme is mapped to ws:/wss: and the path is
     * extended with `feai/<taskID>/<clientID>`. After opening the socket, a
     * `clientConnected` message is sent and its reply awaited before the client
     * is marked as connected.
     *
     * Rejects with "unknown protocol: ..." for any other URL scheme.
     */
    Client.prototype.connect = function () {
        var _this = this;
        return new Promise(function (resolve) {
            // Named URLCtor to avoid shadowing the global URL.
            var URLCtor = typeof window !== 'undefined' ? window.URL : nodeUrl.URL;
            var serverURL = new URLCtor('', _this.url.href);
            switch (_this.url.protocol) {
                case 'http:':
                    serverURL.protocol = 'ws:';
                    break;
                case 'https:':
                    serverURL.protocol = 'wss:';
                    break;
                default:
                    throw new Error("unknown protocol: " + _this.url.protocol);
            }
            serverURL.pathname += "feai/" + _this.task.taskID + "/" + _this.clientID;
            resolve(_this.connectServer(serverURL));
        }).then(function (server) {
            _this._server = server;
            var msg = {
                type: messages_1.type.clientConnected
            };
            _this.server.send(msg);
            return (0, event_connection_1.waitMessageWithTimeout)(_this.server, messages_1.type.clientConnected, utils_1.MAX_WAIT_PER_ROUND);
        }).then(function () {
            _this.connected = true;
        });
    };
    /**
     * Disconnection process when user quits the task.
     *
     * Safe to call when no connection is open: the socket is closed only if it
     * exists, and the connection state is reset either way. Previously this
     * threw "server undefined, not connected".
     */
    Client.prototype.disconnect = function () {
        var _this = this;
        return new Promise(function (resolve) {
            if (_this._server !== undefined) {
                _this._server.disconnect();
            }
            _this._server = undefined;
            _this.connected = false;
            resolve(undefined);
        });
    };
    // Sends a message to the server. Throws 'server undefined, not connected'
    // when no connection is open; the `server` getter enforces this, which made
    // the former optional chaining dead code.
    Client.prototype.sendMessage = function (msg) {
        this.server.send(msg);
    };
    // Sends the encoded weights to the server, tagged with the current local round.
    Client.prototype.postWeightsToServer = function (weights) {
        var _this = this;
        return new Promise(function (resolve) {
            resolve(__1.serialization.weights.encode(weights));
        }).then(function (encoded) {
            var msg = {
                type: messages_1.type.postWeightsToServer,
                weights: encoded,
                round: _this.round
            };
            _this.sendMessage(msg);
        });
    };
    // Retrieves the latest server round and weights, caches both, and resolves
    // with the server round only.
    Client.prototype.getLatestServerRound = function () {
        var _this = this;
        return new Promise(function (resolve) {
            _this.serverRound = undefined;
            _this.serverWeights = undefined;
            var msg = {
                type: messages_1.type.latestServerRound
            };
            _this.sendMessage(msg);
            resolve((0, event_connection_1.waitMessageWithTimeout)(_this.server, messages_1.type.latestServerRound, utils_1.MAX_WAIT_PER_ROUND));
        }).then(function (received) {
            _this.serverRound = received.round;
            _this.serverWeights = __1.serialization.weights.decode(received.weights);
            return _this.serverRound;
        });
    };
    // Retrieves the latest server round and weights. Resolves with the server
    // weights only when the server is ahead of the local round (and adopts that
    // round); otherwise resolves with undefined.
    Client.prototype.pullRoundAndFetchWeights = function () {
        var _this = this;
        // get server round of latest model
        return this.getLatestServerRound().then(function () {
            var latestRound = _this.serverRound == null ? 0 : _this.serverRound;
            if (_this.round < latestRound) {
                // Update the local round to match the server's
                _this.round = latestRound;
                return _this.serverWeights;
            }
            return undefined;
        });
    };
    // Pulls statistics from the server and forwards them (or {} if absent) to
    // the training informant.
    Client.prototype.pullServerStatistics = function (trainingInformant) {
        var _this = this;
        return new Promise(function (resolve) {
            _this.receivedStatistics = undefined;
            var msg = {
                type: messages_1.type.pullServerStatistics
            };
            _this.sendMessage(msg);
            resolve((0, event_connection_1.waitMessageWithTimeout)(_this.server, messages_1.type.pullServerStatistics, utils_1.MAX_WAIT_PER_ROUND));
        }).then(function (received) {
            _this.receivedStatistics = received.statistics;
            trainingInformant.update(_this.receivedStatistics == null ? {} : _this.receivedStatistics);
        });
    };
    // Posts a new metadata value to the server; no reply is awaited.
    Client.prototype.postMetadata = function (metadataID, metadata) {
        var _this = this;
        return new Promise(function (resolve) {
            var msg = {
                type: messages_1.type.postMetadata,
                taskId: _this.task.taskID,
                clientId: _this.clientID,
                round: _this.round,
                metadataId: metadataID,
                metadata: metadata
            };
            _this.sendMessage(msg);
            resolve(undefined);
        });
    };
    // Gets a metadata map from the server; resolves with undefined when the
    // reply carries no map.
    Client.prototype.getMetadataMap = function (metadataId) {
        var _this = this;
        return new Promise(function (resolve) {
            _this.metadataMap = undefined;
            var msg = {
                type: messages_1.type.getMetadataMap,
                taskId: _this.task.taskID,
                clientId: _this.clientID,
                round: _this.round,
                metadataId: metadataId
            };
            _this.sendMessage(msg);
            resolve((0, event_connection_1.waitMessageWithTimeout)(_this.server, messages_1.type.getMetadataMap, utils_1.MAX_WAIT_PER_ROUND));
        }).then(function (received) {
            if (received.metadataMap !== undefined) {
                _this.metadataMap = new Map(received.metadataMap);
            }
            return _this.metadataMap;
        });
    };
    // End-of-round exchange. Posts the differentially-private update, refreshes
    // statistics, then returns the server weights if the server is ahead,
    // otherwise the stale weights.
    Client.prototype.onRoundEndCommunication = function (updatedWeights, staleWeights, _, trainingInformant) {
        var _this = this;
        return new Promise(function (resolve) {
            var noisyWeights = __1.privacy.addDifferentialPrivacy(updatedWeights, staleWeights, _this.task);
            resolve(_this.postWeightsToServer(noisyWeights));
        }).then(function () {
            return _this.pullServerStatistics(trainingInformant);
        }).then(function () {
            return _this.pullRoundAndFetchWeights();
        }).then(function (serverWeights) {
            return serverWeights == null ? staleWeights : serverWeights;
        });
    };
    // Nothing to exchange with the server once training ends.
    Client.prototype.onTrainEndCommunication = function () {
        return Promise.resolve();
    };
    return Client;
}(base_1.Base));
273
+ exports.Client = Client;
@@ -0,0 +1,2 @@
1
+ export { Client } from './client';
2
+ export * as messages from './messages';
@@ -0,0 +1,7 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.messages = exports.Client = void 0;
4
+ var tslib_1 = require("tslib");
5
+ var client_1 = require("./client");
6
+ Object.defineProperty(exports, "Client", { enumerable: true, get: function () { return client_1.Client; } });
7
+ exports.messages = (0, tslib_1.__importStar)(require("./messages"));
@@ -0,0 +1,38 @@
1
+ import { MetadataID } from '../..';
2
+ import { weights } from '../../serialization';
3
+ import { type } from '../messages';
4
+ export declare type MessageFederated = postWeightsToServer | latestServerRound | pullServerStatistics | postMetadata | getMetadataMap | messageGeneral;
5
+ export interface messageGeneral {
6
+ type: type;
7
+ }
8
+ export interface postWeightsToServer {
9
+ type: type.postWeightsToServer;
10
+ weights: weights.Encoded;
11
+ round: number;
12
+ }
13
+ export interface latestServerRound {
14
+ type: type.latestServerRound;
15
+ weights: weights.Encoded;
16
+ round: number;
17
+ }
18
+ export interface pullServerStatistics {
19
+ type: type.pullServerStatistics;
20
+ statistics: Record<string, number>;
21
+ }
22
+ export interface postMetadata {
23
+ type: type.postMetadata;
24
+ clientId: string;
25
+ taskId: string;
26
+ round: number;
27
+ metadataId: string;
28
+ metadata: string;
29
+ }
30
+ export interface getMetadataMap {
31
+ type: type.getMetadataMap;
32
+ clientId: string;
33
+ taskId: string;
34
+ round: number;
35
+ metadataId: MetadataID;
36
+ metadataMap?: Array<[string, string | undefined]>;
37
+ }
38
+ export declare function isMessageFederated(o: unknown): o is MessageFederated;
@@ -0,0 +1,25 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.isMessageFederated = void 0;
4
+ var messages_1 = require("../messages");
5
+ function isMessageFederated(o) {
6
+ if (!(0, messages_1.hasMessageType)(o)) {
7
+ return false;
8
+ }
9
+ switch (o.type) {
10
+ case messages_1.type.clientConnected:
11
+ return true;
12
+ case messages_1.type.postWeightsToServer:
13
+ return true;
14
+ case messages_1.type.latestServerRound:
15
+ return true;
16
+ case messages_1.type.pullServerStatistics:
17
+ return true;
18
+ case messages_1.type.postMetadata:
19
+ return true;
20
+ case messages_1.type.getMetadataMap:
21
+ return true;
22
+ }
23
+ return false;
24
+ }
25
+ exports.isMessageFederated = isMessageFederated;
@@ -1,4 +1,5 @@
1
1
  export { Base } from './base';
2
2
  export * as decentralized from './decentralized';
3
- export { Federated } from './federated';
3
+ export * as federated from './federated';
4
+ export * as messages from './messages';
4
5
  export { Local } from './local';
@@ -1,11 +1,11 @@
1
1
  "use strict";
2
2
  Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.Local = exports.Federated = exports.decentralized = exports.Base = void 0;
3
+ exports.Local = exports.messages = exports.federated = exports.decentralized = exports.Base = void 0;
4
4
  var tslib_1 = require("tslib");
5
5
  var base_1 = require("./base");
6
6
  Object.defineProperty(exports, "Base", { enumerable: true, get: function () { return base_1.Base; } });
7
7
  exports.decentralized = (0, tslib_1.__importStar)(require("./decentralized"));
8
- var federated_1 = require("./federated");
9
- Object.defineProperty(exports, "Federated", { enumerable: true, get: function () { return federated_1.Federated; } });
8
+ exports.federated = (0, tslib_1.__importStar)(require("./federated"));
9
+ exports.messages = (0, tslib_1.__importStar)(require("./messages"));
10
10
  var local_1 = require("./local");
11
11
  Object.defineProperty(exports, "Local", { enumerable: true, get: function () { return local_1.Local; } });
@@ -1,8 +1,8 @@
1
- import { Weights } from '../types';
1
+ import { WeightsContainer } from '..';
2
2
  import { Base } from './base';
3
3
  export declare class Local extends Base {
4
4
  connect(): Promise<void>;
5
5
  disconnect(): Promise<void>;
6
- onRoundEndCommunication(_: Weights): Promise<Weights>;
6
+ onRoundEndCommunication(_: WeightsContainer): Promise<WeightsContainer>;
7
7
  onTrainEndCommunication(): Promise<void>;
8
8
  }
File without changes
@@ -0,0 +1,28 @@
1
+ import * as decentralized from './decentralized/messages';
2
+ import * as federated from './federated/messages';
3
+ export declare enum type {
4
+ clientConnected = 0,
5
+ PeerID = 1,
6
+ SignalForPeer = 2,
7
+ PeerIsReady = 3,
8
+ PeersForRound = 4,
9
+ Weights = 5,
10
+ Shares = 6,
11
+ PartialSums = 7,
12
+ postWeightsToServer = 8,
13
+ postMetadata = 9,
14
+ getMetadataMap = 10,
15
+ latestServerRound = 11,
16
+ pullRoundAndFetchWeights = 12,
17
+ pullServerStatistics = 13
18
+ }
19
+ export interface clientConnected {
20
+ type: type.clientConnected;
21
+ }
22
+ export declare type Message = decentralized.MessageFromServer | decentralized.MessageToServer | decentralized.PeerMessage | federated.MessageFederated;
23
+ export declare type NarrowMessage<D> = Extract<Message, {
24
+ type: D;
25
+ }>;
26
+ export declare function hasMessageType(raw: unknown): raw is {
27
+ type: type;
28
+ } & Record<string, unknown>;
@@ -0,0 +1,33 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.hasMessageType = exports.type = void 0;
4
+ var type;
5
+ (function (type) {
6
+ type[type["clientConnected"] = 0] = "clientConnected";
7
+ // decentralized
8
+ type[type["PeerID"] = 1] = "PeerID";
9
+ type[type["SignalForPeer"] = 2] = "SignalForPeer";
10
+ type[type["PeerIsReady"] = 3] = "PeerIsReady";
11
+ type[type["PeersForRound"] = 4] = "PeersForRound";
12
+ type[type["Weights"] = 5] = "Weights";
13
+ type[type["Shares"] = 6] = "Shares";
14
+ type[type["PartialSums"] = 7] = "PartialSums";
15
+ // federated
16
+ type[type["postWeightsToServer"] = 8] = "postWeightsToServer";
17
+ type[type["postMetadata"] = 9] = "postMetadata";
18
+ type[type["getMetadataMap"] = 10] = "getMetadataMap";
19
+ type[type["latestServerRound"] = 11] = "latestServerRound";
20
+ type[type["pullRoundAndFetchWeights"] = 12] = "pullRoundAndFetchWeights";
21
+ type[type["pullServerStatistics"] = 13] = "pullServerStatistics";
22
+ })(type = exports.type || (exports.type = {}));
23
+ function hasMessageType(raw) {
24
+ if (typeof raw !== 'object' || raw === null) {
25
+ return false;
26
+ }
27
+ var o = raw;
28
+ if (!('type' in o && typeof o.type === 'number' && o.type in type)) {
29
+ return false;
30
+ }
31
+ return true;
32
+ }
33
+ exports.hasMessageType = hasMessageType;
@@ -0,0 +1,2 @@
1
+ export declare const MAX_WAIT_PER_ROUND = 10000;
2
+ export declare function timeout(ms: number): Promise<never>;
@@ -0,0 +1,19 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.timeout = exports.MAX_WAIT_PER_ROUND = void 0;
4
+ var tslib_1 = require("tslib");
5
+ // Time to wait for the others in milliseconds.
6
+ exports.MAX_WAIT_PER_ROUND = 10000;
7
+ function timeout(ms) {
8
+ return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
9
+ return (0, tslib_1.__generator)(this, function (_a) {
10
+ switch (_a.label) {
11
+ case 0: return [4 /*yield*/, new Promise(function (resolve, reject) {
12
+ setTimeout(function () { return reject(new Error('timeout')); }, ms);
13
+ })];
14
+ case 1: return [2 /*return*/, _a.sent()];
15
+ }
16
+ });
17
+ });
18
+ }
19
+ exports.timeout = timeout;