@epfml/discojs 2.1.1 → 2.1.2-p20240506085037.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/aggregator/base.d.ts +180 -0
- package/dist/aggregator/base.js +236 -0
- package/dist/aggregator/get.d.ts +16 -0
- package/dist/aggregator/get.js +31 -0
- package/dist/aggregator/index.d.ts +7 -0
- package/dist/aggregator/index.js +4 -0
- package/dist/aggregator/mean.d.ts +23 -0
- package/dist/aggregator/mean.js +69 -0
- package/dist/aggregator/secure.d.ts +27 -0
- package/dist/aggregator/secure.js +91 -0
- package/dist/async_informant.d.ts +15 -0
- package/dist/async_informant.js +42 -0
- package/dist/client/base.d.ts +76 -0
- package/dist/client/base.js +88 -0
- package/dist/client/decentralized/base.d.ts +32 -0
- package/dist/client/decentralized/base.js +192 -0
- package/dist/client/decentralized/index.d.ts +2 -0
- package/dist/client/decentralized/index.js +2 -0
- package/dist/client/decentralized/messages.d.ts +28 -0
- package/dist/client/decentralized/messages.js +44 -0
- package/dist/client/decentralized/peer.d.ts +40 -0
- package/dist/client/decentralized/peer.js +189 -0
- package/dist/client/decentralized/peer_pool.d.ts +12 -0
- package/dist/client/decentralized/peer_pool.js +44 -0
- package/dist/client/event_connection.d.ts +34 -0
- package/dist/client/event_connection.js +105 -0
- package/dist/client/federated/base.d.ts +54 -0
- package/dist/client/federated/base.js +151 -0
- package/dist/client/federated/index.d.ts +2 -0
- package/dist/client/federated/index.js +2 -0
- package/dist/client/federated/messages.d.ts +30 -0
- package/dist/client/federated/messages.js +24 -0
- package/dist/client/index.d.ts +8 -0
- package/dist/client/index.js +8 -0
- package/dist/client/local.d.ts +3 -0
- package/dist/client/local.js +3 -0
- package/dist/client/messages.d.ts +30 -0
- package/dist/client/messages.js +26 -0
- package/dist/client/types.d.ts +2 -0
- package/dist/client/types.js +4 -0
- package/dist/client/utils.d.ts +2 -0
- package/dist/client/utils.js +7 -0
- package/dist/dataset/data/data.d.ts +48 -0
- package/dist/dataset/data/data.js +72 -0
- package/dist/dataset/data/data_split.d.ts +8 -0
- package/dist/dataset/data/data_split.js +1 -0
- package/dist/dataset/data/image_data.d.ts +11 -0
- package/dist/dataset/data/image_data.js +38 -0
- package/dist/dataset/data/index.d.ts +6 -0
- package/dist/dataset/data/index.js +5 -0
- package/dist/dataset/data/preprocessing/base.d.ts +16 -0
- package/dist/dataset/data/preprocessing/base.js +1 -0
- package/dist/dataset/data/preprocessing/image_preprocessing.d.ts +13 -0
- package/dist/dataset/data/preprocessing/image_preprocessing.js +40 -0
- package/dist/dataset/data/preprocessing/index.d.ts +4 -0
- package/dist/dataset/data/preprocessing/index.js +3 -0
- package/dist/dataset/data/preprocessing/tabular_preprocessing.d.ts +13 -0
- package/dist/dataset/data/preprocessing/tabular_preprocessing.js +45 -0
- package/dist/dataset/data/preprocessing/text_preprocessing.d.ts +13 -0
- package/dist/dataset/data/preprocessing/text_preprocessing.js +85 -0
- package/dist/dataset/data/tabular_data.d.ts +11 -0
- package/dist/dataset/data/tabular_data.js +25 -0
- package/dist/dataset/data/text_data.d.ts +11 -0
- package/dist/dataset/data/text_data.js +14 -0
- package/dist/{core/dataset → dataset}/data_loader/data_loader.d.ts +3 -5
- package/dist/dataset/data_loader/data_loader.js +2 -0
- package/dist/dataset/data_loader/image_loader.d.ts +20 -3
- package/dist/dataset/data_loader/image_loader.js +98 -23
- package/dist/dataset/data_loader/index.d.ts +5 -2
- package/dist/dataset/data_loader/index.js +4 -7
- package/dist/dataset/data_loader/tabular_loader.d.ts +34 -3
- package/dist/dataset/data_loader/tabular_loader.js +75 -15
- package/dist/dataset/data_loader/text_loader.d.ts +14 -0
- package/dist/dataset/data_loader/text_loader.js +25 -0
- package/dist/dataset/dataset.d.ts +5 -0
- package/dist/dataset/dataset.js +1 -0
- package/dist/dataset/dataset_builder.d.ts +60 -0
- package/dist/dataset/dataset_builder.js +142 -0
- package/dist/dataset/index.d.ts +5 -0
- package/dist/dataset/index.js +3 -0
- package/dist/default_tasks/cifar10/index.d.ts +2 -0
- package/dist/{core/default_tasks/cifar10.js → default_tasks/cifar10/index.js} +28 -36
- package/dist/default_tasks/cifar10/model.d.ts +434 -0
- package/dist/default_tasks/cifar10/model.js +2385 -0
- package/dist/default_tasks/geotags/index.d.ts +2 -0
- package/dist/default_tasks/geotags/index.js +65 -0
- package/dist/default_tasks/geotags/model.d.ts +593 -0
- package/dist/default_tasks/geotags/model.js +4715 -0
- package/dist/default_tasks/index.d.ts +8 -0
- package/dist/default_tasks/index.js +8 -0
- package/dist/default_tasks/lus_covid.d.ts +2 -0
- package/dist/default_tasks/lus_covid.js +89 -0
- package/dist/default_tasks/mnist.d.ts +2 -0
- package/dist/{core/default_tasks → default_tasks}/mnist.js +26 -34
- package/dist/default_tasks/simple_face/index.d.ts +2 -0
- package/dist/{core/default_tasks/simple_face.js → default_tasks/simple_face/index.js} +17 -22
- package/dist/default_tasks/simple_face/model.d.ts +513 -0
- package/dist/default_tasks/simple_face/model.js +4301 -0
- package/dist/default_tasks/skin_mnist.d.ts +2 -0
- package/dist/default_tasks/skin_mnist.js +80 -0
- package/dist/default_tasks/titanic.d.ts +2 -0
- package/dist/{core/default_tasks → default_tasks}/titanic.js +24 -33
- package/dist/default_tasks/wikitext.d.ts +2 -0
- package/dist/default_tasks/wikitext.js +38 -0
- package/dist/index.d.ts +18 -2
- package/dist/index.js +18 -6
- package/dist/{core/informant → informant}/graph_informant.d.ts +1 -1
- package/dist/informant/graph_informant.js +20 -0
- package/dist/informant/index.d.ts +1 -0
- package/dist/informant/index.js +1 -0
- package/dist/{core/logging → logging}/console_logger.d.ts +2 -2
- package/dist/logging/console_logger.js +22 -0
- package/dist/logging/index.d.ts +2 -0
- package/dist/logging/index.js +1 -0
- package/dist/{core/logging → logging}/logger.d.ts +3 -3
- package/dist/logging/logger.js +1 -0
- package/dist/memory/base.d.ts +119 -0
- package/dist/memory/base.js +9 -0
- package/dist/memory/empty.d.ts +20 -0
- package/dist/memory/empty.js +43 -0
- package/dist/memory/index.d.ts +3 -1
- package/dist/memory/index.js +3 -5
- package/dist/memory/model_type.d.ts +9 -0
- package/dist/memory/model_type.js +10 -0
- package/dist/{core/privacy.d.ts → privacy.d.ts} +1 -1
- package/dist/{core/privacy.js → privacy.js} +11 -16
- package/dist/serialization/index.d.ts +2 -0
- package/dist/serialization/index.js +2 -0
- package/dist/serialization/model.d.ts +5 -0
- package/dist/serialization/model.js +67 -0
- package/dist/{core/serialization → serialization}/weights.d.ts +2 -2
- package/dist/serialization/weights.js +37 -0
- package/dist/task/data_example.js +14 -0
- package/dist/task/digest.js +14 -0
- package/dist/{core/task → task}/display_information.d.ts +5 -3
- package/dist/task/display_information.js +46 -0
- package/dist/task/index.d.ts +7 -0
- package/dist/task/index.js +5 -0
- package/dist/task/label_type.d.ts +9 -0
- package/dist/task/label_type.js +28 -0
- package/dist/task/summary.js +13 -0
- package/dist/{core/task → task}/task.d.ts +7 -7
- package/dist/task/task.js +22 -0
- package/dist/task/task_handler.d.ts +5 -0
- package/dist/task/task_handler.js +20 -0
- package/dist/task/task_provider.d.ts +5 -0
- package/dist/task/task_provider.js +1 -0
- package/dist/{core/task → task}/training_information.d.ts +9 -10
- package/dist/task/training_information.js +88 -0
- package/dist/training/disco.d.ts +40 -0
- package/dist/training/disco.js +107 -0
- package/dist/training/index.d.ts +2 -0
- package/dist/training/index.js +1 -0
- package/dist/training/trainer/distributed_trainer.d.ts +20 -0
- package/dist/training/trainer/distributed_trainer.js +36 -0
- package/dist/training/trainer/local_trainer.d.ts +12 -0
- package/dist/training/trainer/local_trainer.js +19 -0
- package/dist/training/trainer/trainer.d.ts +33 -0
- package/dist/training/trainer/trainer.js +52 -0
- package/dist/{core/training → training}/trainer/trainer_builder.d.ts +5 -7
- package/dist/training/trainer/trainer_builder.js +43 -0
- package/dist/types.d.ts +8 -0
- package/dist/types.js +1 -0
- package/dist/utils/event_emitter.d.ts +40 -0
- package/dist/utils/event_emitter.js +57 -0
- package/dist/validation/index.d.ts +1 -0
- package/dist/validation/index.js +1 -0
- package/dist/validation/validator.d.ts +28 -0
- package/dist/validation/validator.js +132 -0
- package/dist/weights/aggregation.d.ts +21 -0
- package/dist/weights/aggregation.js +44 -0
- package/dist/weights/index.d.ts +2 -0
- package/dist/weights/index.js +2 -0
- package/dist/weights/weights_container.d.ts +68 -0
- package/dist/weights/weights_container.js +96 -0
- package/package.json +24 -15
- package/README.md +0 -53
- package/dist/core/async_buffer.d.ts +0 -41
- package/dist/core/async_buffer.js +0 -97
- package/dist/core/async_informant.d.ts +0 -20
- package/dist/core/async_informant.js +0 -69
- package/dist/core/client/base.d.ts +0 -33
- package/dist/core/client/base.js +0 -35
- package/dist/core/client/decentralized/base.d.ts +0 -32
- package/dist/core/client/decentralized/base.js +0 -212
- package/dist/core/client/decentralized/clear_text.d.ts +0 -14
- package/dist/core/client/decentralized/clear_text.js +0 -96
- package/dist/core/client/decentralized/index.d.ts +0 -4
- package/dist/core/client/decentralized/index.js +0 -9
- package/dist/core/client/decentralized/messages.d.ts +0 -41
- package/dist/core/client/decentralized/messages.js +0 -54
- package/dist/core/client/decentralized/peer.d.ts +0 -26
- package/dist/core/client/decentralized/peer.js +0 -210
- package/dist/core/client/decentralized/peer_pool.d.ts +0 -14
- package/dist/core/client/decentralized/peer_pool.js +0 -92
- package/dist/core/client/decentralized/sec_agg.d.ts +0 -22
- package/dist/core/client/decentralized/sec_agg.js +0 -190
- package/dist/core/client/decentralized/secret_shares.d.ts +0 -3
- package/dist/core/client/decentralized/secret_shares.js +0 -39
- package/dist/core/client/decentralized/types.d.ts +0 -2
- package/dist/core/client/decentralized/types.js +0 -7
- package/dist/core/client/event_connection.d.ts +0 -37
- package/dist/core/client/event_connection.js +0 -158
- package/dist/core/client/federated/client.d.ts +0 -37
- package/dist/core/client/federated/client.js +0 -273
- package/dist/core/client/federated/index.d.ts +0 -2
- package/dist/core/client/federated/index.js +0 -7
- package/dist/core/client/federated/messages.d.ts +0 -38
- package/dist/core/client/federated/messages.js +0 -25
- package/dist/core/client/index.d.ts +0 -5
- package/dist/core/client/index.js +0 -11
- package/dist/core/client/local.d.ts +0 -8
- package/dist/core/client/local.js +0 -36
- package/dist/core/client/messages.d.ts +0 -28
- package/dist/core/client/messages.js +0 -33
- package/dist/core/client/utils.d.ts +0 -2
- package/dist/core/client/utils.js +0 -19
- package/dist/core/dataset/data/data.d.ts +0 -11
- package/dist/core/dataset/data/data.js +0 -20
- package/dist/core/dataset/data/data_split.d.ts +0 -5
- package/dist/core/dataset/data/data_split.js +0 -2
- package/dist/core/dataset/data/image_data.d.ts +0 -8
- package/dist/core/dataset/data/image_data.js +0 -64
- package/dist/core/dataset/data/index.d.ts +0 -5
- package/dist/core/dataset/data/index.js +0 -11
- package/dist/core/dataset/data/preprocessing.d.ts +0 -13
- package/dist/core/dataset/data/preprocessing.js +0 -33
- package/dist/core/dataset/data/tabular_data.d.ts +0 -8
- package/dist/core/dataset/data/tabular_data.js +0 -40
- package/dist/core/dataset/data_loader/data_loader.js +0 -10
- package/dist/core/dataset/data_loader/image_loader.d.ts +0 -17
- package/dist/core/dataset/data_loader/image_loader.js +0 -141
- package/dist/core/dataset/data_loader/index.d.ts +0 -3
- package/dist/core/dataset/data_loader/index.js +0 -9
- package/dist/core/dataset/data_loader/tabular_loader.d.ts +0 -29
- package/dist/core/dataset/data_loader/tabular_loader.js +0 -101
- package/dist/core/dataset/dataset.d.ts +0 -2
- package/dist/core/dataset/dataset.js +0 -2
- package/dist/core/dataset/dataset_builder.d.ts +0 -18
- package/dist/core/dataset/dataset_builder.js +0 -96
- package/dist/core/dataset/index.d.ts +0 -4
- package/dist/core/dataset/index.js +0 -14
- package/dist/core/default_tasks/cifar10.d.ts +0 -2
- package/dist/core/default_tasks/geotags.d.ts +0 -2
- package/dist/core/default_tasks/geotags.js +0 -69
- package/dist/core/default_tasks/index.d.ts +0 -6
- package/dist/core/default_tasks/index.js +0 -15
- package/dist/core/default_tasks/lus_covid.d.ts +0 -2
- package/dist/core/default_tasks/lus_covid.js +0 -96
- package/dist/core/default_tasks/mnist.d.ts +0 -2
- package/dist/core/default_tasks/simple_face.d.ts +0 -2
- package/dist/core/default_tasks/titanic.d.ts +0 -2
- package/dist/core/index.d.ts +0 -18
- package/dist/core/index.js +0 -39
- package/dist/core/informant/graph_informant.js +0 -23
- package/dist/core/informant/index.d.ts +0 -3
- package/dist/core/informant/index.js +0 -9
- package/dist/core/informant/training_informant/base.d.ts +0 -31
- package/dist/core/informant/training_informant/base.js +0 -83
- package/dist/core/informant/training_informant/decentralized.d.ts +0 -5
- package/dist/core/informant/training_informant/decentralized.js +0 -22
- package/dist/core/informant/training_informant/federated.d.ts +0 -14
- package/dist/core/informant/training_informant/federated.js +0 -32
- package/dist/core/informant/training_informant/index.d.ts +0 -4
- package/dist/core/informant/training_informant/index.js +0 -11
- package/dist/core/informant/training_informant/local.d.ts +0 -6
- package/dist/core/informant/training_informant/local.js +0 -20
- package/dist/core/logging/console_logger.js +0 -33
- package/dist/core/logging/index.d.ts +0 -3
- package/dist/core/logging/index.js +0 -9
- package/dist/core/logging/logger.js +0 -9
- package/dist/core/logging/trainer_logger.d.ts +0 -24
- package/dist/core/logging/trainer_logger.js +0 -59
- package/dist/core/memory/base.d.ts +0 -22
- package/dist/core/memory/base.js +0 -9
- package/dist/core/memory/empty.d.ts +0 -14
- package/dist/core/memory/empty.js +0 -75
- package/dist/core/memory/index.d.ts +0 -3
- package/dist/core/memory/index.js +0 -9
- package/dist/core/memory/model_type.d.ts +0 -4
- package/dist/core/memory/model_type.js +0 -9
- package/dist/core/serialization/index.d.ts +0 -2
- package/dist/core/serialization/index.js +0 -6
- package/dist/core/serialization/model.d.ts +0 -5
- package/dist/core/serialization/model.js +0 -55
- package/dist/core/serialization/weights.js +0 -64
- package/dist/core/task/data_example.js +0 -24
- package/dist/core/task/digest.js +0 -18
- package/dist/core/task/display_information.js +0 -49
- package/dist/core/task/index.d.ts +0 -6
- package/dist/core/task/index.js +0 -15
- package/dist/core/task/model_compile_data.d.ts +0 -6
- package/dist/core/task/model_compile_data.js +0 -22
- package/dist/core/task/summary.js +0 -19
- package/dist/core/task/task.js +0 -35
- package/dist/core/task/task_handler.d.ts +0 -5
- package/dist/core/task/task_handler.js +0 -53
- package/dist/core/task/task_provider.d.ts +0 -6
- package/dist/core/task/task_provider.js +0 -13
- package/dist/core/task/training_information.js +0 -66
- package/dist/core/training/disco.d.ts +0 -23
- package/dist/core/training/disco.js +0 -130
- package/dist/core/training/index.d.ts +0 -2
- package/dist/core/training/index.js +0 -7
- package/dist/core/training/trainer/distributed_trainer.d.ts +0 -20
- package/dist/core/training/trainer/distributed_trainer.js +0 -65
- package/dist/core/training/trainer/local_trainer.d.ts +0 -11
- package/dist/core/training/trainer/local_trainer.js +0 -34
- package/dist/core/training/trainer/round_tracker.d.ts +0 -30
- package/dist/core/training/trainer/round_tracker.js +0 -47
- package/dist/core/training/trainer/trainer.d.ts +0 -65
- package/dist/core/training/trainer/trainer.js +0 -160
- package/dist/core/training/trainer/trainer_builder.js +0 -95
- package/dist/core/training/training_schemes.d.ts +0 -5
- package/dist/core/training/training_schemes.js +0 -10
- package/dist/core/types.d.ts +0 -4
- package/dist/core/types.js +0 -2
- package/dist/core/validation/index.d.ts +0 -1
- package/dist/core/validation/index.js +0 -5
- package/dist/core/validation/validator.d.ts +0 -17
- package/dist/core/validation/validator.js +0 -104
- package/dist/core/weights/aggregation.d.ts +0 -7
- package/dist/core/weights/aggregation.js +0 -72
- package/dist/core/weights/index.d.ts +0 -2
- package/dist/core/weights/index.js +0 -7
- package/dist/core/weights/weights_container.d.ts +0 -19
- package/dist/core/weights/weights_container.js +0 -64
- package/dist/imports.d.ts +0 -2
- package/dist/imports.js +0 -7
- package/dist/memory/memory.d.ts +0 -26
- package/dist/memory/memory.js +0 -160
- package/dist/{core/task → task}/data_example.d.ts +1 -1
- package/dist/{core/task → task}/digest.d.ts +0 -0
- package/dist/{core/task → task}/summary.d.ts +1 -1
|
@@ -1,273 +0,0 @@
|
|
|
1
|
-
"use strict";
|
|
2
|
-
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
-
exports.Client = void 0;
|
|
4
|
-
var tslib_1 = require("tslib");
|
|
5
|
-
var uuid_1 = require("uuid");
|
|
6
|
-
var __1 = require("../..");
|
|
7
|
-
var base_1 = require("../base");
|
|
8
|
-
var messages = (0, tslib_1.__importStar)(require("./messages"));
|
|
9
|
-
var messages_1 = require("../messages");
|
|
10
|
-
var nodeUrl = (0, tslib_1.__importStar)(require("url"));
|
|
11
|
-
var event_connection_1 = require("../event_connection");
|
|
12
|
-
var utils_1 = require("../utils");
|
|
13
|
-
/**
|
|
14
|
-
* Class that deals with communication with the centralized server when training
|
|
15
|
-
* a specific task in the federated setting.
|
|
16
|
-
*/
|
|
17
|
-
var Client = /** @class */ (function (_super) {
|
|
18
|
-
(0, tslib_1.__extends)(Client, _super);
|
|
19
|
-
function Client() {
|
|
20
|
-
var _this = _super !== null && _super.apply(this, arguments) || this;
|
|
21
|
-
_this.clientID = (0, uuid_1.v4)();
|
|
22
|
-
_this.round = 0;
|
|
23
|
-
return _this;
|
|
24
|
-
}
|
|
25
|
-
Object.defineProperty(Client.prototype, "server", {
|
|
26
|
-
get: function () {
|
|
27
|
-
if (this._server === undefined) {
|
|
28
|
-
throw new Error('server undefined, not connected');
|
|
29
|
-
}
|
|
30
|
-
return this._server;
|
|
31
|
-
},
|
|
32
|
-
enumerable: false,
|
|
33
|
-
configurable: true
|
|
34
|
-
});
|
|
35
|
-
// It opens a new WebSocket connection and listens to new messages over the channel
|
|
36
|
-
Client.prototype.connectServer = function (url) {
|
|
37
|
-
return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
|
|
38
|
-
var server;
|
|
39
|
-
return (0, tslib_1.__generator)(this, function (_a) {
|
|
40
|
-
switch (_a.label) {
|
|
41
|
-
case 0: return [4 /*yield*/, event_connection_1.WebSocketServer.connect(url, messages.isMessageFederated, messages.isMessageFederated)];
|
|
42
|
-
case 1:
|
|
43
|
-
server = _a.sent();
|
|
44
|
-
return [2 /*return*/, server];
|
|
45
|
-
}
|
|
46
|
-
});
|
|
47
|
-
});
|
|
48
|
-
};
|
|
49
|
-
/**
|
|
50
|
-
* Initialize the connection to the server. TODO: In the case of FeAI,
|
|
51
|
-
* should return the current server-side round for the task.
|
|
52
|
-
*/
|
|
53
|
-
Client.prototype.connect = function () {
|
|
54
|
-
return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
|
|
55
|
-
var URL, serverURL, _a, msg;
|
|
56
|
-
return (0, tslib_1.__generator)(this, function (_b) {
|
|
57
|
-
switch (_b.label) {
|
|
58
|
-
case 0:
|
|
59
|
-
URL = typeof window !== 'undefined' ? window.URL : nodeUrl.URL;
|
|
60
|
-
serverURL = new URL('', this.url.href);
|
|
61
|
-
switch (this.url.protocol) {
|
|
62
|
-
case 'http:':
|
|
63
|
-
serverURL.protocol = 'ws:';
|
|
64
|
-
break;
|
|
65
|
-
case 'https:':
|
|
66
|
-
serverURL.protocol = 'wss:';
|
|
67
|
-
break;
|
|
68
|
-
default:
|
|
69
|
-
throw new Error("unknown protocol: " + this.url.protocol);
|
|
70
|
-
}
|
|
71
|
-
serverURL.pathname += "feai/" + this.task.taskID + "/" + this.clientID;
|
|
72
|
-
_a = this;
|
|
73
|
-
return [4 /*yield*/, this.connectServer(serverURL)];
|
|
74
|
-
case 1:
|
|
75
|
-
_a._server = _b.sent();
|
|
76
|
-
msg = {
|
|
77
|
-
type: messages_1.type.clientConnected
|
|
78
|
-
};
|
|
79
|
-
this.server.send(msg);
|
|
80
|
-
return [4 /*yield*/, (0, event_connection_1.waitMessageWithTimeout)(this.server, messages_1.type.clientConnected, utils_1.MAX_WAIT_PER_ROUND)];
|
|
81
|
-
case 2:
|
|
82
|
-
_b.sent();
|
|
83
|
-
this.connected = true;
|
|
84
|
-
return [2 /*return*/];
|
|
85
|
-
}
|
|
86
|
-
});
|
|
87
|
-
});
|
|
88
|
-
};
|
|
89
|
-
/**
|
|
90
|
-
* Disconnection process when user quits the task.
|
|
91
|
-
*/
|
|
92
|
-
Client.prototype.disconnect = function () {
|
|
93
|
-
return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
|
|
94
|
-
return (0, tslib_1.__generator)(this, function (_a) {
|
|
95
|
-
this.server.disconnect();
|
|
96
|
-
this._server = undefined;
|
|
97
|
-
this.connected = false;
|
|
98
|
-
return [2 /*return*/];
|
|
99
|
-
});
|
|
100
|
-
});
|
|
101
|
-
};
|
|
102
|
-
// It sends a message to the server
|
|
103
|
-
Client.prototype.sendMessage = function (msg) {
|
|
104
|
-
var _a;
|
|
105
|
-
(_a = this.server) === null || _a === void 0 ? void 0 : _a.send(msg);
|
|
106
|
-
};
|
|
107
|
-
// It sends weights to the server
|
|
108
|
-
Client.prototype.postWeightsToServer = function (weights) {
|
|
109
|
-
return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
|
|
110
|
-
var msg;
|
|
111
|
-
var _a;
|
|
112
|
-
return (0, tslib_1.__generator)(this, function (_b) {
|
|
113
|
-
switch (_b.label) {
|
|
114
|
-
case 0:
|
|
115
|
-
_a = {
|
|
116
|
-
type: messages_1.type.postWeightsToServer
|
|
117
|
-
};
|
|
118
|
-
return [4 /*yield*/, __1.serialization.weights.encode(weights)];
|
|
119
|
-
case 1:
|
|
120
|
-
msg = (_a.weights = _b.sent(),
|
|
121
|
-
_a.round = this.round,
|
|
122
|
-
_a);
|
|
123
|
-
this.sendMessage(msg);
|
|
124
|
-
return [2 /*return*/];
|
|
125
|
-
}
|
|
126
|
-
});
|
|
127
|
-
});
|
|
128
|
-
};
|
|
129
|
-
// It retrieves the last server round and weights, but return only the server round
|
|
130
|
-
Client.prototype.getLatestServerRound = function () {
|
|
131
|
-
return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
|
|
132
|
-
var msg, received;
|
|
133
|
-
return (0, tslib_1.__generator)(this, function (_a) {
|
|
134
|
-
switch (_a.label) {
|
|
135
|
-
case 0:
|
|
136
|
-
this.serverRound = undefined;
|
|
137
|
-
this.serverWeights = undefined;
|
|
138
|
-
msg = {
|
|
139
|
-
type: messages_1.type.latestServerRound
|
|
140
|
-
};
|
|
141
|
-
this.sendMessage(msg);
|
|
142
|
-
return [4 /*yield*/, (0, event_connection_1.waitMessageWithTimeout)(this.server, messages_1.type.latestServerRound, utils_1.MAX_WAIT_PER_ROUND)];
|
|
143
|
-
case 1:
|
|
144
|
-
received = _a.sent();
|
|
145
|
-
this.serverRound = received.round;
|
|
146
|
-
this.serverWeights = __1.serialization.weights.decode(received.weights);
|
|
147
|
-
return [2 /*return*/, this.serverRound];
|
|
148
|
-
}
|
|
149
|
-
});
|
|
150
|
-
});
|
|
151
|
-
};
|
|
152
|
-
// It retrieves the last server round and weights, but return only the server weights
|
|
153
|
-
Client.prototype.pullRoundAndFetchWeights = function () {
|
|
154
|
-
var _a;
|
|
155
|
-
return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
|
|
156
|
-
return (0, tslib_1.__generator)(this, function (_b) {
|
|
157
|
-
switch (_b.label) {
|
|
158
|
-
case 0:
|
|
159
|
-
// get server round of latest model
|
|
160
|
-
return [4 /*yield*/, this.getLatestServerRound()];
|
|
161
|
-
case 1:
|
|
162
|
-
// get server round of latest model
|
|
163
|
-
_b.sent();
|
|
164
|
-
if (this.round < ((_a = this.serverRound) !== null && _a !== void 0 ? _a : 0)) {
|
|
165
|
-
// Update the local round to match the server's
|
|
166
|
-
this.round = this.serverRound;
|
|
167
|
-
return [2 /*return*/, this.serverWeights];
|
|
168
|
-
}
|
|
169
|
-
else {
|
|
170
|
-
return [2 /*return*/, undefined];
|
|
171
|
-
}
|
|
172
|
-
return [2 /*return*/];
|
|
173
|
-
}
|
|
174
|
-
});
|
|
175
|
-
});
|
|
176
|
-
};
|
|
177
|
-
// It pulls statistics from the server
|
|
178
|
-
Client.prototype.pullServerStatistics = function (trainingInformant) {
|
|
179
|
-
var _a;
|
|
180
|
-
return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
|
|
181
|
-
var msg, received;
|
|
182
|
-
return (0, tslib_1.__generator)(this, function (_b) {
|
|
183
|
-
switch (_b.label) {
|
|
184
|
-
case 0:
|
|
185
|
-
this.receivedStatistics = undefined;
|
|
186
|
-
msg = {
|
|
187
|
-
type: messages_1.type.pullServerStatistics
|
|
188
|
-
};
|
|
189
|
-
this.sendMessage(msg);
|
|
190
|
-
return [4 /*yield*/, (0, event_connection_1.waitMessageWithTimeout)(this.server, messages_1.type.pullServerStatistics, utils_1.MAX_WAIT_PER_ROUND)];
|
|
191
|
-
case 1:
|
|
192
|
-
received = _b.sent();
|
|
193
|
-
this.receivedStatistics = received.statistics;
|
|
194
|
-
trainingInformant.update((_a = this.receivedStatistics) !== null && _a !== void 0 ? _a : {});
|
|
195
|
-
return [2 /*return*/];
|
|
196
|
-
}
|
|
197
|
-
});
|
|
198
|
-
});
|
|
199
|
-
};
|
|
200
|
-
// It posts a new metadata value to the server
|
|
201
|
-
Client.prototype.postMetadata = function (metadataID, metadata) {
|
|
202
|
-
return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
|
|
203
|
-
var msg;
|
|
204
|
-
return (0, tslib_1.__generator)(this, function (_a) {
|
|
205
|
-
msg = {
|
|
206
|
-
type: messages_1.type.postMetadata,
|
|
207
|
-
taskId: this.task.taskID,
|
|
208
|
-
clientId: this.clientID,
|
|
209
|
-
round: this.round,
|
|
210
|
-
metadataId: metadataID,
|
|
211
|
-
metadata: metadata
|
|
212
|
-
};
|
|
213
|
-
this.sendMessage(msg);
|
|
214
|
-
return [2 /*return*/];
|
|
215
|
-
});
|
|
216
|
-
});
|
|
217
|
-
};
|
|
218
|
-
// It gets a metadata map from the server
|
|
219
|
-
Client.prototype.getMetadataMap = function (metadataId) {
|
|
220
|
-
return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
|
|
221
|
-
var msg, received;
|
|
222
|
-
return (0, tslib_1.__generator)(this, function (_a) {
|
|
223
|
-
switch (_a.label) {
|
|
224
|
-
case 0:
|
|
225
|
-
this.metadataMap = undefined;
|
|
226
|
-
msg = {
|
|
227
|
-
type: messages_1.type.getMetadataMap,
|
|
228
|
-
taskId: this.task.taskID,
|
|
229
|
-
clientId: this.clientID,
|
|
230
|
-
round: this.round,
|
|
231
|
-
metadataId: metadataId
|
|
232
|
-
};
|
|
233
|
-
this.sendMessage(msg);
|
|
234
|
-
return [4 /*yield*/, (0, event_connection_1.waitMessageWithTimeout)(this.server, messages_1.type.getMetadataMap, utils_1.MAX_WAIT_PER_ROUND)];
|
|
235
|
-
case 1:
|
|
236
|
-
received = _a.sent();
|
|
237
|
-
if (received.metadataMap !== undefined) {
|
|
238
|
-
this.metadataMap = new Map(received.metadataMap);
|
|
239
|
-
}
|
|
240
|
-
return [2 /*return*/, this.metadataMap];
|
|
241
|
-
}
|
|
242
|
-
});
|
|
243
|
-
});
|
|
244
|
-
};
|
|
245
|
-
Client.prototype.onRoundEndCommunication = function (updatedWeights, staleWeights, _, trainingInformant) {
|
|
246
|
-
return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
|
|
247
|
-
var noisyWeights, serverWeights;
|
|
248
|
-
return (0, tslib_1.__generator)(this, function (_a) {
|
|
249
|
-
switch (_a.label) {
|
|
250
|
-
case 0:
|
|
251
|
-
noisyWeights = __1.privacy.addDifferentialPrivacy(updatedWeights, staleWeights, this.task);
|
|
252
|
-
return [4 /*yield*/, this.postWeightsToServer(noisyWeights)];
|
|
253
|
-
case 1:
|
|
254
|
-
_a.sent();
|
|
255
|
-
return [4 /*yield*/, this.pullServerStatistics(trainingInformant)];
|
|
256
|
-
case 2:
|
|
257
|
-
_a.sent();
|
|
258
|
-
return [4 /*yield*/, this.pullRoundAndFetchWeights()];
|
|
259
|
-
case 3:
|
|
260
|
-
serverWeights = _a.sent();
|
|
261
|
-
return [2 /*return*/, serverWeights !== null && serverWeights !== void 0 ? serverWeights : staleWeights];
|
|
262
|
-
}
|
|
263
|
-
});
|
|
264
|
-
});
|
|
265
|
-
};
|
|
266
|
-
Client.prototype.onTrainEndCommunication = function () {
|
|
267
|
-
return (0, tslib_1.__awaiter)(this, void 0, void 0, function () { return (0, tslib_1.__generator)(this, function (_a) {
|
|
268
|
-
return [2 /*return*/];
|
|
269
|
-
}); });
|
|
270
|
-
};
|
|
271
|
-
return Client;
|
|
272
|
-
}(base_1.Base));
|
|
273
|
-
exports.Client = Client;
|
|
@@ -1,7 +0,0 @@
|
|
|
1
|
-
"use strict";
|
|
2
|
-
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
-
exports.messages = exports.Client = void 0;
|
|
4
|
-
var tslib_1 = require("tslib");
|
|
5
|
-
var client_1 = require("./client");
|
|
6
|
-
Object.defineProperty(exports, "Client", { enumerable: true, get: function () { return client_1.Client; } });
|
|
7
|
-
exports.messages = (0, tslib_1.__importStar)(require("./messages"));
|
|
@@ -1,38 +0,0 @@
|
|
|
1
|
-
import { MetadataID } from '../..';
|
|
2
|
-
import { weights } from '../../serialization';
|
|
3
|
-
import { type } from '../messages';
|
|
4
|
-
export declare type MessageFederated = postWeightsToServer | latestServerRound | pullServerStatistics | postMetadata | getMetadataMap | messageGeneral;
|
|
5
|
-
export interface messageGeneral {
|
|
6
|
-
type: type;
|
|
7
|
-
}
|
|
8
|
-
export interface postWeightsToServer {
|
|
9
|
-
type: type.postWeightsToServer;
|
|
10
|
-
weights: weights.Encoded;
|
|
11
|
-
round: number;
|
|
12
|
-
}
|
|
13
|
-
export interface latestServerRound {
|
|
14
|
-
type: type.latestServerRound;
|
|
15
|
-
weights: weights.Encoded;
|
|
16
|
-
round: number;
|
|
17
|
-
}
|
|
18
|
-
export interface pullServerStatistics {
|
|
19
|
-
type: type.pullServerStatistics;
|
|
20
|
-
statistics: Record<string, number>;
|
|
21
|
-
}
|
|
22
|
-
export interface postMetadata {
|
|
23
|
-
type: type.postMetadata;
|
|
24
|
-
clientId: string;
|
|
25
|
-
taskId: string;
|
|
26
|
-
round: number;
|
|
27
|
-
metadataId: string;
|
|
28
|
-
metadata: string;
|
|
29
|
-
}
|
|
30
|
-
export interface getMetadataMap {
|
|
31
|
-
type: type.getMetadataMap;
|
|
32
|
-
clientId: string;
|
|
33
|
-
taskId: string;
|
|
34
|
-
round: number;
|
|
35
|
-
metadataId: MetadataID;
|
|
36
|
-
metadataMap?: Array<[string, string | undefined]>;
|
|
37
|
-
}
|
|
38
|
-
export declare function isMessageFederated(o: unknown): o is MessageFederated;
|
|
@@ -1,25 +0,0 @@
|
|
|
1
|
-
"use strict";
|
|
2
|
-
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
-
exports.isMessageFederated = void 0;
|
|
4
|
-
var messages_1 = require("../messages");
|
|
5
|
-
function isMessageFederated(o) {
|
|
6
|
-
if (!(0, messages_1.hasMessageType)(o)) {
|
|
7
|
-
return false;
|
|
8
|
-
}
|
|
9
|
-
switch (o.type) {
|
|
10
|
-
case messages_1.type.clientConnected:
|
|
11
|
-
return true;
|
|
12
|
-
case messages_1.type.postWeightsToServer:
|
|
13
|
-
return true;
|
|
14
|
-
case messages_1.type.latestServerRound:
|
|
15
|
-
return true;
|
|
16
|
-
case messages_1.type.pullServerStatistics:
|
|
17
|
-
return true;
|
|
18
|
-
case messages_1.type.postMetadata:
|
|
19
|
-
return true;
|
|
20
|
-
case messages_1.type.getMetadataMap:
|
|
21
|
-
return true;
|
|
22
|
-
}
|
|
23
|
-
return false;
|
|
24
|
-
}
|
|
25
|
-
exports.isMessageFederated = isMessageFederated;
|
|
@@ -1,11 +0,0 @@
|
|
|
1
|
-
"use strict";
|
|
2
|
-
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
-
exports.Local = exports.messages = exports.federated = exports.decentralized = exports.Base = void 0;
|
|
4
|
-
var tslib_1 = require("tslib");
|
|
5
|
-
// Public surface of the client module. Named classes are exposed through
// enumerable getters (live bindings onto the source module); sub-modules are
// exposed as namespace objects. Load order is kept as originally emitted.
var baseModule = require("./base");
Object.defineProperty(exports, "Base", { enumerable: true, get: function () { return baseModule.Base; } });
exports.decentralized = (0, tslib_1.__importStar)(require("./decentralized"));
exports.federated = (0, tslib_1.__importStar)(require("./federated"));
exports.messages = (0, tslib_1.__importStar)(require("./messages"));
var localModule = require("./local");
Object.defineProperty(exports, "Local", { enumerable: true, get: function () { return localModule.Local; } });
|
|
@@ -1,8 +0,0 @@
|
|
|
1
|
-
import { WeightsContainer } from '..';
|
|
2
|
-
import { Base } from './base';
|
|
3
|
-
/**
 * Client that performs no communication: connecting, disconnecting and the
 * training-end hook resolve without doing anything, and the round-end hook
 * resolves to the weights it was given.
 */
export declare class Local extends Base {
    /** Resolves immediately; no connection is opened. */
    connect(): Promise<void>;
    /** Resolves immediately; nothing to close. */
    disconnect(): Promise<void>;
    /** Resolves to the received weights, unchanged. */
    onRoundEndCommunication(_: WeightsContainer): Promise<WeightsContainer>;
    /** Resolves immediately. */
    onTrainEndCommunication(): Promise<void>;
}
|
|
@@ -1,36 +0,0 @@
|
|
|
1
|
-
"use strict";
|
|
2
|
-
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
-
exports.Local = void 0;
|
|
4
|
-
var tslib_1 = require("tslib");
|
|
5
|
-
var base_1 = require("./base");
|
|
6
|
-
// Client that does pretty much nothing: it never talks to anyone.
// connect / disconnect / onTrainEndCommunication resolve to undefined and
// onRoundEndCommunication resolves to the weights it was handed.
var Local = /** @class */ (function (Parent) {
    (0, tslib_1.__extends)(Local, Parent);
    function Local() {
        return Parent !== null && Parent.apply(this, arguments) || this;
    }
    // Builds a promise already settled with `value` — the same result the
    // compiled `async` no-op bodies produced through `__awaiter`.
    function settledWith(value) {
        return new Promise(function (resolve) {
            resolve(value);
        });
    }
    Local.prototype.connect = function () {
        return settledWith(void 0);
    };
    Local.prototype.disconnect = function () {
        return settledWith(void 0);
    };
    // Returns the received weights unchanged.
    Local.prototype.onRoundEndCommunication = function (_) {
        return settledWith(_);
    };
    Local.prototype.onTrainEndCommunication = function () {
        return settledWith(void 0);
    };
    return Local;
}(base_1.Base));
exports.Local = Local;
|
|
@@ -1,28 +0,0 @@
|
|
|
1
|
-
import * as decentralized from './decentralized/messages';
|
|
2
|
-
import * as federated from './federated/messages';
|
|
3
|
-
/**
 * Numeric tag stored in the `type` field of every message.
 * Comments in the compiled enum group 1–7 as decentralized message kinds and
 * 8–13 as federated ones; 0 (`clientConnected`) precedes both groups.
 */
export declare enum type {
    clientConnected = 0,
    PeerID = 1,
    SignalForPeer = 2,
    PeerIsReady = 3,
    PeersForRound = 4,
    Weights = 5,
    Shares = 6,
    PartialSums = 7,
    postWeightsToServer = 8,
    postMetadata = 9,
    getMetadataMap = 10,
    latestServerRound = 11,
    pullRoundAndFetchWeights = 12,
    pullServerStatistics = 13
}

/** Message carrying only the `clientConnected` tag. */
export interface clientConnected {
    type: type.clientConnected;
}

/** Union of every decentralized and federated message shape. */
export declare type Message =
    | decentralized.MessageFromServer
    | decentralized.MessageToServer
    | decentralized.PeerMessage
    | federated.MessageFederated;

/** The member(s) of `Message` whose `type` field is `D`. */
export declare type NarrowMessage<D> = Extract<Message, {
    type: D;
}>;

/**
 * Type guard: true when `raw` is a non-null object whose `type` property is a
 * number naming a member of the `type` enum.
 */
export declare function hasMessageType(raw: unknown): raw is {
    type: type;
} & Record<string, unknown>;
|
|
@@ -1,33 +0,0 @@
|
|
|
1
|
-
"use strict";
|
|
2
|
-
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
-
exports.hasMessageType = exports.type = void 0;
|
|
4
|
-
var type;
(function (enumObject) {
    // Member names in declaration order; each member's value is its index.
    // Like any numeric TS enum, both name -> value and value -> name are stored.
    [
        "clientConnected",
        // decentralized
        "PeerID",
        "SignalForPeer",
        "PeerIsReady",
        "PeersForRound",
        "Weights",
        "Shares",
        "PartialSums",
        // federated
        "postWeightsToServer",
        "postMetadata",
        "getMetadataMap",
        "latestServerRound",
        "pullRoundAndFetchWeights",
        "pullServerStatistics"
    ].forEach(function (name, value) {
        enumObject[enumObject[name] = value] = name;
    });
})(type = exports.type || (exports.type = {}));
/**
 * Type guard: accepts only non-null objects that own or inherit a `type`
 * property whose value is a number mapped by the `type` enum.
 */
function hasMessageType(raw) {
    return typeof raw === 'object' &&
        raw !== null &&
        'type' in raw &&
        typeof raw.type === 'number' &&
        raw.type in type;
}
exports.hasMessageType = hasMessageType;
|
|
@@ -1,19 +0,0 @@
|
|
|
1
|
-
"use strict";
|
|
2
|
-
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
-
exports.timeout = exports.MAX_WAIT_PER_ROUND = void 0;
|
|
4
|
-
var tslib_1 = require("tslib");
|
|
5
|
-
/** Time to wait for the others, in milliseconds. */
exports.MAX_WAIT_PER_ROUND = 10000;
/**
 * Rejects with `Error('timeout')` once `ms` milliseconds have elapsed; it never
 * fulfils. The timer is not cleared, so it fires even if nobody is listening.
 * NOTE(review): presumably meant to be raced against real work — confirm at
 * call sites.
 */
function timeout(ms) {
    return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
        var expiry;
        return (0, tslib_1.__generator)(this, function (state) {
            if (state.label === 0) {
                expiry = new Promise(function (_resolve, reject) {
                    setTimeout(function () { return reject(new Error('timeout')); }, ms);
                });
                return [4 /*yield*/, expiry];
            }
            // Re-throws the rejection delivered by __awaiter (sent() throws).
            return [2 /*return*/, state.sent()];
        });
    });
}
exports.timeout = timeout;
|
|
@@ -1,11 +0,0 @@
|
|
|
1
|
-
import { Task } from '../..';
|
|
2
|
-
import { Dataset } from '../dataset';
|
|
3
|
-
/**
 * Base wrapper pairing a `Dataset` with the `Task` it serves, plus an optional
 * `size`.
 * NOTE(review): `size` presumably counts dataset elements — confirm with callers.
 */
export declare abstract class Data {
    readonly dataset: Dataset;
    readonly task: Task;
    readonly size?: number | undefined;
    protected constructor(dataset: Dataset, task: Task, size?: number | undefined);
    /** Asynchronous factory; the base implementation always rejects with `Error('abstract')`. */
    static init(dataset: Dataset, task: Task, size?: number): Promise<Data>;
    /** Returns a `Data` whose dataset is batched. */
    abstract batch(): Data;
    /** Returns a `Data` whose dataset is preprocessed. */
    abstract preprocess(): Data;
}
|
|
@@ -1,20 +0,0 @@
|
|
|
1
|
-
"use strict";
|
|
2
|
-
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
-
exports.Data = void 0;
|
|
4
|
-
var tslib_1 = require("tslib");
|
|
5
|
-
var Data = /** @class */ (function () {
    // Stores the three fields in declaration order (dataset, task, size).
    function Data(dataset, task, size) {
        this.dataset = dataset;
        this.task = task;
        this.size = size;
    }
    /**
     * Abstract factory placeholder: always yields a promise rejected with
     * `Error('abstract')`. Subclasses provide the real implementation.
     */
    Data.init = function (dataset, task, size) {
        return new Promise(function (_resolve, reject) {
            reject(new Error('abstract'));
        });
    };
    return Data;
}());
exports.Data = Data;
|
|
@@ -1,8 +0,0 @@
|
|
|
1
|
-
import { Task } from '../..';
|
|
2
|
-
import { Dataset } from '../dataset';
|
|
3
|
-
import { Data } from './data';
|
|
4
|
-
/** `Data` specialization for image datasets. */
export declare class ImageData extends Data {
    /**
     * Unless the task lists the `Resize` preprocessing step, checks that the
     * first sample's first two shape entries equal the task's IMAGE_W and
     * IMAGE_H, rejecting with a format error otherwise.
     */
    static init(dataset: Dataset, task: Task, size?: number): Promise<Data>;
    /** Batches by the task's `batchSize` when one is set; otherwise keeps the dataset as-is. */
    batch(): Data;
    /** Maps every element through the task's image preprocessing function. */
    preprocess(): Data;
}
|
|
@@ -1,64 +0,0 @@
|
|
|
1
|
-
"use strict";
|
|
2
|
-
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
-
exports.ImageData = void 0;
|
|
4
|
-
var tslib_1 = require("tslib");
|
|
5
|
-
var preprocessing_1 = require("./preprocessing");
|
|
6
|
-
var data_1 = require("./data");
|
|
7
|
-
var ImageData = /** @class */ (function (Parent) {
    (0, tslib_1.__extends)(ImageData, Parent);
    function ImageData() {
        return Parent !== null && Parent.apply(this, arguments) || this;
    }
    // Throws (message-less) unless `sample` is a non-null object whose first two
    // shape entries equal the task's IMAGE_W and IMAGE_H. A sample holding both
    // `xs` and `ys` is checked on `xs`; anything else on its own `shape`.
    // TODO: We suppose the presence of labels
    // TODO: Typing (discojs-node/src/dataset/data_loader/image_loader.spec.ts)
    function assertMatchesTask(sample, task) {
        if (typeof sample !== 'object' || sample === null) {
            throw new Error();
        }
        var shape = ('xs' in sample && 'ys' in sample) ? sample.xs.shape : sample.shape;
        if (shape[0] !== task.trainingInformation.IMAGE_W ||
            shape[1] !== task.trainingInformation.IMAGE_H) {
            throw new Error();
        }
    }
    /**
     * Validates the dataset against `task` and wraps it in an ImageData.
     * When the task does not list the Resize preprocessing step, the first
     * element is fetched and checked; any failure there (fetching, missing
     * shape, mismatched dimensions) rejects with a single generic format error.
     */
    ImageData.init = function (dataset, task, size) {
        return new Promise(function (resolve) {
            var steps = task.trainingInformation.preprocessingFunctions;
            var resizes = steps !== null && steps !== void 0 &&
                steps.includes(preprocessing_1.ImagePreprocessing.Resize);
            if (resizes) {
                // NOTE(review): dimensions are skipped when Resize is configured —
                // presumably because preprocessing resizes the images; confirm.
                resolve(new ImageData(dataset, task, size));
                return;
            }
            var checked = new Promise(function (fetched) {
                fetched(dataset.take(1).toArray());
            }).then(function (samples) {
                assertMatchesTask(samples[0], task);
            }).catch(function () {
                throw new Error('Data input format is not compatible with the chosen task');
            });
            resolve(checked.then(function () {
                return new ImageData(dataset, task, size);
            }));
        });
    };
    // Batches by the task's batchSize when it is defined; otherwise reuses the dataset.
    ImageData.prototype.batch = function () {
        var size = this.task.trainingInformation.batchSize;
        var dataset = size !== void 0 ? this.dataset.batch(size) : this.dataset;
        return new ImageData(dataset, this.task, this.size);
    };
    // Maps each element through the preprocessing function built for this task.
    ImageData.prototype.preprocess = function () {
        var source = this.dataset;
        var transform = (0, preprocessing_1.getPreprocessImage)(this.task);
        var mapped = source.map(function (x) { return transform(x); });
        return new ImageData(mapped, this.task, this.size);
    };
    return ImageData;
}(data_1.Data));
exports.ImageData = ImageData;
|
|
@@ -1,11 +0,0 @@
|
|
|
1
|
-
"use strict";
|
|
2
|
-
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
-
exports.ImagePreprocessing = exports.TabularData = exports.ImageData = exports.Data = void 0;
|
|
4
|
-
// Public surface of the data module: each name is an enumerable getter that
// reads the export live from its source module. Load order is unchanged.
var dataModule = require("./data");
Object.defineProperty(exports, "Data", { enumerable: true, get: function () { return dataModule.Data; } });
var imageDataModule = require("./image_data");
Object.defineProperty(exports, "ImageData", { enumerable: true, get: function () { return imageDataModule.ImageData; } });
var tabularDataModule = require("./tabular_data");
Object.defineProperty(exports, "TabularData", { enumerable: true, get: function () { return tabularDataModule.TabularData; } });
var preprocessingModule = require("./preprocessing");
Object.defineProperty(exports, "ImagePreprocessing", { enumerable: true, get: function () { return preprocessingModule.ImagePreprocessing; } });
|
|
@@ -1,13 +0,0 @@
|
|
|
1
|
-
import { tf, Task } from '../..';
|
|
2
|
-
/** Transformation applied to a single dataset element. */
declare type PreprocessImage = (image: tf.TensorContainer) => tf.TensorContainer;

/** All supported preprocessing step names (currently image steps only). */
export declare type Preprocessing = ImagePreprocessing;

/**
 * Image sample: `xs` is a rank-3 or rank-4 tensor; `ys` is a rank-1 tensor,
 * a number, or undefined.
 */
export interface ImageTensorContainer extends tf.TensorContainerObject {
    xs: tf.Tensor3D | tf.Tensor4D;
    ys: tf.Tensor1D | number | undefined;
}

/** Image preprocessing step names a task may list in `preprocessingFunctions`. */
export declare enum ImagePreprocessing {
    Normalize = "normalize",
    Resize = "resize"
}

/** Builds the per-element preprocessing function for `task`. */
export declare function getPreprocessImage(task: Task): PreprocessImage;

export {};
|