@epfml/discojs 0.0.1 → 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (95)
  1. package/README.md +12 -11
  2. package/dist/aggregation.d.ts +4 -2
  3. package/dist/aggregation.js +20 -6
  4. package/dist/client/base.d.ts +1 -1
  5. package/dist/client/decentralized/base.d.ts +43 -0
  6. package/dist/client/decentralized/base.js +243 -0
  7. package/dist/client/decentralized/clear_text.d.ts +13 -0
  8. package/dist/client/decentralized/clear_text.js +78 -0
  9. package/dist/client/decentralized/index.d.ts +4 -0
  10. package/dist/client/decentralized/index.js +9 -0
  11. package/dist/client/decentralized/messages.d.ts +37 -0
  12. package/dist/client/decentralized/messages.js +15 -0
  13. package/dist/client/decentralized/sec_agg.d.ts +18 -0
  14. package/dist/client/decentralized/sec_agg.js +169 -0
  15. package/dist/client/decentralized/secret_shares.d.ts +5 -0
  16. package/dist/client/decentralized/secret_shares.js +58 -0
  17. package/dist/client/decentralized/types.d.ts +1 -0
  18. package/dist/client/decentralized/types.js +2 -0
  19. package/dist/client/federated.d.ts +4 -4
  20. package/dist/client/federated.js +5 -8
  21. package/dist/client/index.d.ts +1 -1
  22. package/dist/client/index.js +3 -3
  23. package/dist/client/local.js +5 -3
  24. package/dist/dataset/data_loader/data_loader.d.ts +7 -1
  25. package/dist/dataset/data_loader/image_loader.d.ts +5 -3
  26. package/dist/dataset/data_loader/image_loader.js +64 -18
  27. package/dist/dataset/data_loader/index.d.ts +1 -1
  28. package/dist/dataset/data_loader/tabular_loader.d.ts +3 -3
  29. package/dist/dataset/data_loader/tabular_loader.js +27 -18
  30. package/dist/dataset/dataset_builder.d.ts +3 -2
  31. package/dist/dataset/dataset_builder.js +29 -17
  32. package/dist/index.d.ts +6 -4
  33. package/dist/index.js +13 -5
  34. package/dist/informant/graph_informant.d.ts +10 -0
  35. package/dist/informant/graph_informant.js +23 -0
  36. package/dist/informant/index.d.ts +3 -0
  37. package/dist/informant/index.js +9 -0
  38. package/dist/informant/training_informant/base.d.ts +31 -0
  39. package/dist/informant/training_informant/base.js +82 -0
  40. package/dist/informant/training_informant/decentralized.d.ts +5 -0
  41. package/dist/informant/training_informant/decentralized.js +22 -0
  42. package/dist/informant/training_informant/federated.d.ts +14 -0
  43. package/dist/informant/training_informant/federated.js +32 -0
  44. package/dist/informant/training_informant/index.d.ts +4 -0
  45. package/dist/informant/training_informant/index.js +11 -0
  46. package/dist/informant/training_informant/local.d.ts +6 -0
  47. package/dist/informant/training_informant/local.js +20 -0
  48. package/dist/logging/index.d.ts +1 -0
  49. package/dist/logging/index.js +3 -1
  50. package/dist/logging/trainer_logger.d.ts +1 -1
  51. package/dist/logging/trainer_logger.js +5 -5
  52. package/dist/memory/base.d.ts +17 -48
  53. package/dist/memory/empty.d.ts +6 -4
  54. package/dist/memory/empty.js +8 -2
  55. package/dist/memory/index.d.ts +1 -1
  56. package/dist/privacy.js +3 -3
  57. package/dist/serialization/model.d.ts +1 -1
  58. package/dist/serialization/model.js +2 -2
  59. package/dist/serialization/weights.js +2 -2
  60. package/dist/task/display_information.d.ts +2 -2
  61. package/dist/task/display_information.js +6 -5
  62. package/dist/task/summary.d.ts +5 -0
  63. package/dist/task/summary.js +23 -0
  64. package/dist/task/training_information.d.ts +3 -0
  65. package/dist/tasks/cifar10.js +5 -3
  66. package/dist/tasks/lus_covid.d.ts +1 -1
  67. package/dist/tasks/lus_covid.js +50 -13
  68. package/dist/tasks/mnist.js +4 -2
  69. package/dist/tasks/simple_face.d.ts +1 -1
  70. package/dist/tasks/simple_face.js +13 -17
  71. package/dist/tasks/titanic.js +9 -7
  72. package/dist/tfjs.d.ts +2 -0
  73. package/dist/tfjs.js +6 -0
  74. package/dist/training/disco.d.ts +3 -1
  75. package/dist/training/disco.js +14 -6
  76. package/dist/training/trainer/distributed_trainer.d.ts +1 -1
  77. package/dist/training/trainer/distributed_trainer.js +5 -1
  78. package/dist/training/trainer/local_trainer.d.ts +4 -3
  79. package/dist/training/trainer/local_trainer.js +6 -9
  80. package/dist/training/trainer/round_tracker.js +3 -0
  81. package/dist/training/trainer/trainer.d.ts +15 -15
  82. package/dist/training/trainer/trainer.js +57 -43
  83. package/dist/training/trainer/trainer_builder.js +8 -15
  84. package/dist/types.d.ts +1 -1
  85. package/dist/validation/index.d.ts +1 -0
  86. package/dist/validation/index.js +5 -0
  87. package/dist/validation/validator.d.ts +20 -0
  88. package/dist/validation/validator.js +106 -0
  89. package/package.json +2 -3
  90. package/dist/client/decentralized.d.ts +0 -23
  91. package/dist/client/decentralized.js +0 -275
  92. package/dist/testing/tester.d.ts +0 -5
  93. package/dist/testing/tester.js +0 -21
  94. package/dist/training_informant.d.ts +0 -88
  95. package/dist/training_informant.js +0 -135
package/dist/training/trainer/trainer.d.ts CHANGED
@@ -3,12 +3,12 @@ import { Memory, Task, TrainingInformant, TrainingInformation } from '@/.';
  import { RoundTracker } from './round_tracker';
  import { TrainerLog } from '../../logging/trainer_logger';
  /** Abstract class whose role is to train a model with a given dataset. This can be either done
- * locally or in a distributed way. The Trainer works as follows:
+ * locally (alone) or in a distributed way with collaborators. The Trainer works as follows:
  *
  * 1. Call trainModel(dataset) to start training
- * 2. Once a batch ends, onBatchEnd is triggered, witch will then call onRoundEnd once the round has ended.
+ * 2. Once a batch ends, onBatchEnd is triggered, which will then call onRoundEnd once the round has ended.
  *
- * The onRoundEnd needs to be implemented to specify what actions to do when the round has ended. To know when
+ * The onRoundEnd needs to be implemented to specify what actions to do when the round has ended, such as a communication step with collaborators. To know when
  * a round has ended we use the roundTracker object.
  */
  export declare abstract class Trainer {
@@ -30,10 +30,17 @@ export declare abstract class Trainer {
  * Every time a round ends this function will be called
  */
  protected abstract onRoundEnd(accuracy: number): Promise<void>;
+ /** onBatchEnd callback, when a round ends, we call onRoundEnd (to be implemented for local and distributed instances)
+ */
+ protected onBatchEnd(_: number, logs?: tf.Logs): Promise<void>;
+ /**
+ * We update the training graph, this needs to be done on epoch end as there is no validation accuracy onBatchEnd.
+ */
+ protected onEpochEnd(epoch: number, logs?: tf.Logs): void;
  /**
  * When the training ends this function will be call
  */
- protected abstract onTrainEnd(): Promise<void>;
+ protected onTrainEnd(logs?: tf.Logs): Promise<void>;
  /**
  * Request stop training to be used from the Disco instance or any class that is taking care of the trainer.
  */
@@ -42,25 +49,18 @@ export declare abstract class Trainer {
  * Start training the model with the given dataset
  * @param dataset
  */
- trainModel(dataset: tf.data.Dataset<tf.TensorContainer>): Promise<void>;
+ trainModel(dataset: tf.data.Dataset<tf.TensorContainer>, valDataset: tf.data.Dataset<tf.TensorContainer>): Promise<void>;
  /**
  * Format accuracy
  */
- private roundDecimals;
- /**
- * We update the training graph, this needs to be done on epoch end as there is no validation accuracy onBatchEnd.
- */
- private onEpochEnd;
- /** onBatchEnd callback, when a round ends, we call onRoundEnd (to be implemented for local and distributed instances)
- */
- private onBatchEnd;
+ protected roundDecimals(accuracy: number, decimalsToRound?: number): number;
  /**
  * reset stop training state
  */
- private resetStopTrainerState;
+ protected resetStopTrainerState(): void;
  /**
  * If stop training is requested, do so
  */
- private stopTrainModelIfRequested;
+ protected stopTrainModelIfRequested(): void;
  getTrainerLog(): TrainerLog;
  }
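As the comment above describes, after this change a subclass only needs to supply onRoundEnd; batch counting, round tracking and the Keras callbacks are handled by the base class. A minimal sketch of a distributed subclass, assuming Trainer is importable and behaves as declared above (the class name and the body of the hook are illustrative, not part of the package):

  // Sketch only: a subclass implementing the single abstract hook of Trainer.
  class SketchDistributedTrainer extends Trainer {
    protected async onRoundEnd (accuracy: number): Promise<void> {
      // called from onBatchEnd once roundTracker reports that the round has ended;
      // a real implementation would run a communication step with collaborators here
      console.log(`round ended with training accuracy ${accuracy}`)
    }
  }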
package/dist/training/trainer/trainer.js CHANGED
@@ -5,12 +5,12 @@ var tslib_1 = require("tslib");
  var round_tracker_1 = require("./round_tracker");
  var trainer_logger_1 = require("../../logging/trainer_logger");
  /** Abstract class whose role is to train a model with a given dataset. This can be either done
- * locally or in a distributed way. The Trainer works as follows:
+ * locally (alone) or in a distributed way with collaborators. The Trainer works as follows:
  *
  * 1. Call trainModel(dataset) to start training
- * 2. Once a batch ends, onBatchEnd is triggered, witch will then call onRoundEnd once the round has ended.
+ * 2. Once a batch ends, onBatchEnd is triggered, which will then call onRoundEnd once the round has ended.
  *
- * The onRoundEnd needs to be implemented to specify what actions to do when the round has ended. To know when
+ * The onRoundEnd needs to be implemented to specify what actions to do when the round has ended, such as a communication step with collaborators. To know when
  * a round has ended we use the roundTracker object.
  */
  var Trainer = /** @class */ (function () {
@@ -33,9 +33,52 @@ var Trainer = /** @class */ (function () {
  this.trainingInformation = trainingInformation;
  this.roundTracker = new round_tracker_1.RoundTracker(trainingInformation.roundDuration);
  }
- //* ***************************************************************
- //* Functions to be used by the training manager
- //* ***************************************************************
+ /** onBatchEnd callback, when a round ends, we call onRoundEnd (to be implemented for local and distributed instances)
+ */
+ Trainer.prototype.onBatchEnd = function (_, logs) {
+ return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
+ return (0, tslib_1.__generator)(this, function (_a) {
+ switch (_a.label) {
+ case 0:
+ if (logs === undefined) {
+ return [2 /*return*/];
+ }
+ this.roundTracker.updateBatch();
+ this.stopTrainModelIfRequested();
+ if (!this.roundTracker.roundHasEnded()) return [3 /*break*/, 2];
+ return [4 /*yield*/, this.onRoundEnd(logs.acc)];
+ case 1:
+ _a.sent();
+ _a.label = 2;
+ case 2: return [2 /*return*/];
+ }
+ });
+ });
+ };
+ /**
+ * We update the training graph, this needs to be done on epoch end as there is no validation accuracy onBatchEnd.
+ */
+ Trainer.prototype.onEpochEnd = function (epoch, logs) {
+ this.trainerLogger.onEpochEnd(epoch, logs);
+ if (logs !== undefined && !isNaN(logs.acc) && !isNaN(logs.val_acc)) {
+ this.trainingInformant.updateTrainingGraph(this.roundDecimals(logs.acc));
+ this.trainingInformant.updateValidationGraph(this.roundDecimals(logs.val_acc));
+ }
+ else {
+ this.trainerLogger.error('onEpochEnd: NaN value');
+ }
+ };
+ /**
+ * When the training ends this function will be call
+ */
+ Trainer.prototype.onTrainEnd = function (logs) {
+ return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
+ return (0, tslib_1.__generator)(this, function (_a) {
+ this.trainingInformant.addMessage('Training finished.');
+ return [2 /*return*/];
+ });
+ });
+ };
  /**
  * Request stop training to be used from the Disco instance or any class that is taking care of the trainer.
  */
@@ -51,7 +94,7 @@ var Trainer = /** @class */ (function () {
  * Start training the model with the given dataset
  * @param dataset
  */
- Trainer.prototype.trainModel = function (dataset) {
+ Trainer.prototype.trainModel = function (dataset, valDataset) {
  return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
  var _this = this;
  return (0, tslib_1.__generator)(this, function (_a) {
@@ -61,7 +104,7 @@ var Trainer = /** @class */ (function () {
  // Assign callbacks and start training
  return [4 /*yield*/, this.model.fitDataset(dataset.batch(this.trainingInformation.batchSize), {
  epochs: this.trainingInformation.epochs,
- validationData: dataset.batch(this.trainingInformation.batchSize),
+ validationData: valDataset.batch(this.trainingInformation.batchSize),
  callbacks: {
  onEpochEnd: function (epoch, logs) { return _this.onEpochEnd(epoch, logs); },
  onBatchEnd: function (epoch, logs) { return (0, tslib_1.__awaiter)(_this, void 0, void 0, function () { return (0, tslib_1.__generator)(this, function (_a) {
@@ -69,6 +112,12 @@ var Trainer = /** @class */ (function () {
  case 0: return [4 /*yield*/, this.onBatchEnd(epoch, logs)];
  case 1: return [2 /*return*/, _a.sent()];
  }
+ }); }); },
+ onTrainEnd: function (logs) { return (0, tslib_1.__awaiter)(_this, void 0, void 0, function () { return (0, tslib_1.__generator)(this, function (_a) {
+ switch (_a.label) {
+ case 0: return [4 /*yield*/, this.onTrainEnd(logs)];
+ case 1: return [2 /*return*/, _a.sent()];
+ }
  }); }); }
  }
  })];
@@ -87,41 +136,6 @@ var Trainer = /** @class */ (function () {
  if (decimalsToRound === void 0) { decimalsToRound = 2; }
  return +(accuracy * 100).toFixed(decimalsToRound);
  };
- /**
- * We update the training graph, this needs to be done on epoch end as there is no validation accuracy onBatchEnd.
- */
- Trainer.prototype.onEpochEnd = function (epoch, logs) {
- this.trainerLogger.onEpochEnd(epoch, logs);
- if (logs !== undefined && !isNaN(logs.acc) && !isNaN(logs.val_acc)) {
- this.trainingInformant.updateTrainingAccuracyGraph(this.roundDecimals(logs.acc));
- this.trainingInformant.updateValidationAccuracyGraph(this.roundDecimals(logs.val_acc));
- }
- else {
- this.trainerLogger.error('onEpochEnd: NaN value');
- }
- };
- /** onBatchEnd callback, when a round ends, we call onRoundEnd (to be implemented for local and distributed instances)
- */
- Trainer.prototype.onBatchEnd = function (_, logs) {
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
- return (0, tslib_1.__generator)(this, function (_a) {
- switch (_a.label) {
- case 0:
- if (logs === undefined) {
- return [2 /*return*/];
- }
- this.roundTracker.updateBatch();
- this.stopTrainModelIfRequested();
- if (!this.roundTracker.roundHasEnded()) return [3 /*break*/, 2];
- return [4 /*yield*/, this.onRoundEnd(logs.acc)];
- case 1:
- _a.sent();
- _a.label = 2;
- case 2: return [2 /*return*/];
- }
- });
- });
- };
  /**
  * reset stop training state
  */
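The net effect of these trainer changes is that validation no longer reuses the training data: trainModel now takes a separate validation dataset, and an onTrainEnd callback is wired into fitDataset. A rough sketch of the new call shape (trainSet and valSet are illustrative names, not taken from the diff):

  // Before this change: trainer.trainModel(dataset) validated on the training data itself.
  // After this change, a distinct validation split is batched and passed to fitDataset:
  await trainer.trainModel(trainSet, valSet)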
package/dist/training/trainer/trainer_builder.js CHANGED
@@ -48,28 +48,21 @@ var TrainerBuilder = /** @class */ (function () {
  TrainerBuilder.prototype.getModel = function (client) {
  var _a;
  return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
- var modelID, modelExistsInMemory, model;
+ var modelID, info, model;
  return (0, tslib_1.__generator)(this, function (_b) {
  switch (_b.label) {
  case 0:
  modelID = (_a = this.task.trainingInformation) === null || _a === void 0 ? void 0 : _a.modelID;
  if (modelID === undefined) {
- throw new Error('undefined model ID');
+ throw new TypeError('model ID is undefined');
  }
- return [4 /*yield*/, this.memory.getModelMetadata(__1.ModelType.WORKING, this.task.taskID, modelID)];
- case 1:
- modelExistsInMemory = _b.sent();
- if (!(modelExistsInMemory !== undefined)) return [3 /*break*/, 3];
- return [4 /*yield*/, this.memory.getModel(__1.ModelType.WORKING, this.task.taskID, modelID)];
+ info = { type: __1.ModelType.WORKING, taskID: this.task.taskID, name: modelID };
+ return [4 /*yield*/, this.memory.contains(info)];
+ case 1: return [4 /*yield*/, ((_b.sent()) ? this.memory.getModel(info) : client.getLatestModel())];
  case 2:
  model = _b.sent();
- return [3 /*break*/, 5];
- case 3: return [4 /*yield*/, client.getLatestModel()];
- case 4:
- model = _b.sent();
- _b.label = 5;
- case 5: return [4 /*yield*/, this.updateModelInformation(model)];
- case 6: return [2 /*return*/, _b.sent()];
+ return [4 /*yield*/, this.updateModelInformation(model)];
+ case 3: return [2 /*return*/, _b.sent()];
  }
  });
  });
@@ -84,7 +77,7 @@ var TrainerBuilder = /** @class */ (function () {
  }
  info = this.task.trainingInformation;
  if (info === undefined) {
- throw new Error('undefined training information');
+ throw new TypeError('training information is undefined');
  }
  model.compile(info.modelCompileData);
  if (info.learningRate !== undefined) {
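For readability, the rewritten getModel above boils down to roughly the following TypeScript, reconstructed from the compiled output (a sketch, not the package's source code): memory lookups now take a single model-source object instead of separate type/taskID/name arguments.

  // Sketch reconstructed from the compiled getModel body above.
  const info = { type: ModelType.WORKING, taskID: this.task.taskID, name: modelID }
  const model = (await this.memory.contains(info))
    ? await this.memory.getModel(info)   // reuse the working model stored in memory
    : await client.getLatestModel()      // otherwise fetch the latest model from the server
  return await this.updateModelInformation(model)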
package/dist/types.d.ts CHANGED
@@ -1,4 +1,4 @@
- import * as tf from '@tensorflow/tfjs';
+ import { tf } from '.';
  export declare type Path = string;
  export declare type Weights = tf.Tensor[];
  export declare type MetadataID = string;
package/dist/validation/index.d.ts ADDED
@@ -0,0 +1 @@
+ export { Validator } from './validator';
package/dist/validation/index.js ADDED
@@ -0,0 +1,5 @@
+ "use strict";
+ Object.defineProperty(exports, "__esModule", { value: true });
+ exports.Validator = void 0;
+ var validator_1 = require("./validator");
+ Object.defineProperty(exports, "Validator", { enumerable: true, get: function () { return validator_1.Validator; } });
package/dist/validation/validator.d.ts ADDED
@@ -0,0 +1,20 @@
+ import * as tf from '@tensorflow/tfjs';
+ import { ModelActor } from '../model_actor';
+ import { Task } from '@/task';
+ import { Data } from '@/dataset';
+ import { Logger } from '@/logging';
+ import { List } from 'immutable';
+ import { Client, Memory, ModelSource } from '..';
+ export declare class Validator extends ModelActor {
+ private readonly memory;
+ private readonly source?;
+ private readonly client?;
+ private readonly graphInformant;
+ private size;
+ constructor(task: Task, logger: Logger, memory: Memory, source?: ModelSource | undefined, client?: Client | undefined);
+ assess(data: Data): Promise<void>;
+ getModel(): Promise<tf.LayersModel>;
+ accuracyData(): List<number>;
+ accuracy(): number;
+ visitedSamples(): number;
+ }
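The declaration above implies the following usage pattern for the new validation API; this is a hedged sketch (the task, logger, memory, client and data objects are assumed to exist elsewhere, and either a model source or a client must be supplied to the constructor):

  // Hedged usage sketch; only the Validator surface comes from this diff.
  const validator = new Validator(task, logger, memory, undefined, client)
  await validator.assess(data)             // runs the model over the dataset batch by batch
  console.log(validator.accuracy())        // running accuracy tracked by the GraphInformant
  console.log(validator.visitedSamples())  // number of samples evaluated so far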
package/dist/validation/validator.js ADDED
@@ -0,0 +1,106 @@
+ "use strict";
+ Object.defineProperty(exports, "__esModule", { value: true });
+ exports.Validator = void 0;
+ var tslib_1 = require("tslib");
+ var model_actor_1 = require("../model_actor");
+ var immutable_1 = require("immutable");
+ var __1 = require("..");
+ var Validator = /** @class */ (function (_super) {
+ (0, tslib_1.__extends)(Validator, _super);
+ function Validator(task, logger, memory, source, client) {
+ var _this = _super.call(this, task, logger) || this;
+ _this.memory = memory;
+ _this.source = source;
+ _this.client = client;
+ _this.graphInformant = new __1.GraphInformant();
+ _this.size = 0;
+ if (source === undefined && client === undefined) {
+ throw new Error('cannot identify model');
+ }
+ return _this;
+ }
+ Validator.prototype.assess = function (data) {
+ var _a, _b, _c;
+ return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
+ var batchSize, labels, classes, model, hits;
+ var _this = this;
+ return (0, tslib_1.__generator)(this, function (_d) {
+ switch (_d.label) {
+ case 0:
+ batchSize = (_a = this.task.trainingInformation) === null || _a === void 0 ? void 0 : _a.batchSize;
+ if (batchSize === undefined) {
+ throw new TypeError('batch size is undefined');
+ }
+ labels = (_b = this.task.trainingInformation) === null || _b === void 0 ? void 0 : _b.LABEL_LIST;
+ classes = (_c = labels === null || labels === void 0 ? void 0 : labels.length) !== null && _c !== void 0 ? _c : 1;
+ return [4 /*yield*/, this.getModel()];
+ case 1:
+ model = _d.sent();
+ hits = 0;
+ return [4 /*yield*/, data.dataset.batch(batchSize).forEachAsync(function (e) {
+ if (typeof e === 'object' && 'xs' in e && 'ys' in e) {
+ var xs = e.xs;
+ var ys = e.ys.dataSync();
+ var pred = model.predict(xs, { batchSize: batchSize })
+ .dataSync()
+ .map(Math.round);
+ _this.size += xs.shape[0];
+ hits += (0, immutable_1.List)(pred).zip((0, immutable_1.List)(ys))
+ .map(function (_a) {
+ var _b = (0, tslib_1.__read)(_a, 2), p = _b[0], y = _b[1];
+ return 1 - Math.abs(p - y);
+ })
+ .reduce(function (acc, e) { return acc + e; }) / classes;
+ var currentAccuracy = hits / _this.size;
+ _this.graphInformant.updateAccuracy(currentAccuracy);
+ }
+ else {
+ throw new TypeError('missing feature/label in dataset');
+ }
+ })];
+ case 2:
+ _d.sent();
+ this.logger.success("Obtained validation accuracy of " + this.accuracy());
+ this.logger.success("Visited " + this.visitedSamples() + " samples");
+ return [2 /*return*/];
+ }
+ });
+ });
+ };
+ Validator.prototype.getModel = function () {
+ return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
+ var _a;
+ return (0, tslib_1.__generator)(this, function (_b) {
+ switch (_b.label) {
+ case 0:
+ _a = this.source !== undefined;
+ if (!_a) return [3 /*break*/, 2];
+ return [4 /*yield*/, this.memory.contains(this.source)];
+ case 1:
+ _a = (_b.sent());
+ _b.label = 2;
+ case 2:
+ if (!_a) return [3 /*break*/, 4];
+ return [4 /*yield*/, this.memory.getModel(this.source)];
+ case 3: return [2 /*return*/, _b.sent()];
+ case 4:
+ if (!(this.client !== undefined)) return [3 /*break*/, 6];
+ return [4 /*yield*/, this.client.getLatestModel()];
+ case 5: return [2 /*return*/, _b.sent()];
+ case 6: throw new Error('cannot identify model');
+ }
+ });
+ });
+ };
+ Validator.prototype.accuracyData = function () {
+ return this.graphInformant.data();
+ };
+ Validator.prototype.accuracy = function () {
+ return this.graphInformant.accuracy();
+ };
+ Validator.prototype.visitedSamples = function () {
+ return this.size;
+ };
+ return Validator;
+ }(model_actor_1.ModelActor));
+ exports.Validator = Validator;
package/package.json CHANGED
@@ -1,6 +1,6 @@
  {
  "name": "@epfml/discojs",
- "version": "0.0.1",
+ "version": "0.1.0",
  "main": "dist/index.js",
  "types": "dist/index.d.ts",
  "scripts": {
@@ -16,7 +16,6 @@
  },
  "homepage": "https://github.com/epfml/disco#readme",
  "dependencies": {
- "@tensorflow/tfjs": "3",
  "axios": "0.27",
  "immutable": "4",
  "isomorphic-ws": "4",
@@ -45,4 +44,4 @@
  "ts-node": "10",
  "typescript": "<4.5.0"
  }
- }
+ }
package/dist/client/decentralized.d.ts DELETED
@@ -1,23 +0,0 @@
- import { TrainingInformant, Weights } from '..';
- import { Base } from './base';
- /**
- * Class that deals with communication with the PeerJS server.
- * Collects the list of receivers currently connected to the PeerJS server.
- */
- export declare class Decentralized extends Base {
- private server?;
- private peers;
- private readonly weights;
- private connectServer;
- private connectNewPeer;
- /**
- * Initialize the connection to the peers and to the other nodes.
- */
- connect(): Promise<void>;
- /**
- * Disconnection process when user quits the task.
- */
- disconnect(): Promise<void>;
- onRoundEndCommunication(updatedWeights: Weights, staleWeights: Weights, epoch: number, trainingInformant: TrainingInformant): Promise<Weights>;
- onTrainEndCommunication(_: Weights, trainingInformant: TrainingInformant): Promise<void>;
- }