@epfml/discojs 0.1.0 → 2.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +28 -8
- package/dist/{async_buffer.d.ts → core/async_buffer.d.ts} +3 -3
- package/dist/{async_buffer.js → core/async_buffer.js} +5 -6
- package/dist/{async_informant.d.ts → core/async_informant.d.ts} +0 -0
- package/dist/{async_informant.js → core/async_informant.js} +0 -0
- package/dist/{client → core/client}/base.d.ts +4 -7
- package/dist/{client → core/client}/base.js +3 -2
- package/dist/core/client/decentralized/base.d.ts +32 -0
- package/dist/core/client/decentralized/base.js +212 -0
- package/dist/core/client/decentralized/clear_text.d.ts +14 -0
- package/dist/core/client/decentralized/clear_text.js +96 -0
- package/dist/{client → core/client}/decentralized/index.d.ts +0 -0
- package/dist/{client → core/client}/decentralized/index.js +0 -0
- package/dist/core/client/decentralized/messages.d.ts +41 -0
- package/dist/core/client/decentralized/messages.js +54 -0
- package/dist/core/client/decentralized/peer.d.ts +26 -0
- package/dist/core/client/decentralized/peer.js +210 -0
- package/dist/core/client/decentralized/peer_pool.d.ts +14 -0
- package/dist/core/client/decentralized/peer_pool.js +92 -0
- package/dist/core/client/decentralized/sec_agg.d.ts +22 -0
- package/dist/core/client/decentralized/sec_agg.js +190 -0
- package/dist/core/client/decentralized/secret_shares.d.ts +3 -0
- package/dist/core/client/decentralized/secret_shares.js +39 -0
- package/dist/core/client/decentralized/types.d.ts +2 -0
- package/dist/core/client/decentralized/types.js +7 -0
- package/dist/core/client/event_connection.d.ts +37 -0
- package/dist/core/client/event_connection.js +158 -0
- package/dist/core/client/federated/client.d.ts +37 -0
- package/dist/core/client/federated/client.js +273 -0
- package/dist/core/client/federated/index.d.ts +2 -0
- package/dist/core/client/federated/index.js +7 -0
- package/dist/core/client/federated/messages.d.ts +38 -0
- package/dist/core/client/federated/messages.js +25 -0
- package/dist/{client → core/client}/index.d.ts +2 -1
- package/dist/{client → core/client}/index.js +3 -3
- package/dist/{client → core/client}/local.d.ts +2 -2
- package/dist/{client → core/client}/local.js +0 -0
- package/dist/core/client/messages.d.ts +28 -0
- package/dist/core/client/messages.js +33 -0
- package/dist/core/client/utils.d.ts +2 -0
- package/dist/core/client/utils.js +19 -0
- package/dist/core/dataset/data/data.d.ts +11 -0
- package/dist/core/dataset/data/data.js +20 -0
- package/dist/core/dataset/data/data_split.d.ts +5 -0
- package/dist/{client/decentralized/types.js → core/dataset/data/data_split.js} +0 -0
- package/dist/core/dataset/data/image_data.d.ts +8 -0
- package/dist/core/dataset/data/image_data.js +64 -0
- package/dist/core/dataset/data/index.d.ts +5 -0
- package/dist/core/dataset/data/index.js +11 -0
- package/dist/core/dataset/data/preprocessing.d.ts +13 -0
- package/dist/core/dataset/data/preprocessing.js +33 -0
- package/dist/core/dataset/data/tabular_data.d.ts +8 -0
- package/dist/core/dataset/data/tabular_data.js +40 -0
- package/dist/{dataset → core/dataset}/data_loader/data_loader.d.ts +4 -11
- package/dist/{dataset → core/dataset}/data_loader/data_loader.js +0 -0
- package/dist/core/dataset/data_loader/image_loader.d.ts +17 -0
- package/dist/core/dataset/data_loader/image_loader.js +141 -0
- package/dist/core/dataset/data_loader/index.d.ts +3 -0
- package/dist/core/dataset/data_loader/index.js +9 -0
- package/dist/core/dataset/data_loader/tabular_loader.d.ts +29 -0
- package/dist/core/dataset/data_loader/tabular_loader.js +101 -0
- package/dist/core/dataset/dataset.d.ts +2 -0
- package/dist/{task/training_information.js → core/dataset/dataset.js} +0 -0
- package/dist/{dataset → core/dataset}/dataset_builder.d.ts +5 -5
- package/dist/{dataset → core/dataset}/dataset_builder.js +14 -10
- package/dist/core/dataset/index.d.ts +4 -0
- package/dist/core/dataset/index.js +14 -0
- package/dist/core/index.d.ts +18 -0
- package/dist/core/index.js +41 -0
- package/dist/{informant → core/informant}/graph_informant.d.ts +0 -0
- package/dist/{informant → core/informant}/graph_informant.js +0 -0
- package/dist/{informant → core/informant}/index.d.ts +0 -0
- package/dist/{informant → core/informant}/index.js +0 -0
- package/dist/{informant → core/informant}/training_informant/base.d.ts +3 -3
- package/dist/{informant → core/informant}/training_informant/base.js +3 -2
- package/dist/{informant → core/informant}/training_informant/decentralized.d.ts +0 -0
- package/dist/{informant → core/informant}/training_informant/decentralized.js +0 -0
- package/dist/{informant → core/informant}/training_informant/federated.d.ts +0 -0
- package/dist/{informant → core/informant}/training_informant/federated.js +0 -0
- package/dist/{informant → core/informant}/training_informant/index.d.ts +0 -0
- package/dist/{informant → core/informant}/training_informant/index.js +0 -0
- package/dist/{informant → core/informant}/training_informant/local.d.ts +2 -2
- package/dist/{informant → core/informant}/training_informant/local.js +2 -2
- package/dist/{logging → core/logging}/console_logger.d.ts +0 -0
- package/dist/{logging → core/logging}/console_logger.js +0 -0
- package/dist/{logging → core/logging}/index.d.ts +0 -0
- package/dist/{logging → core/logging}/index.js +0 -0
- package/dist/{logging → core/logging}/logger.d.ts +0 -0
- package/dist/{logging → core/logging}/logger.js +0 -0
- package/dist/{logging → core/logging}/trainer_logger.d.ts +0 -0
- package/dist/{logging → core/logging}/trainer_logger.js +0 -0
- package/dist/{memory → core/memory}/base.d.ts +2 -2
- package/dist/{memory → core/memory}/base.js +0 -0
- package/dist/{memory → core/memory}/empty.d.ts +0 -0
- package/dist/{memory → core/memory}/empty.js +0 -0
- package/dist/core/memory/index.d.ts +3 -0
- package/dist/core/memory/index.js +9 -0
- package/dist/{memory → core/memory}/model_type.d.ts +0 -0
- package/dist/{memory → core/memory}/model_type.js +0 -0
- package/dist/{privacy.d.ts → core/privacy.d.ts} +2 -3
- package/dist/{privacy.js → core/privacy.js} +3 -16
- package/dist/{serialization → core/serialization}/index.d.ts +0 -0
- package/dist/{serialization → core/serialization}/index.js +0 -0
- package/dist/{serialization → core/serialization}/model.d.ts +0 -0
- package/dist/{serialization → core/serialization}/model.js +0 -0
- package/dist/core/serialization/weights.d.ts +5 -0
- package/dist/{serialization → core/serialization}/weights.js +11 -9
- package/dist/{task → core/task}/data_example.d.ts +0 -0
- package/dist/{task → core/task}/data_example.js +0 -0
- package/dist/{task → core/task}/display_information.d.ts +5 -5
- package/dist/{task → core/task}/display_information.js +5 -10
- package/dist/{task → core/task}/index.d.ts +0 -0
- package/dist/{task → core/task}/index.js +0 -0
- package/dist/core/task/model_compile_data.d.ts +6 -0
- package/dist/core/task/model_compile_data.js +22 -0
- package/dist/{task → core/task}/summary.d.ts +0 -0
- package/dist/{task → core/task}/summary.js +0 -4
- package/dist/{task → core/task}/task.d.ts +2 -2
- package/dist/{task → core/task}/task.js +6 -7
- package/dist/{task → core/task}/training_information.d.ts +10 -14
- package/dist/core/task/training_information.js +66 -0
- package/dist/{tasks → core/tasks}/cifar10.d.ts +1 -2
- package/dist/{tasks → core/tasks}/cifar10.js +12 -23
- package/dist/core/tasks/geotags.d.ts +3 -0
- package/dist/core/tasks/geotags.js +67 -0
- package/dist/{tasks → core/tasks}/index.d.ts +2 -1
- package/dist/{tasks → core/tasks}/index.js +3 -2
- package/dist/core/tasks/lus_covid.d.ts +3 -0
- package/dist/{tasks → core/tasks}/lus_covid.js +26 -24
- package/dist/{tasks → core/tasks}/mnist.d.ts +1 -2
- package/dist/{tasks → core/tasks}/mnist.js +18 -16
- package/dist/core/tasks/simple_face.d.ts +2 -0
- package/dist/core/tasks/simple_face.js +41 -0
- package/dist/{tasks → core/tasks}/titanic.d.ts +1 -2
- package/dist/{tasks → core/tasks}/titanic.js +11 -11
- package/dist/core/training/disco.d.ts +23 -0
- package/dist/core/training/disco.js +130 -0
- package/dist/{training → core/training}/index.d.ts +0 -0
- package/dist/{training → core/training}/index.js +0 -0
- package/dist/{training → core/training}/trainer/distributed_trainer.d.ts +1 -2
- package/dist/{training → core/training}/trainer/distributed_trainer.js +6 -5
- package/dist/{training → core/training}/trainer/local_trainer.d.ts +2 -2
- package/dist/{training → core/training}/trainer/local_trainer.js +0 -0
- package/dist/{training → core/training}/trainer/round_tracker.d.ts +0 -0
- package/dist/{training → core/training}/trainer/round_tracker.js +0 -0
- package/dist/{training → core/training}/trainer/trainer.d.ts +1 -2
- package/dist/{training → core/training}/trainer/trainer.js +2 -2
- package/dist/{training → core/training}/trainer/trainer_builder.d.ts +0 -0
- package/dist/{training → core/training}/trainer/trainer_builder.js +0 -0
- package/dist/core/training/training_schemes.d.ts +5 -0
- package/dist/{training → core/training}/training_schemes.js +2 -2
- package/dist/{types.d.ts → core/types.d.ts} +0 -0
- package/dist/{types.js → core/types.js} +0 -0
- package/dist/{validation → core/validation}/index.d.ts +0 -0
- package/dist/{validation → core/validation}/index.js +0 -0
- package/dist/{validation → core/validation}/validator.d.ts +5 -8
- package/dist/{validation → core/validation}/validator.js +9 -11
- package/dist/core/weights/aggregation.d.ts +8 -0
- package/dist/core/weights/aggregation.js +96 -0
- package/dist/core/weights/index.d.ts +2 -0
- package/dist/core/weights/index.js +7 -0
- package/dist/core/weights/weights_container.d.ts +19 -0
- package/dist/core/weights/weights_container.js +64 -0
- package/dist/dataset/data_loader/image_loader.d.ts +3 -15
- package/dist/dataset/data_loader/image_loader.js +12 -125
- package/dist/dataset/data_loader/index.d.ts +2 -3
- package/dist/dataset/data_loader/index.js +3 -5
- package/dist/dataset/data_loader/tabular_loader.d.ts +3 -28
- package/dist/dataset/data_loader/tabular_loader.js +11 -92
- package/dist/imports.d.ts +2 -0
- package/dist/imports.js +7 -0
- package/dist/index.d.ts +2 -19
- package/dist/index.js +3 -39
- package/dist/memory/index.d.ts +1 -3
- package/dist/memory/index.js +3 -7
- package/dist/memory/memory.d.ts +26 -0
- package/dist/memory/memory.js +160 -0
- package/package.json +13 -26
- package/dist/aggregation.d.ts +0 -5
- package/dist/aggregation.js +0 -33
- package/dist/client/decentralized/base.d.ts +0 -43
- package/dist/client/decentralized/base.js +0 -243
- package/dist/client/decentralized/clear_text.d.ts +0 -13
- package/dist/client/decentralized/clear_text.js +0 -78
- package/dist/client/decentralized/messages.d.ts +0 -37
- package/dist/client/decentralized/messages.js +0 -15
- package/dist/client/decentralized/sec_agg.d.ts +0 -18
- package/dist/client/decentralized/sec_agg.js +0 -169
- package/dist/client/decentralized/secret_shares.d.ts +0 -5
- package/dist/client/decentralized/secret_shares.js +0 -58
- package/dist/client/decentralized/types.d.ts +0 -1
- package/dist/client/federated.d.ts +0 -30
- package/dist/client/federated.js +0 -218
- package/dist/dataset/index.d.ts +0 -2
- package/dist/dataset/index.js +0 -7
- package/dist/model_actor.d.ts +0 -16
- package/dist/model_actor.js +0 -20
- package/dist/serialization/weights.d.ts +0 -5
- package/dist/task/model_compile_data.d.ts +0 -6
- package/dist/task/model_compile_data.js +0 -12
- package/dist/tasks/lus_covid.d.ts +0 -4
- package/dist/tasks/simple_face.d.ts +0 -4
- package/dist/tasks/simple_face.js +0 -84
- package/dist/tfjs.d.ts +0 -2
- package/dist/tfjs.js +0 -6
- package/dist/training/disco.d.ts +0 -14
- package/dist/training/disco.js +0 -70
- package/dist/training/training_schemes.d.ts +0 -5
package/dist/core/dataset/data/data.d.ts

```diff
@@ -0,0 +1,11 @@
+import { Task } from '../..';
+import { Dataset } from '../dataset';
+export declare abstract class Data {
+    readonly dataset: Dataset;
+    readonly task: Task;
+    readonly size?: number | undefined;
+    protected constructor(dataset: Dataset, task: Task, size?: number | undefined);
+    static init(dataset: Dataset, task: Task, size?: number): Promise<Data>;
+    abstract batch(): Data;
+    abstract preprocess(): Data;
+}
```
package/dist/core/dataset/data/data.js

```diff
@@ -0,0 +1,20 @@
+"use strict";
+Object.defineProperty(exports, "__esModule", { value: true });
+exports.Data = void 0;
+var tslib_1 = require("tslib");
+var Data = /** @class */ (function () {
+    function Data(dataset, task, size) {
+        this.dataset = dataset;
+        this.task = task;
+        this.size = size;
+    }
+    Data.init = function (dataset, task, size) {
+        return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
+            return (0, tslib_1.__generator)(this, function (_a) {
+                throw new Error('abstract');
+            });
+        });
+    };
+    return Data;
+}());
+exports.Data = Data;
```
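The pair of files above introduces the new `Data` wrapper: a dataset bundled with its `Task` so that batching and preprocessing become chainable operations returning fresh wrappers. A minimal sketch of the pattern in TypeScript source form (the `Task` and `Dataset` stand-ins below are simplified, illustrative types, not the real discojs declarations):

```ts
import * as tf from '@tensorflow/tfjs'

// Simplified stand-ins for illustration; the real types carry more fields.
type Dataset = tf.data.Dataset<tf.TensorContainer>
interface Task { trainingInformation: { batchSize?: number } }

abstract class Data {
  protected constructor (
    readonly dataset: Dataset,
    readonly task: Task,
    readonly size?: number
  ) {}

  abstract batch (): Data
  abstract preprocess (): Data
}

// A toy subclass showing the intended pattern: init validates the input,
// the transformations wrap a derived dataset instead of mutating in place.
class NumberData extends Data {
  static async init (dataset: Dataset, task: Task, size?: number): Promise<Data> {
    return new NumberData(dataset, task, size)
  }

  batch (): Data {
    const { batchSize } = this.task.trainingInformation
    const batched = batchSize === undefined ? this.dataset : this.dataset.batch(batchSize)
    return new NumberData(batched, this.task, this.size)
  }

  preprocess (): Data {
    return this
  }
}

async function demo (): Promise<void> {
  const task: Task = { trainingInformation: { batchSize: 2 } }
  const data = await NumberData.init(tf.data.array([1, 2, 3, 4]), task, 4)
  // Chainable, as the trainer uses it: preprocess first, then batch.
  await data.preprocess().batch().dataset.forEachAsync(t => console.log(t))
}

demo().catch(console.error)
```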
package/dist/core/dataset/data/data_split.js (renamed)

File without changes
package/dist/core/dataset/data/image_data.d.ts

```diff
@@ -0,0 +1,8 @@
+import { Task } from '../..';
+import { Dataset } from '../dataset';
+import { Data } from './data';
+export declare class ImageData extends Data {
+    static init(dataset: Dataset, task: Task, size?: number): Promise<Data>;
+    batch(): Data;
+    preprocess(): Data;
+}
```
package/dist/core/dataset/data/image_data.js

```diff
@@ -0,0 +1,64 @@
+"use strict";
+Object.defineProperty(exports, "__esModule", { value: true });
+exports.ImageData = void 0;
+var tslib_1 = require("tslib");
+var preprocessing_1 = require("./preprocessing");
+var data_1 = require("./data");
+var ImageData = /** @class */ (function (_super) {
+    (0, tslib_1.__extends)(ImageData, _super);
+    function ImageData() {
+        return _super !== null && _super.apply(this, arguments) || this;
+    }
+    ImageData.init = function (dataset, task, size) {
+        var _a;
+        return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
+            var sample, shape, e_1;
+            return (0, tslib_1.__generator)(this, function (_b) {
+                switch (_b.label) {
+                    case 0:
+                        if (!!((_a = task.trainingInformation.preprocessingFunctions) === null || _a === void 0 ? void 0 : _a.includes(preprocessing_1.ImagePreprocessing.Resize))) return [3 /*break*/, 4];
+                        _b.label = 1;
+                    case 1:
+                        _b.trys.push([1, 3, , 4]);
+                        return [4 /*yield*/, dataset.take(1).toArray()];
+                    case 2:
+                        sample = (_b.sent())[0];
+                        // TODO: We suppose the presence of labels
+                        // TODO: Typing (discojs-node/src/dataset/data_loader/image_loader.spec.ts)
+                        if (!(typeof sample === 'object' && sample !== null)) {
+                            throw new Error();
+                        }
+                        shape = void 0;
+                        if ('xs' in sample && 'ys' in sample) {
+                            shape = sample.xs.shape;
+                        }
+                        else {
+                            shape = sample.shape;
+                        }
+                        if (!(shape[0] === task.trainingInformation.IMAGE_W &&
+                            shape[1] === task.trainingInformation.IMAGE_H)) {
+                            throw new Error();
+                        }
+                        return [3 /*break*/, 4];
+                    case 3:
+                        e_1 = _b.sent();
+                        throw new Error('Data input format is not compatible with the chosen task');
+                    case 4: return [2 /*return*/, new ImageData(dataset, task, size)];
+                }
+            });
+        });
+    };
+    ImageData.prototype.batch = function () {
+        var batchSize = this.task.trainingInformation.batchSize;
+        var newDataset = batchSize === undefined ? this.dataset : this.dataset.batch(batchSize);
+        return new ImageData(newDataset, this.task, this.size);
+    };
+    ImageData.prototype.preprocess = function () {
+        var newDataset = this.dataset;
+        var preprocessImage = (0, preprocessing_1.getPreprocessImage)(this.task);
+        newDataset = newDataset.map(function (x) { return preprocessImage(x); });
+        return new ImageData(newDataset, this.task, this.size);
+    };
+    return ImageData;
+}(data_1.Data));
+exports.ImageData = ImageData;
```
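`ImageData.init` now fails fast: unless the task enables `ImagePreprocessing.Resize`, it reads one sample and rejects datasets whose image shape disagrees with the task's `IMAGE_W`/`IMAGE_H`. A rough sketch of that check in TypeScript source form (types simplified; field names follow the compiled output above):

```ts
import * as tf from '@tensorflow/tfjs'

// Simplified view of the fields the check consults (illustrative only).
interface TrainingInformation {
  IMAGE_W?: number
  IMAGE_H?: number
  preprocessingFunctions?: string[]
}

async function checkFirstSample (
  dataset: tf.data.Dataset<tf.TensorContainer>,
  info: TrainingInformation
): Promise<void> {
  // With 'resize' enabled any input shape is fine: images get resized later.
  if (info.preprocessingFunctions?.includes('resize')) {
    return
  }
  const [sample] = await dataset.take(1).toArray()
  // Samples are either bare tensors or { xs, ys } records.
  const xs = (typeof sample === 'object' && sample !== null && 'xs' in sample
    ? (sample as { xs: tf.Tensor3D }).xs
    : sample) as tf.Tensor3D
  if (xs.shape[0] !== info.IMAGE_W || xs.shape[1] !== info.IMAGE_H) {
    throw new Error('Data input format is not compatible with the chosen task')
  }
}
```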
package/dist/core/dataset/data/index.js

```diff
@@ -0,0 +1,11 @@
+"use strict";
+Object.defineProperty(exports, "__esModule", { value: true });
+exports.ImagePreprocessing = exports.TabularData = exports.ImageData = exports.Data = void 0;
+var data_1 = require("./data");
+Object.defineProperty(exports, "Data", { enumerable: true, get: function () { return data_1.Data; } });
+var image_data_1 = require("./image_data");
+Object.defineProperty(exports, "ImageData", { enumerable: true, get: function () { return image_data_1.ImageData; } });
+var tabular_data_1 = require("./tabular_data");
+Object.defineProperty(exports, "TabularData", { enumerable: true, get: function () { return tabular_data_1.TabularData; } });
+var preprocessing_1 = require("./preprocessing");
+Object.defineProperty(exports, "ImagePreprocessing", { enumerable: true, get: function () { return preprocessing_1.ImagePreprocessing; } });
```
package/dist/core/dataset/data/preprocessing.d.ts

```diff
@@ -0,0 +1,13 @@
+import { tf, Task } from '../..';
+declare type PreprocessImage = (image: tf.TensorContainer) => tf.TensorContainer;
+export declare type Preprocessing = ImagePreprocessing;
+export interface ImageTensorContainer extends tf.TensorContainerObject {
+    xs: tf.Tensor3D | tf.Tensor4D;
+    ys: tf.Tensor1D | number | undefined;
+}
+export declare enum ImagePreprocessing {
+    Normalize = "normalize",
+    Resize = "resize"
+}
+export declare function getPreprocessImage(task: Task): PreprocessImage;
+export {};
```
package/dist/core/dataset/data/preprocessing.js

```diff
@@ -0,0 +1,33 @@
+"use strict";
+Object.defineProperty(exports, "__esModule", { value: true });
+exports.getPreprocessImage = exports.ImagePreprocessing = void 0;
+var __1 = require("../..");
+var ImagePreprocessing;
+(function (ImagePreprocessing) {
+    ImagePreprocessing["Normalize"] = "normalize";
+    ImagePreprocessing["Resize"] = "resize";
+})(ImagePreprocessing = exports.ImagePreprocessing || (exports.ImagePreprocessing = {}));
+function getPreprocessImage(task) {
+    var preprocessImage = function (tensorContainer) {
+        var _a, _b;
+        // TODO unsafe cast, tfjs does not provide the right interface
+        var info = task.trainingInformation;
+        var _c = tensorContainer, xs = _c.xs, ys = _c.ys;
+        if ((_a = info.preprocessingFunctions) === null || _a === void 0 ? void 0 : _a.includes(ImagePreprocessing.Normalize)) {
+            xs = xs.div(__1.tf.scalar(255));
+        }
+        if (((_b = info.preprocessingFunctions) === null || _b === void 0 ? void 0 : _b.includes(ImagePreprocessing.Resize)) &&
+            info.IMAGE_H !== undefined &&
+            info.IMAGE_W !== undefined) {
+            xs = __1.tf.image.resizeBilinear(xs, [
+                info.IMAGE_H, info.IMAGE_W
+            ]);
+        }
+        return {
+            xs: xs,
+            ys: ys
+        };
+    };
+    return preprocessImage;
+}
+exports.getPreprocessImage = getPreprocessImage;
```
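The preprocessing pipeline is driven entirely by the task definition: `getPreprocessImage` reads `preprocessingFunctions` off `trainingInformation` and applies each enabled step to `xs`. A hypothetical task excerpt opting into both steps (only the relevant fields shown; export path assumed):

```ts
import { ImagePreprocessing } from '@epfml/discojs' // assumed export path

// Illustrative excerpt; a real task carries many more fields.
const trainingInformation = {
  IMAGE_H: 32,
  IMAGE_W: 32,
  preprocessingFunctions: [ImagePreprocessing.Normalize, ImagePreprocessing.Resize]
}
// getPreprocessImage then divides xs by 255 and, because IMAGE_H/IMAGE_W are
// set, bilinearly resizes xs to 32x32 before the trainer sees the sample.
```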
package/dist/core/dataset/data/tabular_data.d.ts

```diff
@@ -0,0 +1,8 @@
+import { Task } from '../..';
+import { Dataset } from '../dataset';
+import { Data } from './data';
+export declare class TabularData extends Data {
+    static init(dataset: Dataset, task: Task, size?: number): Promise<Data>;
+    batch(): Data;
+    preprocess(): Data;
+}
```
package/dist/core/dataset/data/tabular_data.js

```diff
@@ -0,0 +1,40 @@
+"use strict";
+Object.defineProperty(exports, "__esModule", { value: true });
+exports.TabularData = void 0;
+var tslib_1 = require("tslib");
+var data_1 = require("./data");
+var TabularData = /** @class */ (function (_super) {
+    (0, tslib_1.__extends)(TabularData, _super);
+    function TabularData() {
+        return _super !== null && _super.apply(this, arguments) || this;
+    }
+    TabularData.init = function (dataset, task, size) {
+        return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
+            var e_1;
+            return (0, tslib_1.__generator)(this, function (_a) {
+                switch (_a.label) {
+                    case 0:
+                        _a.trys.push([0, 2, , 3]);
+                        return [4 /*yield*/, dataset.iterator()];
+                    case 1:
+                        _a.sent();
+                        return [3 /*break*/, 3];
+                    case 2:
+                        e_1 = _a.sent();
+                        throw new Error('Data input format is not compatible with the chosen task');
+                    case 3: return [2 /*return*/, new TabularData(dataset, task, size)];
+                }
+            });
+        });
+    };
+    TabularData.prototype.batch = function () {
+        var batchSize = this.task.trainingInformation.batchSize;
+        var newDataset = batchSize === undefined ? this.dataset : this.dataset.batch(batchSize);
+        return new TabularData(newDataset, this.task, this.size);
+    };
+    TabularData.prototype.preprocess = function () {
+        return this;
+    };
+    return TabularData;
+}(data_1.Data));
+exports.TabularData = TabularData;
```
package/dist/core/dataset/data_loader/data_loader.d.ts

```diff
@@ -1,22 +1,15 @@
-import {
-import {
+import { Task } from '../..';
+import { Dataset } from '../dataset';
+import { DataSplit } from '../data';
 export interface DataConfig {
     features?: string[];
     labels?: string[];
     shuffle?: boolean;
     validationSplit?: number;
 }
-export interface Data {
-    dataset: Dataset;
-    size: number;
-}
-export interface DataTuple {
-    train: Data;
-    validation?: Data;
-}
 export declare abstract class DataLoader<Source> {
     protected task: Task;
     constructor(task: Task);
     abstract load(source: Source, config: DataConfig): Promise<Dataset>;
-    abstract loadAll(sources: Source[], config: DataConfig): Promise<
+    abstract loadAll(sources: Source[], config: DataConfig): Promise<DataSplit>;
 }
```
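Note the return-type change: the loader-local `Data`/`DataTuple` interfaces are deleted and `loadAll` now resolves to a `DataSplit`. Its declaration (`core/dataset/data/data_split.d.ts`, +5 lines) is not rendered in this view, but judging from what the `loadAll` implementations return, it is presumably the direct successor of `DataTuple`, along these lines:

```ts
// Assumed shape of DataSplit, inferred from the loadAll implementations
// in this diff; the actual declaration lives in core/dataset/data/data_split.d.ts.
interface DataSplit {
  train: Data
  validation?: Data
}
```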
package/dist/core/dataset/data_loader/data_loader.js (renamed)

File without changes
package/dist/core/dataset/data_loader/image_loader.d.ts

```diff
@@ -0,0 +1,17 @@
+import { tf } from '../..';
+import { Dataset } from '../dataset';
+import { DataSplit } from '../data';
+import { DataLoader, DataConfig } from '../data_loader';
+/**
+ * TODO @s314cy:
+ * Load labels and correctly match them with their respective images, with the following constraints:
+ * 1. Images are given as 1 image/1 file
+ * 2. Labels are given as multiple labels/1 file, each label file can contain a different amount of labels
+ */
+export declare abstract class ImageLoader<Source> extends DataLoader<Source> {
+    abstract readImageFrom(source: Source): Promise<tf.Tensor3D>;
+    load(image: Source, config?: DataConfig): Promise<Dataset>;
+    private buildDataset;
+    loadAll(images: Source[], config?: DataConfig): Promise<DataSplit>;
+    shuffle(array: number[]): void;
+}
```
package/dist/core/dataset/data_loader/image_loader.js

```diff
@@ -0,0 +1,141 @@
+"use strict";
+Object.defineProperty(exports, "__esModule", { value: true });
+exports.ImageLoader = void 0;
+var tslib_1 = require("tslib");
+var immutable_1 = require("immutable");
+var __1 = require("../..");
+var data_1 = require("../data");
+var data_loader_1 = require("../data_loader");
+/**
+ * TODO @s314cy:
+ * Load labels and correctly match them with their respective images, with the following constraints:
+ * 1. Images are given as 1 image/1 file
+ * 2. Labels are given as multiple labels/1 file, each label file can contain a different amount of labels
+ */
+var ImageLoader = /** @class */ (function (_super) {
+    (0, tslib_1.__extends)(ImageLoader, _super);
+    function ImageLoader() {
+        return _super !== null && _super.apply(this, arguments) || this;
+    }
+    ImageLoader.prototype.load = function (image, config) {
+        return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
+            var tensorContainer;
+            var _a;
+            return (0, tslib_1.__generator)(this, function (_b) {
+                switch (_b.label) {
+                    case 0:
+                        if (!(config === undefined || config.labels === undefined)) return [3 /*break*/, 2];
+                        return [4 /*yield*/, this.readImageFrom(image)];
+                    case 1:
+                        tensorContainer = _b.sent();
+                        return [3 /*break*/, 4];
+                    case 2:
+                        _a = {};
+                        return [4 /*yield*/, this.readImageFrom(image)];
+                    case 3:
+                        tensorContainer = (_a.xs = _b.sent(),
+                            _a.ys = config.labels[0],
+                            _a);
+                        _b.label = 4;
+                    case 4: return [2 /*return*/, __1.tf.data.array([tensorContainer])];
+                }
+            });
+        });
+    };
+    ImageLoader.prototype.buildDataset = function (images, labels, indices, config) {
+        return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
+            var dataset;
+            var _this = this;
+            return (0, tslib_1.__generator)(this, function (_a) {
+                switch (_a.label) {
+                    case 0:
+                        dataset = __1.tf.data.generator(function () {
+                            var withLabels = (config === null || config === void 0 ? void 0 : config.labels) !== undefined;
+                            var index = 0;
+                            var iterator = {
+                                next: function () { return (0, tslib_1.__awaiter)(_this, void 0, void 0, function () {
+                                    var sample, label, value;
+                                    return (0, tslib_1.__generator)(this, function (_a) {
+                                        switch (_a.label) {
+                                            case 0:
+                                                if (index === indices.length) {
+                                                    return [2 /*return*/, { done: true }];
+                                                }
+                                                return [4 /*yield*/, this.readImageFrom(images[indices[index]])];
+                                            case 1:
+                                                sample = _a.sent();
+                                                label = withLabels ? labels[indices[index]] : undefined;
+                                                value = withLabels ? { xs: sample, ys: label } : sample;
+                                                index++;
+                                                return [2 /*return*/, {
+                                                        value: value,
+                                                        done: false
+                                                    }];
+                                        }
+                                    });
+                                }); }
+                            };
+                            return iterator; // Lazy
+                        });
+                        return [4 /*yield*/, data_1.ImageData.init(dataset, this.task, indices.length)];
+                    case 1: return [2 /*return*/, _a.sent()];
+                }
+            });
+        });
+    };
+    ImageLoader.prototype.loadAll = function (images, config) {
+        var _a, _b;
+        return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
+            var labels, indices, numberOfClasses, dataset, trainSize, trainIndices, valIndices, trainDataset, valDataset;
+            return (0, tslib_1.__generator)(this, function (_c) {
+                switch (_c.label) {
+                    case 0:
+                        labels = [];
+                        indices = (0, immutable_1.Range)(0, images.length).toArray();
+                        if ((config === null || config === void 0 ? void 0 : config.labels) !== undefined) {
+                            numberOfClasses = (_b = (_a = this.task.trainingInformation) === null || _a === void 0 ? void 0 : _a.LABEL_LIST) === null || _b === void 0 ? void 0 : _b.length;
+                            if (numberOfClasses === undefined) {
+                                throw new Error('wanted labels but none found in task');
+                            }
+                            labels = __1.tf.oneHot(__1.tf.tensor1d(config.labels, 'int32'), numberOfClasses).arraySync();
+                        }
+                        if ((config === null || config === void 0 ? void 0 : config.shuffle) === undefined || (config === null || config === void 0 ? void 0 : config.shuffle)) {
+                            this.shuffle(indices);
+                        }
+                        if (!((config === null || config === void 0 ? void 0 : config.validationSplit) === undefined || (config === null || config === void 0 ? void 0 : config.validationSplit) === 0)) return [3 /*break*/, 2];
+                        return [4 /*yield*/, this.buildDataset(images, labels, indices, config)];
+                    case 1:
+                        dataset = _c.sent();
+                        return [2 /*return*/, {
+                                train: dataset,
+                                validation: undefined
+                            }];
+                    case 2:
+                        trainSize = Math.floor(images.length * (1 - config.validationSplit));
+                        trainIndices = indices.slice(0, trainSize);
+                        valIndices = indices.slice(trainSize);
+                        return [4 /*yield*/, this.buildDataset(images, labels, trainIndices, config)];
+                    case 3:
+                        trainDataset = _c.sent();
+                        return [4 /*yield*/, this.buildDataset(images, labels, valIndices, config)];
+                    case 4:
+                        valDataset = _c.sent();
+                        return [2 /*return*/, {
+                                train: trainDataset,
+                                validation: valDataset
+                            }];
+                }
+            });
+        });
+    };
+    ImageLoader.prototype.shuffle = function (array) {
+        for (var i = 0; i < array.length; i++) {
+            var j = Math.floor(Math.random() * i);
+            var swap = array[i];
+            array[i] = array[j];
+            array[j] = swap;
+        }
+    };
+    return ImageLoader;
+}(data_loader_1.DataLoader));
+exports.ImageLoader = ImageLoader;
```
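`ImageLoader` is abstract over the image source: a platform package only supplies `readImageFrom`, while the base class handles one-hot label encoding, shuffling, and the train/validation split. A minimal, hypothetical Node-flavoured subclass (class name and export path are assumptions; `tf.node.decodeImage` is the standard tfjs-node decoder):

```ts
import * as tf from '@tensorflow/tfjs-node'
import * as fs from 'node:fs/promises'
import { ImageLoader } from '@epfml/discojs' // assumed export path

// Hypothetical loader reading images from the local file system.
class FsImageLoader extends ImageLoader<string> {
  async readImageFrom (path: string): Promise<tf.Tensor3D> {
    const bytes = await fs.readFile(path)
    return tf.node.decodeImage(bytes) as tf.Tensor3D
  }
}

// Usage sketch: labels line up index-wise with the image sources, and
// loadAll one-hot encodes them against the task's LABEL_LIST.
// const split = await new FsImageLoader(task).loadAll(
//   ['img/0.png', 'img/1.png'],
//   { labels: ['0', '1'], validationSplit: 0.5 }
// )
```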
package/dist/core/dataset/data_loader/index.js

```diff
@@ -0,0 +1,9 @@
+"use strict";
+Object.defineProperty(exports, "__esModule", { value: true });
+exports.TabularLoader = exports.ImageLoader = exports.DataLoader = void 0;
+var data_loader_1 = require("./data_loader");
+Object.defineProperty(exports, "DataLoader", { enumerable: true, get: function () { return data_loader_1.DataLoader; } });
+var image_loader_1 = require("./image_loader");
+Object.defineProperty(exports, "ImageLoader", { enumerable: true, get: function () { return image_loader_1.ImageLoader; } });
+var tabular_loader_1 = require("./tabular_loader");
+Object.defineProperty(exports, "TabularLoader", { enumerable: true, get: function () { return tabular_loader_1.TabularLoader; } });
```
package/dist/core/dataset/data_loader/tabular_loader.d.ts

```diff
@@ -0,0 +1,29 @@
+import { tf, Task } from '../..';
+import { Dataset } from '../dataset';
+import { DataSplit } from '../data';
+import { DataLoader, DataConfig } from '../data_loader';
+export declare abstract class TabularLoader<Source> extends DataLoader<Source> {
+    private readonly delimiter;
+    constructor(task: Task, delimiter: string);
+    /**
+     * Creates a CSV dataset object based off the given source.
+     * @param source File object, URL string or local file system path.
+     * @param csvConfig Object expected by TF.js to create a CSVDataset.
+     * @returns The CSVDataset object built upon the given source.
+     */
+    abstract loadTabularDatasetFrom(source: Source, csvConfig: Record<string, unknown>): tf.data.CSVDataset;
+    /**
+     * Expects delimiter-separated tabular data made of N columns. The data may be
+     * potentially split among several sources. Every source should contain N-1
+     * feature columns and 1 single label column.
+     * @param source List of File objects, URLs or file system paths.
+     * @param config
+     * @returns A TF.js dataset built upon read tabular data stored in the given sources.
+     */
+    load(source: Source, config?: DataConfig): Promise<Dataset>;
+    /**
+     * Creates the CSV datasets based off the given sources, then fuses them into a single CSV
+     * dataset.
+     */
+    loadAll(sources: Source[], config: DataConfig): Promise<DataSplit>;
+}
```
package/dist/core/dataset/data_loader/tabular_loader.js

```diff
@@ -0,0 +1,101 @@
+"use strict";
+Object.defineProperty(exports, "__esModule", { value: true });
+exports.TabularLoader = void 0;
+var tslib_1 = require("tslib");
+var immutable_1 = require("immutable");
+var data_1 = require("../data");
+var data_loader_1 = require("../data_loader");
+// window size from which the dataset shuffling will sample
+var BUFFER_SIZE = 1000;
+var TabularLoader = /** @class */ (function (_super) {
+    (0, tslib_1.__extends)(TabularLoader, _super);
+    function TabularLoader(task, delimiter) {
+        var _this = _super.call(this, task) || this;
+        _this.delimiter = delimiter;
+        return _this;
+    }
+    /**
+     * Expects delimiter-separated tabular data made of N columns. The data may be
+     * potentially split among several sources. Every source should contain N-1
+     * feature columns and 1 single label column.
+     * @param source List of File objects, URLs or file system paths.
+     * @param config
+     * @returns A TF.js dataset built upon read tabular data stored in the given sources.
+     */
+    TabularLoader.prototype.load = function (source, config) {
+        return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
+            var columnConfigs, csvConfig, dataset;
+            return (0, tslib_1.__generator)(this, function (_a) {
+                /**
+                 * Prepare the CSV config object based off the given features and labels.
+                 * If labels is empty, then the returned dataset is comprised of samples only.
+                 * Otherwise, each entry is of the form `{ xs, ys }` with `xs` as features and `ys`
+                 * as labels.
+                 */
+                if ((config === null || config === void 0 ? void 0 : config.features) === undefined) {
+                    // TODO @s314cy
+                    throw new Error('Not implemented');
+                }
+                columnConfigs = (0, immutable_1.Map)((0, immutable_1.Set)(config.features).map(function (feature) { return [feature, { required: false, isLabel: false }]; })).merge((0, immutable_1.Set)(config.labels).map(function (label) { return [label, { required: true, isLabel: true }]; }));
+                csvConfig = {
+                    hasHeader: true,
+                    columnConfigs: columnConfigs.toObject(),
+                    configuredColumnsOnly: true,
+                    delimiter: this.delimiter
+                };
+                dataset = this.loadTabularDatasetFrom(source, csvConfig).map(function (t) {
+                    if (typeof t === 'object' && ('xs' in t) && ('ys' in t)) {
+                        return t;
+                    }
+                    throw new TypeError('Expected TensorContainerObject');
+                }).map(function (t) {
+                    // TODO order may not be stable between tensor
+                    var _a = t, xs = _a.xs, ys = _a.ys;
+                    return {
+                        xs: Object.values(xs),
+                        ys: Object.values(ys)
+                    };
+                });
+                return [2 /*return*/, ((config === null || config === void 0 ? void 0 : config.shuffle) === undefined || (config === null || config === void 0 ? void 0 : config.shuffle)) ? dataset.shuffle(BUFFER_SIZE) : dataset];
+            });
+        });
+    };
+    /**
+     * Creates the CSV datasets based off the given sources, then fuses them into a single CSV
+     * dataset.
+     */
+    TabularLoader.prototype.loadAll = function (sources, config) {
+        return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
+            var datasets, dataset, data;
+            var _this = this;
+            return (0, tslib_1.__generator)(this, function (_a) {
+                switch (_a.label) {
+                    case 0: return [4 /*yield*/, Promise.all(sources.map(function (source) { return (0, tslib_1.__awaiter)(_this, void 0, void 0, function () { return (0, tslib_1.__generator)(this, function (_a) {
+                        switch (_a.label) {
+                            case 0: return [4 /*yield*/, this.load(source, (0, tslib_1.__assign)((0, tslib_1.__assign)({}, config), { shuffle: false }))];
+                            case 1: return [2 /*return*/, _a.sent()];
+                        }
+                    }); }); }))];
+                    case 1:
+                        datasets = _a.sent();
+                        dataset = (0, immutable_1.List)(datasets).reduce(function (acc, dataset) { return acc.concatenate(dataset); });
+                        dataset = (config === null || config === void 0 ? void 0 : config.shuffle) ? dataset.shuffle(BUFFER_SIZE) : dataset;
+                        return [4 /*yield*/, data_1.TabularData.init(dataset, this.task,
+                            // dataset.size does not work for csv datasets
+                            // https://github.com/tensorflow/tfjs/issues/5845
+                            undefined)
+                            // TODO: Implement validation split for tabular data (tricky due to streaming)
+                        ];
+                    case 2:
+                        data = _a.sent();
+                        // TODO: Implement validation split for tabular data (tricky due to streaming)
+                        return [2 /*return*/, {
+                                train: data
+                            }];
+                }
+            });
+        });
+    };
+    return TabularLoader;
+}(data_loader_1.DataLoader));
+exports.TabularLoader = TabularLoader;
```
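The tabular side mirrors the image side: a subclass only implements `loadTabularDatasetFrom`, typically by delegating to `tf.data.csv`. A hedged sketch (class name and export path are assumptions):

```ts
import * as tf from '@tensorflow/tfjs'
import { TabularLoader } from '@epfml/discojs' // assumed export path

// Hypothetical loader for CSV files reachable by URL.
class UrlTabularLoader extends TabularLoader<string> {
  loadTabularDatasetFrom (url: string, csvConfig: Record<string, unknown>): tf.data.CSVDataset {
    return tf.data.csv(url, csvConfig as Parameters<typeof tf.data.csv>[1])
  }
}

// Usage sketch with Titanic-style columns: features become xs, labels ys.
// const split = await new UrlTabularLoader(task, ',').loadAll(
//   ['https://example.com/titanic.csv'],
//   { features: ['Age', 'Fare'], labels: ['Survived'] }
// )
```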
package/dist/core/dataset/dataset.js (renamed)

File without changes
package/dist/core/dataset/dataset_builder.d.ts

```diff
@@ -1,7 +1,6 @@
-import
-import {
-import {
-export declare type Dataset = tf.data.Dataset<tf.TensorContainer>;
+import { Task } from '..';
+import { DataSplit } from './data';
+import { DataConfig, DataLoader } from './data_loader/data_loader';
 export declare class DatasetBuilder<Source> {
     private readonly task;
     private readonly dataLoader;
@@ -11,8 +10,9 @@ export declare class DatasetBuilder<Source> {
     constructor(dataLoader: DataLoader<Source>, task: Task);
     addFiles(sources: Source[], label?: string): void;
     clearFiles(label?: string): void;
+    private resetBuiltState;
     private getLabels;
-    build(config?: DataConfig): Promise<
+    build(config?: DataConfig): Promise<DataSplit>;
     isBuilt(): boolean;
     size(): number;
 }
```