@epfml/discojs 1.0.0 → 2.1.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +28 -8
- package/dist/{async_buffer.d.ts → core/async_buffer.d.ts} +3 -3
- package/dist/{async_buffer.js → core/async_buffer.js} +5 -6
- package/dist/{async_informant.d.ts → core/async_informant.d.ts} +0 -0
- package/dist/{async_informant.js → core/async_informant.js} +0 -0
- package/dist/{client → core/client}/base.d.ts +4 -7
- package/dist/{client → core/client}/base.js +3 -2
- package/dist/core/client/decentralized/base.d.ts +32 -0
- package/dist/core/client/decentralized/base.js +212 -0
- package/dist/core/client/decentralized/clear_text.d.ts +14 -0
- package/dist/core/client/decentralized/clear_text.js +96 -0
- package/dist/{client → core/client}/decentralized/index.d.ts +0 -0
- package/dist/{client → core/client}/decentralized/index.js +0 -0
- package/dist/core/client/decentralized/messages.d.ts +41 -0
- package/dist/core/client/decentralized/messages.js +54 -0
- package/dist/core/client/decentralized/peer.d.ts +26 -0
- package/dist/core/client/decentralized/peer.js +210 -0
- package/dist/core/client/decentralized/peer_pool.d.ts +14 -0
- package/dist/core/client/decentralized/peer_pool.js +92 -0
- package/dist/core/client/decentralized/sec_agg.d.ts +22 -0
- package/dist/core/client/decentralized/sec_agg.js +190 -0
- package/dist/core/client/decentralized/secret_shares.d.ts +3 -0
- package/dist/core/client/decentralized/secret_shares.js +39 -0
- package/dist/core/client/decentralized/types.d.ts +2 -0
- package/dist/core/client/decentralized/types.js +7 -0
- package/dist/core/client/event_connection.d.ts +37 -0
- package/dist/core/client/event_connection.js +158 -0
- package/dist/core/client/federated/client.d.ts +37 -0
- package/dist/core/client/federated/client.js +273 -0
- package/dist/core/client/federated/index.d.ts +2 -0
- package/dist/core/client/federated/index.js +7 -0
- package/dist/core/client/federated/messages.d.ts +38 -0
- package/dist/core/client/federated/messages.js +25 -0
- package/dist/{client → core/client}/index.d.ts +2 -1
- package/dist/{client → core/client}/index.js +3 -3
- package/dist/{client → core/client}/local.d.ts +2 -2
- package/dist/{client → core/client}/local.js +0 -0
- package/dist/core/client/messages.d.ts +28 -0
- package/dist/core/client/messages.js +33 -0
- package/dist/core/client/utils.d.ts +2 -0
- package/dist/core/client/utils.js +19 -0
- package/dist/core/dataset/data/data.d.ts +11 -0
- package/dist/core/dataset/data/data.js +20 -0
- package/dist/core/dataset/data/data_split.d.ts +5 -0
- package/dist/{client/decentralized/types.js → core/dataset/data/data_split.js} +0 -0
- package/dist/core/dataset/data/image_data.d.ts +8 -0
- package/dist/core/dataset/data/image_data.js +64 -0
- package/dist/core/dataset/data/index.d.ts +5 -0
- package/dist/core/dataset/data/index.js +11 -0
- package/dist/core/dataset/data/preprocessing.d.ts +13 -0
- package/dist/core/dataset/data/preprocessing.js +33 -0
- package/dist/core/dataset/data/tabular_data.d.ts +8 -0
- package/dist/core/dataset/data/tabular_data.js +40 -0
- package/dist/{dataset → core/dataset}/data_loader/data_loader.d.ts +4 -11
- package/dist/{dataset → core/dataset}/data_loader/data_loader.js +0 -0
- package/dist/core/dataset/data_loader/image_loader.d.ts +17 -0
- package/dist/core/dataset/data_loader/image_loader.js +141 -0
- package/dist/core/dataset/data_loader/index.d.ts +3 -0
- package/dist/core/dataset/data_loader/index.js +9 -0
- package/dist/core/dataset/data_loader/tabular_loader.d.ts +29 -0
- package/dist/core/dataset/data_loader/tabular_loader.js +101 -0
- package/dist/core/dataset/dataset.d.ts +2 -0
- package/dist/{task/training_information.js → core/dataset/dataset.js} +0 -0
- package/dist/{dataset → core/dataset}/dataset_builder.d.ts +5 -5
- package/dist/{dataset → core/dataset}/dataset_builder.js +14 -10
- package/dist/core/dataset/index.d.ts +4 -0
- package/dist/core/dataset/index.js +14 -0
- package/dist/core/default_tasks/cifar10.d.ts +2 -0
- package/dist/core/default_tasks/cifar10.js +68 -0
- package/dist/core/default_tasks/geotags.d.ts +2 -0
- package/dist/core/default_tasks/geotags.js +69 -0
- package/dist/core/default_tasks/index.d.ts +6 -0
- package/dist/core/default_tasks/index.js +15 -0
- package/dist/core/default_tasks/lus_covid.d.ts +2 -0
- package/dist/core/default_tasks/lus_covid.js +96 -0
- package/dist/core/default_tasks/mnist.d.ts +2 -0
- package/dist/core/default_tasks/mnist.js +69 -0
- package/dist/core/default_tasks/simple_face.d.ts +2 -0
- package/dist/core/default_tasks/simple_face.js +53 -0
- package/dist/core/default_tasks/titanic.d.ts +2 -0
- package/dist/core/default_tasks/titanic.js +97 -0
- package/dist/core/index.d.ts +18 -0
- package/dist/core/index.js +39 -0
- package/dist/{informant → core/informant}/graph_informant.d.ts +0 -0
- package/dist/{informant → core/informant}/graph_informant.js +0 -0
- package/dist/{informant → core/informant}/index.d.ts +0 -0
- package/dist/{informant → core/informant}/index.js +0 -0
- package/dist/{informant → core/informant}/training_informant/base.d.ts +3 -3
- package/dist/{informant → core/informant}/training_informant/base.js +3 -2
- package/dist/{informant → core/informant}/training_informant/decentralized.d.ts +0 -0
- package/dist/{informant → core/informant}/training_informant/decentralized.js +0 -0
- package/dist/{informant → core/informant}/training_informant/federated.d.ts +0 -0
- package/dist/{informant → core/informant}/training_informant/federated.js +0 -0
- package/dist/{informant → core/informant}/training_informant/index.d.ts +0 -0
- package/dist/{informant → core/informant}/training_informant/index.js +0 -0
- package/dist/{informant → core/informant}/training_informant/local.d.ts +2 -2
- package/dist/{informant → core/informant}/training_informant/local.js +2 -2
- package/dist/{logging → core/logging}/console_logger.d.ts +0 -0
- package/dist/{logging → core/logging}/console_logger.js +0 -0
- package/dist/{logging → core/logging}/index.d.ts +0 -0
- package/dist/{logging → core/logging}/index.js +0 -0
- package/dist/{logging → core/logging}/logger.d.ts +0 -0
- package/dist/{logging → core/logging}/logger.js +0 -0
- package/dist/{logging → core/logging}/trainer_logger.d.ts +0 -0
- package/dist/{logging → core/logging}/trainer_logger.js +0 -0
- package/dist/{memory → core/memory}/base.d.ts +2 -2
- package/dist/{memory → core/memory}/base.js +0 -0
- package/dist/{memory → core/memory}/empty.d.ts +0 -0
- package/dist/{memory → core/memory}/empty.js +0 -0
- package/dist/core/memory/index.d.ts +3 -0
- package/dist/core/memory/index.js +9 -0
- package/dist/{memory → core/memory}/model_type.d.ts +0 -0
- package/dist/{memory → core/memory}/model_type.js +0 -0
- package/dist/{privacy.d.ts → core/privacy.d.ts} +2 -3
- package/dist/{privacy.js → core/privacy.js} +3 -16
- package/dist/{serialization → core/serialization}/index.d.ts +0 -0
- package/dist/{serialization → core/serialization}/index.js +0 -0
- package/dist/{serialization → core/serialization}/model.d.ts +0 -0
- package/dist/{serialization → core/serialization}/model.js +0 -0
- package/dist/core/serialization/weights.d.ts +5 -0
- package/dist/{serialization → core/serialization}/weights.js +11 -9
- package/dist/{task → core/task}/data_example.d.ts +0 -0
- package/dist/{task → core/task}/data_example.js +0 -0
- package/dist/core/task/digest.d.ts +5 -0
- package/dist/core/task/digest.js +18 -0
- package/dist/{task → core/task}/display_information.d.ts +5 -5
- package/dist/{task → core/task}/display_information.js +5 -10
- package/dist/{task → core/task}/index.d.ts +3 -0
- package/dist/core/task/index.js +15 -0
- package/dist/core/task/model_compile_data.d.ts +6 -0
- package/dist/core/task/model_compile_data.js +22 -0
- package/dist/{task → core/task}/summary.d.ts +0 -0
- package/dist/{task → core/task}/summary.js +0 -4
- package/dist/{task → core/task}/task.d.ts +4 -2
- package/dist/{task → core/task}/task.js +10 -7
- package/dist/core/task/task_handler.d.ts +5 -0
- package/dist/core/task/task_handler.js +53 -0
- package/dist/core/task/task_provider.d.ts +6 -0
- package/dist/core/task/task_provider.js +13 -0
- package/dist/{task → core/task}/training_information.d.ts +10 -14
- package/dist/core/task/training_information.js +66 -0
- package/dist/core/training/disco.d.ts +23 -0
- package/dist/core/training/disco.js +130 -0
- package/dist/{training → core/training}/index.d.ts +0 -0
- package/dist/{training → core/training}/index.js +0 -0
- package/dist/{training → core/training}/trainer/distributed_trainer.d.ts +1 -2
- package/dist/{training → core/training}/trainer/distributed_trainer.js +6 -5
- package/dist/{training → core/training}/trainer/local_trainer.d.ts +2 -2
- package/dist/{training → core/training}/trainer/local_trainer.js +0 -0
- package/dist/{training → core/training}/trainer/round_tracker.d.ts +0 -0
- package/dist/{training → core/training}/trainer/round_tracker.js +0 -0
- package/dist/{training → core/training}/trainer/trainer.d.ts +1 -2
- package/dist/{training → core/training}/trainer/trainer.js +2 -2
- package/dist/{training → core/training}/trainer/trainer_builder.d.ts +0 -0
- package/dist/{training → core/training}/trainer/trainer_builder.js +0 -0
- package/dist/core/training/training_schemes.d.ts +5 -0
- package/dist/{training → core/training}/training_schemes.js +2 -2
- package/dist/{types.d.ts → core/types.d.ts} +0 -0
- package/dist/{types.js → core/types.js} +0 -0
- package/dist/{validation → core/validation}/index.d.ts +0 -0
- package/dist/{validation → core/validation}/index.js +0 -0
- package/dist/{validation → core/validation}/validator.d.ts +5 -8
- package/dist/{validation → core/validation}/validator.js +9 -11
- package/dist/core/weights/aggregation.d.ts +7 -0
- package/dist/core/weights/aggregation.js +72 -0
- package/dist/core/weights/index.d.ts +2 -0
- package/dist/core/weights/index.js +7 -0
- package/dist/core/weights/weights_container.d.ts +19 -0
- package/dist/core/weights/weights_container.js +64 -0
- package/dist/dataset/data_loader/image_loader.d.ts +3 -15
- package/dist/dataset/data_loader/image_loader.js +12 -125
- package/dist/dataset/data_loader/index.d.ts +2 -3
- package/dist/dataset/data_loader/index.js +3 -5
- package/dist/dataset/data_loader/tabular_loader.d.ts +3 -28
- package/dist/dataset/data_loader/tabular_loader.js +11 -92
- package/dist/imports.d.ts +2 -0
- package/dist/imports.js +7 -0
- package/dist/index.d.ts +2 -19
- package/dist/index.js +3 -39
- package/dist/memory/index.d.ts +1 -3
- package/dist/memory/index.js +3 -7
- package/dist/memory/memory.d.ts +26 -0
- package/dist/memory/memory.js +160 -0
- package/package.json +13 -26
- package/dist/aggregation.d.ts +0 -5
- package/dist/aggregation.js +0 -33
- package/dist/client/decentralized/base.d.ts +0 -43
- package/dist/client/decentralized/base.js +0 -243
- package/dist/client/decentralized/clear_text.d.ts +0 -13
- package/dist/client/decentralized/clear_text.js +0 -78
- package/dist/client/decentralized/messages.d.ts +0 -37
- package/dist/client/decentralized/messages.js +0 -15
- package/dist/client/decentralized/sec_agg.d.ts +0 -18
- package/dist/client/decentralized/sec_agg.js +0 -169
- package/dist/client/decentralized/secret_shares.d.ts +0 -5
- package/dist/client/decentralized/secret_shares.js +0 -58
- package/dist/client/decentralized/types.d.ts +0 -1
- package/dist/client/federated.d.ts +0 -30
- package/dist/client/federated.js +0 -218
- package/dist/dataset/index.d.ts +0 -2
- package/dist/dataset/index.js +0 -7
- package/dist/model_actor.d.ts +0 -16
- package/dist/model_actor.js +0 -20
- package/dist/serialization/weights.d.ts +0 -5
- package/dist/task/index.js +0 -8
- package/dist/task/model_compile_data.d.ts +0 -6
- package/dist/task/model_compile_data.js +0 -12
- package/dist/tasks/cifar10.d.ts +0 -4
- package/dist/tasks/cifar10.js +0 -76
- package/dist/tasks/index.d.ts +0 -5
- package/dist/tasks/index.js +0 -9
- package/dist/tasks/lus_covid.d.ts +0 -4
- package/dist/tasks/lus_covid.js +0 -85
- package/dist/tasks/mnist.d.ts +0 -4
- package/dist/tasks/mnist.js +0 -58
- package/dist/tasks/simple_face.d.ts +0 -4
- package/dist/tasks/simple_face.js +0 -84
- package/dist/tasks/titanic.d.ts +0 -4
- package/dist/tasks/titanic.js +0 -88
- package/dist/tfjs.d.ts +0 -2
- package/dist/tfjs.js +0 -6
- package/dist/training/disco.d.ts +0 -14
- package/dist/training/disco.js +0 -70
- package/dist/training/training_schemes.d.ts +0 -5
|
@@ -1,169 +0,0 @@
|
|
|
1
|
-
"use strict";
|
|
2
|
-
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
-
exports.SecAgg = void 0;
|
|
4
|
-
var tslib_1 = require("tslib");
|
|
5
|
-
var immutable_1 = require("immutable");
|
|
6
|
-
var msgpack_lite_1 = (0, tslib_1.__importDefault)(require("msgpack-lite"));
|
|
7
|
-
var __1 = require("../..");
|
|
8
|
-
var base_1 = require("./base");
|
|
9
|
-
var messages = (0, tslib_1.__importStar)(require("./messages"));
|
|
10
|
-
var secret_shares = (0, tslib_1.__importStar)(require("./secret_shares"));
|
|
11
|
-
/**
 * Decentralized client that aggregates updates through additive secret sharing:
 * peers only ever receive random-looking shares and partial sums, never another
 * client's individual (noisy) weights.
 */
var SecAgg = /** @class */ (function (_super) {
    (0, tslib_1.__extends)(SecAgg, _super);
    /**
     * Evaluates `body` inside a promise so that synchronous throws surface as
     * rejections, the same way an async function behaves.
     */
    function asPromise(body) {
        return new Promise(function (resolve) {
            resolve(body());
        });
    }
    function SecAgg() {
        var self = _super !== null && _super.apply(this, arguments) || this;
        // shares of other clients' weights received during the current round
        self.receivedShares = (0, immutable_1.List)();
        // partial sums received from other clients during the current round
        self.receivedPartialSums = (0, immutable_1.List)();
        // partial sum this client computes from the shares it received
        self.mySum = [];
        return self;
    }
    /*
    splits the (already noisy) weights into one additive share per peer and sends
    share number i to peer number i, one peer after the other
    */
    SecAgg.prototype.sendShares = function (noisyWeights, round, trainingInformant) {
        var client = this;
        return asPromise(function () {
            var shares = secret_shares.generateAllShares(noisyWeights, client.peers.length, client.maxShareValue);
            var sendShareTo = function (index) {
                if (!(index < client.peers.length)) {
                    return undefined;
                }
                var share = shares.get(index);
                if (share === undefined) {
                    throw new Error('weight shares generated incorrectly');
                }
                var message = {
                    type: messages.messageType.clientSharesMessageServer,
                    peerID: client.ID
                };
                return Promise.resolve(__1.serialization.weights.encode(share)).then(function (encoded) {
                    message.weights = encoded;
                    message.destination = client.peers[index];
                    client.sendMessagetoPeer(msgpack_lite_1.default.encode(message));
                    return sendShareTo(index + 1);
                });
            };
            return sendShareTo(0);
        });
    };
    /*
    sums the received shares into this client's partial sum and sends it to every peer
    */
    SecAgg.prototype.sendPartialSums = function () {
        var client = this;
        return asPromise(function () {
            client.mySum = __1.aggregation.sumWeights((0, immutable_1.List)(Array.from(client.receivedShares.values())));
            var sendSumTo = function (index) {
                if (!(index < client.peers.length)) {
                    return undefined;
                }
                var message = {
                    type: messages.messageType.clientPartialSumsMessageServer,
                    peerID: client.ID
                };
                return Promise.resolve(__1.serialization.weights.encode(client.mySum)).then(function (encoded) {
                    message.partials = encoded;
                    message.destination = client.peers[index];
                    client.sendMessagetoPeer(msgpack_lite_1.default.encode(message));
                    return sendSumTo(index + 1);
                });
            };
            return sendSumTo(0);
        });
    };
    /*
    runs both communication phases of one round and resolves with the list of
    partial sums received from the peers
    */
    SecAgg.prototype.sendAndReceiveWeights = function (noisyWeights, round, trainingInformant) {
        var client = this;
        return asPromise(function () {
            // every round starts from empty buffers
            // NOTE(review): anything a faster peer already delivered for this round is
            // discarded here; confirm peers cannot run ahead of this call.
            client.receivedShares = client.receivedShares.clear();
            client.receivedPartialSums = client.receivedPartialSums.clear();
            // PHASE 1: send additive shares, then wait until one share per peer arrived
            return client.sendShares(noisyWeights, round, trainingInformant)
                .then(function () {
                    return client.pauseUntil(function () { return client.receivedShares.size >= client.peers.length; });
                })
                // PHASE 2: send partial sums, then wait until one partial sum per peer arrived
                .then(function () {
                    return client.sendPartialSums();
                })
                .then(function () {
                    return client.pauseUntil(function () { return client.receivedPartialSums.size >= client.peers.length; });
                })
                .then(function () {
                    return client.receivedPartialSums;
                });
        });
    };
    /*
    true when the message carries a share sent by a peer
    */
    SecAgg.prototype.instanceOfClientSharesMessageServer = function (msg) {
        var type = msg.type;
        return type === messages.messageType.clientSharesMessageServer;
    };
    /*
    true when the message carries a partial sum sent by a peer
    */
    SecAgg.prototype.instanceOfClientPartialSumsMessageServer = function (msg) {
        var type = msg.type;
        return type === messages.messageType.clientPartialSumsMessageServer;
    };
    /*
    stores a share or a partial sum received through the signaling server;
    any other message type is rejected
    */
    SecAgg.prototype.clientHandle = function (msg) {
        if (this.instanceOfClientSharesMessageServer(msg)) {
            var share = __1.serialization.weights.decode(msg.weights);
            this.receivedShares = this.receivedShares.push(share);
            return;
        }
        if (this.instanceOfClientPartialSumsMessageServer(msg)) {
            var partialSum = __1.serialization.weights.decode(msg.partials);
            this.receivedPartialSums = this.receivedPartialSums.push(partialSum);
            return;
        }
        throw new Error('Unexpected Message Type');
    };
    return SecAgg;
}(base_1.Base));
exports.SecAgg = SecAgg;
|
|
@@ -1,5 +0,0 @@
|
|
|
1
|
-
import { List } from 'immutable';
|
|
2
|
-
import { Weights } from '../..';
|
|
3
|
-
/**
 * Computes the share that completes an additive sharing of `secret`.
 * @param currentShares - shares generated so far; must not be empty (the implementation throws otherwise)
 * @param secret - the weights being shared
 * @returns `secret` minus the sum of `currentShares`
 */
export declare function lastShare(currentShares: Weights[], secret: Weights): Weights;
/**
 * Generates `nParticipants` additive shares of `secret`: `nParticipants - 1` random
 * shares followed by the share produced by `lastShare`.
 * @param secret - the weights being shared
 * @param nParticipants - number of shares to produce
 * @param maxShareValue - bound of the uniform range used for the random shares
 * @returns the list of shares
 */
export declare function generateAllShares(secret: Weights, nParticipants: number, maxShareValue: number): List<Weights>;
/**
 * Generates one random share with the same tensor shapes as `secret`, sampled uniformly
 * between -maxShareValue and maxShareValue.
 * @param secret - weights whose tensor shapes are mirrored
 * @param maxShareValue - bound of the uniform range
 * @returns the random share
 */
export declare function generateRandomShare(secret: Weights, maxShareValue: number): Weights;
|
|
@@ -1,58 +0,0 @@
|
|
|
1
|
-
"use strict";
|
|
2
|
-
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
-
exports.generateRandomShare = exports.generateAllShares = exports.lastShare = void 0;
|
|
4
|
-
var tslib_1 = require("tslib");
|
|
5
|
-
var immutable_1 = require("immutable");
|
|
6
|
-
var tf = (0, tslib_1.__importStar)(require("@tensorflow/tfjs"));
|
|
7
|
-
var crypto = (0, tslib_1.__importStar)(require("crypto"));
|
|
8
|
-
var __1 = require("../..");
|
|
9
|
-
// Exclusive upper bound for random seeds: 2^47, within crypto.randomInt's allowed range.
var maxSeed = 0x800000000000;
|
|
10
|
-
/*
Returns the final share once N-1 shares have been generated (N being the number of ready
clients): the secret minus the sum of the shares generated so far, so that all N shares
add up to the secret.
*/
function lastShare(currentShares, secret) {
    if (currentShares.length === 0) {
        throw new Error('Need at least one current share to be able to subtract secret from');
    }
    var sumOfShares = __1.aggregation.sumWeights((0, immutable_1.List)(currentShares));
    return __1.aggregation.subtractWeights((0, immutable_1.List)([secret, sumOfShares]));
}
exports.lastShare = lastShare;
|
|
22
|
-
/*
Builds N additive shares whose sum equals the secret weights (N being the number of ready
clients): N-1 random shares, followed by the share that completes the sum.
*/
function generateAllShares(secret, nParticipants, maxShareValue) {
    var randomShares = [];
    while (randomShares.length < nParticipants - 1) {
        randomShares.push(generateRandomShare(secret, maxShareValue));
    }
    randomShares.push(lastShare(randomShares, secret));
    return (0, immutable_1.List)(randomShares);
}
exports.generateAllShares = generateAllShares;
|
|
35
|
-
/*
Generates one share with the same tensor shapes as the secret, populated with values drawn
uniformly from (-maxShareValue, maxShareValue).

Every tensor gets its own freshly drawn seed. tf.randomUniform is deterministic for a given
seed, so reusing one seed for all tensors would give them identical value sequences; the
last share (secret minus the sum of the random shares) would then reveal the differences
between corresponding entries of the secret's tensors.
*/
function generateRandomShare(secret, maxShareValue) {
    return Array.from(secret, function (tensor) {
        var seed = crypto.randomInt(maxSeed);
        return tf.randomUniform(tensor.shape, -maxShareValue, maxShareValue, 'float32', seed);
    });
}
exports.generateRandomShare = generateRandomShare;
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
/** Numeric identifier used for peers of the decentralized client. */
export declare type PeerID = number;
|
|
@@ -1,30 +0,0 @@
|
|
|
1
|
-
import { informant, MetadataID, Weights } from '..';
|
|
2
|
-
import { Base } from './base';
|
|
3
|
-
/**
 * Class that deals with communication with the centralized server when training
 * a specific task in the federated setting.
 */
export declare class Federated extends Base {
    /** Random identifier (uuid v4) generated at construction and embedded in request URLs. */
    private readonly clientID;
    /** NOTE(review): not referenced by the compiled implementation; confirm whether still needed. */
    private readonly peer;
    /** Latest server round this client has adopted; starts at 0. */
    private round;
    /** Builds the URL `feai/<category>/<taskID>/<clientID>` relative to the server URL. */
    private urlTo;
    /** Builds the URL `feai/metadata/<metadataID>/<taskID>/<round>/<clientID>` relative to the server URL. */
    private urlToMetadata;
    /**
     * Initialize the connection to the server. TODO: In the case of FeAI,
     * should return the current server-side round for the task.
     */
    connect(): Promise<void>;
    /**
     * Disconnection process when user quits the task.
     */
    disconnect(): Promise<void>;
    /**
     * POSTs the serialized weights, together with the local round, to the `weights` endpoint.
     */
    postWeightsToServer(weights: Weights): Promise<void>;
    /**
     * POSTs `metadata` to the metadata endpoint identified by `metadataID`.
     */
    postMetadata(metadataID: string, metadata: string): Promise<void>;
    /**
     * Fetches the metadata identified by `metadataID` and msgpack-decodes it into a Map.
     */
    getMetadataMap(metadataID: MetadataID): Promise<Map<string, unknown>>;
    /**
     * Returns the server's current round, or -1 when the response status is not 200.
     */
    getLatestServerRound(): Promise<number>;
    /**
     * When the server round is ahead of the local round, adopts the server round and returns
     * the server weights; otherwise returns undefined.
     */
    pullRoundAndFetchWeights(): Promise<Weights | undefined>;
    /**
     * Fetches the server statistics and forwards them to `trainingInformant.update`.
     */
    pullServerStatistics(trainingInformant: informant.FederatedInformant): Promise<void>;
    /**
     * Adds differential privacy to the update, posts it, refreshes the statistics, and returns
     * the newer server weights if any, otherwise `staleWeights`. The third argument is unused.
     */
    onRoundEndCommunication(updatedWeights: Weights, staleWeights: Weights, _: number, trainingInformant: informant.FederatedInformant): Promise<Weights>;
    /** Nothing is exchanged with the server at the end of training. */
    onTrainEndCommunication(): Promise<void>;
}
|
package/dist/client/federated.js
DELETED
|
@@ -1,218 +0,0 @@
|
|
|
1
|
-
"use strict";
|
|
2
|
-
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
-
exports.Federated = void 0;
|
|
4
|
-
var tslib_1 = require("tslib");
|
|
5
|
-
var msgpack = (0, tslib_1.__importStar)(require("msgpack-lite"));
|
|
6
|
-
var axios_1 = (0, tslib_1.__importDefault)(require("axios"));
|
|
7
|
-
var uuid_1 = require("uuid");
|
|
8
|
-
var __1 = require("..");
|
|
9
|
-
var base_1 = require("./base");
|
|
10
|
-
/**
|
|
11
|
-
* Class that deals with communication with the centralized server when training
|
|
12
|
-
* a specific task in the federated setting.
|
|
13
|
-
*/
|
|
14
|
-
var Federated = /** @class */ (function (_super) {
|
|
15
|
-
(0, tslib_1.__extends)(Federated, _super);
|
|
16
|
-
function Federated() {
|
|
17
|
-
var _this = _super !== null && _super.apply(this, arguments) || this;
|
|
18
|
-
_this.clientID = (0, uuid_1.v4)();
|
|
19
|
-
_this.round = 0;
|
|
20
|
-
return _this;
|
|
21
|
-
}
|
|
22
|
-
Federated.prototype.urlTo = function (category) {
|
|
23
|
-
var url = new URL('', this.url);
|
|
24
|
-
url.pathname += [
|
|
25
|
-
'feai',
|
|
26
|
-
category,
|
|
27
|
-
this.task.taskID,
|
|
28
|
-
this.clientID
|
|
29
|
-
].join('/');
|
|
30
|
-
return url.href;
|
|
31
|
-
};
|
|
32
|
-
Federated.prototype.urlToMetadata = function (metadataID) {
|
|
33
|
-
var url = new URL('', this.url);
|
|
34
|
-
url.pathname += [
|
|
35
|
-
'feai',
|
|
36
|
-
'metadata',
|
|
37
|
-
metadataID,
|
|
38
|
-
this.task.taskID,
|
|
39
|
-
this.round,
|
|
40
|
-
this.clientID
|
|
41
|
-
].join('/');
|
|
42
|
-
return url.href;
|
|
43
|
-
};
|
|
44
|
-
/**
|
|
45
|
-
* Initialize the connection to the server. TODO: In the case of FeAI,
|
|
46
|
-
* should return the current server-side round for the task.
|
|
47
|
-
*/
|
|
48
|
-
Federated.prototype.connect = function () {
|
|
49
|
-
return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
|
|
50
|
-
return (0, tslib_1.__generator)(this, function (_a) {
|
|
51
|
-
switch (_a.label) {
|
|
52
|
-
case 0: return [4 /*yield*/, axios_1.default.get(this.urlTo('connect'))];
|
|
53
|
-
case 1:
|
|
54
|
-
_a.sent();
|
|
55
|
-
return [2 /*return*/];
|
|
56
|
-
}
|
|
57
|
-
});
|
|
58
|
-
});
|
|
59
|
-
};
|
|
60
|
-
/**
|
|
61
|
-
* Disconnection process when user quits the task.
|
|
62
|
-
*/
|
|
63
|
-
Federated.prototype.disconnect = function () {
|
|
64
|
-
return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
|
|
65
|
-
return (0, tslib_1.__generator)(this, function (_a) {
|
|
66
|
-
switch (_a.label) {
|
|
67
|
-
case 0: return [4 /*yield*/, axios_1.default.get(this.urlTo('disconnect'))];
|
|
68
|
-
case 1:
|
|
69
|
-
_a.sent();
|
|
70
|
-
return [2 /*return*/];
|
|
71
|
-
}
|
|
72
|
-
});
|
|
73
|
-
});
|
|
74
|
-
};
|
|
75
|
-
Federated.prototype.postWeightsToServer = function (weights) {
|
|
76
|
-
return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
|
|
77
|
-
var _a;
|
|
78
|
-
var _b, _c;
|
|
79
|
-
return (0, tslib_1.__generator)(this, function (_d) {
|
|
80
|
-
switch (_d.label) {
|
|
81
|
-
case 0:
|
|
82
|
-
_a = axios_1.default;
|
|
83
|
-
_b = {
|
|
84
|
-
method: 'post',
|
|
85
|
-
url: this.urlTo('weights')
|
|
86
|
-
};
|
|
87
|
-
_c = {};
|
|
88
|
-
return [4 /*yield*/, __1.serialization.weights.encode(weights)];
|
|
89
|
-
case 1: return [4 /*yield*/, _a.apply(void 0, [(_b.data = (_c.weights = _d.sent(),
|
|
90
|
-
_c.round = this.round,
|
|
91
|
-
_c),
|
|
92
|
-
_b)])];
|
|
93
|
-
case 2:
|
|
94
|
-
_d.sent();
|
|
95
|
-
return [2 /*return*/];
|
|
96
|
-
}
|
|
97
|
-
});
|
|
98
|
-
});
|
|
99
|
-
};
|
|
100
|
-
Federated.prototype.postMetadata = function (metadataID, metadata) {
|
|
101
|
-
return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
|
|
102
|
-
return (0, tslib_1.__generator)(this, function (_a) {
|
|
103
|
-
switch (_a.label) {
|
|
104
|
-
case 0: return [4 /*yield*/, (0, axios_1.default)({
|
|
105
|
-
method: 'post',
|
|
106
|
-
url: this.urlToMetadata(metadataID),
|
|
107
|
-
data: {
|
|
108
|
-
metadataID: metadata
|
|
109
|
-
}
|
|
110
|
-
})];
|
|
111
|
-
case 1:
|
|
112
|
-
_a.sent();
|
|
113
|
-
return [2 /*return*/];
|
|
114
|
-
}
|
|
115
|
-
});
|
|
116
|
-
});
|
|
117
|
-
};
|
|
118
|
-
Federated.prototype.getMetadataMap = function (metadataID) {
|
|
119
|
-
return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
|
|
120
|
-
var response, body;
|
|
121
|
-
return (0, tslib_1.__generator)(this, function (_a) {
|
|
122
|
-
switch (_a.label) {
|
|
123
|
-
case 0: return [4 /*yield*/, axios_1.default.get(this.urlToMetadata(metadataID))];
|
|
124
|
-
case 1:
|
|
125
|
-
response = _a.sent();
|
|
126
|
-
return [4 /*yield*/, response.data];
|
|
127
|
-
case 2:
|
|
128
|
-
body = _a.sent();
|
|
129
|
-
return [2 /*return*/, new Map(msgpack.decode(body[metadataID]))];
|
|
130
|
-
}
|
|
131
|
-
});
|
|
132
|
-
});
|
|
133
|
-
};
|
|
134
|
-
Federated.prototype.getLatestServerRound = function () {
|
|
135
|
-
return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
|
|
136
|
-
var response;
|
|
137
|
-
return (0, tslib_1.__generator)(this, function (_a) {
|
|
138
|
-
switch (_a.label) {
|
|
139
|
-
case 0: return [4 /*yield*/, axios_1.default.get(this.urlTo('round'))];
|
|
140
|
-
case 1:
|
|
141
|
-
response = _a.sent();
|
|
142
|
-
if (response.status === 200) {
|
|
143
|
-
return [2 /*return*/, response.data.round];
|
|
144
|
-
}
|
|
145
|
-
console.log('Error getting weights: code', response.status);
|
|
146
|
-
return [2 /*return*/, -1];
|
|
147
|
-
}
|
|
148
|
-
});
|
|
149
|
-
});
|
|
150
|
-
};
|
|
151
|
-
Federated.prototype.pullRoundAndFetchWeights = function () {
|
|
152
|
-
return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
|
|
153
|
-
var serverRound, response, serverWeights;
|
|
154
|
-
return (0, tslib_1.__generator)(this, function (_a) {
|
|
155
|
-
switch (_a.label) {
|
|
156
|
-
case 0: return [4 /*yield*/, this.getLatestServerRound()];
|
|
157
|
-
case 1:
|
|
158
|
-
serverRound = _a.sent();
|
|
159
|
-
return [4 /*yield*/, axios_1.default.get(this.urlTo('weights'))];
|
|
160
|
-
case 2:
|
|
161
|
-
response = _a.sent();
|
|
162
|
-
serverWeights = __1.serialization.weights.decode(response.data);
|
|
163
|
-
if (this.round < serverRound) {
|
|
164
|
-
// Update the local round to match the server's
|
|
165
|
-
this.round = serverRound;
|
|
166
|
-
return [2 /*return*/, serverWeights];
|
|
167
|
-
}
|
|
168
|
-
else {
|
|
169
|
-
return [2 /*return*/, undefined];
|
|
170
|
-
}
|
|
171
|
-
return [2 /*return*/];
|
|
172
|
-
}
|
|
173
|
-
});
|
|
174
|
-
});
|
|
175
|
-
};
|
|
176
|
-
/**
 * Downloads the server's training statistics and forwards them to the informant.
 *
 * @param trainingInformant - object whose `update` method receives
 *   `response.data.statistics` from the `statistics` endpoint
 * @returns {Promise<void>} settles once the informant has been updated
 */
Federated.prototype.pullServerStatistics = function (trainingInformant) {
    var client = this;
    return (0, tslib_1.__awaiter)(client, void 0, void 0, function () {
        return (0, tslib_1.__generator)(client, function (state) {
            if (state.label === 0) {
                var statisticsUrl = client.urlTo('statistics');
                return [4 /*yield*/, axios_1.default.get(statisticsUrl)];
            }
            // Resumed after the request: `sent()` yields the response (or rethrows its error).
            var reply = state.sent();
            trainingInformant.update(reply.data.statistics);
            return [2 /*return*/];
        });
    });
};
|
|
190
|
-
/**
 * End-of-round exchange with the federated server, performed in order:
 *   1. the local update is passed through `privacy.addDifferentialPrivacy`
 *      and posted to the server;
 *   2. server statistics are pulled into the training informant;
 *   3. the server weights are pulled if the server reached a newer round.
 *
 * @param updatedWeights - weights after this round's local training
 * @param staleWeights - weights from before the round; returned as fallback
 * @param _ - unused by this client
 * @param trainingInformant - forwarded to `pullServerStatistics`
 * @returns {Promise} the server weights when some were pulled, else `staleWeights`
 */
Federated.prototype.onRoundEndCommunication = function (updatedWeights, staleWeights, _, trainingInformant) {
    var client = this;
    return (0, tslib_1.__awaiter)(client, void 0, void 0, function () {
        return (0, tslib_1.__generator)(client, function (state) {
            var step = state.label;
            if (step === 0) {
                var noisyWeights = __1.privacy.addDifferentialPrivacy(updatedWeights, staleWeights, client.task);
                return [4 /*yield*/, client.postWeightsToServer(noisyWeights)];
            }
            if (step === 1) {
                state.sent();
                return [4 /*yield*/, client.pullServerStatistics(trainingInformant)];
            }
            if (step === 2) {
                state.sent();
                return [4 /*yield*/, client.pullRoundAndFetchWeights()];
            }
            var pulled = state.sent();
            // Fall back to the pre-round weights when nothing newer came from the server.
            var nothingPulled = pulled === null || pulled === undefined;
            return [2 /*return*/, nothingPulled ? staleWeights : pulled];
        });
    });
};
|
|
211
|
-
/**
 * Hook called when training ends. The federated client has nothing to
 * exchange at that point, so this resolves immediately with `undefined`.
 *
 * @returns {Promise<void>} an already-resolved promise
 */
Federated.prototype.onTrainEndCommunication = function () {
    return Promise.resolve(undefined);
};
|
|
216
|
-
return Federated;
|
|
217
|
-
}(base_1.Base));
|
|
218
|
-
exports.Federated = Federated;
|
package/dist/dataset/index.d.ts
DELETED
package/dist/dataset/index.js
DELETED
|
@@ -1,7 +0,0 @@
|
|
|
1
|
-
"use strict";
|
|
2
|
-
Object.defineProperty(exports, "__esModule", { value: true });
exports.DatasetBuilder = void 0;
var tslib_1 = require("tslib");
// DatasetBuilder is re-exported through an enumerable live getter so reads
// always see the current binding of ./dataset_builder.
var datasetBuilderModule = require("./dataset_builder");
function readDatasetBuilder() {
    return datasetBuilderModule.DatasetBuilder;
}
Object.defineProperty(exports, "DatasetBuilder", { enumerable: true, get: readDatasetBuilder });
// Every export of ./data_loader is forwarded as-is.
var dataLoaderModule = require("./data_loader");
(0, tslib_1.__exportStar)(dataLoaderModule, exports);
|
package/dist/model_actor.d.ts
DELETED
|
@@ -1,16 +0,0 @@
|
|
|
1
|
-
// Relative specifiers: tsc does not rewrite tsconfig `paths` aliases such as
// `@/...` when emitting declarations, so consumers of the published package
// could not resolve the previous imports.
import { Logger } from './logging/logger';
import { Task } from './task';
/**
 * Base class for all actors of the system (e.g. trainer, tester),
 * holding the parameters they commonly use.
 */
export declare class ModelActor {
    /** Task the actor works on. */
    task: Task;
    /** Logging system (e.g. toaster). */
    logger: Logger;
    /**
     * @param task - task on which the actor performs its work
     * @param logger - logging system (e.g. toaster)
     */
    constructor(task: Task, logger: Logger);
}
|
package/dist/model_actor.js
DELETED
|
@@ -1,20 +0,0 @@
|
|
|
1
|
-
"use strict";
|
|
2
|
-
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
-
exports.ModelActor = void 0;
|
|
4
|
-
/**
 * Common base for the system's actors (e.g. trainer, tester).
 * It only records the task and the logger handed to it, in that order.
 *
 * @param {Task} task - the task the actor operates on
 * @param {Logger} logger - logging system (e.g. toaster)
 */
var ModelActor = function ModelActor(task, logger) {
    var actor = this;
    actor.task = task;
    actor.logger = logger;
};
|
|
20
|
-
exports.ModelActor = ModelActor;
|
|
@@ -1,5 +0,0 @@
|
|
|
1
|
-
// Relative specifier instead of the `@/types` tsconfig alias, which tsc leaves
// untouched in emitted declarations and consumers cannot resolve.
// NOTE(review): assumes this file is dist/serialization/weights.d.ts next to
// dist/types.d.ts — confirm against the package layout.
import { Weights } from '../types';
/** Serialized form of a set of weights: a flat array of numbers. */
export declare type Encoded = number[];
/** Type guard: narrows `raw` to `Encoded`. */
export declare function isEncoded(raw: unknown): raw is Encoded;
/** Serializes `weights`; asynchronous, resolves to the encoded form. */
export declare function encode(weights: Weights): Promise<Encoded>;
/** Rebuilds weights from their encoded form (synchronous). */
export declare function decode(encoded: Encoded): Weights;
|
package/dist/task/index.js
DELETED
|
@@ -1,8 +0,0 @@
|
|
|
1
|
-
"use strict";
|
|
2
|
-
Object.defineProperty(exports, "__esModule", { value: true });
exports.isDisplayInformation = exports.isTaskID = exports.isTask = void 0;
/**
 * Publishes `source[name]` on this module's exports as an enumerable getter
 * without a setter, so reads always reflect the source module's current value.
 */
function reexportBinding(name, source) {
    Object.defineProperty(exports, name, {
        enumerable: true,
        get: function () { return source[name]; }
    });
}
var task_1 = require("./task");
reexportBinding("isTask", task_1);
reexportBinding("isTaskID", task_1);
var display_information_1 = require("./display_information");
reexportBinding("isDisplayInformation", display_information_1);
|
|
@@ -1,12 +0,0 @@
|
|
|
1
|
-
"use strict";
|
|
2
|
-
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
-
exports.ModelCompileData = void 0;
|
|
4
|
-
/**
 * Plain record of a model's compile settings: the same three fields that
 * task definitions (e.g. cifar10's `modelCompileData`) provide.
 *
 * @param optimizer - optimizer setting, stored as given
 * @param loss - loss setting, stored as given
 * @param metrics - metrics setting, stored as given
 */
var ModelCompileData = function ModelCompileData(optimizer, loss, metrics) {
    var settings = this;
    settings.optimizer = optimizer;
    settings.loss = loss;
    settings.metrics = metrics;
};
|
|
12
|
-
exports.ModelCompileData = ModelCompileData;
|
package/dist/tasks/cifar10.d.ts
DELETED
package/dist/tasks/cifar10.js
DELETED
|
@@ -1,76 +0,0 @@
|
|
|
1
|
-
"use strict";
|
|
2
|
-
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
-
exports.model = exports.task = void 0;
|
|
4
|
-
var tslib_1 = require("tslib");
|
|
5
|
-
var tf = (0, tslib_1.__importStar)(require("@tensorflow/tfjs"));
|
|
6
|
-
exports.task = {
|
|
7
|
-
taskID: 'cifar10',
|
|
8
|
-
displayInformation: {
|
|
9
|
-
taskTitle: 'CIFAR10',
|
|
10
|
-
summary: {
|
|
11
|
-
preview: 'In this challenge, we ask you to classify images into categories based on the objects shown on the image.',
|
|
12
|
-
overview: 'The CIFAR-10 dataset is a collection of images that are commonly used to train machine learning and computer vision algorithms. It is one of the most widely used datasets for machine learning research.'
|
|
13
|
-
},
|
|
14
|
-
limitations: 'The training data is limited to small images of size 32x32.',
|
|
15
|
-
tradeoffs: 'Training success strongly depends on label distribution',
|
|
16
|
-
dataFormatInformation: 'Images should be of .png format and of size 32x32. <br> The label file should be .csv, where each row contains a file_name, class. <br> <br> e.g. if you have images: 0.png (of a frog) and 1.png (of a car) <br> labels.csv contains: (Note that no header is needed)<br> 0.png, frog <br> 1.png, car',
|
|
17
|
-
dataExampleText: 'Below you can find 10 random examples from each of the 10 classes in the dataset.',
|
|
18
|
-
dataExampleImage: './cifar10-example.png'
|
|
19
|
-
},
|
|
20
|
-
trainingInformation: {
|
|
21
|
-
modelID: 'cifar10-model',
|
|
22
|
-
epochs: 10,
|
|
23
|
-
roundDuration: 10,
|
|
24
|
-
validationSplit: 0.2,
|
|
25
|
-
batchSize: 10,
|
|
26
|
-
modelCompileData: {
|
|
27
|
-
optimizer: 'sgd',
|
|
28
|
-
loss: 'categoricalCrossentropy',
|
|
29
|
-
metrics: ['accuracy']
|
|
30
|
-
},
|
|
31
|
-
threshold: 1,
|
|
32
|
-
dataType: 'image',
|
|
33
|
-
csvLabels: true,
|
|
34
|
-
IMAGE_H: 32,
|
|
35
|
-
IMAGE_W: 32,
|
|
36
|
-
preprocessFunctions: ['resize'],
|
|
37
|
-
RESIZED_IMAGE_H: 224,
|
|
38
|
-
RESIZED_IMAGE_W: 224,
|
|
39
|
-
LABEL_LIST: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'],
|
|
40
|
-
LABEL_ASSIGNMENT: [
|
|
41
|
-
{ columnName: 'airplane', columnData: 0 },
|
|
42
|
-
{ columnName: 'automobile', columnData: 1 },
|
|
43
|
-
{ columnName: 'bird', columnData: 2 },
|
|
44
|
-
{ columnName: 'cat', columnData: 3 },
|
|
45
|
-
{ columnName: 'deer', columnData: 4 },
|
|
46
|
-
{ columnName: 'dog', columnData: 5 },
|
|
47
|
-
{ columnName: 'frog', columnData: 6 },
|
|
48
|
-
{ columnName: 'horse', columnData: 7 },
|
|
49
|
-
{ columnName: 'ship', columnData: 8 },
|
|
50
|
-
{ columnName: 'truck', columnData: 9 }
|
|
51
|
-
],
|
|
52
|
-
scheme: 'Decentralized'
|
|
53
|
-
}
|
|
54
|
-
};
|
|
55
|
-
/**
 * Builds the CIFAR-10 classifier by transfer learning from MobileNet v1 (0.25, 224).
 * The pretrained network is loaded from the tfjs-models bucket, and a new
 * 10-unit softmax dense layer ('denseModified') is applied to the output of its
 * 'global_average_pooling2d_1' layer.
 *
 * @returns {Promise} resolves to the `tf.model` named 'modelModified' that takes
 *   MobileNet's input and produces the new layer's predictions
 */
function model() {
    return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
        return (0, tslib_1.__generator)(this, function (state) {
            if (state.label === 0) {
                var mobilenetUrl = 'https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json';
                return [4 /*yield*/, tf.loadLayersModel(mobilenetUrl)];
            }
            var base = state.sent();
            var poolingLayer = base.getLayer('global_average_pooling2d_1');
            var classifierHead = tf.layers.dense({ units: 10, activation: 'softmax', name: 'denseModified' });
            var predictions = classifierHead.apply(poolingLayer.output);
            return [2 /*return*/, tf.model({
                    inputs: base.input,
                    outputs: predictions,
                    name: 'modelModified'
                })];
        });
    });
}
|
|
76
|
-
exports.model = model;
|