@epfml/discojs 0.0.1 → 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +12 -11
- package/dist/aggregation.d.ts +4 -2
- package/dist/aggregation.js +20 -6
- package/dist/client/base.d.ts +1 -1
- package/dist/client/decentralized/base.d.ts +43 -0
- package/dist/client/decentralized/base.js +243 -0
- package/dist/client/decentralized/clear_text.d.ts +13 -0
- package/dist/client/decentralized/clear_text.js +78 -0
- package/dist/client/decentralized/index.d.ts +4 -0
- package/dist/client/decentralized/index.js +9 -0
- package/dist/client/decentralized/messages.d.ts +37 -0
- package/dist/client/decentralized/messages.js +15 -0
- package/dist/client/decentralized/sec_agg.d.ts +18 -0
- package/dist/client/decentralized/sec_agg.js +169 -0
- package/dist/client/decentralized/secret_shares.d.ts +5 -0
- package/dist/client/decentralized/secret_shares.js +58 -0
- package/dist/client/decentralized/types.d.ts +1 -0
- package/dist/client/decentralized/types.js +2 -0
- package/dist/client/federated.d.ts +4 -4
- package/dist/client/federated.js +5 -8
- package/dist/client/index.d.ts +1 -1
- package/dist/client/index.js +3 -3
- package/dist/client/local.js +5 -3
- package/dist/dataset/data_loader/data_loader.d.ts +7 -1
- package/dist/dataset/data_loader/image_loader.d.ts +5 -3
- package/dist/dataset/data_loader/image_loader.js +64 -18
- package/dist/dataset/data_loader/index.d.ts +1 -1
- package/dist/dataset/data_loader/tabular_loader.d.ts +3 -3
- package/dist/dataset/data_loader/tabular_loader.js +27 -18
- package/dist/dataset/dataset_builder.d.ts +3 -2
- package/dist/dataset/dataset_builder.js +29 -17
- package/dist/index.d.ts +6 -4
- package/dist/index.js +13 -5
- package/dist/informant/graph_informant.d.ts +10 -0
- package/dist/informant/graph_informant.js +23 -0
- package/dist/informant/index.d.ts +3 -0
- package/dist/informant/index.js +9 -0
- package/dist/informant/training_informant/base.d.ts +31 -0
- package/dist/informant/training_informant/base.js +82 -0
- package/dist/informant/training_informant/decentralized.d.ts +5 -0
- package/dist/informant/training_informant/decentralized.js +22 -0
- package/dist/informant/training_informant/federated.d.ts +14 -0
- package/dist/informant/training_informant/federated.js +32 -0
- package/dist/informant/training_informant/index.d.ts +4 -0
- package/dist/informant/training_informant/index.js +11 -0
- package/dist/informant/training_informant/local.d.ts +6 -0
- package/dist/informant/training_informant/local.js +20 -0
- package/dist/logging/index.d.ts +1 -0
- package/dist/logging/index.js +3 -1
- package/dist/logging/trainer_logger.d.ts +1 -1
- package/dist/logging/trainer_logger.js +5 -5
- package/dist/memory/base.d.ts +17 -48
- package/dist/memory/empty.d.ts +6 -4
- package/dist/memory/empty.js +8 -2
- package/dist/memory/index.d.ts +1 -1
- package/dist/privacy.js +3 -3
- package/dist/serialization/model.d.ts +1 -1
- package/dist/serialization/model.js +2 -2
- package/dist/serialization/weights.js +2 -2
- package/dist/task/display_information.d.ts +2 -2
- package/dist/task/display_information.js +6 -5
- package/dist/task/summary.d.ts +5 -0
- package/dist/task/summary.js +23 -0
- package/dist/task/training_information.d.ts +3 -0
- package/dist/tasks/cifar10.js +5 -3
- package/dist/tasks/lus_covid.d.ts +1 -1
- package/dist/tasks/lus_covid.js +50 -13
- package/dist/tasks/mnist.js +4 -2
- package/dist/tasks/simple_face.d.ts +1 -1
- package/dist/tasks/simple_face.js +13 -17
- package/dist/tasks/titanic.js +9 -7
- package/dist/tfjs.d.ts +2 -0
- package/dist/tfjs.js +6 -0
- package/dist/training/disco.d.ts +3 -1
- package/dist/training/disco.js +14 -6
- package/dist/training/trainer/distributed_trainer.d.ts +1 -1
- package/dist/training/trainer/distributed_trainer.js +5 -1
- package/dist/training/trainer/local_trainer.d.ts +4 -3
- package/dist/training/trainer/local_trainer.js +6 -9
- package/dist/training/trainer/round_tracker.js +3 -0
- package/dist/training/trainer/trainer.d.ts +15 -15
- package/dist/training/trainer/trainer.js +57 -43
- package/dist/training/trainer/trainer_builder.js +8 -15
- package/dist/types.d.ts +1 -1
- package/dist/validation/index.d.ts +1 -0
- package/dist/validation/index.js +5 -0
- package/dist/validation/validator.d.ts +20 -0
- package/dist/validation/validator.js +106 -0
- package/package.json +2 -3
- package/dist/client/decentralized.d.ts +0 -23
- package/dist/client/decentralized.js +0 -275
- package/dist/testing/tester.d.ts +0 -5
- package/dist/testing/tester.js +0 -21
- package/dist/training_informant.d.ts +0 -88
- package/dist/training_informant.js +0 -135
|
@@ -0,0 +1,169 @@
|
|
|
1
|
+
"use strict";
|
|
2
|
+
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
+
exports.SecAgg = void 0;
|
|
4
|
+
var tslib_1 = require("tslib");
|
|
5
|
+
var immutable_1 = require("immutable");
|
|
6
|
+
var msgpack_lite_1 = (0, tslib_1.__importDefault)(require("msgpack-lite"));
|
|
7
|
+
var __1 = require("../..");
|
|
8
|
+
var base_1 = require("./base");
|
|
9
|
+
var messages = (0, tslib_1.__importStar)(require("./messages"));
|
|
10
|
+
var secret_shares = (0, tslib_1.__importStar)(require("./secret_shares"));
|
|
11
|
+
/**
 * Decentralized client that utilizes secure aggregation so client updates remain private.
 *
 * One round (sendAndReceiveWeights) runs two communication phases:
 *   1. the weights handed in by the caller are split into `this.peers.length` additive
 *      shares (secret_shares.generateAllShares) and the i-th share is sent to the i-th peer;
 *      the client then pauses until it has received at least `this.peers.length` shares;
 *   2. the received shares are summed into `this.mySum`, which is sent to every peer;
 *      the client then pauses until it has received at least `this.peers.length` partial sums.
 * The list of received partial sums is returned for the final aggregation.
 */
var SecAgg = /** @class */ (function (_super) {
    (0, tslib_1.__extends)(SecAgg, _super);
    function SecAgg() {
        var _this = _super !== null && _super.apply(this, arguments) || this;
        // weight shares received from peers during the current round (filled by clientHandle)
        _this.receivedShares = (0, immutable_1.List)();
        // partial sums received from peers during the current round (filled by clientHandle)
        _this.receivedPartialSums = (0, immutable_1.List)();
        // this client's partial sum: the sum of receivedShares, computed in sendPartialSums
        _this.mySum = [];
        return _this;
    }
    /*
    Phase 1: splits `noisyWeights` into one additive share per peer (bounded by
    `this.maxShareValue`) and sends the i-th share to `this.peers[i]`.
    No noise is added here: the caller is expected to pass already-noised weights
    (see the parameter name). `round` and `trainingInformant` are not used.
    */
    SecAgg.prototype.sendShares = function (noisyWeights, round, trainingInformant) {
        return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
            var weightShares, i, weights, msg, encodedMsg;
            var _a;
            return (0, tslib_1.__generator)(this, function (_b) {
                switch (_b.label) {
                    case 0:
                        weightShares = secret_shares.generateAllShares(noisyWeights, this.peers.length, this.maxShareValue);
                        i = 0;
                        _b.label = 1;
                    case 1:
                        // loop over peers: one share per peer
                        if (!(i < this.peers.length)) return [3 /*break*/, 4];
                        weights = weightShares.get(i);
                        if (weights === undefined) {
                            throw new Error('weight shares generated incorrectly');
                        }
                        _a = {
                            type: messages.messageType.clientSharesMessageServer,
                            peerID: this.ID
                        };
                        return [4 /*yield*/, __1.serialization.weights.encode(weights)];
                    case 2:
                        msg = (_a.weights = _b.sent(),
                            _a.destination = this.peers[i],
                            _a);
                        encodedMsg = msgpack_lite_1.default.encode(msg);
                        // the result of sendMessagetoPeer is not awaited
                        this.sendMessagetoPeer(encodedMsg);
                        _b.label = 3;
                    case 3:
                        i++;
                        return [3 /*break*/, 1];
                    case 4: return [2 /*return*/];
                }
            });
        });
    };
    /*
    Phase 2: sums the shares received so far into `this.mySum` and sends that
    partial sum to every peer so the final update can be calculated.
    NOTE(review): `this.mySum` is re-encoded once per peer inside the loop;
    encoding it once before the loop would suffice.
    */
    SecAgg.prototype.sendPartialSums = function () {
        return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
            var i, msg, encodedMsg;
            var _a;
            return (0, tslib_1.__generator)(this, function (_b) {
                switch (_b.label) {
                    case 0:
                        // calculating my personal partial sum from received shares that i will share with peers
                        this.mySum = __1.aggregation.sumWeights((0, immutable_1.List)(Array.from(this.receivedShares.values())));
                        i = 0;
                        _b.label = 1;
                    case 1:
                        // loop over peers: the same partial sum goes to each of them
                        if (!(i < this.peers.length)) return [3 /*break*/, 4];
                        _a = {
                            type: messages.messageType.clientPartialSumsMessageServer,
                            peerID: this.ID
                        };
                        return [4 /*yield*/, __1.serialization.weights.encode(this.mySum)];
                    case 2:
                        msg = (_a.partials = _b.sent(),
                            _a.destination = this.peers[i],
                            _a);
                        encodedMsg = msgpack_lite_1.default.encode(msg);
                        this.sendMessagetoPeer(encodedMsg);
                        _b.label = 3;
                    case 3:
                        i++;
                        return [3 /*break*/, 1];
                    case 4: return [2 /*return*/];
                }
            });
        });
    };
    /*
    Runs one secure-aggregation round and resolves with the partial sums received
    from peers (to be aggregated by the caller).
    NOTE(review): both received lists are cleared when this method starts; a share or
    partial sum that a faster peer delivers before this client reaches this point would
    be discarded and the pauses below could then wait forever — confirm the protocol
    rules this out.
    */
    SecAgg.prototype.sendAndReceiveWeights = function (noisyWeights, round, trainingInformant) {
        return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
            var _this = this;
            return (0, tslib_1.__generator)(this, function (_a) {
                switch (_a.label) {
                    case 0:
                        // reset fields at beginning of each round
                        this.receivedShares = this.receivedShares.clear();
                        this.receivedPartialSums = this.receivedPartialSums.clear();
                        // PHASE 1 COMMUNICATION --> send additive shares to ready peers, pause program until shares are received from all peers
                        return [4 /*yield*/, this.sendShares(noisyWeights, round, trainingInformant)];
                    case 1:
                        _a.sent();
                        return [4 /*yield*/, this.pauseUntil(function () { return _this.receivedShares.size >= _this.peers.length; })];
                    case 2:
                        _a.sent();
                        // PHASE 2 COMMUNICATION --> send partial sums to ready peers
                        return [4 /*yield*/, this.sendPartialSums()];
                    case 3:
                        _a.sent();
                        // after all partial sums are received, return list of partial sums to be aggregated
                        return [4 /*yield*/, this.pauseUntil(function () { return _this.receivedPartialSums.size >= _this.peers.length; })];
                    case 4:
                        _a.sent();
                        return [2 /*return*/, this.receivedPartialSums];
                }
            });
        });
    };
    /*
    checks if message contains shares from a peer (by its `type` field)
    */
    SecAgg.prototype.instanceOfClientSharesMessageServer = function (msg) {
        return msg.type === messages.messageType.clientSharesMessageServer;
    };
    /*
    checks if message contains partial sums from a peer (by its `type` field)
    */
    SecAgg.prototype.instanceOfClientPartialSumsMessageServer = function (msg) {
        return msg.type === messages.messageType.clientPartialSumsMessageServer;
    };
    /*
    handles received messages from signaling server: decoded weight shares are appended to
    receivedShares, decoded partial sums to receivedPartialSums, anything else throws.
    NOTE(review): messages are appended without checking their round or sender.
    */
    SecAgg.prototype.clientHandle = function (msg) {
        if (this.instanceOfClientSharesMessageServer(msg)) {
            // update received weights by one weights reception
            var weights = __1.serialization.weights.decode(msg.weights);
            this.receivedShares = this.receivedShares.push(weights);
        }
        else if (this.instanceOfClientPartialSumsMessageServer(msg)) {
            // update received partial sums by one partial sum
            var partials = __1.serialization.weights.decode(msg.partials);
            this.receivedPartialSums = this.receivedPartialSums.push(partials);
        }
        else {
            throw new Error('Unexpected Message Type');
        }
    };
    return SecAgg;
}(base_1.Base));
|
|
169
|
+
exports.SecAgg = SecAgg;
|
|
@@ -0,0 +1,5 @@
|
|
|
1
|
+
import { List } from 'immutable';
|
|
2
|
+
import { Weights } from '../..';
|
|
3
|
+
/**
 * Returns the share that completes `currentShares` so that all shares add up to `secret`
 * (computed with `aggregation.subtractWeights` on `[secret, sum(currentShares)]`).
 * @throws Error if `currentShares` is empty.
 */
export declare function lastShare(currentShares: Weights[], secret: Weights): Weights;
/**
 * Splits `secret` into `nParticipants` additive shares: `nParticipants - 1` random shares
 * (see {@link generateRandomShare}) followed by the completing share from {@link lastShare}.
 */
export declare function generateAllShares(secret: Weights, nParticipants: number, maxShareValue: number): List<Weights>;
/**
 * Returns a random share containing one tensor per tensor of `secret`, with the same shapes,
 * filled with float32 values drawn uniformly between `-maxShareValue` and `maxShareValue`.
 */
export declare function generateRandomShare(secret: Weights, maxShareValue: number): Weights;
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
"use strict";
|
|
2
|
+
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
+
exports.generateRandomShare = exports.generateAllShares = exports.lastShare = void 0;
|
|
4
|
+
var tslib_1 = require("tslib");
|
|
5
|
+
var immutable_1 = require("immutable");
|
|
6
|
+
var tf = (0, tslib_1.__importStar)(require("@tensorflow/tfjs"));
|
|
7
|
+
var crypto = (0, tslib_1.__importStar)(require("crypto"));
|
|
8
|
+
var __1 = require("../..");
|
|
9
|
+
// Exclusive upper bound of the seeds that generateRandomShare draws with crypto.randomInt.
var maxSeed = Math.pow(2, 47);
|
|
10
|
+
/*
|
|
11
|
+
Return Weights in the remaining share once N-1 shares have been constructed (where N is number of ready clients)
|
|
12
|
+
*/
|
|
13
|
+
function lastShare(currentShares, secret) {
|
|
14
|
+
if (currentShares.length === 0) {
|
|
15
|
+
throw new Error('Need at least one current share to be able to subtract secret from');
|
|
16
|
+
}
|
|
17
|
+
var currentShares2 = (0, immutable_1.List)(currentShares);
|
|
18
|
+
var last = __1.aggregation.subtractWeights((0, immutable_1.List)([secret, __1.aggregation.sumWeights(currentShares2)]));
|
|
19
|
+
return last;
|
|
20
|
+
}
|
|
21
|
+
exports.lastShare = lastShare;
|
|
22
|
+
/*
|
|
23
|
+
Generate N additive shares that aggregate to the secret weights array (where N is number of ready clients)
|
|
24
|
+
*/
|
|
25
|
+
function generateAllShares(secret, nParticipants, maxShareValue) {
|
|
26
|
+
var shares = [];
|
|
27
|
+
for (var i = 0; i < nParticipants - 1; i++) {
|
|
28
|
+
shares.push(generateRandomShare(secret, maxShareValue));
|
|
29
|
+
}
|
|
30
|
+
shares.push(lastShare(shares, secret));
|
|
31
|
+
var sharesFinal = (0, immutable_1.List)(shares);
|
|
32
|
+
return sharesFinal;
|
|
33
|
+
}
|
|
34
|
+
exports.generateAllShares = generateAllShares;
|
|
35
|
+
/*
|
|
36
|
+
generates one share in the same shape as the secret that is populated with values randomly chosend from
|
|
37
|
+
a uniform distribution between (-maxShareValue, maxShareValue).
|
|
38
|
+
*/
|
|
39
|
+
function generateRandomShare(secret, maxShareValue) {
|
|
40
|
+
var e_1, _a;
|
|
41
|
+
var share = [];
|
|
42
|
+
var seed = crypto.randomInt(maxSeed);
|
|
43
|
+
try {
|
|
44
|
+
for (var secret_1 = (0, tslib_1.__values)(secret), secret_1_1 = secret_1.next(); !secret_1_1.done; secret_1_1 = secret_1.next()) {
|
|
45
|
+
var t = secret_1_1.value;
|
|
46
|
+
share.push(tf.randomUniform(t.shape, -maxShareValue, maxShareValue, 'float32', seed));
|
|
47
|
+
}
|
|
48
|
+
}
|
|
49
|
+
catch (e_1_1) { e_1 = { error: e_1_1 }; }
|
|
50
|
+
finally {
|
|
51
|
+
try {
|
|
52
|
+
if (secret_1_1 && !secret_1_1.done && (_a = secret_1.return)) _a.call(secret_1);
|
|
53
|
+
}
|
|
54
|
+
finally { if (e_1) throw e_1.error; }
|
|
55
|
+
}
|
|
56
|
+
return share;
|
|
57
|
+
}
|
|
58
|
+
exports.generateRandomShare = generateRandomShare;
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
/** Identifier of a peer, represented as a plain number. */
export declare type PeerID = number;
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import {
|
|
1
|
+
import { informant, MetadataID, Weights } from '..';
|
|
2
2
|
import { Base } from './base';
|
|
3
3
|
/**
|
|
4
4
|
* Class that deals with communication with the centralized server when training
|
|
@@ -24,7 +24,7 @@ export declare class Federated extends Base {
|
|
|
24
24
|
getMetadataMap(metadataID: MetadataID): Promise<Map<string, unknown>>;
|
|
25
25
|
getLatestServerRound(): Promise<number>;
|
|
26
26
|
pullRoundAndFetchWeights(): Promise<Weights | undefined>;
|
|
27
|
-
pullServerStatistics(trainingInformant:
|
|
28
|
-
onRoundEndCommunication(updatedWeights: Weights, staleWeights: Weights, _: number, trainingInformant:
|
|
29
|
-
onTrainEndCommunication(
|
|
27
|
+
pullServerStatistics(trainingInformant: informant.FederatedInformant): Promise<void>;
|
|
28
|
+
onRoundEndCommunication(updatedWeights: Weights, staleWeights: Weights, _: number, trainingInformant: informant.FederatedInformant): Promise<Weights>;
|
|
29
|
+
onTrainEndCommunication(): Promise<void>;
|
|
30
30
|
}
|
package/dist/client/federated.js
CHANGED
|
@@ -181,7 +181,7 @@ var Federated = /** @class */ (function (_super) {
|
|
|
181
181
|
case 0: return [4 /*yield*/, axios_1.default.get(this.urlTo('statistics'))];
|
|
182
182
|
case 1:
|
|
183
183
|
response = _a.sent();
|
|
184
|
-
trainingInformant.
|
|
184
|
+
trainingInformant.update(response.data.statistics);
|
|
185
185
|
return [2 /*return*/];
|
|
186
186
|
}
|
|
187
187
|
});
|
|
@@ -208,13 +208,10 @@ var Federated = /** @class */ (function (_super) {
|
|
|
208
208
|
});
|
|
209
209
|
});
|
|
210
210
|
};
|
|
211
|
-
Federated.prototype.onTrainEndCommunication = function (
|
|
212
|
-
return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
|
|
213
|
-
return
|
|
214
|
-
|
|
215
|
-
return [2 /*return*/];
|
|
216
|
-
});
|
|
217
|
-
});
|
|
211
|
+
Federated.prototype.onTrainEndCommunication = function () {
    // Nothing to exchange with the server when training ends: the returned
    // promise resolves with undefined without doing any work.
    var self = this;
    return (0, tslib_1.__awaiter)(self, void 0, void 0, function () {
        return (0, tslib_1.__generator)(this, function (_state) {
            return [2 /*return*/];
        });
    });
};
|
|
219
216
|
return Federated;
|
|
220
217
|
}(base_1.Base));
|
package/dist/client/index.d.ts
CHANGED
package/dist/client/index.js
CHANGED
|
@@ -1,10 +1,10 @@
|
|
|
1
1
|
"use strict";
|
|
2
2
|
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
-
exports.Local = exports.Federated = exports.
|
|
3
|
+
exports.Local = exports.Federated = exports.decentralized = exports.Base = void 0;
|
|
4
|
+
var tslib_1 = require("tslib");
|
|
4
5
|
var base_1 = require("./base");
|
|
5
6
|
Object.defineProperty(exports, "Base", { enumerable: true, get: function () { return base_1.Base; } });
|
|
6
|
-
|
|
7
|
-
Object.defineProperty(exports, "Decentralized", { enumerable: true, get: function () { return decentralized_1.Decentralized; } });
|
|
7
|
+
exports.decentralized = (0, tslib_1.__importStar)(require("./decentralized"));
|
|
8
8
|
var federated_1 = require("./federated");
|
|
9
9
|
Object.defineProperty(exports, "Federated", { enumerable: true, get: function () { return federated_1.Federated; } });
|
|
10
10
|
var local_1 = require("./local");
|
package/dist/client/local.js
CHANGED
|
@@ -20,9 +20,11 @@ var Local = /** @class */ (function (_super) {
|
|
|
20
20
|
}); });
|
|
21
21
|
};
|
|
22
22
|
Local.prototype.onRoundEndCommunication = function (_) {
|
|
23
|
-
return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
|
|
24
|
-
return
|
|
25
|
-
|
|
23
|
+
return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
|
|
24
|
+
return (0, tslib_1.__generator)(this, function (_a) {
|
|
25
|
+
return [2 /*return*/, _];
|
|
26
|
+
});
|
|
27
|
+
});
|
|
26
28
|
};
|
|
27
29
|
Local.prototype.onTrainEndCommunication = function () {
|
|
28
30
|
return (0, tslib_1.__awaiter)(this, void 0, void 0, function () { return (0, tslib_1.__generator)(this, function (_a) {
|
|
@@ -3,14 +3,20 @@ import { Task } from '../../task';
|
|
|
3
3
|
/** Options controlling how a DataLoader loads its sources. */
export interface DataConfig {
    /** Feature names to load. */
    features?: string[];
    /**
     * Labels to load. NOTE(review): ImageLoader.loadAll passes these to
     * tf.tensor1d(..., 'int32'), i.e. treats them as class indices — the string[] type
     * looks inconsistent with that use; confirm.
     */
    labels?: string[];
    /**
     * Whether to shuffle the data. ImageLoader.loadAll and TabularLoader.load shuffle when
     * this is `undefined` or truthy, whereas TabularLoader.loadAll only shuffles the
     * concatenated dataset when it is truthy.
     */
    shuffle?: boolean;
    /**
     * Fraction of the data held out for validation. Used by ImageLoader.loadAll;
     * TabularLoader.loadAll does not implement it yet (see the TODO there).
     */
    validationSplit?: number;
}
|
|
7
9
|
/** A loaded dataset together with its number of samples. */
export interface Data {
    dataset: Dataset;
    /**
     * Number of samples in `dataset`. NOTE(review): TabularLoader.loadAll currently
     * reports 0 because tfjs cannot report the size of CSV datasets
     * (tensorflow/tfjs#5845).
     */
    size: number;
}
|
|
13
|
+
/** Result of DataLoader.loadAll: training data and, optionally, validation data. */
export interface DataTuple {
    train: Data;
    validation?: Data;
}
|
|
11
17
|
export declare abstract class DataLoader<Source> {
|
|
12
18
|
protected task: Task;
|
|
13
19
|
constructor(task: Task);
|
|
14
20
|
abstract load(source: Source, config: DataConfig): Promise<Dataset>;
|
|
15
|
-
abstract loadAll(sources: Source[], config: DataConfig): Promise<
|
|
21
|
+
abstract loadAll(sources: Source[], config: DataConfig): Promise<DataTuple>;
|
|
16
22
|
}
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
import
|
|
1
|
+
import { tf } from '../..';
|
|
2
2
|
import { Dataset } from '../dataset_builder';
|
|
3
|
-
import { DataLoader, DataConfig,
|
|
3
|
+
import { DataLoader, DataConfig, DataTuple } from './data_loader';
|
|
4
4
|
/**
|
|
5
5
|
* TODO @s314cy:
|
|
6
6
|
* Load labels and correctly match them with their respective images, with the following constraints:
|
|
@@ -10,5 +10,7 @@ import { DataLoader, DataConfig, Data } from './data_loader';
|
|
|
10
10
|
export declare abstract class ImageLoader<Source> extends DataLoader<Source> {
|
|
11
11
|
abstract readImageFrom(source: Source): Promise<tf.Tensor3D>;
|
|
12
12
|
load(image: Source, config?: DataConfig): Promise<Dataset>;
|
|
13
|
-
|
|
13
|
+
private buildDataset;
|
|
14
|
+
loadAll(images: Source[], config?: DataConfig): Promise<DataTuple>;
|
|
15
|
+
shuffle(array: number[]): void;
|
|
14
16
|
}
|
|
@@ -2,7 +2,8 @@
|
|
|
2
2
|
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
3
|
exports.ImageLoader = void 0;
|
|
4
4
|
var tslib_1 = require("tslib");
|
|
5
|
-
var
|
|
5
|
+
var immutable_1 = require("immutable");
|
|
6
|
+
var __1 = require("../..");
|
|
6
7
|
var data_loader_1 = require("./data_loader");
|
|
7
8
|
/**
|
|
8
9
|
* TODO @s314cy:
|
|
@@ -35,25 +36,17 @@ var ImageLoader = /** @class */ (function (_super) {
|
|
|
35
36
|
_a.ys = config.labels[0],
|
|
36
37
|
_a);
|
|
37
38
|
_b.label = 4;
|
|
38
|
-
case 4: return [2 /*return*/, tf.data.array([tensorContainer])];
|
|
39
|
+
case 4: return [2 /*return*/, __1.tf.data.array([tensorContainer])];
|
|
39
40
|
}
|
|
40
41
|
});
|
|
41
42
|
});
|
|
42
43
|
};
|
|
43
|
-
ImageLoader.prototype.
|
|
44
|
-
var _a, _b;
|
|
44
|
+
ImageLoader.prototype.buildDataset = function (images, labels, indices, config) {
|
|
45
45
|
return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
|
|
46
|
-
var
|
|
46
|
+
var dataset;
|
|
47
47
|
var _this = this;
|
|
48
|
-
return (0, tslib_1.__generator)(this, function (
|
|
49
|
-
|
|
50
|
-
numberOfClasses = (_b = (_a = this.task.trainingInformation) === null || _a === void 0 ? void 0 : _a.LABEL_LIST) === null || _b === void 0 ? void 0 : _b.length;
|
|
51
|
-
if (numberOfClasses === undefined) {
|
|
52
|
-
throw new Error('wanted labels but none found in task');
|
|
53
|
-
}
|
|
54
|
-
labels = tf.oneHot(tf.tensor1d(config.labels, 'int32'), numberOfClasses).arraySync();
|
|
55
|
-
}
|
|
56
|
-
dataset = tf.data.generator(function () {
|
|
48
|
+
return (0, tslib_1.__generator)(this, function (_a) {
|
|
49
|
+
dataset = __1.tf.data.generator(function () {
|
|
57
50
|
var withLabels = (config === null || config === void 0 ? void 0 : config.labels) !== undefined;
|
|
58
51
|
var index = 0;
|
|
59
52
|
var iterator = {
|
|
@@ -62,13 +55,13 @@ var ImageLoader = /** @class */ (function (_super) {
|
|
|
62
55
|
return (0, tslib_1.__generator)(this, function (_a) {
|
|
63
56
|
switch (_a.label) {
|
|
64
57
|
case 0:
|
|
65
|
-
if (index ===
|
|
58
|
+
if (index === indices.length) {
|
|
66
59
|
return [2 /*return*/, { done: true }];
|
|
67
60
|
}
|
|
68
|
-
return [4 /*yield*/, this.readImageFrom(images[index])];
|
|
61
|
+
return [4 /*yield*/, this.readImageFrom(images[indices[index]])];
|
|
69
62
|
case 1:
|
|
70
63
|
sample = _a.sent();
|
|
71
|
-
label = withLabels ? labels[index] : undefined;
|
|
64
|
+
label = withLabels ? labels[indices[index]] : undefined;
|
|
72
65
|
value = withLabels ? { xs: sample, ys: label } : sample;
|
|
73
66
|
index++;
|
|
74
67
|
return [2 /*return*/, {
|
|
@@ -83,11 +76,64 @@ var ImageLoader = /** @class */ (function (_super) {
|
|
|
83
76
|
});
|
|
84
77
|
return [2 /*return*/, {
|
|
85
78
|
dataset: dataset,
|
|
86
|
-
size:
|
|
79
|
+
size: indices.length
|
|
87
80
|
}];
|
|
88
81
|
});
|
|
89
82
|
});
|
|
90
83
|
};
|
|
84
|
+
/**
 * Builds a training dataset and a validation dataset from `images`.
 *
 * - When `config.labels` is set, it is one-hot encoded with as many classes as
 *   `task.trainingInformation.LABEL_LIST` has entries (throws if that list is missing);
 *   buildDataset then pairs images[k] with labels[k].
 *   NOTE(review): labels go through tf.tensor1d(..., 'int32'), so they are treated as
 *   class indices — confirm callers pass indices.
 * - The image order is shuffled unless `config.shuffle` is set to a falsy value.
 * - Without `config.validationSplit`, the same dataset object is returned as both
 *   `train` and `validation`. NOTE(review): validation then fully overlaps training.
 * - With `config.validationSplit`, the first floor(n * (1 - validationSplit)) indices
 *   (after shuffling) form the training set and the remaining ones the validation set.
 *   The split value is not range-checked.
 * NOTE(review): the tensors created by tf.tensor1d / tf.oneHot do not appear to be disposed.
 */
ImageLoader.prototype.loadAll = function (images, config) {
    var _a, _b;
    return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
        var labels, indices, numberOfClasses, dataset, trainSize, trainIndices, valIndices, trainDataset, valDataset;
        return (0, tslib_1.__generator)(this, function (_c) {
            switch (_c.label) {
                case 0:
                    // prepare the (optional) one-hot labels and the list of image indices
                    labels = [];
                    indices = (0, immutable_1.Range)(0, images.length).toArray();
                    if ((config === null || config === void 0 ? void 0 : config.labels) !== undefined) {
                        numberOfClasses = (_b = (_a = this.task.trainingInformation) === null || _a === void 0 ? void 0 : _a.LABEL_LIST) === null || _b === void 0 ? void 0 : _b.length;
                        if (numberOfClasses === undefined) {
                            throw new Error('wanted labels but none found in task');
                        }
                        labels = __1.tf.oneHot(__1.tf.tensor1d(config.labels, 'int32'), numberOfClasses).arraySync();
                    }
                    // shuffle by default (config or config.shuffle undefined)
                    if ((config === null || config === void 0 ? void 0 : config.shuffle) === undefined || (config === null || config === void 0 ? void 0 : config.shuffle)) {
                        this.shuffle(indices);
                    }
                    // jump to the split branch (case 2) when a validation split is given
                    if (!((config === null || config === void 0 ? void 0 : config.validationSplit) === undefined)) return [3 /*break*/, 2];
                    return [4 /*yield*/, this.buildDataset(images, labels, indices, config)];
                case 1:
                    // no validation split: reuse the single dataset for both entries
                    dataset = _c.sent();
                    return [2 /*return*/, {
                            train: dataset,
                            validation: dataset
                        }];
                case 2:
                    // split the (possibly shuffled) indices into training and validation parts
                    trainSize = Math.floor(images.length * (1 - config.validationSplit));
                    trainIndices = indices.slice(0, trainSize);
                    valIndices = indices.slice(trainSize);
                    return [4 /*yield*/, this.buildDataset(images, labels, trainIndices, config)];
                case 3:
                    trainDataset = _c.sent();
                    return [4 /*yield*/, this.buildDataset(images, labels, valIndices, config)];
                case 4:
                    valDataset = _c.sent();
                    return [2 /*return*/, {
                            train: trainDataset,
                            validation: valDataset
                        }];
            }
        });
    });
};
|
|
129
|
+
ImageLoader.prototype.shuffle = function (array) {
|
|
130
|
+
for (var i = 0; i < array.length; i++) {
|
|
131
|
+
var j = Math.floor(Math.random() * i);
|
|
132
|
+
var swap = array[i];
|
|
133
|
+
array[i] = array[j];
|
|
134
|
+
array[j] = swap;
|
|
135
|
+
}
|
|
136
|
+
};
|
|
91
137
|
return ImageLoader;
|
|
92
138
|
}(data_loader_1.DataLoader));
|
|
93
139
|
exports.ImageLoader = ImageLoader;
|
|
@@ -1,7 +1,7 @@
|
|
|
1
|
-
import { DataLoader, DataConfig,
|
|
1
|
+
import { DataLoader, DataConfig, DataTuple } from './data_loader';
|
|
2
2
|
import { Dataset } from '../dataset_builder';
|
|
3
3
|
import { Task } from '../../task';
|
|
4
|
-
import
|
|
4
|
+
import { tf } from '../..';
|
|
5
5
|
export declare abstract class TabularLoader<Source> extends DataLoader<Source> {
|
|
6
6
|
private readonly delimiter;
|
|
7
7
|
constructor(task: Task, delimiter: string);
|
|
@@ -25,5 +25,5 @@ export declare abstract class TabularLoader<Source> extends DataLoader<Source> {
|
|
|
25
25
|
* Creates the CSV datasets based off the given sources, then fuses them into a single CSV
|
|
26
26
|
* dataset.
|
|
27
27
|
*/
|
|
28
|
-
loadAll(sources: Source[], config: DataConfig): Promise<
|
|
28
|
+
loadAll(sources: Source[], config: DataConfig): Promise<DataTuple>;
|
|
29
29
|
}
|
|
@@ -4,6 +4,8 @@ exports.TabularLoader = void 0;
|
|
|
4
4
|
var tslib_1 = require("tslib");
|
|
5
5
|
var data_loader_1 = require("./data_loader");
|
|
6
6
|
var immutable_1 = require("immutable");
|
|
7
|
+
// Buffer size passed to Dataset.shuffle by TabularLoader.load and TabularLoader.loadAll:
// the window size from which the dataset shuffling will sample.
var BUFFER_SIZE = 1000;
|
|
7
9
|
var TabularLoader = /** @class */ (function (_super) {
|
|
8
10
|
(0, tslib_1.__extends)(TabularLoader, _super);
|
|
9
11
|
function TabularLoader(task, delimiter) {
|
|
@@ -21,7 +23,7 @@ var TabularLoader = /** @class */ (function (_super) {
|
|
|
21
23
|
*/
|
|
22
24
|
TabularLoader.prototype.load = function (source, config) {
|
|
23
25
|
return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
|
|
24
|
-
var columnConfigs, csvConfig;
|
|
26
|
+
var columnConfigs, csvConfig, dataset;
|
|
25
27
|
return (0, tslib_1.__generator)(this, function (_a) {
|
|
26
28
|
/**
|
|
27
29
|
* Prepare the CSV config object based off the given features and labels.
|
|
@@ -40,19 +42,20 @@ var TabularLoader = /** @class */ (function (_super) {
|
|
|
40
42
|
configuredColumnsOnly: true,
|
|
41
43
|
delimiter: this.delimiter
|
|
42
44
|
};
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
45
|
+
dataset = this.loadTabularDatasetFrom(source, csvConfig).map(function (t) {
|
|
46
|
+
if (typeof t === 'object' && ('xs' in t) && ('ys' in t)) {
|
|
47
|
+
return t;
|
|
48
|
+
}
|
|
49
|
+
throw new Error('expected TensorContainerObject');
|
|
50
|
+
}).map(function (t) {
|
|
51
|
+
// TODO order may not be stable between tensor
|
|
52
|
+
var _a = t, xs = _a.xs, ys = _a.ys;
|
|
53
|
+
return {
|
|
54
|
+
xs: Object.values(xs),
|
|
55
|
+
ys: Object.values(ys)
|
|
56
|
+
};
|
|
57
|
+
});
|
|
58
|
+
return [2 /*return*/, ((config === null || config === void 0 ? void 0 : config.shuffle) === undefined || (config === null || config === void 0 ? void 0 : config.shuffle)) ? dataset.shuffle(BUFFER_SIZE) : dataset];
|
|
56
59
|
});
|
|
57
60
|
});
|
|
58
61
|
};
|
|
@@ -62,22 +65,28 @@ var TabularLoader = /** @class */ (function (_super) {
|
|
|
62
65
|
*/
|
|
63
66
|
TabularLoader.prototype.loadAll = function (sources, config) {
|
|
64
67
|
return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
|
|
65
|
-
var datasets, dataset;
|
|
68
|
+
var datasets, dataset, data;
|
|
66
69
|
var _this = this;
|
|
67
70
|
return (0, tslib_1.__generator)(this, function (_a) {
|
|
68
71
|
switch (_a.label) {
|
|
69
72
|
case 0: return [4 /*yield*/, Promise.all(sources.map(function (source) { return (0, tslib_1.__awaiter)(_this, void 0, void 0, function () { return (0, tslib_1.__generator)(this, function (_a) {
|
|
70
73
|
switch (_a.label) {
|
|
71
|
-
case 0: return [4 /*yield*/, this.load(source, config)];
|
|
74
|
+
case 0: return [4 /*yield*/, this.load(source, (0, tslib_1.__assign)((0, tslib_1.__assign)({}, config), { shuffle: false }))];
|
|
72
75
|
case 1: return [2 /*return*/, _a.sent()];
|
|
73
76
|
}
|
|
74
77
|
}); }); }))];
|
|
75
78
|
case 1:
|
|
76
79
|
datasets = _a.sent();
|
|
77
80
|
dataset = (0, immutable_1.List)(datasets).reduce(function (acc, dataset) { return acc.concatenate(dataset); });
|
|
81
|
+
data = {
|
|
82
|
+
dataset: (config === null || config === void 0 ? void 0 : config.shuffle) ? dataset.shuffle(BUFFER_SIZE) : dataset,
|
|
83
|
+
// dataset.size does not work for csv datasets
|
|
84
|
+
// https://github.com/tensorflow/tfjs/issues/5845
|
|
85
|
+
size: 0
|
|
86
|
+
};
|
|
87
|
+
// TODO: Implement validation split for tabular data (tricky due to streaming)
|
|
78
88
|
return [2 /*return*/, {
|
|
79
|
-
|
|
80
|
-
size: dataset.size // TODO: needs to be tested
|
|
89
|
+
train: data
|
|
81
90
|
}];
|
|
82
91
|
}
|
|
83
92
|
});
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import * as tf from '@tensorflow/tfjs';
|
|
2
|
-
import { DataLoader,
|
|
2
|
+
import { DataConfig, DataLoader, DataTuple } from './data_loader/data_loader';
|
|
3
3
|
import { Task } from '@/task';
|
|
4
4
|
export declare type Dataset = tf.data.Dataset<tf.TensorContainer>;
|
|
5
5
|
export declare class DatasetBuilder<Source> {
|
|
@@ -11,7 +11,8 @@ export declare class DatasetBuilder<Source> {
|
|
|
11
11
|
constructor(dataLoader: DataLoader<Source>, task: Task);
|
|
12
12
|
addFiles(sources: Source[], label?: string): void;
|
|
13
13
|
clearFiles(label?: string): void;
|
|
14
|
-
|
|
14
|
+
private getLabels;
|
|
15
|
+
build(config?: DataConfig): Promise<DataTuple>;
|
|
15
16
|
isBuilt(): boolean;
|
|
16
17
|
size(): number;
|
|
17
18
|
}
|