@epfml/discojs 0.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (104)
  1. package/README.md +32 -0
  2. package/dist/aggregation.d.ts +3 -0
  3. package/dist/aggregation.js +19 -0
  4. package/dist/async_buffer.d.ts +41 -0
  5. package/dist/async_buffer.js +98 -0
  6. package/dist/async_informant.d.ts +20 -0
  7. package/dist/async_informant.js +69 -0
  8. package/dist/client/base.d.ts +36 -0
  9. package/dist/client/base.js +34 -0
  10. package/dist/client/decentralized.d.ts +23 -0
  11. package/dist/client/decentralized.js +275 -0
  12. package/dist/client/federated.d.ts +30 -0
  13. package/dist/client/federated.js +221 -0
  14. package/dist/client/index.d.ts +4 -0
  15. package/dist/client/index.js +11 -0
  16. package/dist/client/local.d.ts +8 -0
  17. package/dist/client/local.js +34 -0
  18. package/dist/dataset/data_loader/data_loader.d.ts +16 -0
  19. package/dist/dataset/data_loader/data_loader.js +10 -0
  20. package/dist/dataset/data_loader/image_loader.d.ts +14 -0
  21. package/dist/dataset/data_loader/image_loader.js +93 -0
  22. package/dist/dataset/data_loader/index.d.ts +3 -0
  23. package/dist/dataset/data_loader/index.js +9 -0
  24. package/dist/dataset/data_loader/tabular_loader.d.ts +29 -0
  25. package/dist/dataset/data_loader/tabular_loader.js +88 -0
  26. package/dist/dataset/dataset_builder.d.ts +17 -0
  27. package/dist/dataset/dataset_builder.js +80 -0
  28. package/dist/dataset/index.d.ts +2 -0
  29. package/dist/dataset/index.js +7 -0
  30. package/dist/index.d.ts +17 -0
  31. package/dist/index.js +34 -0
  32. package/dist/logging/console_logger.d.ts +18 -0
  33. package/dist/logging/console_logger.js +33 -0
  34. package/dist/logging/index.d.ts +2 -0
  35. package/dist/logging/index.js +7 -0
  36. package/dist/logging/logger.d.ts +12 -0
  37. package/dist/logging/logger.js +9 -0
  38. package/dist/logging/trainer_logger.d.ts +24 -0
  39. package/dist/logging/trainer_logger.js +59 -0
  40. package/dist/memory/base.d.ts +53 -0
  41. package/dist/memory/base.js +9 -0
  42. package/dist/memory/empty.d.ts +12 -0
  43. package/dist/memory/empty.js +69 -0
  44. package/dist/memory/index.d.ts +3 -0
  45. package/dist/memory/index.js +9 -0
  46. package/dist/memory/model_type.d.ts +4 -0
  47. package/dist/memory/model_type.js +9 -0
  48. package/dist/model_actor.d.ts +16 -0
  49. package/dist/model_actor.js +20 -0
  50. package/dist/privacy.d.ts +12 -0
  51. package/dist/privacy.js +60 -0
  52. package/dist/serialization/index.d.ts +2 -0
  53. package/dist/serialization/index.js +6 -0
  54. package/dist/serialization/model.d.ts +5 -0
  55. package/dist/serialization/model.js +55 -0
  56. package/dist/serialization/weights.d.ts +5 -0
  57. package/dist/serialization/weights.js +62 -0
  58. package/dist/task/data_example.d.ts +5 -0
  59. package/dist/task/data_example.js +24 -0
  60. package/dist/task/display_information.d.ts +15 -0
  61. package/dist/task/display_information.js +53 -0
  62. package/dist/task/index.d.ts +3 -0
  63. package/dist/task/index.js +8 -0
  64. package/dist/task/model_compile_data.d.ts +6 -0
  65. package/dist/task/model_compile_data.js +12 -0
  66. package/dist/task/task.d.ts +10 -0
  67. package/dist/task/task.js +32 -0
  68. package/dist/task/training_information.d.ts +29 -0
  69. package/dist/task/training_information.js +2 -0
  70. package/dist/tasks/cifar10.d.ts +4 -0
  71. package/dist/tasks/cifar10.js +74 -0
  72. package/dist/tasks/index.d.ts +5 -0
  73. package/dist/tasks/index.js +9 -0
  74. package/dist/tasks/lus_covid.d.ts +4 -0
  75. package/dist/tasks/lus_covid.js +48 -0
  76. package/dist/tasks/mnist.d.ts +4 -0
  77. package/dist/tasks/mnist.js +56 -0
  78. package/dist/tasks/simple_face.d.ts +4 -0
  79. package/dist/tasks/simple_face.js +88 -0
  80. package/dist/tasks/titanic.d.ts +4 -0
  81. package/dist/tasks/titanic.js +86 -0
  82. package/dist/testing/tester.d.ts +5 -0
  83. package/dist/testing/tester.js +21 -0
  84. package/dist/training/disco.d.ts +12 -0
  85. package/dist/training/disco.js +62 -0
  86. package/dist/training/index.d.ts +2 -0
  87. package/dist/training/index.js +7 -0
  88. package/dist/training/trainer/distributed_trainer.d.ts +21 -0
  89. package/dist/training/trainer/distributed_trainer.js +60 -0
  90. package/dist/training/trainer/local_trainer.d.ts +10 -0
  91. package/dist/training/trainer/local_trainer.js +37 -0
  92. package/dist/training/trainer/round_tracker.d.ts +30 -0
  93. package/dist/training/trainer/round_tracker.js +44 -0
  94. package/dist/training/trainer/trainer.d.ts +66 -0
  95. package/dist/training/trainer/trainer.js +146 -0
  96. package/dist/training/trainer/trainer_builder.d.ts +25 -0
  97. package/dist/training/trainer/trainer_builder.js +102 -0
  98. package/dist/training/training_schemes.d.ts +5 -0
  99. package/dist/training/training_schemes.js +10 -0
  100. package/dist/training_informant.d.ts +88 -0
  101. package/dist/training_informant.js +135 -0
  102. package/dist/types.d.ts +4 -0
  103. package/dist/types.js +2 -0
  104. package/package.json +48 -0
@@ -0,0 +1,102 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.TrainerBuilder = void 0;
4
+ var tslib_1 = require("tslib");
5
+ var __1 = require("../..");
6
+ var distributed_trainer_1 = require("./distributed_trainer");
7
+ var local_trainer_1 = require("./local_trainer");
8
+ /**
9
+ * A class that helps build the Trainer and auxiliary classes.
+ * It holds the memory, task and training informant, and assembles either a
+ * LocalTrainer or a DistributedTrainer around a model that is loaded from
+ * memory when present, otherwise fetched from the client.
10
+ */
11
+ var TrainerBuilder = /** @class */ (function () {
12
+ function TrainerBuilder(memory, task, trainingInformant) {
13
+ this.memory = memory;
14
+ this.task = task;
15
+ this.trainingInformant = trainingInformant;
16
+ }
17
+ /**
18
+ * Builds a trainer object.
19
+ *
20
+ * @param client client to share weights with (either decentralized or federated); also used to fetch the model when none is found in memory
21
+ * @param distributed whether to build a distributed or local trainer (defaults to false)
22
+ * @returns a promise resolving to a DistributedTrainer when `distributed` is true, otherwise a LocalTrainer
+ *
+ * NOTE(review): the same model instance is passed twice to DistributedTrainer;
+ * confirm that this is what its constructor expects.
23
+ */
24
+ TrainerBuilder.prototype.build = function (client, distributed) {
25
+ if (distributed === void 0) { distributed = false; }
26
+ return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
27
+ var model;
28
+ return (0, tslib_1.__generator)(this, function (_a) {
29
+ switch (_a.label) {
30
+ case 0: return [4 /*yield*/, this.getModel(client)];
31
+ case 1:
32
+ model = _a.sent();
33
+ if (distributed) {
34
+ return [2 /*return*/, new distributed_trainer_1.DistributedTrainer(this.task, this.trainingInformant, this.memory, model, model, client)];
35
+ }
36
+ else {
37
+ return [2 /*return*/, new local_trainer_1.LocalTrainer(this.task, this.trainingInformant, this.memory, model)];
38
+ }
39
+ return [2 /*return*/]; // unreachable: both branches above already return
40
+ }
41
+ });
42
+ });
43
+ };
44
+ /**
45
+ * If a model exists in memory, load it, otherwise load the latest model from the client.
+ * In both cases the model is then prepared by updateModelInformation before being returned.
46
+ * @param client client used to fetch the latest model when none is found in memory
+ * @returns a promise resolving to the prepared model; the promise rejects with
+ * Error('undefined model ID') when the task's training information has no modelID
47
+ */
48
+ TrainerBuilder.prototype.getModel = function (client) {
49
+ var _a;
50
+ return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
51
+ var modelID, modelExistsInMemory, model;
52
+ return (0, tslib_1.__generator)(this, function (_b) {
53
+ switch (_b.label) {
54
+ case 0:
55
+ modelID = (_a = this.task.trainingInformation) === null || _a === void 0 ? void 0 : _a.modelID; // i.e. this.task.trainingInformation?.modelID
56
+ if (modelID === undefined) {
57
+ throw new Error('undefined model ID');
58
+ }
59
+ return [4 /*yield*/, this.memory.getModelMetadata(__1.ModelType.WORKING, this.task.taskID, modelID)];
60
+ case 1:
61
+ modelExistsInMemory = _b.sent();
62
+ if (!(modelExistsInMemory !== undefined)) return [3 /*break*/, 3]; // no metadata in memory: jump to case 3 and fetch from the client
63
+ return [4 /*yield*/, this.memory.getModel(__1.ModelType.WORKING, this.task.taskID, modelID)];
64
+ case 2:
65
+ model = _b.sent();
66
+ return [3 /*break*/, 5];
67
+ case 3: return [4 /*yield*/, client.getLatestModel()];
68
+ case 4:
69
+ model = _b.sent();
70
+ _b.label = 5;
71
+ case 5: return [4 /*yield*/, this.updateModelInformation(model)];
72
+ case 6: return [2 /*return*/, _b.sent()];
73
+ }
74
+ });
75
+ });
76
+ };
77
+ TrainerBuilder.prototype.updateModelInformation = function (model) { // Prepares `model` in place (epoch metadata, compile, learning rate) and resolves to it; rejects if the task has no trainingInformation
78
+ return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
79
+ var info;
80
+ return (0, tslib_1.__generator)(this, function (_a) {
81
+ // Default the epoch metadata to 0 when the model has none (presumably read later to continue local training from a previous epoch checkpoint)
82
+ if (model.getUserDefinedMetadata() === undefined) {
83
+ model.setUserDefinedMetadata({ epoch: 0 });
84
+ }
85
+ info = this.task.trainingInformation;
86
+ if (info === undefined) {
87
+ throw new Error('undefined training information');
88
+ }
89
+ model.compile(info.modelCompileData);
90
+ if (info.learningRate !== undefined) {
91
+ // TODO: Not the right way to change learningRate and hence we cast to any
92
+ // the right way is to construct the optimiser and pass learningRate via
93
+ // argument.
94
+ model.optimizer.learningRate = info.learningRate;
95
+ }
96
+ return [2 /*return*/, model];
97
+ });
98
+ });
99
+ };
100
+ return TrainerBuilder;
101
+ }());
102
+ exports.TrainerBuilder = TrainerBuilder;
@@ -0,0 +1,5 @@
1
+ export declare enum TrainingSchemes { // training scheme of a task; TrainingInformant compares its taskTrainingScheme against these members
2
+ LOCAL = "local",
3
+ DECENTRALIZED = "deai", // decentralized training; note the abbreviated string value
4
+ FEDERATED = "feai" // federated training; note the abbreviated string value
5
+ }
@@ -0,0 +1,10 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.TrainingSchemes = void 0;
4
+ /* eslint-disable no-unused-vars */
5
/**
 * Training schemes a task can use, each mapped to its string value.
 * The enum object lives on `exports.TrainingSchemes`; it is reused if it
 * already exists and created otherwise, then populated member by member.
 */
var TrainingSchemes = exports.TrainingSchemes || (exports.TrainingSchemes = {});
TrainingSchemes.LOCAL = "local";
TrainingSchemes.DECENTRALIZED = "deai";
TrainingSchemes.FEDERATED = "feai";
@@ -0,0 +1,88 @@
1
+ import { List, Set } from 'immutable';
2
+ import { TaskID } from '.';
3
+ import { TrainingSchemes } from './training/training_schemes';
4
+ /**
5
+ * Class that collects information about the status of the training-loop of the model.
6
+ */
7
+ export declare class TrainingInformant {
8
+ private readonly nbrMessagesToShow;
9
+ readonly taskID: TaskID;
10
+ readonly taskTrainingScheme: TrainingSchemes;
11
+ whoReceivedMyModel: Set<unknown>; // peers I have shared my model with
12
+ nbrUpdatesWithOthers: number; // how many times the model has been averaged with someone else's model
13
+ waitingTime: number; // accumulated time spent waiting to receive weights
14
+ nbrWeightRequests: number; // number of weight requests I have responded to
15
+ messages: List<unknown>; // most recent communication messages, bounded by nbrMessagesToShow
16
+ currentRound: number;
17
+ currentNumberOfParticipants: number;
18
+ totalNumberOfParticipants: number;
19
+ averageNumberOfParticipants: number;
20
+ validationAccuracyChart: unknown;
21
+ validationAccuracy: number;
22
+ trainingAccuracyChart: unknown;
23
+ trainingAccuracy: number;
24
+ displayHeatmap: boolean;
25
+ currentValidationAccuracy: number;
26
+ validationAccuracyDataSerie: List<number>; // fixed-length sliding window of recent validation accuracies
27
+ currentTrainingAccuracy: number;
28
+ trainingAccuracyDataSerie: List<number>; // fixed-length sliding window of recent training accuracies
29
+ weightsIn: number;
30
+ weightsOut: number;
31
+ /**
32
+ *
33
+ * @param nbrMessagesToShow the number of messages to be kept to inform the users about status of communication with other peers.
34
+ * @param taskID the task's identifier.
+ * @param taskTrainingScheme the training scheme of the task.
35
+ */
36
+ constructor(nbrMessagesToShow: number, taskID: TaskID, taskTrainingScheme: TrainingSchemes);
37
+ /**
38
+ * Updates the set of peers who received my model.
39
+ * @param {String} peerName the name of the peer with whom I recently shared my model.
40
+ */
41
+ updateWhoReceivedMyModel(peerName: string): void;
42
+ /**
43
+ * Updates the number of updates I did with other peers.
44
+ * @param {Number} nbrUpdates the number of updates I did thanks to other peers' contributions since the last update of the parameter.
45
+ */
46
+ updateNbrUpdatesWithOthers(nbrUpdates: number): void;
47
+ /**
48
+ * Updates the time I waited to receive weights.
49
+ * @param {Number} time the waiting time to add to the running total.
50
+ */
51
+ updateWaitingTime(time: number): void;
52
+ /**
53
+ * Updates the number of weight requests I received.
54
+ * @param {Number} nbrRequests the number of weight requests I received since the last update of the parameter.
55
+ */
56
+ updateNbrWeightsRequests(nbrRequests: number): void;
57
+ /**
58
+ * Adds a new message to the message list, dropping the oldest message(s) when the list is full.
59
+ * @param {String} msg a message.
60
+ */
61
+ addMessage(msg: string): void;
62
+ /**
63
+ * Updates the server statistics (current round and participant counts) with the JSON received from the server.
64
+ * For now it's just the JSON, but we might want to keep it as a dictionary.
65
+ * @param {Record<string, number>} receivedStatistics statistics received from the server.
66
+ */
67
+ updateWithServerStatistics(receivedStatistics: Record<string, number>): void;
68
+ /**
69
+ * Updates the data to be displayed on the validation accuracy graph.
70
+ * @param {Number} validationAccuracy the current validation accuracy of the model
71
+ */
72
+ updateValidationAccuracyGraph(validationAccuracy: number): void;
73
+ /**
74
+ * Returns whether or not the Task's training scheme is Decentralized
75
+ * @returns Boolean value
76
+ */
77
+ isTaskTrainingSchemeDecentralized(): boolean;
78
+ /**
79
+ * Returns whether or not the Task's training scheme is Federated
80
+ * @returns Boolean value
81
+ */
82
+ isTaskTrainingSchemeFederated(): boolean;
83
+ /**
84
+ * Updates the data to be displayed on the training accuracy graph.
85
+ * @param {Number} trainingAccuracy the current training accuracy of the model
86
+ */
87
+ updateTrainingAccuracyGraph(trainingAccuracy: number): void;
88
+ }
@@ -0,0 +1,135 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.TrainingInformant = void 0;
4
+ var immutable_1 = require("immutable");
5
+ var training_schemes_1 = require("./training/training_schemes");
6
/**
 * Number of most recent accuracy values kept in each graph's data series.
 * Each series starts as this many zeros and then behaves as a fixed-length
 * sliding window (oldest value dropped, newest appended).
 */
var nbEpochsOnGraphs = 10;
/**
 * Class that collects information about the status of the training-loop of the model.
 */
var TrainingInformant = /** @class */ (function () {
    /**
     * @param nbrMessagesToShow the maximum number of messages kept to inform the users
     *   about the status of communication with other peers.
     * @param taskID the task's identifier.
     * @param taskTrainingScheme the training scheme of the task (a TrainingSchemes value).
     */
    function TrainingInformant(nbrMessagesToShow, taskID, taskTrainingScheme) {
        this.nbrMessagesToShow = nbrMessagesToShow;
        this.taskID = taskID;
        this.taskTrainingScheme = taskTrainingScheme;
        // Decentralized information
        // peers with whom I've shared my model
        this.whoReceivedMyModel = (0, immutable_1.Set)();
        // message feedback from peer-to-peer training
        this.messages = (0, immutable_1.List)();
        this.validationAccuracyDataSerie = (0, immutable_1.Repeat)(0, nbEpochsOnGraphs).toList();
        this.trainingAccuracyDataSerie = (0, immutable_1.Repeat)(0, nbEpochsOnGraphs).toList();
        this.weightsIn = 0;
        this.weightsOut = 0;
        // how many times the model has been averaged with someone else's model
        this.nbrUpdatesWithOthers = 0;
        // how much time I've been waiting for a model
        this.waitingTime = 0;
        // number of weight requests I've responded to
        this.nbrWeightRequests = 0;
        // statistics received from the server
        this.currentRound = 0;
        this.currentNumberOfParticipants = 0;
        this.totalNumberOfParticipants = 0;
        this.averageNumberOfParticipants = 0;
        // validation accuracy chart (no chart object is created here)
        this.validationAccuracyChart = null;
        this.validationAccuracy = 0;
        // training accuracy chart (no chart object is created here)
        this.trainingAccuracyChart = null;
        this.trainingAccuracy = 0;
        // NOTE(review): previously documented as "is the model using Interoperability";
        // the field name suggests a heatmap display toggle — confirm. Defaults to false.
        this.displayHeatmap = false;
        // default values for the validation and training charts
        this.currentValidationAccuracy = 0;
        this.currentTrainingAccuracy = 0;
    }
    /**
     * Updates the set of peers who received my model.
     * @param {String} peerName the name of the peer with whom I recently shared my model.
     */
    TrainingInformant.prototype.updateWhoReceivedMyModel = function (peerName) {
        this.whoReceivedMyModel = this.whoReceivedMyModel.add(peerName);
    };
    /**
     * Updates the number of updates I did with other peers.
     * @param {Number} nbrUpdates the number of updates I did thanks to other peers'
     *   contributions since the last update of the parameter.
     */
    TrainingInformant.prototype.updateNbrUpdatesWithOthers = function (nbrUpdates) {
        this.nbrUpdatesWithOthers += nbrUpdates;
    };
    /**
     * Updates the time I waited to receive weights.
     * @param {Number} time the waiting time to add to the running total.
     */
    TrainingInformant.prototype.updateWaitingTime = function (time) {
        this.waitingTime += time;
    };
    /**
     * Updates the number of weight requests I received.
     * @param {Number} nbrRequests the number of weight requests I received since the
     *   last update of the parameter.
     */
    TrainingInformant.prototype.updateNbrWeightsRequests = function (nbrRequests) {
        this.nbrWeightRequests += nbrRequests;
    };
    /**
     * Adds a new message to the message list, keeping only the `nbrMessagesToShow`
     * most recent messages (none at all when that limit is zero or negative).
     * @param {String} msg a message.
     */
    TrainingInformant.prototype.addMessage = function (msg) {
        var messages = this.messages.push(msg);
        // Drop the oldest messages until the limit is respected. Pushing first and
        // trimming afterwards means a limit of 0 keeps nothing; the `size > 0` guard
        // stops the loop on an empty list when the limit is negative.
        while (messages.size > 0 && messages.size > this.nbrMessagesToShow) {
            messages = messages.shift();
        }
        this.messages = messages;
    };
    /**
     * Updates the server statistics (current round and participant counts) with the
     * JSON received from the server.
     * For now it's just the JSON, but we might want to keep it as a dictionary.
     * @param {Record<string, number>} receivedStatistics statistics received from the server.
     */
    TrainingInformant.prototype.updateWithServerStatistics = function (receivedStatistics) {
        this.currentRound = receivedStatistics.round;
        this.currentNumberOfParticipants = receivedStatistics.currentNumberOfParticipants;
        this.totalNumberOfParticipants = receivedStatistics.totalNumberOfParticipants;
        this.averageNumberOfParticipants = receivedStatistics.averageNumberOfParticipants;
    };
    /**
     * Updates the data to be displayed on the validation accuracy graph: the oldest
     * value of the series is dropped and the new one appended.
     * @param {Number} validationAccuracy the current validation accuracy of the model
     */
    TrainingInformant.prototype.updateValidationAccuracyGraph = function (validationAccuracy) {
        this.validationAccuracyDataSerie =
            this.validationAccuracyDataSerie.shift().push(validationAccuracy);
        this.currentValidationAccuracy = validationAccuracy;
    };
    /**
     * Returns whether or not the Task's training scheme is Decentralized.
     * @returns Boolean value
     */
    TrainingInformant.prototype.isTaskTrainingSchemeDecentralized = function () {
        return this.taskTrainingScheme === training_schemes_1.TrainingSchemes.DECENTRALIZED;
    };
    /**
     * Returns whether or not the Task's training scheme is Federated.
     * @returns Boolean value
     */
    TrainingInformant.prototype.isTaskTrainingSchemeFederated = function () {
        return this.taskTrainingScheme === training_schemes_1.TrainingSchemes.FEDERATED;
    };
    /**
     * Updates the data to be displayed on the training accuracy graph: the oldest
     * value of the series is dropped and the new one appended.
     * @param {Number} trainingAccuracy the current training accuracy of the model
     */
    TrainingInformant.prototype.updateTrainingAccuracyGraph = function (trainingAccuracy) {
        this.trainingAccuracyDataSerie =
            this.trainingAccuracyDataSerie.shift().push(trainingAccuracy);
        this.currentTrainingAccuracy = trainingAccuracy;
    };
    return TrainingInformant;
}());
135
+ exports.TrainingInformant = TrainingInformant;
@@ -0,0 +1,4 @@
1
+ import * as tf from '@tensorflow/tfjs';
2
+ export declare type Path = string; // plain string alias for a path
3
+ export declare type Weights = tf.Tensor[]; // array of TensorFlow.js tensors (presumably one per model weight variable — confirm)
4
+ export declare type MetadataID = string; // plain string alias for a metadata identifier
package/dist/types.js ADDED
@@ -0,0 +1,2 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
package/package.json ADDED
@@ -0,0 +1,48 @@
1
+ {
2
+ "name": "@epfml/discojs",
3
+ "version": "0.0.1",
4
+ "main": "dist/index.js",
5
+ "types": "dist/index.d.ts",
6
+ "scripts": {
7
+ "build": "tsc",
8
+ "test": "mocha"
9
+ },
10
+ "repository": {
11
+ "type": "git",
12
+ "url": "git+https://github.com/epfml/disco.git"
13
+ },
14
+ "bugs": {
15
+ "url": "https://github.com/epfml/disco/issues"
16
+ },
17
+ "homepage": "https://github.com/epfml/disco#readme",
18
+ "dependencies": {
19
+ "@tensorflow/tfjs": "3",
20
+ "axios": "0.27",
21
+ "immutable": "4",
22
+ "isomorphic-ws": "4",
23
+ "lodash": "4",
24
+ "msgpack-lite": "0.1",
25
+ "simple-peer": "9",
26
+ "tslib": "2",
27
+ "url": "0.11",
28
+ "uuid": "8",
29
+ "ws": "8"
30
+ },
31
+ "devDependencies": {
32
+ "@tensorflow/tfjs-node": "3",
33
+ "@types/chai": "4",
34
+ "@types/lodash": "4",
35
+ "@types/mocha": "9",
36
+ "@types/msgpack-lite": "0.1",
37
+ "@types/simple-peer": "9",
38
+ "@types/uuid": "8",
39
+ "@typescript-eslint/eslint-plugin": "4",
40
+ "@typescript-eslint/parser": "4",
41
+ "chai": "4",
42
+ "eslint": "7",
43
+ "eslint-config-standard-with-typescript": "21",
44
+ "mocha": "9",
45
+ "ts-node": "10",
46
+ "typescript": "<4.5.0"
47
+ }
48
+ }