@epfml/discojs 0.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (104) hide show
  1. package/README.md +32 -0
  2. package/dist/aggregation.d.ts +3 -0
  3. package/dist/aggregation.js +19 -0
  4. package/dist/async_buffer.d.ts +41 -0
  5. package/dist/async_buffer.js +98 -0
  6. package/dist/async_informant.d.ts +20 -0
  7. package/dist/async_informant.js +69 -0
  8. package/dist/client/base.d.ts +36 -0
  9. package/dist/client/base.js +34 -0
  10. package/dist/client/decentralized.d.ts +23 -0
  11. package/dist/client/decentralized.js +275 -0
  12. package/dist/client/federated.d.ts +30 -0
  13. package/dist/client/federated.js +221 -0
  14. package/dist/client/index.d.ts +4 -0
  15. package/dist/client/index.js +11 -0
  16. package/dist/client/local.d.ts +8 -0
  17. package/dist/client/local.js +34 -0
  18. package/dist/dataset/data_loader/data_loader.d.ts +16 -0
  19. package/dist/dataset/data_loader/data_loader.js +10 -0
  20. package/dist/dataset/data_loader/image_loader.d.ts +14 -0
  21. package/dist/dataset/data_loader/image_loader.js +93 -0
  22. package/dist/dataset/data_loader/index.d.ts +3 -0
  23. package/dist/dataset/data_loader/index.js +9 -0
  24. package/dist/dataset/data_loader/tabular_loader.d.ts +29 -0
  25. package/dist/dataset/data_loader/tabular_loader.js +88 -0
  26. package/dist/dataset/dataset_builder.d.ts +17 -0
  27. package/dist/dataset/dataset_builder.js +80 -0
  28. package/dist/dataset/index.d.ts +2 -0
  29. package/dist/dataset/index.js +7 -0
  30. package/dist/index.d.ts +17 -0
  31. package/dist/index.js +34 -0
  32. package/dist/logging/console_logger.d.ts +18 -0
  33. package/dist/logging/console_logger.js +33 -0
  34. package/dist/logging/index.d.ts +2 -0
  35. package/dist/logging/index.js +7 -0
  36. package/dist/logging/logger.d.ts +12 -0
  37. package/dist/logging/logger.js +9 -0
  38. package/dist/logging/trainer_logger.d.ts +24 -0
  39. package/dist/logging/trainer_logger.js +59 -0
  40. package/dist/memory/base.d.ts +53 -0
  41. package/dist/memory/base.js +9 -0
  42. package/dist/memory/empty.d.ts +12 -0
  43. package/dist/memory/empty.js +69 -0
  44. package/dist/memory/index.d.ts +3 -0
  45. package/dist/memory/index.js +9 -0
  46. package/dist/memory/model_type.d.ts +4 -0
  47. package/dist/memory/model_type.js +9 -0
  48. package/dist/model_actor.d.ts +16 -0
  49. package/dist/model_actor.js +20 -0
  50. package/dist/privacy.d.ts +12 -0
  51. package/dist/privacy.js +60 -0
  52. package/dist/serialization/index.d.ts +2 -0
  53. package/dist/serialization/index.js +6 -0
  54. package/dist/serialization/model.d.ts +5 -0
  55. package/dist/serialization/model.js +55 -0
  56. package/dist/serialization/weights.d.ts +5 -0
  57. package/dist/serialization/weights.js +62 -0
  58. package/dist/task/data_example.d.ts +5 -0
  59. package/dist/task/data_example.js +24 -0
  60. package/dist/task/display_information.d.ts +15 -0
  61. package/dist/task/display_information.js +53 -0
  62. package/dist/task/index.d.ts +3 -0
  63. package/dist/task/index.js +8 -0
  64. package/dist/task/model_compile_data.d.ts +6 -0
  65. package/dist/task/model_compile_data.js +12 -0
  66. package/dist/task/task.d.ts +10 -0
  67. package/dist/task/task.js +32 -0
  68. package/dist/task/training_information.d.ts +29 -0
  69. package/dist/task/training_information.js +2 -0
  70. package/dist/tasks/cifar10.d.ts +4 -0
  71. package/dist/tasks/cifar10.js +74 -0
  72. package/dist/tasks/index.d.ts +5 -0
  73. package/dist/tasks/index.js +9 -0
  74. package/dist/tasks/lus_covid.d.ts +4 -0
  75. package/dist/tasks/lus_covid.js +48 -0
  76. package/dist/tasks/mnist.d.ts +4 -0
  77. package/dist/tasks/mnist.js +56 -0
  78. package/dist/tasks/simple_face.d.ts +4 -0
  79. package/dist/tasks/simple_face.js +88 -0
  80. package/dist/tasks/titanic.d.ts +4 -0
  81. package/dist/tasks/titanic.js +86 -0
  82. package/dist/testing/tester.d.ts +5 -0
  83. package/dist/testing/tester.js +21 -0
  84. package/dist/training/disco.d.ts +12 -0
  85. package/dist/training/disco.js +62 -0
  86. package/dist/training/index.d.ts +2 -0
  87. package/dist/training/index.js +7 -0
  88. package/dist/training/trainer/distributed_trainer.d.ts +21 -0
  89. package/dist/training/trainer/distributed_trainer.js +60 -0
  90. package/dist/training/trainer/local_trainer.d.ts +10 -0
  91. package/dist/training/trainer/local_trainer.js +37 -0
  92. package/dist/training/trainer/round_tracker.d.ts +30 -0
  93. package/dist/training/trainer/round_tracker.js +44 -0
  94. package/dist/training/trainer/trainer.d.ts +66 -0
  95. package/dist/training/trainer/trainer.js +146 -0
  96. package/dist/training/trainer/trainer_builder.d.ts +25 -0
  97. package/dist/training/trainer/trainer_builder.js +102 -0
  98. package/dist/training/training_schemes.d.ts +5 -0
  99. package/dist/training/training_schemes.js +10 -0
  100. package/dist/training_informant.d.ts +88 -0
  101. package/dist/training_informant.js +135 -0
  102. package/dist/types.d.ts +4 -0
  103. package/dist/types.js +2 -0
  104. package/package.json +48 -0
@@ -0,0 +1,30 @@
1
+ import { TrainingInformant, MetadataID, Weights } from '..';
2
+ import { Base } from './base';
3
/**
 * Client used when a task is trained in the federated setting: all exchanges
 * go through the central server, addressed by task ID and client ID.
 */
export declare class Federated extends Base {
    /** v4 UUID generated when the instance is constructed. */
    private readonly clientID;
    private readonly peer;
    /** Round counter: starts at 0 and is raised to the server's round when that one is higher. */
    private round;
    /** Builds `<base>/feai/<category>/<taskID>/<clientID>`. */
    private urlTo;
    /** Builds `<base>/feai/metadata/<metadataID>/<taskID>/<round>/<clientID>`. */
    private urlToMetadata;
    /**
     * Issues a GET on the `connect` endpoint and ignores the response body.
     * TODO: In the case of FeAI, should return the current server-side round for the task.
     */
    connect(): Promise<void>;
    /**
     * Issues a GET on the `disconnect` endpoint; meant to be called when the
     * user quits the task.
     */
    disconnect(): Promise<void>;
    /** Encodes the weights and POSTs them together with the local round number. */
    postWeightsToServer(weights: Weights): Promise<void>;
    /** POSTs one metadata value to the metadata endpoint for `metadataID`. */
    postMetadata(metadataID: string, metadata: string): Promise<void>;
    /** GETs the metadata endpoint and msgpack-decodes the entry stored under `metadataID` into a Map. */
    getMetadataMap(metadataID: MetadataID): Promise<Map<string, unknown>>;
    /** Resolves to the round reported by the server, or -1 when the response status is not 200. */
    getLatestServerRound(): Promise<number>;
    /**
     * Fetches the server round and weights. Resolves to the weights (and adopts
     * the server round) only when the server is ahead; otherwise `undefined`.
     */
    pullRoundAndFetchWeights(): Promise<Weights | undefined>;
    /** GETs the `statistics` endpoint and forwards the result to the informant. */
    pullServerStatistics(trainingInformant: TrainingInformant): Promise<void>;
    /**
     * End-of-round exchange: uploads the differentially-private weights, pulls
     * statistics, then resolves to newer server weights or `staleWeights`.
     */
    onRoundEndCommunication(updatedWeights: Weights, staleWeights: Weights, _: number, trainingInformant: TrainingInformant): Promise<Weights>;
    /** Adds a "training finished" message to the informant; no network traffic. */
    onTrainEndCommunication(_: Weights, trainingInformant: TrainingInformant): Promise<void>;
}
@@ -0,0 +1,221 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.Federated = void 0;
4
+ var tslib_1 = require("tslib");
5
+ var msgpack = (0, tslib_1.__importStar)(require("msgpack-lite"));
6
+ var axios_1 = (0, tslib_1.__importDefault)(require("axios"));
7
+ var uuid_1 = require("uuid");
8
+ var __1 = require("..");
9
+ var base_1 = require("./base");
10
+ /**
11
+ * Class that deals with communication with the centralized server when training
12
+ * a specific task in the federated setting.
13
+ */
14
+ var Federated = /** @class */ (function (_super) {
15
+ (0, tslib_1.__extends)(Federated, _super);
16
/**
 * Forwards every argument to the base client, then gives the instance a
 * fresh v4 UUID and starts it at round 0.
 */
function Federated() {
    var instance = this;
    if (_super !== null) {
        // A base constructor may return a replacement object; keep `this` otherwise.
        instance = _super.apply(this, arguments) || this;
    }
    var generateId = uuid_1.v4;
    instance.clientID = generateId();
    instance.round = 0;
    return instance;
}
22
+ Federated.prototype.urlTo = function (category) {
23
+ var url = new URL('', this.url);
24
+ url.pathname += [
25
+ 'feai',
26
+ category,
27
+ this.task.taskID,
28
+ this.clientID
29
+ ].join('/');
30
+ return url.href;
31
+ };
32
+ Federated.prototype.urlToMetadata = function (metadataID) {
33
+ var url = new URL('', this.url);
34
+ url.pathname += [
35
+ 'feai',
36
+ 'metadata',
37
+ metadataID,
38
+ this.task.taskID,
39
+ this.round,
40
+ this.clientID
41
+ ].join('/');
42
+ return url.href;
43
+ };
44
/**
 * Opens the session with the server by issuing a GET on the `connect`
 * endpoint. The response is discarded; a failed request rejects the promise.
 * TODO: In the case of FeAI, should return the current server-side round for the task.
 */
Federated.prototype.connect = async function () {
    await axios_1.default.get(this.urlTo('connect'));
};
60
/**
 * Tells the server the user is leaving the task by issuing a GET on the
 * `disconnect` endpoint. The response is discarded.
 */
Federated.prototype.disconnect = async function () {
    await axios_1.default.get(this.urlTo('disconnect'));
};
75
/**
 * Uploads the given weights: they are encoded with the package serializer and
 * POSTed to the `weights` endpoint alongside the local round number.
 */
Federated.prototype.postWeightsToServer = async function (weights) {
    var send = axios_1.default;
    // The target URL is resolved before encoding starts; the round is read afterwards.
    var target = this.urlTo('weights');
    var encoded = await __1.serialization.weights.encode(weights);
    await send({
        method: 'post',
        url: target,
        data: {
            weights: encoded,
            round: this.round
        }
    });
};
100
/**
 * POSTs one metadata value to the metadata endpoint of `metadataID` for the
 * current round.
 */
Federated.prototype.postMetadata = async function (metadataID, metadata) {
    var send = axios_1.default;
    // NOTE(review): the body key is the literal string "metadataID", not the
    // value of the parameter, whereas getMetadataMap reads `body[metadataID]`.
    // Confirm against the server which key it expects.
    var payload = { metadataID: metadata };
    await send({
        method: 'post',
        url: this.urlToMetadata(metadataID),
        data: payload
    });
};
118
/**
 * Downloads the metadata stored under `metadataID` for the current round.
 * The response entry keyed by `metadataID` is msgpack-decoded and its result
 * is handed to the `Map` constructor.
 */
Federated.prototype.getMetadataMap = async function (metadataID) {
    var reply = await axios_1.default.get(this.urlToMetadata(metadataID));
    var payload = await reply.data;
    var entries = msgpack.decode(payload[metadataID]);
    return new Map(entries);
};
134
+ Federated.prototype.getLatestServerRound = function () {
135
+ return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
136
+ var response;
137
+ return (0, tslib_1.__generator)(this, function (_a) {
138
+ switch (_a.label) {
139
+ case 0: return [4 /*yield*/, axios_1.default.get(this.urlTo('round'))];
140
+ case 1:
141
+ response = _a.sent();
142
+ if (response.status === 200) {
143
+ return [2 /*return*/, response.data.round];
144
+ }
145
+ console.log('Error getting weights: code', response.status);
146
+ return [2 /*return*/, -1];
147
+ }
148
+ });
149
+ });
150
+ };
151
/**
 * Synchronises with the server: fetches its round, then downloads and decodes
 * its weights. The weights are downloaded whether or not the server is ahead.
 *
 * @returns the decoded server weights when the server round is strictly
 *          greater than the local one (the local round is then updated);
 *          `undefined` otherwise.
 */
Federated.prototype.pullRoundAndFetchWeights = async function () {
    var latestRound = await this.getLatestServerRound();
    var reply = await axios_1.default.get(this.urlTo('weights'));
    var decoded = __1.serialization.weights.decode(reply.data);
    if (!(this.round < latestRound)) {
        return undefined;
    }
    // Update the local round to match the server's
    this.round = latestRound;
    return decoded;
};
176
/**
 * Fetches the `statistics` endpoint and passes the `statistics` field of the
 * response to `trainingInformant.updateWithServerStatistics`.
 */
Federated.prototype.pullServerStatistics = async function (trainingInformant) {
    var reply = await axios_1.default.get(this.urlTo('statistics'));
    var statistics = reply.data.statistics;
    trainingInformant.updateWithServerStatistics(statistics);
};
190
/**
 * End-of-round exchange with the server, in this order:
 *  1. apply differential privacy to the updated weights and upload the result,
 *  2. refresh the informant with the server statistics,
 *  3. pull the server round and weights.
 *
 * @returns the server weights when newer ones were pulled, else `staleWeights`.
 */
Federated.prototype.onRoundEndCommunication = async function (updatedWeights, staleWeights, _, trainingInformant) {
    var noisy = __1.privacy.addDifferentialPrivacy(updatedWeights, staleWeights, this.task);
    await this.postWeightsToServer(noisy);
    await this.pullServerStatistics(trainingInformant);
    var pulled = await this.pullRoundAndFetchWeights();
    if (pulled === null || pulled === undefined) {
        return staleWeights;
    }
    return pulled;
};
211
+ Federated.prototype.onTrainEndCommunication = function (_, trainingInformant) {
212
+ return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
213
+ return (0, tslib_1.__generator)(this, function (_a) {
214
+ trainingInformant.addMessage('Training finished.');
215
+ return [2 /*return*/];
216
+ });
217
+ });
218
+ };
219
+ return Federated;
220
+ }(base_1.Base));
221
+ exports.Federated = Federated;
@@ -0,0 +1,4 @@
1
+ export { Base } from './base';
2
+ export { Decentralized } from './decentralized';
3
+ export { Federated } from './federated';
4
+ export { Local } from './local';
@@ -0,0 +1,11 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.Local = exports.Federated = exports.Decentralized = exports.Base = void 0;
4
+ var base_1 = require("./base");
5
+ Object.defineProperty(exports, "Base", { enumerable: true, get: function () { return base_1.Base; } });
6
+ var decentralized_1 = require("./decentralized");
7
+ Object.defineProperty(exports, "Decentralized", { enumerable: true, get: function () { return decentralized_1.Decentralized; } });
8
+ var federated_1 = require("./federated");
9
+ Object.defineProperty(exports, "Federated", { enumerable: true, get: function () { return federated_1.Federated; } });
10
+ var local_1 = require("./local");
11
+ Object.defineProperty(exports, "Local", { enumerable: true, get: function () { return local_1.Local; } });
@@ -0,0 +1,8 @@
1
+ import { Weights } from '../types';
2
+ import { Base } from './base';
3
/**
 * Client for local-only training: every hook resolves without contacting
 * anyone.
 */
export declare class Local extends Base {
    /** Resolves immediately. */
    connect(): Promise<void>;
    /** Resolves immediately. */
    disconnect(): Promise<void>;
    /** Resolves to the weights it was given, unchanged. */
    onRoundEndCommunication(_: Weights): Promise<Weights>;
    /** Resolves immediately. */
    onTrainEndCommunication(): Promise<void>;
}
@@ -0,0 +1,34 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.Local = void 0;
4
+ var tslib_1 = require("tslib");
5
+ var base_1 = require("./base");
6
/**
 * Client for local-only training. There is nobody to talk to, so every
 * communication hook is a no-op that resolves right away.
 */
var Local = /** @class */ (function (_super) {
    (0, tslib_1.__extends)(Local, _super);
    function Local() {
        var instance = this;
        if (_super !== null) {
            instance = _super.apply(this, arguments) || this;
        }
        return instance;
    }
    /** Nothing to connect to. */
    Local.prototype.connect = async function () { };
    /** Nothing to disconnect from. */
    Local.prototype.disconnect = async function () { };
    /** Hands the received weights straight back. */
    Local.prototype.onRoundEndCommunication = async function (_) {
        return _;
    };
    /** Nothing to report at the end of training. */
    Local.prototype.onTrainEndCommunication = async function () { };
    return Local;
}(base_1.Base));
34
+ exports.Local = Local;
@@ -0,0 +1,16 @@
1
+ import { Dataset } from '../dataset_builder';
2
+ import { Task } from '../../task';
3
/** Optional hints passed to a loader alongside its sources. */
export interface DataConfig {
    /** Feature names; the tabular loader uses them as the CSV columns to read. */
    features?: string[];
    /**
     * Labels. The tabular loader treats them as label column names; the image
     * loader feeds them to a one-hot encoding as class indices.
     */
    labels?: string[];
}
7
/** Result of loading several sources at once. */
export interface Data {
    /** The dataset built from all sources. */
    dataset: Dataset;
    /** Size reported by the loader that produced `dataset`. */
    size: number;
}
11
/**
 * Base class of all loaders. `Source` is whatever a concrete loader reads
 * from; the task is stored for subclasses to consult.
 */
export declare abstract class DataLoader<Source> {
    /** Task this loader was created for. */
    protected task: Task;
    constructor(task: Task);
    /** Builds a dataset from a single source. */
    abstract load(source: Source, config: DataConfig): Promise<Dataset>;
    /** Builds one dataset from several sources and reports its size. */
    abstract loadAll(sources: Source[], config: DataConfig): Promise<Data>;
}
@@ -0,0 +1,10 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.DataLoader = void 0;
4
+ var DataLoader = /** @class */ (function () {
5
+ function DataLoader(task) {
6
+ this.task = task;
7
+ }
8
+ return DataLoader;
9
+ }());
10
+ exports.DataLoader = DataLoader;
@@ -0,0 +1,14 @@
1
+ import * as tf from '@tensorflow/tfjs';
2
+ import { Dataset } from '../dataset_builder';
3
+ import { DataLoader, DataConfig, Data } from './data_loader';
4
/**
 * Loader for image sources; subclasses supply the decoding step.
 *
 * TODO @s314cy:
 * Load labels and correctly match them with their respective images, with the following constraints:
 * 1. Images are given as 1 image/1 file
 * 2. Labels are given as multiple labels/1 file, each label file can contain a different amount of labels
 */
export declare abstract class ImageLoader<Source> extends DataLoader<Source> {
    /** Decodes one source into an image tensor. */
    abstract readImageFrom(source: Source): Promise<tf.Tensor3D>;
    /**
     * Builds a one-element dataset from a single image. When `config.labels`
     * is set, the element is `{ xs: image, ys: config.labels[0] }`.
     */
    load(image: Source, config?: DataConfig): Promise<Dataset>;
    /**
     * Builds a lazy dataset over all images; `size` is `images.length`. When
     * `config.labels` is set, labels are one-hot encoded using the length of
     * the task's `LABEL_LIST`, and an error is thrown if that list is missing.
     */
    loadAll(images: Source[], config?: DataConfig): Promise<Data>;
}
@@ -0,0 +1,93 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.ImageLoader = void 0;
4
+ var tslib_1 = require("tslib");
5
+ var tf = (0, tslib_1.__importStar)(require("@tensorflow/tfjs"));
6
+ var data_loader_1 = require("./data_loader");
7
+ /**
8
+ * TODO @s314cy:
9
+ * Load labels and correctly match them with their respective images, with the following constraints:
10
+ * 1. Images are given as 1 image/1 file
11
+ * 2. Labels are given as multiple labels/1 file, each label file can contain a different amount of labels
12
+ */
13
+ var ImageLoader = /** @class */ (function (_super) {
14
+ (0, tslib_1.__extends)(ImageLoader, _super);
15
/** Delegates construction entirely to the base loader. */
function ImageLoader() {
    var instance = this;
    if (_super !== null) {
        instance = _super.apply(this, arguments) || this;
    }
    return instance;
}
18
/**
 * Reads a single image and wraps it in a one-element dataset.
 *
 * Without `config.labels` the element is the image tensor itself; with it,
 * the element is `{ xs: image, ys: config.labels[0] }`.
 */
ImageLoader.prototype.load = async function (image, config) {
    var unlabelled = config === undefined || config.labels === undefined;
    var pixels = await this.readImageFrom(image);
    var element = unlabelled ? pixels : { xs: pixels, ys: config.labels[0] };
    return tf.data.array([element]);
};
43
/**
 * Builds a lazy dataset over `images`; each image is decoded only when the
 * dataset iterator reaches it.
 *
 * When `config.labels` is set, the labels are one-hot encoded with as many
 * classes as the task's `LABEL_LIST` has entries, and every element is
 * `{ xs: image, ys: oneHotLabel }`. Otherwise elements are bare image tensors.
 *
 * @param images sources handed to `readImageFrom`, in order
 * @param config optional; only `labels` is read
 * @returns the dataset together with `images.length` as its size
 * @throws Error when labels are requested but the task has no `LABEL_LIST`
 */
ImageLoader.prototype.loadAll = async function (images, config) {
    var self = this;
    var labels;
    if (config !== null && config !== undefined && config.labels !== undefined) {
        var info = this.task.trainingInformation;
        var labelList = info === null || info === undefined ? undefined : info.LABEL_LIST;
        var numberOfClasses = labelList === null || labelList === undefined ? undefined : labelList.length;
        if (numberOfClasses === undefined) {
            throw new Error('wanted labels but none found in task');
        }
        // Only the plain nested array is kept, so the intermediate tensors are
        // released by tf.tidy instead of staying allocated.
        labels = tf.tidy(function () {
            return tf.oneHot(tf.tensor1d(config.labels, 'int32'), numberOfClasses).arraySync();
        });
    }
    var dataset = tf.data.generator(function () {
        var withLabels = (config === null || config === void 0 ? void 0 : config.labels) !== undefined;
        var index = 0;
        return {
            next: async function () {
                if (index === images.length) {
                    return { done: true };
                }
                // Claim the position before awaiting, so that overlapping next()
                // calls cannot read the same image twice. The cursor therefore
                // advances even if decoding this image fails.
                var current = index++;
                var sample = await self.readImageFrom(images[current]);
                var value = withLabels ? { xs: sample, ys: labels[current] } : sample;
                return {
                    value: value,
                    done: false
                };
            }
        }; // Lazy
    });
    return {
        dataset: dataset,
        size: images.length
    };
};
91
+ return ImageLoader;
92
+ }(data_loader_1.DataLoader));
93
+ exports.ImageLoader = ImageLoader;
@@ -0,0 +1,3 @@
1
+ export { Data, DataConfig, DataLoader } from './data_loader';
2
+ export { ImageLoader } from './image_loader';
3
+ export { TabularLoader } from './tabular_loader';
@@ -0,0 +1,9 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.TabularLoader = exports.ImageLoader = exports.DataLoader = void 0;
4
+ var data_loader_1 = require("./data_loader");
5
+ Object.defineProperty(exports, "DataLoader", { enumerable: true, get: function () { return data_loader_1.DataLoader; } });
6
+ var image_loader_1 = require("./image_loader");
7
+ Object.defineProperty(exports, "ImageLoader", { enumerable: true, get: function () { return image_loader_1.ImageLoader; } });
8
+ var tabular_loader_1 = require("./tabular_loader");
9
+ Object.defineProperty(exports, "TabularLoader", { enumerable: true, get: function () { return tabular_loader_1.TabularLoader; } });
@@ -0,0 +1,29 @@
1
+ import { DataLoader, DataConfig, Data } from './data_loader';
2
+ import { Dataset } from '../dataset_builder';
3
+ import { Task } from '../../task';
4
+ import * as tf from '@tensorflow/tfjs';
5
/** Loader for delimiter-separated tabular sources. */
export declare abstract class TabularLoader<Source> extends DataLoader<Source> {
    /** Column delimiter forwarded to the CSV configuration. */
    private readonly delimiter;
    constructor(task: Task, delimiter: string);
    /**
     * Creates a CSV dataset object based off the given source.
     * @param source File object, URL string or local file system path.
     * @param csvConfig Object expected by TF.js to create a CSVDataset.
     * @returns The CSVDataset object built upon the given source.
     */
    abstract loadTabularDatasetFrom(source: Source, csvConfig: Record<string, unknown>): tf.data.CSVDataset;
    /**
     * Reads one delimiter-separated source with a header row. Only the columns
     * named in `config.features` and `config.labels` are read.
     * @param source File object, URL or file system path.
     * @param config `features` is required: the call rejects with
     * "not implemented" when it is missing.
     * @returns A TF.js dataset whose rows are `{ xs, ys }` with both sides
     * flattened to arrays of column values.
     */
    load(source: Source, config?: DataConfig): Promise<Dataset>;
    /**
     * Loads every source with `load`, then concatenates the resulting
     * datasets, in order, into a single one.
     */
    loadAll(sources: Source[], config: DataConfig): Promise<Data>;
}
@@ -0,0 +1,88 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.TabularLoader = void 0;
4
+ var tslib_1 = require("tslib");
5
+ var data_loader_1 = require("./data_loader");
6
+ var immutable_1 = require("immutable");
7
+ var TabularLoader = /** @class */ (function (_super) {
8
+ (0, tslib_1.__extends)(TabularLoader, _super);
9
/**
 * Passes the task to the base loader and stores the column delimiter that
 * is later placed in the CSV configuration.
 */
function TabularLoader(task, delimiter) {
    var instance = _super.call(this, task);
    if (!instance) {
        instance = this;
    }
    instance.delimiter = delimiter;
    return instance;
}
14
/**
 * Reads one delimiter-separated source that has a header row. Feature
 * columns are configured as optional, label columns as required, and
 * columns not listed in `config` are ignored.
 *
 * @param source File object, URL or file system path.
 * @param config `features` must be present; `labels` may be omitted.
 * @returns A TF.js dataset whose rows are `{ xs, ys }`, each side being the
 *          array of values of the corresponding columns.
 * @throws Error 'not implemented' when `config.features` is missing, and
 *         'expected TensorContainerObject' for any row lacking `xs` or `ys`.
 */
TabularLoader.prototype.load = async function (source, config) {
    if ((config === null || config === void 0 ? void 0 : config.features) === undefined) {
        // TODO @s314cy
        throw new Error('not implemented');
    }
    var ImmutableMap = immutable_1.Map;
    var ImmutableSet = immutable_1.Set;
    var featureColumns = ImmutableSet(config.features).map(function (feature) {
        return [feature, { required: false, isLabel: false }];
    });
    var labelColumns = ImmutableSet(config.labels).map(function (label) {
        return [label, { required: true, isLabel: true }];
    });
    // Labels are merged last, so a name listed on both sides ends up a label.
    var columnConfigs = ImmutableMap(featureColumns).merge(labelColumns);
    var csvConfig = {
        hasHeader: true,
        columnConfigs: columnConfigs.toObject(),
        configuredColumnsOnly: true,
        delimiter: this.delimiter
    };
    return this.loadTabularDatasetFrom(source, csvConfig).map(function (row) {
        // NOTE(review): rows without both `xs` and `ys` are rejected here, so
        // loading without label columns presumably fails; confirm intent.
        if (!(typeof row === 'object' && ('xs' in row) && ('ys' in row))) {
            throw new Error('expected TensorContainerObject');
        }
        // TODO order may not be stable between tensor
        return {
            xs: Object.values(row.xs),
            ys: Object.values(row.ys)
        };
    });
};
59
+ /**
60
+ * Creates the CSV datasets based off the given sources, then fuses them into a single CSV
61
+ * dataset.
62
+ */
63
+ TabularLoader.prototype.loadAll = function (sources, config) {
64
+ return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
65
+ var datasets, dataset;
66
+ var _this = this;
67
+ return (0, tslib_1.__generator)(this, function (_a) {
68
+ switch (_a.label) {
69
+ case 0: return [4 /*yield*/, Promise.all(sources.map(function (source) { return (0, tslib_1.__awaiter)(_this, void 0, void 0, function () { return (0, tslib_1.__generator)(this, function (_a) {
70
+ switch (_a.label) {
71
+ case 0: return [4 /*yield*/, this.load(source, config)];
72
+ case 1: return [2 /*return*/, _a.sent()];
73
+ }
74
+ }); }); }))];
75
+ case 1:
76
+ datasets = _a.sent();
77
+ dataset = (0, immutable_1.List)(datasets).reduce(function (acc, dataset) { return acc.concatenate(dataset); });
78
+ return [2 /*return*/, {
79
+ dataset: dataset,
80
+ size: dataset.size // TODO: needs to be tested
81
+ }];
82
+ }
83
+ });
84
+ });
85
+ };
86
+ return TabularLoader;
87
+ }(data_loader_1.DataLoader));
88
+ exports.TabularLoader = TabularLoader;
@@ -0,0 +1,17 @@
1
+ import * as tf from '@tensorflow/tfjs';
2
+ import { DataLoader, Data } from './data_loader/data_loader';
3
+ import { Task } from '@/task';
4
/** Dataset type shared by all loaders: a TF.js dataset of arbitrary tensor containers. */
export declare type Dataset = tf.data.Dataset<tf.TensorContainer>;
5
/**
 * Collects sources and turns them into a dataset through a `DataLoader`.
 *
 * NOTE(review): the implementation is not part of this excerpt; the member
 * descriptions below are inferred from the signatures only and should be
 * confirmed against dataset_builder.js.
 */
export declare class DatasetBuilder<Source> {
    private readonly task;
    private readonly dataLoader;
    private sources;
    private readonly labelledSources;
    private built;
    constructor(dataLoader: DataLoader<Source>, task: Task);
    /** Registers sources, optionally under a label. */
    addFiles(sources: Source[], label?: string): void;
    /** Removes previously registered sources, optionally only those of one label. */
    clearFiles(label?: string): void;
    /** Produces the dataset and its size from the registered sources. */
    build(): Promise<Data>;
    /** Presumably reports whether `build` has already run. */
    isBuilt(): boolean;
    /** Presumably the number of registered sources. */
    size(): number;
}