@epfml/discojs-node 2.0.0 → 2.1.2-p20240506085037.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (182) hide show
  1. package/dist/data/image_loader.d.ts +5 -0
  2. package/dist/data/image_loader.js +11 -0
  3. package/dist/data/index.d.ts +3 -0
  4. package/dist/data/index.js +3 -0
  5. package/dist/data/tabular_loader.d.ts +4 -0
  6. package/dist/data/tabular_loader.js +11 -0
  7. package/dist/data/text_loader.d.ts +4 -0
  8. package/dist/data/text_loader.js +14 -0
  9. package/dist/index.d.ts +2 -2
  10. package/dist/index.js +2 -6
  11. package/package.json +14 -17
  12. package/README.md +0 -53
  13. package/dist/core/async_buffer.d.ts +0 -41
  14. package/dist/core/async_buffer.js +0 -97
  15. package/dist/core/async_informant.d.ts +0 -20
  16. package/dist/core/async_informant.js +0 -69
  17. package/dist/core/client/base.d.ts +0 -33
  18. package/dist/core/client/base.js +0 -35
  19. package/dist/core/client/decentralized/base.d.ts +0 -32
  20. package/dist/core/client/decentralized/base.js +0 -212
  21. package/dist/core/client/decentralized/clear_text.d.ts +0 -14
  22. package/dist/core/client/decentralized/clear_text.js +0 -96
  23. package/dist/core/client/decentralized/index.d.ts +0 -4
  24. package/dist/core/client/decentralized/index.js +0 -9
  25. package/dist/core/client/decentralized/messages.d.ts +0 -41
  26. package/dist/core/client/decentralized/messages.js +0 -54
  27. package/dist/core/client/decentralized/peer.d.ts +0 -26
  28. package/dist/core/client/decentralized/peer.js +0 -210
  29. package/dist/core/client/decentralized/peer_pool.d.ts +0 -14
  30. package/dist/core/client/decentralized/peer_pool.js +0 -92
  31. package/dist/core/client/decentralized/sec_agg.d.ts +0 -22
  32. package/dist/core/client/decentralized/sec_agg.js +0 -190
  33. package/dist/core/client/decentralized/secret_shares.d.ts +0 -3
  34. package/dist/core/client/decentralized/secret_shares.js +0 -39
  35. package/dist/core/client/decentralized/types.d.ts +0 -2
  36. package/dist/core/client/decentralized/types.js +0 -7
  37. package/dist/core/client/event_connection.d.ts +0 -37
  38. package/dist/core/client/event_connection.js +0 -158
  39. package/dist/core/client/federated/client.d.ts +0 -37
  40. package/dist/core/client/federated/client.js +0 -273
  41. package/dist/core/client/federated/index.d.ts +0 -2
  42. package/dist/core/client/federated/index.js +0 -7
  43. package/dist/core/client/federated/messages.d.ts +0 -38
  44. package/dist/core/client/federated/messages.js +0 -25
  45. package/dist/core/client/index.d.ts +0 -5
  46. package/dist/core/client/index.js +0 -11
  47. package/dist/core/client/local.d.ts +0 -8
  48. package/dist/core/client/local.js +0 -36
  49. package/dist/core/client/messages.d.ts +0 -28
  50. package/dist/core/client/messages.js +0 -33
  51. package/dist/core/client/utils.d.ts +0 -2
  52. package/dist/core/client/utils.js +0 -19
  53. package/dist/core/dataset/data/data.d.ts +0 -11
  54. package/dist/core/dataset/data/data.js +0 -20
  55. package/dist/core/dataset/data/data_split.d.ts +0 -5
  56. package/dist/core/dataset/data/data_split.js +0 -2
  57. package/dist/core/dataset/data/image_data.d.ts +0 -8
  58. package/dist/core/dataset/data/image_data.js +0 -64
  59. package/dist/core/dataset/data/index.d.ts +0 -5
  60. package/dist/core/dataset/data/index.js +0 -11
  61. package/dist/core/dataset/data/preprocessing.d.ts +0 -13
  62. package/dist/core/dataset/data/preprocessing.js +0 -33
  63. package/dist/core/dataset/data/tabular_data.d.ts +0 -8
  64. package/dist/core/dataset/data/tabular_data.js +0 -40
  65. package/dist/core/dataset/data_loader/data_loader.d.ts +0 -15
  66. package/dist/core/dataset/data_loader/data_loader.js +0 -10
  67. package/dist/core/dataset/data_loader/image_loader.d.ts +0 -17
  68. package/dist/core/dataset/data_loader/image_loader.js +0 -141
  69. package/dist/core/dataset/data_loader/index.d.ts +0 -3
  70. package/dist/core/dataset/data_loader/index.js +0 -9
  71. package/dist/core/dataset/data_loader/tabular_loader.d.ts +0 -29
  72. package/dist/core/dataset/data_loader/tabular_loader.js +0 -101
  73. package/dist/core/dataset/dataset.d.ts +0 -2
  74. package/dist/core/dataset/dataset.js +0 -2
  75. package/dist/core/dataset/dataset_builder.d.ts +0 -18
  76. package/dist/core/dataset/dataset_builder.js +0 -96
  77. package/dist/core/dataset/index.d.ts +0 -4
  78. package/dist/core/dataset/index.js +0 -14
  79. package/dist/core/index.d.ts +0 -18
  80. package/dist/core/index.js +0 -41
  81. package/dist/core/informant/graph_informant.d.ts +0 -10
  82. package/dist/core/informant/graph_informant.js +0 -23
  83. package/dist/core/informant/index.d.ts +0 -3
  84. package/dist/core/informant/index.js +0 -9
  85. package/dist/core/informant/training_informant/base.d.ts +0 -31
  86. package/dist/core/informant/training_informant/base.js +0 -83
  87. package/dist/core/informant/training_informant/decentralized.d.ts +0 -5
  88. package/dist/core/informant/training_informant/decentralized.js +0 -22
  89. package/dist/core/informant/training_informant/federated.d.ts +0 -14
  90. package/dist/core/informant/training_informant/federated.js +0 -32
  91. package/dist/core/informant/training_informant/index.d.ts +0 -4
  92. package/dist/core/informant/training_informant/index.js +0 -11
  93. package/dist/core/informant/training_informant/local.d.ts +0 -6
  94. package/dist/core/informant/training_informant/local.js +0 -20
  95. package/dist/core/logging/console_logger.d.ts +0 -18
  96. package/dist/core/logging/console_logger.js +0 -33
  97. package/dist/core/logging/index.d.ts +0 -3
  98. package/dist/core/logging/index.js +0 -9
  99. package/dist/core/logging/logger.d.ts +0 -12
  100. package/dist/core/logging/logger.js +0 -9
  101. package/dist/core/logging/trainer_logger.d.ts +0 -24
  102. package/dist/core/logging/trainer_logger.js +0 -59
  103. package/dist/core/memory/base.d.ts +0 -22
  104. package/dist/core/memory/base.js +0 -9
  105. package/dist/core/memory/empty.d.ts +0 -14
  106. package/dist/core/memory/empty.js +0 -75
  107. package/dist/core/memory/index.d.ts +0 -3
  108. package/dist/core/memory/index.js +0 -9
  109. package/dist/core/memory/model_type.d.ts +0 -4
  110. package/dist/core/memory/model_type.js +0 -9
  111. package/dist/core/privacy.d.ts +0 -11
  112. package/dist/core/privacy.js +0 -47
  113. package/dist/core/serialization/index.d.ts +0 -2
  114. package/dist/core/serialization/index.js +0 -6
  115. package/dist/core/serialization/model.d.ts +0 -5
  116. package/dist/core/serialization/model.js +0 -55
  117. package/dist/core/serialization/weights.d.ts +0 -5
  118. package/dist/core/serialization/weights.js +0 -64
  119. package/dist/core/task/data_example.d.ts +0 -5
  120. package/dist/core/task/data_example.js +0 -24
  121. package/dist/core/task/display_information.d.ts +0 -15
  122. package/dist/core/task/display_information.js +0 -49
  123. package/dist/core/task/index.d.ts +0 -3
  124. package/dist/core/task/index.js +0 -8
  125. package/dist/core/task/model_compile_data.d.ts +0 -6
  126. package/dist/core/task/model_compile_data.js +0 -22
  127. package/dist/core/task/summary.d.ts +0 -5
  128. package/dist/core/task/summary.js +0 -19
  129. package/dist/core/task/task.d.ts +0 -10
  130. package/dist/core/task/task.js +0 -31
  131. package/dist/core/task/training_information.d.ts +0 -28
  132. package/dist/core/task/training_information.js +0 -66
  133. package/dist/core/tasks/cifar10.d.ts +0 -3
  134. package/dist/core/tasks/cifar10.js +0 -65
  135. package/dist/core/tasks/geotags.d.ts +0 -3
  136. package/dist/core/tasks/geotags.js +0 -67
  137. package/dist/core/tasks/index.d.ts +0 -6
  138. package/dist/core/tasks/index.js +0 -10
  139. package/dist/core/tasks/lus_covid.d.ts +0 -3
  140. package/dist/core/tasks/lus_covid.js +0 -87
  141. package/dist/core/tasks/mnist.d.ts +0 -3
  142. package/dist/core/tasks/mnist.js +0 -60
  143. package/dist/core/tasks/simple_face.d.ts +0 -2
  144. package/dist/core/tasks/simple_face.js +0 -41
  145. package/dist/core/tasks/titanic.d.ts +0 -3
  146. package/dist/core/tasks/titanic.js +0 -88
  147. package/dist/core/training/disco.d.ts +0 -23
  148. package/dist/core/training/disco.js +0 -130
  149. package/dist/core/training/index.d.ts +0 -2
  150. package/dist/core/training/index.js +0 -7
  151. package/dist/core/training/trainer/distributed_trainer.d.ts +0 -20
  152. package/dist/core/training/trainer/distributed_trainer.js +0 -65
  153. package/dist/core/training/trainer/local_trainer.d.ts +0 -11
  154. package/dist/core/training/trainer/local_trainer.js +0 -34
  155. package/dist/core/training/trainer/round_tracker.d.ts +0 -30
  156. package/dist/core/training/trainer/round_tracker.js +0 -47
  157. package/dist/core/training/trainer/trainer.d.ts +0 -65
  158. package/dist/core/training/trainer/trainer.js +0 -160
  159. package/dist/core/training/trainer/trainer_builder.d.ts +0 -25
  160. package/dist/core/training/trainer/trainer_builder.js +0 -95
  161. package/dist/core/training/training_schemes.d.ts +0 -5
  162. package/dist/core/training/training_schemes.js +0 -10
  163. package/dist/core/types.d.ts +0 -4
  164. package/dist/core/types.js +0 -2
  165. package/dist/core/validation/index.d.ts +0 -1
  166. package/dist/core/validation/index.js +0 -5
  167. package/dist/core/validation/validator.d.ts +0 -17
  168. package/dist/core/validation/validator.js +0 -104
  169. package/dist/core/weights/aggregation.d.ts +0 -8
  170. package/dist/core/weights/aggregation.js +0 -96
  171. package/dist/core/weights/index.d.ts +0 -2
  172. package/dist/core/weights/index.js +0 -7
  173. package/dist/core/weights/weights_container.d.ts +0 -19
  174. package/dist/core/weights/weights_container.js +0 -64
  175. package/dist/dataset/data_loader/image_loader.d.ts +0 -4
  176. package/dist/dataset/data_loader/image_loader.js +0 -21
  177. package/dist/dataset/data_loader/index.d.ts +0 -2
  178. package/dist/dataset/data_loader/index.js +0 -7
  179. package/dist/dataset/data_loader/tabular_loader.d.ts +0 -4
  180. package/dist/dataset/data_loader/tabular_loader.js +0 -20
  181. package/dist/imports.d.ts +0 -1
  182. package/dist/imports.js +0 -5
@@ -1,273 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.Client = void 0;
4
- var tslib_1 = require("tslib");
5
- var uuid_1 = require("uuid");
6
- var __1 = require("../..");
7
- var base_1 = require("../base");
8
- var messages = (0, tslib_1.__importStar)(require("./messages"));
9
- var messages_1 = require("../messages");
10
- var nodeUrl = (0, tslib_1.__importStar)(require("url"));
11
- var event_connection_1 = require("../event_connection");
12
- var utils_1 = require("../utils");
13
- /**
14
- * Class that deals with communication with the centralized server when training
15
- * a specific task in the federated setting.
16
- */
17
- var Client = /** @class */ (function (_super) {
18
- (0, tslib_1.__extends)(Client, _super);
19
- function Client() {
20
- var _this = _super !== null && _super.apply(this, arguments) || this;
21
- _this.clientID = (0, uuid_1.v4)();
22
- _this.round = 0;
23
- return _this;
24
- }
25
- Object.defineProperty(Client.prototype, "server", {
26
- get: function () {
27
- if (this._server === undefined) {
28
- throw new Error('server undefined, not connected');
29
- }
30
- return this._server;
31
- },
32
- enumerable: false,
33
- configurable: true
34
- });
35
- // It opens a new WebSocket connection and listens to new messages over the channel
36
- Client.prototype.connectServer = function (url) {
37
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
38
- var server;
39
- return (0, tslib_1.__generator)(this, function (_a) {
40
- switch (_a.label) {
41
- case 0: return [4 /*yield*/, event_connection_1.WebSocketServer.connect(url, messages.isMessageFederated, messages.isMessageFederated)];
42
- case 1:
43
- server = _a.sent();
44
- return [2 /*return*/, server];
45
- }
46
- });
47
- });
48
- };
49
- /**
50
- * Initialize the connection to the server. TODO: In the case of FeAI,
51
- * should return the current server-side round for the task.
52
- */
53
- Client.prototype.connect = function () {
54
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
55
- var URL, serverURL, _a, msg;
56
- return (0, tslib_1.__generator)(this, function (_b) {
57
- switch (_b.label) {
58
- case 0:
59
- URL = typeof window !== 'undefined' ? window.URL : nodeUrl.URL;
60
- serverURL = new URL('', this.url.href);
61
- switch (this.url.protocol) {
62
- case 'http:':
63
- serverURL.protocol = 'ws:';
64
- break;
65
- case 'https:':
66
- serverURL.protocol = 'wss:';
67
- break;
68
- default:
69
- throw new Error("unknown protocol: " + this.url.protocol);
70
- }
71
- serverURL.pathname += "feai/" + this.task.taskID + "/" + this.clientID;
72
- _a = this;
73
- return [4 /*yield*/, this.connectServer(serverURL)];
74
- case 1:
75
- _a._server = _b.sent();
76
- msg = {
77
- type: messages_1.type.clientConnected
78
- };
79
- this.server.send(msg);
80
- return [4 /*yield*/, (0, event_connection_1.waitMessageWithTimeout)(this.server, messages_1.type.clientConnected, utils_1.MAX_WAIT_PER_ROUND)];
81
- case 2:
82
- _b.sent();
83
- this.connected = true;
84
- return [2 /*return*/];
85
- }
86
- });
87
- });
88
- };
89
- /**
90
- * Disconnection process when user quits the task.
91
- */
92
- Client.prototype.disconnect = function () {
93
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
94
- return (0, tslib_1.__generator)(this, function (_a) {
95
- this.server.disconnect();
96
- this._server = undefined;
97
- this.connected = false;
98
- return [2 /*return*/];
99
- });
100
- });
101
- };
102
- // It sends a message to the server
103
- Client.prototype.sendMessage = function (msg) {
104
- var _a;
105
- (_a = this.server) === null || _a === void 0 ? void 0 : _a.send(msg);
106
- };
107
- // It sends weights to the server
108
- Client.prototype.postWeightsToServer = function (weights) {
109
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
110
- var msg;
111
- var _a;
112
- return (0, tslib_1.__generator)(this, function (_b) {
113
- switch (_b.label) {
114
- case 0:
115
- _a = {
116
- type: messages_1.type.postWeightsToServer
117
- };
118
- return [4 /*yield*/, __1.serialization.weights.encode(weights)];
119
- case 1:
120
- msg = (_a.weights = _b.sent(),
121
- _a.round = this.round,
122
- _a);
123
- this.sendMessage(msg);
124
- return [2 /*return*/];
125
- }
126
- });
127
- });
128
- };
129
- // It retrieves the last server round and weights, but return only the server round
130
- Client.prototype.getLatestServerRound = function () {
131
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
132
- var msg, received;
133
- return (0, tslib_1.__generator)(this, function (_a) {
134
- switch (_a.label) {
135
- case 0:
136
- this.serverRound = undefined;
137
- this.serverWeights = undefined;
138
- msg = {
139
- type: messages_1.type.latestServerRound
140
- };
141
- this.sendMessage(msg);
142
- return [4 /*yield*/, (0, event_connection_1.waitMessageWithTimeout)(this.server, messages_1.type.latestServerRound, utils_1.MAX_WAIT_PER_ROUND)];
143
- case 1:
144
- received = _a.sent();
145
- this.serverRound = received.round;
146
- this.serverWeights = __1.serialization.weights.decode(received.weights);
147
- return [2 /*return*/, this.serverRound];
148
- }
149
- });
150
- });
151
- };
152
- // It retrieves the last server round and weights, but return only the server weights
153
- Client.prototype.pullRoundAndFetchWeights = function () {
154
- var _a;
155
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
156
- return (0, tslib_1.__generator)(this, function (_b) {
157
- switch (_b.label) {
158
- case 0:
159
- // get server round of latest model
160
- return [4 /*yield*/, this.getLatestServerRound()];
161
- case 1:
162
- // get server round of latest model
163
- _b.sent();
164
- if (this.round < ((_a = this.serverRound) !== null && _a !== void 0 ? _a : 0)) {
165
- // Update the local round to match the server's
166
- this.round = this.serverRound;
167
- return [2 /*return*/, this.serverWeights];
168
- }
169
- else {
170
- return [2 /*return*/, undefined];
171
- }
172
- return [2 /*return*/];
173
- }
174
- });
175
- });
176
- };
177
- // It pulls statistics from the server
178
- Client.prototype.pullServerStatistics = function (trainingInformant) {
179
- var _a;
180
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
181
- var msg, received;
182
- return (0, tslib_1.__generator)(this, function (_b) {
183
- switch (_b.label) {
184
- case 0:
185
- this.receivedStatistics = undefined;
186
- msg = {
187
- type: messages_1.type.pullServerStatistics
188
- };
189
- this.sendMessage(msg);
190
- return [4 /*yield*/, (0, event_connection_1.waitMessageWithTimeout)(this.server, messages_1.type.pullServerStatistics, utils_1.MAX_WAIT_PER_ROUND)];
191
- case 1:
192
- received = _b.sent();
193
- this.receivedStatistics = received.statistics;
194
- trainingInformant.update((_a = this.receivedStatistics) !== null && _a !== void 0 ? _a : {});
195
- return [2 /*return*/];
196
- }
197
- });
198
- });
199
- };
200
- // It posts a new metadata value to the server
201
- Client.prototype.postMetadata = function (metadataID, metadata) {
202
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
203
- var msg;
204
- return (0, tslib_1.__generator)(this, function (_a) {
205
- msg = {
206
- type: messages_1.type.postMetadata,
207
- taskId: this.task.taskID,
208
- clientId: this.clientID,
209
- round: this.round,
210
- metadataId: metadataID,
211
- metadata: metadata
212
- };
213
- this.sendMessage(msg);
214
- return [2 /*return*/];
215
- });
216
- });
217
- };
218
- // It gets a metadata map from the server
219
- Client.prototype.getMetadataMap = function (metadataId) {
220
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
221
- var msg, received;
222
- return (0, tslib_1.__generator)(this, function (_a) {
223
- switch (_a.label) {
224
- case 0:
225
- this.metadataMap = undefined;
226
- msg = {
227
- type: messages_1.type.getMetadataMap,
228
- taskId: this.task.taskID,
229
- clientId: this.clientID,
230
- round: this.round,
231
- metadataId: metadataId
232
- };
233
- this.sendMessage(msg);
234
- return [4 /*yield*/, (0, event_connection_1.waitMessageWithTimeout)(this.server, messages_1.type.getMetadataMap, utils_1.MAX_WAIT_PER_ROUND)];
235
- case 1:
236
- received = _a.sent();
237
- if (received.metadataMap !== undefined) {
238
- this.metadataMap = new Map(received.metadataMap);
239
- }
240
- return [2 /*return*/, this.metadataMap];
241
- }
242
- });
243
- });
244
- };
245
- Client.prototype.onRoundEndCommunication = function (updatedWeights, staleWeights, _, trainingInformant) {
246
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
247
- var noisyWeights, serverWeights;
248
- return (0, tslib_1.__generator)(this, function (_a) {
249
- switch (_a.label) {
250
- case 0:
251
- noisyWeights = __1.privacy.addDifferentialPrivacy(updatedWeights, staleWeights, this.task);
252
- return [4 /*yield*/, this.postWeightsToServer(noisyWeights)];
253
- case 1:
254
- _a.sent();
255
- return [4 /*yield*/, this.pullServerStatistics(trainingInformant)];
256
- case 2:
257
- _a.sent();
258
- return [4 /*yield*/, this.pullRoundAndFetchWeights()];
259
- case 3:
260
- serverWeights = _a.sent();
261
- return [2 /*return*/, serverWeights !== null && serverWeights !== void 0 ? serverWeights : staleWeights];
262
- }
263
- });
264
- });
265
- };
266
- Client.prototype.onTrainEndCommunication = function () {
267
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () { return (0, tslib_1.__generator)(this, function (_a) {
268
- return [2 /*return*/];
269
- }); });
270
- };
271
- return Client;
272
- }(base_1.Base));
273
- exports.Client = Client;
@@ -1,2 +0,0 @@
1
- export { Client } from './client';
2
- export * as messages from './messages';
@@ -1,7 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.messages = exports.Client = void 0;
4
- var tslib_1 = require("tslib");
5
- var client_1 = require("./client");
6
- Object.defineProperty(exports, "Client", { enumerable: true, get: function () { return client_1.Client; } });
7
- exports.messages = (0, tslib_1.__importStar)(require("./messages"));
@@ -1,38 +0,0 @@
1
- import { MetadataID } from '../..';
2
- import { weights } from '../../serialization';
3
- import { type } from '../messages';
4
- export declare type MessageFederated = postWeightsToServer | latestServerRound | pullServerStatistics | postMetadata | getMetadataMap | messageGeneral;
5
- export interface messageGeneral {
6
- type: type;
7
- }
8
- export interface postWeightsToServer {
9
- type: type.postWeightsToServer;
10
- weights: weights.Encoded;
11
- round: number;
12
- }
13
- export interface latestServerRound {
14
- type: type.latestServerRound;
15
- weights: weights.Encoded;
16
- round: number;
17
- }
18
- export interface pullServerStatistics {
19
- type: type.pullServerStatistics;
20
- statistics: Record<string, number>;
21
- }
22
- export interface postMetadata {
23
- type: type.postMetadata;
24
- clientId: string;
25
- taskId: string;
26
- round: number;
27
- metadataId: string;
28
- metadata: string;
29
- }
30
- export interface getMetadataMap {
31
- type: type.getMetadataMap;
32
- clientId: string;
33
- taskId: string;
34
- round: number;
35
- metadataId: MetadataID;
36
- metadataMap?: Array<[string, string | undefined]>;
37
- }
38
- export declare function isMessageFederated(o: unknown): o is MessageFederated;
@@ -1,25 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.isMessageFederated = void 0;
4
- var messages_1 = require("../messages");
5
- function isMessageFederated(o) {
6
- if (!(0, messages_1.hasMessageType)(o)) {
7
- return false;
8
- }
9
- switch (o.type) {
10
- case messages_1.type.clientConnected:
11
- return true;
12
- case messages_1.type.postWeightsToServer:
13
- return true;
14
- case messages_1.type.latestServerRound:
15
- return true;
16
- case messages_1.type.pullServerStatistics:
17
- return true;
18
- case messages_1.type.postMetadata:
19
- return true;
20
- case messages_1.type.getMetadataMap:
21
- return true;
22
- }
23
- return false;
24
- }
25
- exports.isMessageFederated = isMessageFederated;
@@ -1,5 +0,0 @@
1
- export { Base } from './base';
2
- export * as decentralized from './decentralized';
3
- export * as federated from './federated';
4
- export * as messages from './messages';
5
- export { Local } from './local';
@@ -1,11 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.Local = exports.messages = exports.federated = exports.decentralized = exports.Base = void 0;
4
- var tslib_1 = require("tslib");
5
- var base_1 = require("./base");
6
- Object.defineProperty(exports, "Base", { enumerable: true, get: function () { return base_1.Base; } });
7
- exports.decentralized = (0, tslib_1.__importStar)(require("./decentralized"));
8
- exports.federated = (0, tslib_1.__importStar)(require("./federated"));
9
- exports.messages = (0, tslib_1.__importStar)(require("./messages"));
10
- var local_1 = require("./local");
11
- Object.defineProperty(exports, "Local", { enumerable: true, get: function () { return local_1.Local; } });
@@ -1,8 +0,0 @@
1
- import { WeightsContainer } from '..';
2
- import { Base } from './base';
3
- export declare class Local extends Base {
4
- connect(): Promise<void>;
5
- disconnect(): Promise<void>;
6
- onRoundEndCommunication(_: WeightsContainer): Promise<WeightsContainer>;
7
- onTrainEndCommunication(): Promise<void>;
8
- }
@@ -1,36 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.Local = void 0;
4
- var tslib_1 = require("tslib");
5
- var base_1 = require("./base");
6
- // does pretty much nothing
7
- var Local = /** @class */ (function (_super) {
8
- (0, tslib_1.__extends)(Local, _super);
9
- function Local() {
10
- return _super !== null && _super.apply(this, arguments) || this;
11
- }
12
- Local.prototype.connect = function () {
13
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () { return (0, tslib_1.__generator)(this, function (_a) {
14
- return [2 /*return*/];
15
- }); });
16
- };
17
- Local.prototype.disconnect = function () {
18
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () { return (0, tslib_1.__generator)(this, function (_a) {
19
- return [2 /*return*/];
20
- }); });
21
- };
22
- Local.prototype.onRoundEndCommunication = function (_) {
23
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
24
- return (0, tslib_1.__generator)(this, function (_a) {
25
- return [2 /*return*/, _];
26
- });
27
- });
28
- };
29
- Local.prototype.onTrainEndCommunication = function () {
30
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () { return (0, tslib_1.__generator)(this, function (_a) {
31
- return [2 /*return*/];
32
- }); });
33
- };
34
- return Local;
35
- }(base_1.Base));
36
- exports.Local = Local;
@@ -1,28 +0,0 @@
1
- import * as decentralized from './decentralized/messages';
2
- import * as federated from './federated/messages';
3
- export declare enum type {
4
- clientConnected = 0,
5
- PeerID = 1,
6
- SignalForPeer = 2,
7
- PeerIsReady = 3,
8
- PeersForRound = 4,
9
- Weights = 5,
10
- Shares = 6,
11
- PartialSums = 7,
12
- postWeightsToServer = 8,
13
- postMetadata = 9,
14
- getMetadataMap = 10,
15
- latestServerRound = 11,
16
- pullRoundAndFetchWeights = 12,
17
- pullServerStatistics = 13
18
- }
19
- export interface clientConnected {
20
- type: type.clientConnected;
21
- }
22
- export declare type Message = decentralized.MessageFromServer | decentralized.MessageToServer | decentralized.PeerMessage | federated.MessageFederated;
23
- export declare type NarrowMessage<D> = Extract<Message, {
24
- type: D;
25
- }>;
26
- export declare function hasMessageType(raw: unknown): raw is {
27
- type: type;
28
- } & Record<string, unknown>;
@@ -1,33 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.hasMessageType = exports.type = void 0;
4
- var type;
5
- (function (type) {
6
- type[type["clientConnected"] = 0] = "clientConnected";
7
- // decentralized
8
- type[type["PeerID"] = 1] = "PeerID";
9
- type[type["SignalForPeer"] = 2] = "SignalForPeer";
10
- type[type["PeerIsReady"] = 3] = "PeerIsReady";
11
- type[type["PeersForRound"] = 4] = "PeersForRound";
12
- type[type["Weights"] = 5] = "Weights";
13
- type[type["Shares"] = 6] = "Shares";
14
- type[type["PartialSums"] = 7] = "PartialSums";
15
- // federated
16
- type[type["postWeightsToServer"] = 8] = "postWeightsToServer";
17
- type[type["postMetadata"] = 9] = "postMetadata";
18
- type[type["getMetadataMap"] = 10] = "getMetadataMap";
19
- type[type["latestServerRound"] = 11] = "latestServerRound";
20
- type[type["pullRoundAndFetchWeights"] = 12] = "pullRoundAndFetchWeights";
21
- type[type["pullServerStatistics"] = 13] = "pullServerStatistics";
22
- })(type = exports.type || (exports.type = {}));
23
- function hasMessageType(raw) {
24
- if (typeof raw !== 'object' || raw === null) {
25
- return false;
26
- }
27
- var o = raw;
28
- if (!('type' in o && typeof o.type === 'number' && o.type in type)) {
29
- return false;
30
- }
31
- return true;
32
- }
33
- exports.hasMessageType = hasMessageType;
@@ -1,2 +0,0 @@
1
- export declare const MAX_WAIT_PER_ROUND = 10000;
2
- export declare function timeout(ms: number): Promise<never>;
@@ -1,19 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.timeout = exports.MAX_WAIT_PER_ROUND = void 0;
4
- var tslib_1 = require("tslib");
5
- // Time to wait for the others in milliseconds.
6
- exports.MAX_WAIT_PER_ROUND = 10000;
7
- function timeout(ms) {
8
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
9
- return (0, tslib_1.__generator)(this, function (_a) {
10
- switch (_a.label) {
11
- case 0: return [4 /*yield*/, new Promise(function (resolve, reject) {
12
- setTimeout(function () { return reject(new Error('timeout')); }, ms);
13
- })];
14
- case 1: return [2 /*return*/, _a.sent()];
15
- }
16
- });
17
- });
18
- }
19
- exports.timeout = timeout;
@@ -1,11 +0,0 @@
1
- import { Task } from '../..';
2
- import { Dataset } from '../dataset';
3
- export declare abstract class Data {
4
- readonly dataset: Dataset;
5
- readonly task: Task;
6
- readonly size?: number | undefined;
7
- protected constructor(dataset: Dataset, task: Task, size?: number | undefined);
8
- static init(dataset: Dataset, task: Task, size?: number): Promise<Data>;
9
- abstract batch(): Data;
10
- abstract preprocess(): Data;
11
- }
@@ -1,20 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.Data = void 0;
4
- var tslib_1 = require("tslib");
5
- var Data = /** @class */ (function () {
6
- function Data(dataset, task, size) {
7
- this.dataset = dataset;
8
- this.task = task;
9
- this.size = size;
10
- }
11
- Data.init = function (dataset, task, size) {
12
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
13
- return (0, tslib_1.__generator)(this, function (_a) {
14
- throw new Error('abstract');
15
- });
16
- });
17
- };
18
- return Data;
19
- }());
20
- exports.Data = Data;
@@ -1,5 +0,0 @@
1
- import { Data } from './data';
2
- export interface DataSplit {
3
- train: Data;
4
- validation?: Data;
5
- }
@@ -1,2 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
@@ -1,8 +0,0 @@
1
- import { Task } from '../..';
2
- import { Dataset } from '../dataset';
3
- import { Data } from './data';
4
- export declare class ImageData extends Data {
5
- static init(dataset: Dataset, task: Task, size?: number): Promise<Data>;
6
- batch(): Data;
7
- preprocess(): Data;
8
- }
@@ -1,64 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.ImageData = void 0;
4
- var tslib_1 = require("tslib");
5
- var preprocessing_1 = require("./preprocessing");
6
- var data_1 = require("./data");
7
- var ImageData = /** @class */ (function (_super) {
8
- (0, tslib_1.__extends)(ImageData, _super);
9
- function ImageData() {
10
- return _super !== null && _super.apply(this, arguments) || this;
11
- }
12
- ImageData.init = function (dataset, task, size) {
13
- var _a;
14
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
15
- var sample, shape, e_1;
16
- return (0, tslib_1.__generator)(this, function (_b) {
17
- switch (_b.label) {
18
- case 0:
19
- if (!!((_a = task.trainingInformation.preprocessingFunctions) === null || _a === void 0 ? void 0 : _a.includes(preprocessing_1.ImagePreprocessing.Resize))) return [3 /*break*/, 4];
20
- _b.label = 1;
21
- case 1:
22
- _b.trys.push([1, 3, , 4]);
23
- return [4 /*yield*/, dataset.take(1).toArray()];
24
- case 2:
25
- sample = (_b.sent())[0];
26
- // TODO: We suppose the presence of labels
27
- // TODO: Typing (discojs-node/src/dataset/data_loader/image_loader.spec.ts)
28
- if (!(typeof sample === 'object' && sample !== null)) {
29
- throw new Error();
30
- }
31
- shape = void 0;
32
- if ('xs' in sample && 'ys' in sample) {
33
- shape = sample.xs.shape;
34
- }
35
- else {
36
- shape = sample.shape;
37
- }
38
- if (!(shape[0] === task.trainingInformation.IMAGE_W &&
39
- shape[1] === task.trainingInformation.IMAGE_H)) {
40
- throw new Error();
41
- }
42
- return [3 /*break*/, 4];
43
- case 3:
44
- e_1 = _b.sent();
45
- throw new Error('Data input format is not compatible with the chosen task');
46
- case 4: return [2 /*return*/, new ImageData(dataset, task, size)];
47
- }
48
- });
49
- });
50
- };
51
- ImageData.prototype.batch = function () {
52
- var batchSize = this.task.trainingInformation.batchSize;
53
- var newDataset = batchSize === undefined ? this.dataset : this.dataset.batch(batchSize);
54
- return new ImageData(newDataset, this.task, this.size);
55
- };
56
- ImageData.prototype.preprocess = function () {
57
- var newDataset = this.dataset;
58
- var preprocessImage = (0, preprocessing_1.getPreprocessImage)(this.task);
59
- newDataset = newDataset.map(function (x) { return preprocessImage(x); });
60
- return new ImageData(newDataset, this.task, this.size);
61
- };
62
- return ImageData;
63
- }(data_1.Data));
64
- exports.ImageData = ImageData;
@@ -1,5 +0,0 @@
1
- export { DataSplit } from './data_split';
2
- export { Data } from './data';
3
- export { ImageData } from './image_data';
4
- export { TabularData } from './tabular_data';
5
- export { ImagePreprocessing } from './preprocessing';
@@ -1,11 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.ImagePreprocessing = exports.TabularData = exports.ImageData = exports.Data = void 0;
4
- var data_1 = require("./data");
5
- Object.defineProperty(exports, "Data", { enumerable: true, get: function () { return data_1.Data; } });
6
- var image_data_1 = require("./image_data");
7
- Object.defineProperty(exports, "ImageData", { enumerable: true, get: function () { return image_data_1.ImageData; } });
8
- var tabular_data_1 = require("./tabular_data");
9
- Object.defineProperty(exports, "TabularData", { enumerable: true, get: function () { return tabular_data_1.TabularData; } });
10
- var preprocessing_1 = require("./preprocessing");
11
- Object.defineProperty(exports, "ImagePreprocessing", { enumerable: true, get: function () { return preprocessing_1.ImagePreprocessing; } });
@@ -1,13 +0,0 @@
1
- import { tf, Task } from '../..';
2
- declare type PreprocessImage = (image: tf.TensorContainer) => tf.TensorContainer;
3
- export declare type Preprocessing = ImagePreprocessing;
4
- export interface ImageTensorContainer extends tf.TensorContainerObject {
5
- xs: tf.Tensor3D | tf.Tensor4D;
6
- ys: tf.Tensor1D | number | undefined;
7
- }
8
- export declare enum ImagePreprocessing {
9
- Normalize = "normalize",
10
- Resize = "resize"
11
- }
12
- export declare function getPreprocessImage(task: Task): PreprocessImage;
13
- export {};