@epfml/discojs 0.0.1 → 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (95) hide show
  1. package/README.md +12 -11
  2. package/dist/aggregation.d.ts +4 -2
  3. package/dist/aggregation.js +20 -6
  4. package/dist/client/base.d.ts +1 -1
  5. package/dist/client/decentralized/base.d.ts +43 -0
  6. package/dist/client/decentralized/base.js +243 -0
  7. package/dist/client/decentralized/clear_text.d.ts +13 -0
  8. package/dist/client/decentralized/clear_text.js +78 -0
  9. package/dist/client/decentralized/index.d.ts +4 -0
  10. package/dist/client/decentralized/index.js +9 -0
  11. package/dist/client/decentralized/messages.d.ts +37 -0
  12. package/dist/client/decentralized/messages.js +15 -0
  13. package/dist/client/decentralized/sec_agg.d.ts +18 -0
  14. package/dist/client/decentralized/sec_agg.js +169 -0
  15. package/dist/client/decentralized/secret_shares.d.ts +5 -0
  16. package/dist/client/decentralized/secret_shares.js +58 -0
  17. package/dist/client/decentralized/types.d.ts +1 -0
  18. package/dist/client/decentralized/types.js +2 -0
  19. package/dist/client/federated.d.ts +4 -4
  20. package/dist/client/federated.js +5 -8
  21. package/dist/client/index.d.ts +1 -1
  22. package/dist/client/index.js +3 -3
  23. package/dist/client/local.js +5 -3
  24. package/dist/dataset/data_loader/data_loader.d.ts +7 -1
  25. package/dist/dataset/data_loader/image_loader.d.ts +5 -3
  26. package/dist/dataset/data_loader/image_loader.js +64 -18
  27. package/dist/dataset/data_loader/index.d.ts +1 -1
  28. package/dist/dataset/data_loader/tabular_loader.d.ts +3 -3
  29. package/dist/dataset/data_loader/tabular_loader.js +27 -18
  30. package/dist/dataset/dataset_builder.d.ts +3 -2
  31. package/dist/dataset/dataset_builder.js +29 -17
  32. package/dist/index.d.ts +6 -4
  33. package/dist/index.js +13 -5
  34. package/dist/informant/graph_informant.d.ts +10 -0
  35. package/dist/informant/graph_informant.js +23 -0
  36. package/dist/informant/index.d.ts +3 -0
  37. package/dist/informant/index.js +9 -0
  38. package/dist/informant/training_informant/base.d.ts +31 -0
  39. package/dist/informant/training_informant/base.js +82 -0
  40. package/dist/informant/training_informant/decentralized.d.ts +5 -0
  41. package/dist/informant/training_informant/decentralized.js +22 -0
  42. package/dist/informant/training_informant/federated.d.ts +14 -0
  43. package/dist/informant/training_informant/federated.js +32 -0
  44. package/dist/informant/training_informant/index.d.ts +4 -0
  45. package/dist/informant/training_informant/index.js +11 -0
  46. package/dist/informant/training_informant/local.d.ts +6 -0
  47. package/dist/informant/training_informant/local.js +20 -0
  48. package/dist/logging/index.d.ts +1 -0
  49. package/dist/logging/index.js +3 -1
  50. package/dist/logging/trainer_logger.d.ts +1 -1
  51. package/dist/logging/trainer_logger.js +5 -5
  52. package/dist/memory/base.d.ts +17 -48
  53. package/dist/memory/empty.d.ts +6 -4
  54. package/dist/memory/empty.js +8 -2
  55. package/dist/memory/index.d.ts +1 -1
  56. package/dist/privacy.js +3 -3
  57. package/dist/serialization/model.d.ts +1 -1
  58. package/dist/serialization/model.js +2 -2
  59. package/dist/serialization/weights.js +2 -2
  60. package/dist/task/display_information.d.ts +2 -2
  61. package/dist/task/display_information.js +6 -5
  62. package/dist/task/summary.d.ts +5 -0
  63. package/dist/task/summary.js +23 -0
  64. package/dist/task/training_information.d.ts +3 -0
  65. package/dist/tasks/cifar10.js +5 -3
  66. package/dist/tasks/lus_covid.d.ts +1 -1
  67. package/dist/tasks/lus_covid.js +50 -13
  68. package/dist/tasks/mnist.js +4 -2
  69. package/dist/tasks/simple_face.d.ts +1 -1
  70. package/dist/tasks/simple_face.js +13 -17
  71. package/dist/tasks/titanic.js +9 -7
  72. package/dist/tfjs.d.ts +2 -0
  73. package/dist/tfjs.js +6 -0
  74. package/dist/training/disco.d.ts +3 -1
  75. package/dist/training/disco.js +14 -6
  76. package/dist/training/trainer/distributed_trainer.d.ts +1 -1
  77. package/dist/training/trainer/distributed_trainer.js +5 -1
  78. package/dist/training/trainer/local_trainer.d.ts +4 -3
  79. package/dist/training/trainer/local_trainer.js +6 -9
  80. package/dist/training/trainer/round_tracker.js +3 -0
  81. package/dist/training/trainer/trainer.d.ts +15 -15
  82. package/dist/training/trainer/trainer.js +57 -43
  83. package/dist/training/trainer/trainer_builder.js +8 -15
  84. package/dist/types.d.ts +1 -1
  85. package/dist/validation/index.d.ts +1 -0
  86. package/dist/validation/index.js +5 -0
  87. package/dist/validation/validator.d.ts +20 -0
  88. package/dist/validation/validator.js +106 -0
  89. package/package.json +2 -3
  90. package/dist/client/decentralized.d.ts +0 -23
  91. package/dist/client/decentralized.js +0 -275
  92. package/dist/testing/tester.d.ts +0 -5
  93. package/dist/testing/tester.js +0 -21
  94. package/dist/training_informant.d.ts +0 -88
  95. package/dist/training_informant.js +0 -135
@@ -11,7 +11,6 @@ var DatasetBuilder = /** @class */ (function () {
11
11
  this.built = false;
12
12
  }
13
13
  DatasetBuilder.prototype.addFiles = function (sources, label) {
14
- var _this = this;
15
14
  if (this.built) {
16
15
  throw new Error('builder already consumed');
17
16
  }
@@ -19,7 +18,13 @@ var DatasetBuilder = /** @class */ (function () {
19
18
  this.sources = this.sources.concat(sources);
20
19
  }
21
20
  else {
22
- sources.forEach(function (source) { return _this.labelledSources.set(label, source); });
21
+ var currentSources = this.labelledSources.get(label);
22
+ if (currentSources === undefined) {
23
+ this.labelledSources.set(label, sources);
24
+ }
25
+ else {
26
+ this.labelledSources.set(label, currentSources.concat(sources));
27
+ }
23
28
  }
24
29
  };
25
30
  DatasetBuilder.prototype.clearFiles = function (label) {
@@ -33,10 +38,21 @@ var DatasetBuilder = /** @class */ (function () {
33
38
  this.labelledSources.delete(label);
34
39
  }
35
40
  };
36
- DatasetBuilder.prototype.build = function () {
41
+ DatasetBuilder.prototype.getLabels = function () {
42
+ // We need to duplicate the labels as we need one for each soure.
43
+ // Say for label A we have sources [img1, img2, img3], then we
44
+ // need labels [A, A, A].
45
+ var labels = [];
46
+ Array.from(this.labelledSources.values()).forEach(function (sources, index) {
47
+ var sourcesLabels = Array.from({ length: sources.length }, function (_) { return index.toString(); });
48
+ labels = labels.concat(sourcesLabels);
49
+ });
50
+ return labels.flat();
51
+ };
52
+ DatasetBuilder.prototype.build = function (config) {
37
53
  var _a, _b;
38
54
  return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
39
- var data, config, config;
55
+ var dataTuple, defaultConfig, defaultConfig, sources;
40
56
  return (0, tslib_1.__generator)(this, function (_c) {
41
57
  switch (_c.label) {
42
58
  case 0:
@@ -45,26 +61,22 @@ var DatasetBuilder = /** @class */ (function () {
45
61
  throw new Error('invalid sources');
46
62
  }
47
63
  if (!(this.sources.length > 0)) return [3 /*break*/, 2];
48
- config = {
49
- features: (_a = this.task.trainingInformation) === null || _a === void 0 ? void 0 : _a.inputColumns,
50
- labels: (_b = this.task.trainingInformation) === null || _b === void 0 ? void 0 : _b.outputColumns
51
- };
52
- return [4 /*yield*/, this.dataLoader.loadAll(this.sources, config)];
64
+ defaultConfig = (0, tslib_1.__assign)({ features: (_a = this.task.trainingInformation) === null || _a === void 0 ? void 0 : _a.inputColumns, labels: (_b = this.task.trainingInformation) === null || _b === void 0 ? void 0 : _b.outputColumns }, config);
65
+ return [4 /*yield*/, this.dataLoader.loadAll(this.sources, defaultConfig)];
53
66
  case 1:
54
- data = _c.sent();
67
+ dataTuple = _c.sent();
55
68
  return [3 /*break*/, 4];
56
69
  case 2:
57
- config = {
58
- labels: Array.from(this.labelledSources.keys())
59
- };
60
- return [4 /*yield*/, this.dataLoader.loadAll(Array.from(this.labelledSources.values()), config)];
70
+ defaultConfig = (0, tslib_1.__assign)({ labels: this.getLabels() }, config);
71
+ sources = Array.from(this.labelledSources.values()).flat();
72
+ return [4 /*yield*/, this.dataLoader.loadAll(sources, defaultConfig)];
61
73
  case 3:
62
- data = _c.sent();
74
+ dataTuple = _c.sent();
63
75
  _c.label = 4;
64
76
  case 4:
65
- // TODO @s314cy: Support .csv labels for image datasets
77
+ // TODO @s314cy: Support .csv labels for image datasets (supervised training or testing)
66
78
  this.built = true;
67
- return [2 /*return*/, data];
79
+ return [2 /*return*/, dataTuple];
68
80
  }
69
81
  });
70
82
  });
package/dist/index.d.ts CHANGED
@@ -4,14 +4,16 @@ export * as serialization from './serialization';
4
4
  export * as tasks from './tasks';
5
5
  export * as training from './training';
6
6
  export * as privacy from './privacy';
7
+ export { GraphInformant, TrainingInformant, informant } from './informant';
7
8
  export { Base as Client } from './client';
8
9
  export * as client from './client';
9
10
  export { AsyncBuffer } from './async_buffer';
10
11
  export { AsyncInformant } from './async_informant';
11
- export { Logger, ConsoleLogger } from './logging';
12
- export { Memory, ModelType, Empty as EmptyMemory } from './memory';
12
+ export { Logger, ConsoleLogger, TrainerLog } from './logging';
13
+ export { Memory, ModelType, ModelInfo, Path, ModelSource, Empty as EmptyMemory } from './memory';
13
14
  export { ModelActor } from './model_actor';
15
+ export { Disco, TrainingSchemes } from './training';
16
+ export { Validator } from './validation';
14
17
  export { TrainingInformation, DisplayInformation, isTask, Task, isTaskID, TaskID } from './task';
15
- export { TrainingInformant } from './training_informant';
16
- export { TrainingSchemes } from './training/training_schemes';
17
18
  export * from './types';
19
+ export { tf } from './tfjs';
package/dist/index.js CHANGED
@@ -1,6 +1,6 @@
1
1
  "use strict";
2
2
  Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.TrainingSchemes = exports.TrainingInformant = exports.isTaskID = exports.isTask = exports.ModelActor = exports.EmptyMemory = exports.ModelType = exports.Memory = exports.ConsoleLogger = exports.Logger = exports.AsyncInformant = exports.AsyncBuffer = exports.client = exports.Client = exports.privacy = exports.training = exports.tasks = exports.serialization = exports.dataset = exports.aggregation = void 0;
3
+ exports.tf = exports.isTaskID = exports.isTask = exports.Validator = exports.TrainingSchemes = exports.Disco = exports.ModelActor = exports.EmptyMemory = exports.ModelType = exports.Memory = exports.TrainerLog = exports.ConsoleLogger = exports.Logger = exports.AsyncInformant = exports.AsyncBuffer = exports.client = exports.Client = exports.informant = exports.TrainingInformant = exports.GraphInformant = exports.privacy = exports.training = exports.tasks = exports.serialization = exports.dataset = exports.aggregation = void 0;
4
4
  var tslib_1 = require("tslib");
5
5
  exports.aggregation = (0, tslib_1.__importStar)(require("./aggregation"));
6
6
  exports.dataset = (0, tslib_1.__importStar)(require("./dataset"));
@@ -8,6 +8,10 @@ exports.serialization = (0, tslib_1.__importStar)(require("./serialization"));
8
8
  exports.tasks = (0, tslib_1.__importStar)(require("./tasks"));
9
9
  exports.training = (0, tslib_1.__importStar)(require("./training"));
10
10
  exports.privacy = (0, tslib_1.__importStar)(require("./privacy"));
11
+ var informant_1 = require("./informant");
12
+ Object.defineProperty(exports, "GraphInformant", { enumerable: true, get: function () { return informant_1.GraphInformant; } });
13
+ Object.defineProperty(exports, "TrainingInformant", { enumerable: true, get: function () { return informant_1.TrainingInformant; } });
14
+ Object.defineProperty(exports, "informant", { enumerable: true, get: function () { return informant_1.informant; } });
11
15
  var client_1 = require("./client");
12
16
  Object.defineProperty(exports, "Client", { enumerable: true, get: function () { return client_1.Base; } });
13
17
  exports.client = (0, tslib_1.__importStar)(require("./client"));
@@ -18,17 +22,21 @@ Object.defineProperty(exports, "AsyncInformant", { enumerable: true, get: functi
18
22
  var logging_1 = require("./logging");
19
23
  Object.defineProperty(exports, "Logger", { enumerable: true, get: function () { return logging_1.Logger; } });
20
24
  Object.defineProperty(exports, "ConsoleLogger", { enumerable: true, get: function () { return logging_1.ConsoleLogger; } });
25
+ Object.defineProperty(exports, "TrainerLog", { enumerable: true, get: function () { return logging_1.TrainerLog; } });
21
26
  var memory_1 = require("./memory");
22
27
  Object.defineProperty(exports, "Memory", { enumerable: true, get: function () { return memory_1.Memory; } });
23
28
  Object.defineProperty(exports, "ModelType", { enumerable: true, get: function () { return memory_1.ModelType; } });
24
29
  Object.defineProperty(exports, "EmptyMemory", { enumerable: true, get: function () { return memory_1.Empty; } });
25
30
  var model_actor_1 = require("./model_actor");
26
31
  Object.defineProperty(exports, "ModelActor", { enumerable: true, get: function () { return model_actor_1.ModelActor; } });
32
+ var training_1 = require("./training");
33
+ Object.defineProperty(exports, "Disco", { enumerable: true, get: function () { return training_1.Disco; } });
34
+ Object.defineProperty(exports, "TrainingSchemes", { enumerable: true, get: function () { return training_1.TrainingSchemes; } });
35
+ var validation_1 = require("./validation");
36
+ Object.defineProperty(exports, "Validator", { enumerable: true, get: function () { return validation_1.Validator; } });
27
37
  var task_1 = require("./task");
28
38
  Object.defineProperty(exports, "isTask", { enumerable: true, get: function () { return task_1.isTask; } });
29
39
  Object.defineProperty(exports, "isTaskID", { enumerable: true, get: function () { return task_1.isTaskID; } });
30
- var training_informant_1 = require("./training_informant");
31
- Object.defineProperty(exports, "TrainingInformant", { enumerable: true, get: function () { return training_informant_1.TrainingInformant; } });
32
- var training_schemes_1 = require("./training/training_schemes");
33
- Object.defineProperty(exports, "TrainingSchemes", { enumerable: true, get: function () { return training_schemes_1.TrainingSchemes; } });
34
40
  (0, tslib_1.__exportStar)(require("./types"), exports);
41
+ var tfjs_1 = require("./tfjs");
42
+ Object.defineProperty(exports, "tf", { enumerable: true, get: function () { return tfjs_1.tf; } });
@@ -0,0 +1,10 @@
1
+ import { List } from 'immutable';
2
+ export declare class GraphInformant {
3
+ static readonly NB_EPOCHS_ON_GRAPH = 10;
4
+ private currentAccuracy;
5
+ private accuracyDataSeries;
6
+ constructor();
7
+ updateAccuracy(accuracy: number): void;
8
+ data(): List<number>;
9
+ accuracy(): number;
10
+ }
@@ -0,0 +1,23 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.GraphInformant = void 0;
4
+ var immutable_1 = require("immutable");
5
+ var GraphInformant = /** @class */ (function () {
6
+ function GraphInformant() {
7
+ this.currentAccuracy = 0;
8
+ this.accuracyDataSeries = (0, immutable_1.Repeat)(0, GraphInformant.NB_EPOCHS_ON_GRAPH).toList();
9
+ }
10
+ GraphInformant.prototype.updateAccuracy = function (accuracy) {
11
+ this.accuracyDataSeries = this.accuracyDataSeries.shift().push(accuracy);
12
+ this.currentAccuracy = accuracy;
13
+ };
14
+ GraphInformant.prototype.data = function () {
15
+ return this.accuracyDataSeries;
16
+ };
17
+ GraphInformant.prototype.accuracy = function () {
18
+ return this.currentAccuracy;
19
+ };
20
+ GraphInformant.NB_EPOCHS_ON_GRAPH = 10;
21
+ return GraphInformant;
22
+ }());
23
+ exports.GraphInformant = GraphInformant;
@@ -0,0 +1,3 @@
1
+ export { GraphInformant } from './graph_informant';
2
+ export { Base as TrainingInformant } from './training_informant';
3
+ export * as informant from './training_informant';
@@ -0,0 +1,9 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.informant = exports.TrainingInformant = exports.GraphInformant = void 0;
4
+ var tslib_1 = require("tslib");
5
+ var graph_informant_1 = require("./graph_informant");
6
+ Object.defineProperty(exports, "GraphInformant", { enumerable: true, get: function () { return graph_informant_1.GraphInformant; } });
7
+ var training_informant_1 = require("./training_informant");
8
+ Object.defineProperty(exports, "TrainingInformant", { enumerable: true, get: function () { return training_informant_1.Base; } });
9
+ exports.informant = (0, tslib_1.__importStar)(require("./training_informant"));
@@ -0,0 +1,31 @@
1
+ import { List } from 'immutable';
2
+ import { TaskID } from '@/task';
3
+ import { GraphInformant } from '../graph_informant';
4
+ export declare abstract class Base {
5
+ readonly taskID: TaskID;
6
+ private readonly nbrMessagesToShow;
7
+ private messages;
8
+ protected readonly trainingGraphInformant: GraphInformant;
9
+ protected readonly validationGraphInformant: GraphInformant;
10
+ protected currentRound: number;
11
+ protected currentNumberOfParticipants: number;
12
+ protected totalNumberOfParticipants: number;
13
+ protected averageNumberOfParticipants: number;
14
+ constructor(taskID: TaskID, nbrMessagesToShow: number);
15
+ abstract update(statistics: Record<string, number>): void;
16
+ addMessage(msg: string): void;
17
+ getMessages(): string[];
18
+ round(): number;
19
+ participants(): number;
20
+ totalParticipants(): number;
21
+ averageParticipants(): number;
22
+ updateTrainingGraph(accuracy: number): void;
23
+ updateValidationGraph(accuracy: number): void;
24
+ trainingAccuracy(): number;
25
+ validationAccuracy(): number;
26
+ trainingAccuracyData(): List<number>;
27
+ validationAccuracyData(): List<number>;
28
+ isDecentralized(): boolean;
29
+ isFederated(): boolean;
30
+ static isTrainingInformant(raw: unknown): raw is Base;
31
+ }
@@ -0,0 +1,82 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.Base = void 0;
4
+ var immutable_1 = require("immutable");
5
+ var graph_informant_1 = require("../graph_informant");
6
+ var Base = /** @class */ (function () {
7
+ function Base(taskID, nbrMessagesToShow) {
8
+ this.taskID = taskID;
9
+ this.nbrMessagesToShow = nbrMessagesToShow;
10
+ // written feedback
11
+ this.messages = (0, immutable_1.List)();
12
+ // graph feedback
13
+ this.trainingGraphInformant = new graph_informant_1.GraphInformant();
14
+ this.validationGraphInformant = new graph_informant_1.GraphInformant();
15
+ // statistics
16
+ this.currentRound = 0;
17
+ this.currentNumberOfParticipants = 0;
18
+ this.totalNumberOfParticipants = 0;
19
+ this.averageNumberOfParticipants = 0;
20
+ }
21
+ Base.prototype.addMessage = function (msg) {
22
+ if (this.messages.size >= this.nbrMessagesToShow) {
23
+ this.messages = this.messages.shift();
24
+ }
25
+ this.messages = this.messages.push(msg);
26
+ };
27
+ Base.prototype.getMessages = function () {
28
+ return this.messages.toArray();
29
+ };
30
+ Base.prototype.round = function () {
31
+ return this.currentRound;
32
+ };
33
+ Base.prototype.participants = function () {
34
+ return this.currentNumberOfParticipants;
35
+ };
36
+ Base.prototype.totalParticipants = function () {
37
+ return this.totalNumberOfParticipants;
38
+ };
39
+ Base.prototype.averageParticipants = function () {
40
+ return this.averageNumberOfParticipants;
41
+ };
42
+ Base.prototype.updateTrainingGraph = function (accuracy) {
43
+ this.trainingGraphInformant.updateAccuracy(accuracy);
44
+ };
45
+ Base.prototype.updateValidationGraph = function (accuracy) {
46
+ this.validationGraphInformant.updateAccuracy(accuracy);
47
+ };
48
+ Base.prototype.trainingAccuracy = function () {
49
+ return this.trainingGraphInformant.accuracy();
50
+ };
51
+ Base.prototype.validationAccuracy = function () {
52
+ return this.validationGraphInformant.accuracy();
53
+ };
54
+ Base.prototype.trainingAccuracyData = function () {
55
+ return this.trainingGraphInformant.data();
56
+ };
57
+ Base.prototype.validationAccuracyData = function () {
58
+ return this.validationGraphInformant.data();
59
+ };
60
+ Base.prototype.isDecentralized = function () {
61
+ return false;
62
+ };
63
+ Base.prototype.isFederated = function () {
64
+ return false;
65
+ };
66
+ Base.isTrainingInformant = function (raw) {
67
+ if (typeof raw !== 'object') {
68
+ return false;
69
+ }
70
+ if (raw === null) {
71
+ return false;
72
+ }
73
+ // TODO
74
+ var requiredFields = (0, immutable_1.Set)();
75
+ if (!(requiredFields.every(function (field) { return field in raw; }))) {
76
+ return false;
77
+ }
78
+ return true;
79
+ };
80
+ return Base;
81
+ }());
82
+ exports.Base = Base;
@@ -0,0 +1,5 @@
1
+ import { Base } from '.';
2
+ export declare class DecentralizedInformant extends Base {
3
+ update(statistics: Record<string, number>): void;
4
+ isDecentralized(): boolean;
5
+ }
@@ -0,0 +1,22 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.DecentralizedInformant = void 0;
4
+ var tslib_1 = require("tslib");
5
+ var _1 = require(".");
6
+ var DecentralizedInformant = /** @class */ (function (_super) {
7
+ (0, tslib_1.__extends)(DecentralizedInformant, _super);
8
+ function DecentralizedInformant() {
9
+ return _super !== null && _super.apply(this, arguments) || this;
10
+ }
11
+ DecentralizedInformant.prototype.update = function (statistics) {
12
+ this.currentRound += 1;
13
+ this.currentNumberOfParticipants = statistics.currentNumberOfParticipants;
14
+ this.totalNumberOfParticipants += this.currentNumberOfParticipants;
15
+ this.averageNumberOfParticipants = this.totalNumberOfParticipants / this.currentRound;
16
+ };
17
+ DecentralizedInformant.prototype.isDecentralized = function () {
18
+ return true;
19
+ };
20
+ return DecentralizedInformant;
21
+ }(_1.Base));
22
+ exports.DecentralizedInformant = DecentralizedInformant;
@@ -0,0 +1,14 @@
1
+ import { Base } from '.';
2
+ /**
3
+ * Class that collects information about the status of the training-loop of the model.
4
+ */
5
+ export declare class FederatedInformant extends Base {
6
+ displayHeatmap: boolean;
7
+ /**
8
+ * Update the server statistics with the JSON received from the server
9
+ * For now it's just the JSON, but we might want to keep it as a dictionary
10
+ * @param receivedStatistics statistics received from the server.
11
+ */
12
+ update(receivedStatistics: Record<string, number>): void;
13
+ isFederated(): boolean;
14
+ }
@@ -0,0 +1,32 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.FederatedInformant = void 0;
4
+ var tslib_1 = require("tslib");
5
+ var _1 = require(".");
6
+ /**
7
+ * Class that collects information about the status of the training-loop of the model.
8
+ */
9
+ var FederatedInformant = /** @class */ (function (_super) {
10
+ (0, tslib_1.__extends)(FederatedInformant, _super);
11
+ function FederatedInformant() {
12
+ var _this = _super !== null && _super.apply(this, arguments) || this;
13
+ _this.displayHeatmap = false;
14
+ return _this;
15
+ }
16
+ /**
17
+ * Update the server statistics with the JSON received from the server
18
+ * For now it's just the JSON, but we might want to keep it as a dictionary
19
+ * @param receivedStatistics statistics received from the server.
20
+ */
21
+ FederatedInformant.prototype.update = function (receivedStatistics) {
22
+ this.currentRound = receivedStatistics.round;
23
+ this.currentNumberOfParticipants = receivedStatistics.currentNumberOfParticipants;
24
+ this.totalNumberOfParticipants = receivedStatistics.totalNumberOfParticipants;
25
+ this.averageNumberOfParticipants = receivedStatistics.averageNumberOfParticipants;
26
+ };
27
+ FederatedInformant.prototype.isFederated = function () {
28
+ return true;
29
+ };
30
+ return FederatedInformant;
31
+ }(_1.Base));
32
+ exports.FederatedInformant = FederatedInformant;
@@ -0,0 +1,4 @@
1
+ export { Base } from './base';
2
+ export { FederatedInformant } from './federated';
3
+ export { DecentralizedInformant } from './decentralized';
4
+ export { LocalInformant } from './local';
@@ -0,0 +1,11 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.LocalInformant = exports.DecentralizedInformant = exports.FederatedInformant = exports.Base = void 0;
4
+ var base_1 = require("./base");
5
+ Object.defineProperty(exports, "Base", { enumerable: true, get: function () { return base_1.Base; } });
6
+ var federated_1 = require("./federated");
7
+ Object.defineProperty(exports, "FederatedInformant", { enumerable: true, get: function () { return federated_1.FederatedInformant; } });
8
+ var decentralized_1 = require("./decentralized");
9
+ Object.defineProperty(exports, "DecentralizedInformant", { enumerable: true, get: function () { return decentralized_1.DecentralizedInformant; } });
10
+ var local_1 = require("./local");
11
+ Object.defineProperty(exports, "LocalInformant", { enumerable: true, get: function () { return local_1.LocalInformant; } });
@@ -0,0 +1,6 @@
1
+ import { TaskID } from '@/task';
2
+ import { Base } from '.';
3
+ export declare class LocalInformant extends Base {
4
+ constructor(taskID: TaskID, nbrMessagesToShow: number);
5
+ update(statistics: Record<string, number>): void;
6
+ }
@@ -0,0 +1,20 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.LocalInformant = void 0;
4
+ var tslib_1 = require("tslib");
5
+ var _1 = require(".");
6
+ var LocalInformant = /** @class */ (function (_super) {
7
+ (0, tslib_1.__extends)(LocalInformant, _super);
8
+ function LocalInformant(taskID, nbrMessagesToShow) {
9
+ var _this = _super.call(this, taskID, nbrMessagesToShow) || this;
10
+ _this.currentNumberOfParticipants = 1;
11
+ _this.averageNumberOfParticipants = 1;
12
+ _this.totalNumberOfParticipants = 1;
13
+ return _this;
14
+ }
15
+ LocalInformant.prototype.update = function (statistics) {
16
+ this.currentRound = statistics.currentRound;
17
+ };
18
+ return LocalInformant;
19
+ }(_1.Base));
20
+ exports.LocalInformant = LocalInformant;
@@ -1,2 +1,3 @@
1
1
  export { Logger } from './logger';
2
2
  export { ConsoleLogger } from './console_logger';
3
+ export { TrainerLog } from './trainer_logger';
@@ -1,7 +1,9 @@
1
1
  "use strict";
2
2
  Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.ConsoleLogger = exports.Logger = void 0;
3
+ exports.TrainerLog = exports.ConsoleLogger = exports.Logger = void 0;
4
4
  var logger_1 = require("./logger");
5
5
  Object.defineProperty(exports, "Logger", { enumerable: true, get: function () { return logger_1.Logger; } });
6
6
  var console_logger_1 = require("./console_logger");
7
7
  Object.defineProperty(exports, "ConsoleLogger", { enumerable: true, get: function () { return console_logger_1.ConsoleLogger; } });
8
+ var trainer_logger_1 = require("./trainer_logger");
9
+ Object.defineProperty(exports, "TrainerLog", { enumerable: true, get: function () { return trainer_logger_1.TrainerLog; } });
@@ -1,5 +1,5 @@
1
- import * as tf from '@tensorflow/tfjs';
2
1
  import { List } from 'immutable';
2
+ import { tf } from '..';
3
3
  import { ConsoleLogger } from '.';
4
4
  export declare class TrainerLog {
5
5
  epochs: List<number>;
@@ -2,8 +2,8 @@
2
2
  Object.defineProperty(exports, "__esModule", { value: true });
3
3
  exports.TrainerLogger = exports.TrainerLog = void 0;
4
4
  var tslib_1 = require("tslib");
5
- var tf = (0, tslib_1.__importStar)(require("@tensorflow/tfjs"));
6
5
  var immutable_1 = require("immutable");
6
+ var __1 = require("..");
7
7
  var _1 = require(".");
8
8
  var TrainerLog = /** @class */ (function () {
9
9
  function TrainerLog() {
@@ -44,15 +44,15 @@ var TrainerLogger = /** @class */ (function (_super) {
44
44
  this.log.add(epoch, logs);
45
45
  }
46
46
  // console output
47
- var msg = "Train: " + ((_a = logs === null || logs === void 0 ? void 0 : logs.acc) !== null && _a !== void 0 ? _a : 'undefined') + "\nValidation:" + ((_b = logs === null || logs === void 0 ? void 0 : logs.val_acc) !== null && _b !== void 0 ? _b : 'undefined') + "\nLoss:" + ((_c = logs === null || logs === void 0 ? void 0 : logs.loss) !== null && _c !== void 0 ? _c : 'undefined');
48
- this.success("On epoch end:\n" + msg);
47
+ var msg = "Epoch: " + epoch + "\nTrain: " + ((_a = logs === null || logs === void 0 ? void 0 : logs.acc) !== null && _a !== void 0 ? _a : 'undefined') + "\nValidation:" + ((_b = logs === null || logs === void 0 ? void 0 : logs.val_acc) !== null && _b !== void 0 ? _b : 'undefined') + "\nLoss:" + ((_c = logs === null || logs === void 0 ? void 0 : logs.loss) !== null && _c !== void 0 ? _c : 'undefined');
48
+ this.success("On epoch end:\n" + msg + "\n");
49
49
  };
50
50
  /**
51
51
  * Display ram usage
52
52
  */
53
53
  TrainerLogger.prototype.ramUsage = function () {
54
- this.success("Training RAM usage is = " + tf.memory().numBytes * 0.000001 + " MB");
55
- this.success("Number of allocated tensors = " + tf.memory().numTensors);
54
+ this.success("Training RAM usage is = " + __1.tf.memory().numBytes * 0.000001 + " MB");
55
+ this.success("Number of allocated tensors = " + __1.tf.memory().numTensors);
56
56
  };
57
57
  return TrainerLogger;
58
58
  }(_1.ConsoleLogger));
@@ -1,53 +1,22 @@
1
1
  import * as tf from '@tensorflow/tfjs';
2
2
  import { TaskID } from '..';
3
3
  import { ModelType } from './model_type';
4
+ export declare type Path = string;
5
+ export interface ModelInfo {
6
+ type?: ModelType;
7
+ taskID: TaskID;
8
+ name?: string;
9
+ }
10
+ export declare type ModelSource = ModelInfo | Path;
4
11
  export declare abstract class Memory {
5
- /**
6
- * Fetches metadata of the model.
7
- * @param taskID the working model's corresponding task
8
- * @param modelName the working model's file name
9
- */
10
- abstract getModelMetadata(type: ModelType, taskID: TaskID, modelName: string): Promise<tf.io.ModelArtifactsInfo | undefined>;
11
- /**
12
- * Loads the current working model and returns it as a fresh TFJS object.
13
- * @param taskID the working model's corresponding task
14
- * @param modelName the working model's file name
15
- */
16
- abstract getModel(type: ModelType, taskID: TaskID, modelName: string): Promise<tf.LayersModel>;
17
- /**
18
- * Loads a model from the model library into the current working model.
19
- * @param taskID the saved model's corresponding task
20
- * @param modelName the saved model's file name
21
- */
22
- abstract loadSavedModel(taskID: TaskID, modelName: string): Promise<void>;
23
- /**
24
- * Loads a fresh TFJS model object into the current working model.
25
- * @param taskID the working model's corresponding task
26
- * @param modelName the working model's file name
27
- * @param model the fresh model
28
- */
29
- abstract updateWorkingModel(taskID: TaskID, modelName: string, model: tf.LayersModel): Promise<void>;
30
- /**
31
- * Adds the current working model to the model library.
32
- * @param taskID the working model's corresponding task
33
- * @param modelName the working model's file name
34
- */
35
- abstract saveWorkingModel(taskID: TaskID, modelName: string): Promise<void>;
36
- /**
37
- * Removes the model from the library.
38
- * @param taskID the model's corresponding task
39
- * @param modelName the model's file name
40
- */
41
- abstract deleteModel(type: ModelType, taskID: TaskID, modelName: string): Promise<void>;
42
- /**
43
- * Downloads a previously saved model.
44
- * @param {taskID} taskID the saved model's corresponding task
45
- * @param {string} modelName the saved model's file name
46
- */
47
- abstract downloadSavedModel(taskID: TaskID, modelName: string): Promise<void>;
48
- /**
49
- * @param {taskID} taskID
50
- * @param {string} modelName
51
- */
52
- abstract contains(modelType: ModelType, taskID: TaskID, modelName: string): Promise<boolean>;
12
+ abstract getModel(source: ModelSource): Promise<tf.LayersModel>;
13
+ abstract deleteModel(source: ModelSource): Promise<void>;
14
+ abstract loadModel(source: ModelSource): Promise<void>;
15
+ abstract getModelMetadata(source: ModelSource): Promise<tf.io.ModelArtifactsInfo | undefined>;
16
+ abstract updateWorkingModel(source: ModelSource, model: tf.LayersModel): Promise<void>;
17
+ abstract saveWorkingModel(source: ModelSource): Promise<void>;
18
+ abstract downloadModel(source: ModelSource): Promise<void>;
19
+ abstract contains(source: ModelSource): Promise<boolean>;
20
+ abstract pathFor(source: ModelSource): Path | undefined;
21
+ abstract infoFor(source: ModelSource): ModelInfo | undefined;
53
22
  }
@@ -1,12 +1,14 @@
1
- import * as tf from '@tensorflow/tfjs';
2
- import { Memory } from './base';
1
+ import { tf } from '..';
2
+ import { Memory, ModelInfo, Path } from './base';
3
3
  export declare class Empty extends Memory {
4
4
  getModelMetadata(): Promise<undefined>;
5
5
  contains(): Promise<boolean>;
6
6
  getModel(): Promise<tf.LayersModel>;
7
- loadSavedModel(): Promise<void>;
7
+ loadModel(): Promise<void>;
8
8
  updateWorkingModel(): Promise<void>;
9
9
  saveWorkingModel(): Promise<void>;
10
10
  deleteModel(): Promise<void>;
11
- downloadSavedModel(): Promise<void>;
11
+ downloadModel(): Promise<void>;
12
+ pathFor(): Path;
13
+ infoFor(): ModelInfo;
12
14
  }
@@ -29,7 +29,7 @@ var Empty = /** @class */ (function (_super) {
29
29
  });
30
30
  });
31
31
  };
32
- Empty.prototype.loadSavedModel = function () {
32
+ Empty.prototype.loadModel = function () {
33
33
  return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
34
34
  return (0, tslib_1.__generator)(this, function (_a) {
35
35
  throw new Error('empty');
@@ -57,13 +57,19 @@ var Empty = /** @class */ (function (_super) {
57
57
  });
58
58
  });
59
59
  };
60
- Empty.prototype.downloadSavedModel = function () {
60
+ Empty.prototype.downloadModel = function () {
61
61
  return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
62
62
  return (0, tslib_1.__generator)(this, function (_a) {
63
63
  throw new Error('empty');
64
64
  });
65
65
  });
66
66
  };
67
+ Empty.prototype.pathFor = function () {
68
+ throw new Error('empty');
69
+ };
70
+ Empty.prototype.infoFor = function () {
71
+ throw new Error('empty');
72
+ };
67
73
  return Empty;
68
74
  }(base_1.Memory));
69
75
  exports.Empty = Empty;
@@ -1,3 +1,3 @@
1
1
  export { Empty } from './empty';
2
- export { Memory } from './base';
2
+ export { Memory, ModelInfo, Path, ModelSource } from './base';
3
3
  export { ModelType } from './model_type';