@epfml/discojs 0.0.1 → 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +12 -11
- package/dist/aggregation.d.ts +4 -2
- package/dist/aggregation.js +20 -6
- package/dist/client/base.d.ts +1 -1
- package/dist/client/decentralized/base.d.ts +43 -0
- package/dist/client/decentralized/base.js +243 -0
- package/dist/client/decentralized/clear_text.d.ts +13 -0
- package/dist/client/decentralized/clear_text.js +78 -0
- package/dist/client/decentralized/index.d.ts +4 -0
- package/dist/client/decentralized/index.js +9 -0
- package/dist/client/decentralized/messages.d.ts +37 -0
- package/dist/client/decentralized/messages.js +15 -0
- package/dist/client/decentralized/sec_agg.d.ts +18 -0
- package/dist/client/decentralized/sec_agg.js +169 -0
- package/dist/client/decentralized/secret_shares.d.ts +5 -0
- package/dist/client/decentralized/secret_shares.js +58 -0
- package/dist/client/decentralized/types.d.ts +1 -0
- package/dist/client/decentralized/types.js +2 -0
- package/dist/client/federated.d.ts +4 -4
- package/dist/client/federated.js +5 -8
- package/dist/client/index.d.ts +1 -1
- package/dist/client/index.js +3 -3
- package/dist/client/local.js +5 -3
- package/dist/dataset/data_loader/data_loader.d.ts +7 -1
- package/dist/dataset/data_loader/image_loader.d.ts +5 -3
- package/dist/dataset/data_loader/image_loader.js +64 -18
- package/dist/dataset/data_loader/index.d.ts +1 -1
- package/dist/dataset/data_loader/tabular_loader.d.ts +3 -3
- package/dist/dataset/data_loader/tabular_loader.js +27 -18
- package/dist/dataset/dataset_builder.d.ts +3 -2
- package/dist/dataset/dataset_builder.js +29 -17
- package/dist/index.d.ts +6 -4
- package/dist/index.js +13 -5
- package/dist/informant/graph_informant.d.ts +10 -0
- package/dist/informant/graph_informant.js +23 -0
- package/dist/informant/index.d.ts +3 -0
- package/dist/informant/index.js +9 -0
- package/dist/informant/training_informant/base.d.ts +31 -0
- package/dist/informant/training_informant/base.js +82 -0
- package/dist/informant/training_informant/decentralized.d.ts +5 -0
- package/dist/informant/training_informant/decentralized.js +22 -0
- package/dist/informant/training_informant/federated.d.ts +14 -0
- package/dist/informant/training_informant/federated.js +32 -0
- package/dist/informant/training_informant/index.d.ts +4 -0
- package/dist/informant/training_informant/index.js +11 -0
- package/dist/informant/training_informant/local.d.ts +6 -0
- package/dist/informant/training_informant/local.js +20 -0
- package/dist/logging/index.d.ts +1 -0
- package/dist/logging/index.js +3 -1
- package/dist/logging/trainer_logger.d.ts +1 -1
- package/dist/logging/trainer_logger.js +5 -5
- package/dist/memory/base.d.ts +17 -48
- package/dist/memory/empty.d.ts +6 -4
- package/dist/memory/empty.js +8 -2
- package/dist/memory/index.d.ts +1 -1
- package/dist/privacy.js +3 -3
- package/dist/serialization/model.d.ts +1 -1
- package/dist/serialization/model.js +2 -2
- package/dist/serialization/weights.js +2 -2
- package/dist/task/display_information.d.ts +2 -2
- package/dist/task/display_information.js +6 -5
- package/dist/task/summary.d.ts +5 -0
- package/dist/task/summary.js +23 -0
- package/dist/task/training_information.d.ts +3 -0
- package/dist/tasks/cifar10.js +5 -3
- package/dist/tasks/lus_covid.d.ts +1 -1
- package/dist/tasks/lus_covid.js +50 -13
- package/dist/tasks/mnist.js +4 -2
- package/dist/tasks/simple_face.d.ts +1 -1
- package/dist/tasks/simple_face.js +13 -17
- package/dist/tasks/titanic.js +9 -7
- package/dist/tfjs.d.ts +2 -0
- package/dist/tfjs.js +6 -0
- package/dist/training/disco.d.ts +3 -1
- package/dist/training/disco.js +14 -6
- package/dist/training/trainer/distributed_trainer.d.ts +1 -1
- package/dist/training/trainer/distributed_trainer.js +5 -1
- package/dist/training/trainer/local_trainer.d.ts +4 -3
- package/dist/training/trainer/local_trainer.js +6 -9
- package/dist/training/trainer/round_tracker.js +3 -0
- package/dist/training/trainer/trainer.d.ts +15 -15
- package/dist/training/trainer/trainer.js +57 -43
- package/dist/training/trainer/trainer_builder.js +8 -15
- package/dist/types.d.ts +1 -1
- package/dist/validation/index.d.ts +1 -0
- package/dist/validation/index.js +5 -0
- package/dist/validation/validator.d.ts +20 -0
- package/dist/validation/validator.js +106 -0
- package/package.json +2 -3
- package/dist/client/decentralized.d.ts +0 -23
- package/dist/client/decentralized.js +0 -275
- package/dist/testing/tester.d.ts +0 -5
- package/dist/testing/tester.js +0 -21
- package/dist/training_informant.d.ts +0 -88
- package/dist/training_informant.js +0 -135
|
@@ -3,12 +3,12 @@ import { Memory, Task, TrainingInformant, TrainingInformation } from '@/.';
|
|
|
3
3
|
import { RoundTracker } from './round_tracker';
|
|
4
4
|
import { TrainerLog } from '../../logging/trainer_logger';
|
|
5
5
|
/** Abstract class whose role is to train a model with a given dataset. This can be either done
|
|
6
|
-
* locally or in a distributed way. The Trainer works as follows:
|
|
6
|
+
* locally (alone) or in a distributed way with collaborators. The Trainer works as follows:
|
|
7
7
|
*
|
|
8
8
|
* 1. Call trainModel(dataset) to start training
|
|
9
|
-
* 2. Once a batch ends, onBatchEnd is triggered,
|
|
9
|
+
* 2. Once a batch ends, onBatchEnd is triggered, which will then call onRoundEnd once the round has ended.
|
|
10
10
|
*
|
|
11
|
-
* The onRoundEnd needs to be implemented to specify what actions to do when the round has ended. To know when
|
|
11
|
+
* The onRoundEnd needs to be implemented to specify what actions to do when the round has ended, such as a communication step with collaborators. To know when
|
|
12
12
|
* a round has ended we use the roundTracker object.
|
|
13
13
|
*/
|
|
14
14
|
export declare abstract class Trainer {
|
|
@@ -30,10 +30,17 @@ export declare abstract class Trainer {
|
|
|
30
30
|
* Every time a round ends this function will be called
|
|
31
31
|
*/
|
|
32
32
|
protected abstract onRoundEnd(accuracy: number): Promise<void>;
|
|
33
|
+
/** onBatchEnd callback, when a round ends, we call onRoundEnd (to be implemented for local and distributed instances)
|
|
34
|
+
*/
|
|
35
|
+
protected onBatchEnd(_: number, logs?: tf.Logs): Promise<void>;
|
|
36
|
+
/**
|
|
37
|
+
* We update the training graph, this needs to be done on epoch end as there is no validation accuracy onBatchEnd.
|
|
38
|
+
*/
|
|
39
|
+
protected onEpochEnd(epoch: number, logs?: tf.Logs): void;
|
|
33
40
|
/**
|
|
34
41
|
* When the training ends this function will be call
|
|
35
42
|
*/
|
|
36
|
-
protected
|
|
43
|
+
protected onTrainEnd(logs?: tf.Logs): Promise<void>;
|
|
37
44
|
/**
|
|
38
45
|
* Request stop training to be used from the Disco instance or any class that is taking care of the trainer.
|
|
39
46
|
*/
|
|
@@ -42,25 +49,18 @@ export declare abstract class Trainer {
|
|
|
42
49
|
* Start training the model with the given dataset
|
|
43
50
|
* @param dataset
|
|
44
51
|
*/
|
|
45
|
-
trainModel(dataset: tf.data.Dataset<tf.TensorContainer>): Promise<void>;
|
|
52
|
+
trainModel(dataset: tf.data.Dataset<tf.TensorContainer>, valDataset: tf.data.Dataset<tf.TensorContainer>): Promise<void>;
|
|
46
53
|
/**
|
|
47
54
|
* Format accuracy
|
|
48
55
|
*/
|
|
49
|
-
|
|
50
|
-
/**
|
|
51
|
-
* We update the training graph, this needs to be done on epoch end as there is no validation accuracy onBatchEnd.
|
|
52
|
-
*/
|
|
53
|
-
private onEpochEnd;
|
|
54
|
-
/** onBatchEnd callback, when a round ends, we call onRoundEnd (to be implemented for local and distributed instances)
|
|
55
|
-
*/
|
|
56
|
-
private onBatchEnd;
|
|
56
|
+
protected roundDecimals(accuracy: number, decimalsToRound?: number): number;
|
|
57
57
|
/**
|
|
58
58
|
* reset stop training state
|
|
59
59
|
*/
|
|
60
|
-
|
|
60
|
+
protected resetStopTrainerState(): void;
|
|
61
61
|
/**
|
|
62
62
|
* If stop training is requested, do so
|
|
63
63
|
*/
|
|
64
|
-
|
|
64
|
+
protected stopTrainModelIfRequested(): void;
|
|
65
65
|
getTrainerLog(): TrainerLog;
|
|
66
66
|
}
|
|
@@ -5,12 +5,12 @@ var tslib_1 = require("tslib");
|
|
|
5
5
|
var round_tracker_1 = require("./round_tracker");
|
|
6
6
|
var trainer_logger_1 = require("../../logging/trainer_logger");
|
|
7
7
|
/** Abstract class whose role is to train a model with a given dataset. This can be either done
|
|
8
|
-
* locally or in a distributed way. The Trainer works as follows:
|
|
8
|
+
* locally (alone) or in a distributed way with collaborators. The Trainer works as follows:
|
|
9
9
|
*
|
|
10
10
|
* 1. Call trainModel(dataset) to start training
|
|
11
|
-
* 2. Once a batch ends, onBatchEnd is triggered,
|
|
11
|
+
* 2. Once a batch ends, onBatchEnd is triggered, which will then call onRoundEnd once the round has ended.
|
|
12
12
|
*
|
|
13
|
-
* The onRoundEnd needs to be implemented to specify what actions to do when the round has ended. To know when
|
|
13
|
+
* The onRoundEnd needs to be implemented to specify what actions to do when the round has ended, such as a communication step with collaborators. To know when
|
|
14
14
|
* a round has ended we use the roundTracker object.
|
|
15
15
|
*/
|
|
16
16
|
var Trainer = /** @class */ (function () {
|
|
@@ -33,9 +33,52 @@ var Trainer = /** @class */ (function () {
|
|
|
33
33
|
this.trainingInformation = trainingInformation;
|
|
34
34
|
this.roundTracker = new round_tracker_1.RoundTracker(trainingInformation.roundDuration);
|
|
35
35
|
}
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
36
|
+
/** onBatchEnd callback, when a round ends, we call onRoundEnd (to be implemented for local and distributed instances)
|
|
37
|
+
*/
|
|
38
|
+
Trainer.prototype.onBatchEnd = function (_, logs) {
|
|
39
|
+
return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
|
|
40
|
+
return (0, tslib_1.__generator)(this, function (_a) {
|
|
41
|
+
switch (_a.label) {
|
|
42
|
+
case 0:
|
|
43
|
+
if (logs === undefined) {
|
|
44
|
+
return [2 /*return*/];
|
|
45
|
+
}
|
|
46
|
+
this.roundTracker.updateBatch();
|
|
47
|
+
this.stopTrainModelIfRequested();
|
|
48
|
+
if (!this.roundTracker.roundHasEnded()) return [3 /*break*/, 2];
|
|
49
|
+
return [4 /*yield*/, this.onRoundEnd(logs.acc)];
|
|
50
|
+
case 1:
|
|
51
|
+
_a.sent();
|
|
52
|
+
_a.label = 2;
|
|
53
|
+
case 2: return [2 /*return*/];
|
|
54
|
+
}
|
|
55
|
+
});
|
|
56
|
+
});
|
|
57
|
+
};
|
|
58
|
+
/**
|
|
59
|
+
* We update the training graph, this needs to be done on epoch end as there is no validation accuracy onBatchEnd.
|
|
60
|
+
*/
|
|
61
|
+
Trainer.prototype.onEpochEnd = function (epoch, logs) {
|
|
62
|
+
this.trainerLogger.onEpochEnd(epoch, logs);
|
|
63
|
+
if (logs !== undefined && !isNaN(logs.acc) && !isNaN(logs.val_acc)) {
|
|
64
|
+
this.trainingInformant.updateTrainingGraph(this.roundDecimals(logs.acc));
|
|
65
|
+
this.trainingInformant.updateValidationGraph(this.roundDecimals(logs.val_acc));
|
|
66
|
+
}
|
|
67
|
+
else {
|
|
68
|
+
this.trainerLogger.error('onEpochEnd: NaN value');
|
|
69
|
+
}
|
|
70
|
+
};
|
|
71
|
+
/**
|
|
72
|
+
* When the training ends this function will be call
|
|
73
|
+
*/
|
|
74
|
+
Trainer.prototype.onTrainEnd = function (logs) {
|
|
75
|
+
return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
|
|
76
|
+
return (0, tslib_1.__generator)(this, function (_a) {
|
|
77
|
+
this.trainingInformant.addMessage('Training finished.');
|
|
78
|
+
return [2 /*return*/];
|
|
79
|
+
});
|
|
80
|
+
});
|
|
81
|
+
};
|
|
39
82
|
/**
|
|
40
83
|
* Request stop training to be used from the Disco instance or any class that is taking care of the trainer.
|
|
41
84
|
*/
|
|
@@ -51,7 +94,7 @@ var Trainer = /** @class */ (function () {
|
|
|
51
94
|
* Start training the model with the given dataset
|
|
52
95
|
* @param dataset
|
|
53
96
|
*/
|
|
54
|
-
Trainer.prototype.trainModel = function (dataset) {
|
|
97
|
+
Trainer.prototype.trainModel = function (dataset, valDataset) {
|
|
55
98
|
return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
|
|
56
99
|
var _this = this;
|
|
57
100
|
return (0, tslib_1.__generator)(this, function (_a) {
|
|
@@ -61,7 +104,7 @@ var Trainer = /** @class */ (function () {
|
|
|
61
104
|
// Assign callbacks and start training
|
|
62
105
|
return [4 /*yield*/, this.model.fitDataset(dataset.batch(this.trainingInformation.batchSize), {
|
|
63
106
|
epochs: this.trainingInformation.epochs,
|
|
64
|
-
validationData:
|
|
107
|
+
validationData: valDataset.batch(this.trainingInformation.batchSize),
|
|
65
108
|
callbacks: {
|
|
66
109
|
onEpochEnd: function (epoch, logs) { return _this.onEpochEnd(epoch, logs); },
|
|
67
110
|
onBatchEnd: function (epoch, logs) { return (0, tslib_1.__awaiter)(_this, void 0, void 0, function () { return (0, tslib_1.__generator)(this, function (_a) {
|
|
@@ -69,6 +112,12 @@ var Trainer = /** @class */ (function () {
|
|
|
69
112
|
case 0: return [4 /*yield*/, this.onBatchEnd(epoch, logs)];
|
|
70
113
|
case 1: return [2 /*return*/, _a.sent()];
|
|
71
114
|
}
|
|
115
|
+
}); }); },
|
|
116
|
+
onTrainEnd: function (logs) { return (0, tslib_1.__awaiter)(_this, void 0, void 0, function () { return (0, tslib_1.__generator)(this, function (_a) {
|
|
117
|
+
switch (_a.label) {
|
|
118
|
+
case 0: return [4 /*yield*/, this.onTrainEnd(logs)];
|
|
119
|
+
case 1: return [2 /*return*/, _a.sent()];
|
|
120
|
+
}
|
|
72
121
|
}); }); }
|
|
73
122
|
}
|
|
74
123
|
})];
|
|
@@ -87,41 +136,6 @@ var Trainer = /** @class */ (function () {
|
|
|
87
136
|
if (decimalsToRound === void 0) { decimalsToRound = 2; }
|
|
88
137
|
return +(accuracy * 100).toFixed(decimalsToRound);
|
|
89
138
|
};
|
|
90
|
-
/**
|
|
91
|
-
* We update the training graph, this needs to be done on epoch end as there is no validation accuracy onBatchEnd.
|
|
92
|
-
*/
|
|
93
|
-
Trainer.prototype.onEpochEnd = function (epoch, logs) {
|
|
94
|
-
this.trainerLogger.onEpochEnd(epoch, logs);
|
|
95
|
-
if (logs !== undefined && !isNaN(logs.acc) && !isNaN(logs.val_acc)) {
|
|
96
|
-
this.trainingInformant.updateTrainingAccuracyGraph(this.roundDecimals(logs.acc));
|
|
97
|
-
this.trainingInformant.updateValidationAccuracyGraph(this.roundDecimals(logs.val_acc));
|
|
98
|
-
}
|
|
99
|
-
else {
|
|
100
|
-
this.trainerLogger.error('onEpochEnd: NaN value');
|
|
101
|
-
}
|
|
102
|
-
};
|
|
103
|
-
/** onBatchEnd callback, when a round ends, we call onRoundEnd (to be implemented for local and distributed instances)
|
|
104
|
-
*/
|
|
105
|
-
Trainer.prototype.onBatchEnd = function (_, logs) {
|
|
106
|
-
return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
|
|
107
|
-
return (0, tslib_1.__generator)(this, function (_a) {
|
|
108
|
-
switch (_a.label) {
|
|
109
|
-
case 0:
|
|
110
|
-
if (logs === undefined) {
|
|
111
|
-
return [2 /*return*/];
|
|
112
|
-
}
|
|
113
|
-
this.roundTracker.updateBatch();
|
|
114
|
-
this.stopTrainModelIfRequested();
|
|
115
|
-
if (!this.roundTracker.roundHasEnded()) return [3 /*break*/, 2];
|
|
116
|
-
return [4 /*yield*/, this.onRoundEnd(logs.acc)];
|
|
117
|
-
case 1:
|
|
118
|
-
_a.sent();
|
|
119
|
-
_a.label = 2;
|
|
120
|
-
case 2: return [2 /*return*/];
|
|
121
|
-
}
|
|
122
|
-
});
|
|
123
|
-
});
|
|
124
|
-
};
|
|
125
139
|
/**
|
|
126
140
|
* reset stop training state
|
|
127
141
|
*/
|
|
@@ -48,28 +48,21 @@ var TrainerBuilder = /** @class */ (function () {
|
|
|
48
48
|
TrainerBuilder.prototype.getModel = function (client) {
|
|
49
49
|
var _a;
|
|
50
50
|
return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
|
|
51
|
-
var modelID,
|
|
51
|
+
var modelID, info, model;
|
|
52
52
|
return (0, tslib_1.__generator)(this, function (_b) {
|
|
53
53
|
switch (_b.label) {
|
|
54
54
|
case 0:
|
|
55
55
|
modelID = (_a = this.task.trainingInformation) === null || _a === void 0 ? void 0 : _a.modelID;
|
|
56
56
|
if (modelID === undefined) {
|
|
57
|
-
throw new
|
|
57
|
+
throw new TypeError('model ID is undefined');
|
|
58
58
|
}
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
if (!(modelExistsInMemory !== undefined)) return [3 /*break*/, 3];
|
|
63
|
-
return [4 /*yield*/, this.memory.getModel(__1.ModelType.WORKING, this.task.taskID, modelID)];
|
|
59
|
+
info = { type: __1.ModelType.WORKING, taskID: this.task.taskID, name: modelID };
|
|
60
|
+
return [4 /*yield*/, this.memory.contains(info)];
|
|
61
|
+
case 1: return [4 /*yield*/, ((_b.sent()) ? this.memory.getModel(info) : client.getLatestModel())];
|
|
64
62
|
case 2:
|
|
65
63
|
model = _b.sent();
|
|
66
|
-
return [
|
|
67
|
-
case 3: return [
|
|
68
|
-
case 4:
|
|
69
|
-
model = _b.sent();
|
|
70
|
-
_b.label = 5;
|
|
71
|
-
case 5: return [4 /*yield*/, this.updateModelInformation(model)];
|
|
72
|
-
case 6: return [2 /*return*/, _b.sent()];
|
|
64
|
+
return [4 /*yield*/, this.updateModelInformation(model)];
|
|
65
|
+
case 3: return [2 /*return*/, _b.sent()];
|
|
73
66
|
}
|
|
74
67
|
});
|
|
75
68
|
});
|
|
@@ -84,7 +77,7 @@ var TrainerBuilder = /** @class */ (function () {
|
|
|
84
77
|
}
|
|
85
78
|
info = this.task.trainingInformation;
|
|
86
79
|
if (info === undefined) {
|
|
87
|
-
throw new
|
|
80
|
+
throw new TypeError('training information is undefined');
|
|
88
81
|
}
|
|
89
82
|
model.compile(info.modelCompileData);
|
|
90
83
|
if (info.learningRate !== undefined) {
|
package/dist/types.d.ts
CHANGED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
export { Validator } from './validator';
|
|
@@ -0,0 +1,5 @@
|
|
|
1
|
+
"use strict";
|
|
2
|
+
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
+
exports.Validator = void 0;
|
|
4
|
+
var validator_1 = require("./validator");
|
|
5
|
+
Object.defineProperty(exports, "Validator", { enumerable: true, get: function () { return validator_1.Validator; } });
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
import * as tf from '@tensorflow/tfjs';
|
|
2
|
+
import { ModelActor } from '../model_actor';
|
|
3
|
+
import { Task } from '@/task';
|
|
4
|
+
import { Data } from '@/dataset';
|
|
5
|
+
import { Logger } from '@/logging';
|
|
6
|
+
import { List } from 'immutable';
|
|
7
|
+
import { Client, Memory, ModelSource } from '..';
|
|
8
|
+
export declare class Validator extends ModelActor {
|
|
9
|
+
private readonly memory;
|
|
10
|
+
private readonly source?;
|
|
11
|
+
private readonly client?;
|
|
12
|
+
private readonly graphInformant;
|
|
13
|
+
private size;
|
|
14
|
+
constructor(task: Task, logger: Logger, memory: Memory, source?: ModelSource | undefined, client?: Client | undefined);
|
|
15
|
+
assess(data: Data): Promise<void>;
|
|
16
|
+
getModel(): Promise<tf.LayersModel>;
|
|
17
|
+
accuracyData(): List<number>;
|
|
18
|
+
accuracy(): number;
|
|
19
|
+
visitedSamples(): number;
|
|
20
|
+
}
|
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
"use strict";
|
|
2
|
+
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
+
exports.Validator = void 0;
|
|
4
|
+
var tslib_1 = require("tslib");
|
|
5
|
+
var model_actor_1 = require("../model_actor");
|
|
6
|
+
var immutable_1 = require("immutable");
|
|
7
|
+
var __1 = require("..");
|
|
8
|
+
var Validator = /** @class */ (function (_super) {
|
|
9
|
+
(0, tslib_1.__extends)(Validator, _super);
|
|
10
|
+
function Validator(task, logger, memory, source, client) {
|
|
11
|
+
var _this = _super.call(this, task, logger) || this;
|
|
12
|
+
_this.memory = memory;
|
|
13
|
+
_this.source = source;
|
|
14
|
+
_this.client = client;
|
|
15
|
+
_this.graphInformant = new __1.GraphInformant();
|
|
16
|
+
_this.size = 0;
|
|
17
|
+
if (source === undefined && client === undefined) {
|
|
18
|
+
throw new Error('cannot identify model');
|
|
19
|
+
}
|
|
20
|
+
return _this;
|
|
21
|
+
}
|
|
22
|
+
Validator.prototype.assess = function (data) {
|
|
23
|
+
var _a, _b, _c;
|
|
24
|
+
return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
|
|
25
|
+
var batchSize, labels, classes, model, hits;
|
|
26
|
+
var _this = this;
|
|
27
|
+
return (0, tslib_1.__generator)(this, function (_d) {
|
|
28
|
+
switch (_d.label) {
|
|
29
|
+
case 0:
|
|
30
|
+
batchSize = (_a = this.task.trainingInformation) === null || _a === void 0 ? void 0 : _a.batchSize;
|
|
31
|
+
if (batchSize === undefined) {
|
|
32
|
+
throw new TypeError('batch size is undefined');
|
|
33
|
+
}
|
|
34
|
+
labels = (_b = this.task.trainingInformation) === null || _b === void 0 ? void 0 : _b.LABEL_LIST;
|
|
35
|
+
classes = (_c = labels === null || labels === void 0 ? void 0 : labels.length) !== null && _c !== void 0 ? _c : 1;
|
|
36
|
+
return [4 /*yield*/, this.getModel()];
|
|
37
|
+
case 1:
|
|
38
|
+
model = _d.sent();
|
|
39
|
+
hits = 0;
|
|
40
|
+
return [4 /*yield*/, data.dataset.batch(batchSize).forEachAsync(function (e) {
|
|
41
|
+
if (typeof e === 'object' && 'xs' in e && 'ys' in e) {
|
|
42
|
+
var xs = e.xs;
|
|
43
|
+
var ys = e.ys.dataSync();
|
|
44
|
+
var pred = model.predict(xs, { batchSize: batchSize })
|
|
45
|
+
.dataSync()
|
|
46
|
+
.map(Math.round);
|
|
47
|
+
_this.size += xs.shape[0];
|
|
48
|
+
hits += (0, immutable_1.List)(pred).zip((0, immutable_1.List)(ys))
|
|
49
|
+
.map(function (_a) {
|
|
50
|
+
var _b = (0, tslib_1.__read)(_a, 2), p = _b[0], y = _b[1];
|
|
51
|
+
return 1 - Math.abs(p - y);
|
|
52
|
+
})
|
|
53
|
+
.reduce(function (acc, e) { return acc + e; }) / classes;
|
|
54
|
+
var currentAccuracy = hits / _this.size;
|
|
55
|
+
_this.graphInformant.updateAccuracy(currentAccuracy);
|
|
56
|
+
}
|
|
57
|
+
else {
|
|
58
|
+
throw new TypeError('missing feature/label in dataset');
|
|
59
|
+
}
|
|
60
|
+
})];
|
|
61
|
+
case 2:
|
|
62
|
+
_d.sent();
|
|
63
|
+
this.logger.success("Obtained validation accuracy of " + this.accuracy());
|
|
64
|
+
this.logger.success("Visited " + this.visitedSamples() + " samples");
|
|
65
|
+
return [2 /*return*/];
|
|
66
|
+
}
|
|
67
|
+
});
|
|
68
|
+
});
|
|
69
|
+
};
|
|
70
|
+
Validator.prototype.getModel = function () {
|
|
71
|
+
return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
|
|
72
|
+
var _a;
|
|
73
|
+
return (0, tslib_1.__generator)(this, function (_b) {
|
|
74
|
+
switch (_b.label) {
|
|
75
|
+
case 0:
|
|
76
|
+
_a = this.source !== undefined;
|
|
77
|
+
if (!_a) return [3 /*break*/, 2];
|
|
78
|
+
return [4 /*yield*/, this.memory.contains(this.source)];
|
|
79
|
+
case 1:
|
|
80
|
+
_a = (_b.sent());
|
|
81
|
+
_b.label = 2;
|
|
82
|
+
case 2:
|
|
83
|
+
if (!_a) return [3 /*break*/, 4];
|
|
84
|
+
return [4 /*yield*/, this.memory.getModel(this.source)];
|
|
85
|
+
case 3: return [2 /*return*/, _b.sent()];
|
|
86
|
+
case 4:
|
|
87
|
+
if (!(this.client !== undefined)) return [3 /*break*/, 6];
|
|
88
|
+
return [4 /*yield*/, this.client.getLatestModel()];
|
|
89
|
+
case 5: return [2 /*return*/, _b.sent()];
|
|
90
|
+
case 6: throw new Error('cannot identify model');
|
|
91
|
+
}
|
|
92
|
+
});
|
|
93
|
+
});
|
|
94
|
+
};
|
|
95
|
+
Validator.prototype.accuracyData = function () {
|
|
96
|
+
return this.graphInformant.data();
|
|
97
|
+
};
|
|
98
|
+
Validator.prototype.accuracy = function () {
|
|
99
|
+
return this.graphInformant.accuracy();
|
|
100
|
+
};
|
|
101
|
+
Validator.prototype.visitedSamples = function () {
|
|
102
|
+
return this.size;
|
|
103
|
+
};
|
|
104
|
+
return Validator;
|
|
105
|
+
}(model_actor_1.ModelActor));
|
|
106
|
+
exports.Validator = Validator;
|
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "@epfml/discojs",
|
|
3
|
-
"version": "0.0
|
|
3
|
+
"version": "0.1.0",
|
|
4
4
|
"main": "dist/index.js",
|
|
5
5
|
"types": "dist/index.d.ts",
|
|
6
6
|
"scripts": {
|
|
@@ -16,7 +16,6 @@
|
|
|
16
16
|
},
|
|
17
17
|
"homepage": "https://github.com/epfml/disco#readme",
|
|
18
18
|
"dependencies": {
|
|
19
|
-
"@tensorflow/tfjs": "3",
|
|
20
19
|
"axios": "0.27",
|
|
21
20
|
"immutable": "4",
|
|
22
21
|
"isomorphic-ws": "4",
|
|
@@ -45,4 +44,4 @@
|
|
|
45
44
|
"ts-node": "10",
|
|
46
45
|
"typescript": "<4.5.0"
|
|
47
46
|
}
|
|
48
|
-
}
|
|
47
|
+
}
|
|
@@ -1,23 +0,0 @@
|
|
|
1
|
-
import { TrainingInformant, Weights } from '..';
|
|
2
|
-
import { Base } from './base';
|
|
3
|
-
/**
|
|
4
|
-
* Class that deals with communication with the PeerJS server.
|
|
5
|
-
* Collects the list of receivers currently connected to the PeerJS server.
|
|
6
|
-
*/
|
|
7
|
-
export declare class Decentralized extends Base {
|
|
8
|
-
private server?;
|
|
9
|
-
private peers;
|
|
10
|
-
private readonly weights;
|
|
11
|
-
private connectServer;
|
|
12
|
-
private connectNewPeer;
|
|
13
|
-
/**
|
|
14
|
-
* Initialize the connection to the peers and to the other nodes.
|
|
15
|
-
*/
|
|
16
|
-
connect(): Promise<void>;
|
|
17
|
-
/**
|
|
18
|
-
* Disconnection process when user quits the task.
|
|
19
|
-
*/
|
|
20
|
-
disconnect(): Promise<void>;
|
|
21
|
-
onRoundEndCommunication(updatedWeights: Weights, staleWeights: Weights, epoch: number, trainingInformant: TrainingInformant): Promise<Weights>;
|
|
22
|
-
onTrainEndCommunication(_: Weights, trainingInformant: TrainingInformant): Promise<void>;
|
|
23
|
-
}
|