@epfml/discojs-node 2.1.1 → 2.1.2-p20240506085559.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188) hide show
  1. package/dist/data/image_loader.d.ts +5 -0
  2. package/dist/data/image_loader.js +11 -0
  3. package/dist/data/index.d.ts +3 -0
  4. package/dist/data/index.js +3 -0
  5. package/dist/data/tabular_loader.d.ts +4 -0
  6. package/dist/data/tabular_loader.js +11 -0
  7. package/dist/data/text_loader.d.ts +4 -0
  8. package/dist/data/text_loader.js +14 -0
  9. package/dist/index.d.ts +2 -2
  10. package/dist/index.js +2 -6
  11. package/package.json +13 -16
  12. package/README.md +0 -53
  13. package/dist/core/async_buffer.d.ts +0 -41
  14. package/dist/core/async_buffer.js +0 -97
  15. package/dist/core/async_informant.d.ts +0 -20
  16. package/dist/core/async_informant.js +0 -69
  17. package/dist/core/client/base.d.ts +0 -33
  18. package/dist/core/client/base.js +0 -35
  19. package/dist/core/client/decentralized/base.d.ts +0 -32
  20. package/dist/core/client/decentralized/base.js +0 -212
  21. package/dist/core/client/decentralized/clear_text.d.ts +0 -14
  22. package/dist/core/client/decentralized/clear_text.js +0 -96
  23. package/dist/core/client/decentralized/index.d.ts +0 -4
  24. package/dist/core/client/decentralized/index.js +0 -9
  25. package/dist/core/client/decentralized/messages.d.ts +0 -41
  26. package/dist/core/client/decentralized/messages.js +0 -54
  27. package/dist/core/client/decentralized/peer.d.ts +0 -26
  28. package/dist/core/client/decentralized/peer.js +0 -210
  29. package/dist/core/client/decentralized/peer_pool.d.ts +0 -14
  30. package/dist/core/client/decentralized/peer_pool.js +0 -92
  31. package/dist/core/client/decentralized/sec_agg.d.ts +0 -22
  32. package/dist/core/client/decentralized/sec_agg.js +0 -190
  33. package/dist/core/client/decentralized/secret_shares.d.ts +0 -3
  34. package/dist/core/client/decentralized/secret_shares.js +0 -39
  35. package/dist/core/client/decentralized/types.d.ts +0 -2
  36. package/dist/core/client/decentralized/types.js +0 -7
  37. package/dist/core/client/event_connection.d.ts +0 -37
  38. package/dist/core/client/event_connection.js +0 -158
  39. package/dist/core/client/federated/client.d.ts +0 -37
  40. package/dist/core/client/federated/client.js +0 -273
  41. package/dist/core/client/federated/index.d.ts +0 -2
  42. package/dist/core/client/federated/index.js +0 -7
  43. package/dist/core/client/federated/messages.d.ts +0 -38
  44. package/dist/core/client/federated/messages.js +0 -25
  45. package/dist/core/client/index.d.ts +0 -5
  46. package/dist/core/client/index.js +0 -11
  47. package/dist/core/client/local.d.ts +0 -8
  48. package/dist/core/client/local.js +0 -36
  49. package/dist/core/client/messages.d.ts +0 -28
  50. package/dist/core/client/messages.js +0 -33
  51. package/dist/core/client/utils.d.ts +0 -2
  52. package/dist/core/client/utils.js +0 -19
  53. package/dist/core/dataset/data/data.d.ts +0 -11
  54. package/dist/core/dataset/data/data.js +0 -20
  55. package/dist/core/dataset/data/data_split.d.ts +0 -5
  56. package/dist/core/dataset/data/data_split.js +0 -2
  57. package/dist/core/dataset/data/image_data.d.ts +0 -8
  58. package/dist/core/dataset/data/image_data.js +0 -64
  59. package/dist/core/dataset/data/index.d.ts +0 -5
  60. package/dist/core/dataset/data/index.js +0 -11
  61. package/dist/core/dataset/data/preprocessing.d.ts +0 -13
  62. package/dist/core/dataset/data/preprocessing.js +0 -33
  63. package/dist/core/dataset/data/tabular_data.d.ts +0 -8
  64. package/dist/core/dataset/data/tabular_data.js +0 -40
  65. package/dist/core/dataset/data_loader/data_loader.d.ts +0 -15
  66. package/dist/core/dataset/data_loader/data_loader.js +0 -10
  67. package/dist/core/dataset/data_loader/image_loader.d.ts +0 -17
  68. package/dist/core/dataset/data_loader/image_loader.js +0 -141
  69. package/dist/core/dataset/data_loader/index.d.ts +0 -3
  70. package/dist/core/dataset/data_loader/index.js +0 -9
  71. package/dist/core/dataset/data_loader/tabular_loader.d.ts +0 -29
  72. package/dist/core/dataset/data_loader/tabular_loader.js +0 -101
  73. package/dist/core/dataset/dataset.d.ts +0 -2
  74. package/dist/core/dataset/dataset.js +0 -2
  75. package/dist/core/dataset/dataset_builder.d.ts +0 -18
  76. package/dist/core/dataset/dataset_builder.js +0 -96
  77. package/dist/core/dataset/index.d.ts +0 -4
  78. package/dist/core/dataset/index.js +0 -14
  79. package/dist/core/default_tasks/cifar10.d.ts +0 -2
  80. package/dist/core/default_tasks/cifar10.js +0 -68
  81. package/dist/core/default_tasks/geotags.d.ts +0 -2
  82. package/dist/core/default_tasks/geotags.js +0 -69
  83. package/dist/core/default_tasks/index.d.ts +0 -6
  84. package/dist/core/default_tasks/index.js +0 -15
  85. package/dist/core/default_tasks/lus_covid.d.ts +0 -2
  86. package/dist/core/default_tasks/lus_covid.js +0 -96
  87. package/dist/core/default_tasks/mnist.d.ts +0 -2
  88. package/dist/core/default_tasks/mnist.js +0 -69
  89. package/dist/core/default_tasks/simple_face.d.ts +0 -2
  90. package/dist/core/default_tasks/simple_face.js +0 -53
  91. package/dist/core/default_tasks/titanic.d.ts +0 -2
  92. package/dist/core/default_tasks/titanic.js +0 -97
  93. package/dist/core/index.d.ts +0 -18
  94. package/dist/core/index.js +0 -39
  95. package/dist/core/informant/graph_informant.d.ts +0 -10
  96. package/dist/core/informant/graph_informant.js +0 -23
  97. package/dist/core/informant/index.d.ts +0 -3
  98. package/dist/core/informant/index.js +0 -9
  99. package/dist/core/informant/training_informant/base.d.ts +0 -31
  100. package/dist/core/informant/training_informant/base.js +0 -83
  101. package/dist/core/informant/training_informant/decentralized.d.ts +0 -5
  102. package/dist/core/informant/training_informant/decentralized.js +0 -22
  103. package/dist/core/informant/training_informant/federated.d.ts +0 -14
  104. package/dist/core/informant/training_informant/federated.js +0 -32
  105. package/dist/core/informant/training_informant/index.d.ts +0 -4
  106. package/dist/core/informant/training_informant/index.js +0 -11
  107. package/dist/core/informant/training_informant/local.d.ts +0 -6
  108. package/dist/core/informant/training_informant/local.js +0 -20
  109. package/dist/core/logging/console_logger.d.ts +0 -18
  110. package/dist/core/logging/console_logger.js +0 -33
  111. package/dist/core/logging/index.d.ts +0 -3
  112. package/dist/core/logging/index.js +0 -9
  113. package/dist/core/logging/logger.d.ts +0 -12
  114. package/dist/core/logging/logger.js +0 -9
  115. package/dist/core/logging/trainer_logger.d.ts +0 -24
  116. package/dist/core/logging/trainer_logger.js +0 -59
  117. package/dist/core/memory/base.d.ts +0 -22
  118. package/dist/core/memory/base.js +0 -9
  119. package/dist/core/memory/empty.d.ts +0 -14
  120. package/dist/core/memory/empty.js +0 -75
  121. package/dist/core/memory/index.d.ts +0 -3
  122. package/dist/core/memory/index.js +0 -9
  123. package/dist/core/memory/model_type.d.ts +0 -4
  124. package/dist/core/memory/model_type.js +0 -9
  125. package/dist/core/privacy.d.ts +0 -11
  126. package/dist/core/privacy.js +0 -47
  127. package/dist/core/serialization/index.d.ts +0 -2
  128. package/dist/core/serialization/index.js +0 -6
  129. package/dist/core/serialization/model.d.ts +0 -5
  130. package/dist/core/serialization/model.js +0 -55
  131. package/dist/core/serialization/weights.d.ts +0 -5
  132. package/dist/core/serialization/weights.js +0 -64
  133. package/dist/core/task/data_example.d.ts +0 -5
  134. package/dist/core/task/data_example.js +0 -24
  135. package/dist/core/task/digest.d.ts +0 -5
  136. package/dist/core/task/digest.js +0 -18
  137. package/dist/core/task/display_information.d.ts +0 -15
  138. package/dist/core/task/display_information.js +0 -49
  139. package/dist/core/task/index.d.ts +0 -6
  140. package/dist/core/task/index.js +0 -15
  141. package/dist/core/task/model_compile_data.d.ts +0 -6
  142. package/dist/core/task/model_compile_data.js +0 -22
  143. package/dist/core/task/summary.d.ts +0 -5
  144. package/dist/core/task/summary.js +0 -19
  145. package/dist/core/task/task.d.ts +0 -12
  146. package/dist/core/task/task.js +0 -35
  147. package/dist/core/task/task_handler.d.ts +0 -5
  148. package/dist/core/task/task_handler.js +0 -53
  149. package/dist/core/task/task_provider.d.ts +0 -6
  150. package/dist/core/task/task_provider.js +0 -13
  151. package/dist/core/task/training_information.d.ts +0 -28
  152. package/dist/core/task/training_information.js +0 -66
  153. package/dist/core/training/disco.d.ts +0 -23
  154. package/dist/core/training/disco.js +0 -130
  155. package/dist/core/training/index.d.ts +0 -2
  156. package/dist/core/training/index.js +0 -7
  157. package/dist/core/training/trainer/distributed_trainer.d.ts +0 -20
  158. package/dist/core/training/trainer/distributed_trainer.js +0 -65
  159. package/dist/core/training/trainer/local_trainer.d.ts +0 -11
  160. package/dist/core/training/trainer/local_trainer.js +0 -34
  161. package/dist/core/training/trainer/round_tracker.d.ts +0 -30
  162. package/dist/core/training/trainer/round_tracker.js +0 -47
  163. package/dist/core/training/trainer/trainer.d.ts +0 -65
  164. package/dist/core/training/trainer/trainer.js +0 -160
  165. package/dist/core/training/trainer/trainer_builder.d.ts +0 -25
  166. package/dist/core/training/trainer/trainer_builder.js +0 -95
  167. package/dist/core/training/training_schemes.d.ts +0 -5
  168. package/dist/core/training/training_schemes.js +0 -10
  169. package/dist/core/types.d.ts +0 -4
  170. package/dist/core/types.js +0 -2
  171. package/dist/core/validation/index.d.ts +0 -1
  172. package/dist/core/validation/index.js +0 -5
  173. package/dist/core/validation/validator.d.ts +0 -17
  174. package/dist/core/validation/validator.js +0 -104
  175. package/dist/core/weights/aggregation.d.ts +0 -7
  176. package/dist/core/weights/aggregation.js +0 -72
  177. package/dist/core/weights/index.d.ts +0 -2
  178. package/dist/core/weights/index.js +0 -7
  179. package/dist/core/weights/weights_container.d.ts +0 -19
  180. package/dist/core/weights/weights_container.js +0 -64
  181. package/dist/dataset/data_loader/image_loader.d.ts +0 -4
  182. package/dist/dataset/data_loader/image_loader.js +0 -21
  183. package/dist/dataset/data_loader/index.d.ts +0 -2
  184. package/dist/dataset/data_loader/index.js +0 -7
  185. package/dist/dataset/data_loader/tabular_loader.d.ts +0 -4
  186. package/dist/dataset/data_loader/tabular_loader.js +0 -20
  187. package/dist/imports.d.ts +0 -1
  188. package/dist/imports.js +0 -5
@@ -1,83 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.Base = void 0;
4
- var immutable_1 = require("immutable");
5
- var graph_informant_1 = require("../graph_informant");
6
- var Base = /** @class */ (function () {
7
- function Base(task, nbrMessagesToShow) {
8
- if (nbrMessagesToShow === void 0) { nbrMessagesToShow = 10; }
9
- this.task = task;
10
- this.nbrMessagesToShow = nbrMessagesToShow;
11
- // written feedback
12
- this.messages = (0, immutable_1.List)();
13
- // graph feedback
14
- this.trainingGraphInformant = new graph_informant_1.GraphInformant();
15
- this.validationGraphInformant = new graph_informant_1.GraphInformant();
16
- // statistics
17
- this.currentRound = 0;
18
- this.currentNumberOfParticipants = 0;
19
- this.totalNumberOfParticipants = 0;
20
- this.averageNumberOfParticipants = 0;
21
- }
22
- Base.prototype.addMessage = function (msg) {
23
- if (this.messages.size >= this.nbrMessagesToShow) {
24
- this.messages = this.messages.shift();
25
- }
26
- this.messages = this.messages.push(msg);
27
- };
28
- Base.prototype.getMessages = function () {
29
- return this.messages.toArray();
30
- };
31
- Base.prototype.round = function () {
32
- return this.currentRound;
33
- };
34
- Base.prototype.participants = function () {
35
- return this.currentNumberOfParticipants;
36
- };
37
- Base.prototype.totalParticipants = function () {
38
- return this.totalNumberOfParticipants;
39
- };
40
- Base.prototype.averageParticipants = function () {
41
- return this.averageNumberOfParticipants;
42
- };
43
- Base.prototype.updateTrainingGraph = function (accuracy) {
44
- this.trainingGraphInformant.updateAccuracy(accuracy);
45
- };
46
- Base.prototype.updateValidationGraph = function (accuracy) {
47
- this.validationGraphInformant.updateAccuracy(accuracy);
48
- };
49
- Base.prototype.trainingAccuracy = function () {
50
- return this.trainingGraphInformant.accuracy();
51
- };
52
- Base.prototype.validationAccuracy = function () {
53
- return this.validationGraphInformant.accuracy();
54
- };
55
- Base.prototype.trainingAccuracyData = function () {
56
- return this.trainingGraphInformant.data();
57
- };
58
- Base.prototype.validationAccuracyData = function () {
59
- return this.validationGraphInformant.data();
60
- };
61
- Base.prototype.isDecentralized = function () {
62
- return false;
63
- };
64
- Base.prototype.isFederated = function () {
65
- return false;
66
- };
67
- Base.isTrainingInformant = function (raw) {
68
- if (typeof raw !== 'object') {
69
- return false;
70
- }
71
- if (raw === null) {
72
- return false;
73
- }
74
- // TODO
75
- var requiredFields = (0, immutable_1.Set)();
76
- if (!(requiredFields.every(function (field) { return field in raw; }))) {
77
- return false;
78
- }
79
- return true;
80
- };
81
- return Base;
82
- }());
83
- exports.Base = Base;
@@ -1,5 +0,0 @@
1
- import { Base } from '.';
2
- export declare class DecentralizedInformant extends Base {
3
- update(statistics: Record<string, number>): void;
4
- isDecentralized(): boolean;
5
- }
@@ -1,22 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.DecentralizedInformant = void 0;
4
- var tslib_1 = require("tslib");
5
- var _1 = require(".");
6
- var DecentralizedInformant = /** @class */ (function (_super) {
7
- (0, tslib_1.__extends)(DecentralizedInformant, _super);
8
- function DecentralizedInformant() {
9
- return _super !== null && _super.apply(this, arguments) || this;
10
- }
11
- DecentralizedInformant.prototype.update = function (statistics) {
12
- this.currentRound += 1;
13
- this.currentNumberOfParticipants = statistics.currentNumberOfParticipants;
14
- this.totalNumberOfParticipants += this.currentNumberOfParticipants;
15
- this.averageNumberOfParticipants = this.totalNumberOfParticipants / this.currentRound;
16
- };
17
- DecentralizedInformant.prototype.isDecentralized = function () {
18
- return true;
19
- };
20
- return DecentralizedInformant;
21
- }(_1.Base));
22
- exports.DecentralizedInformant = DecentralizedInformant;
@@ -1,14 +0,0 @@
1
- import { Base } from '.';
2
- /**
3
- * Class that collects information about the status of the training-loop of the model.
4
- */
5
- export declare class FederatedInformant extends Base {
6
- displayHeatmap: boolean;
7
- /**
8
- * Update the server statistics with the JSON received from the server
9
- * For now it's just the JSON, but we might want to keep it as a dictionary
10
- * @param receivedStatistics statistics received from the server.
11
- */
12
- update(receivedStatistics: Record<string, number>): void;
13
- isFederated(): boolean;
14
- }
@@ -1,32 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.FederatedInformant = void 0;
4
- var tslib_1 = require("tslib");
5
- var _1 = require(".");
6
- /**
7
- * Class that collects information about the status of the training-loop of the model.
8
- */
9
- var FederatedInformant = /** @class */ (function (_super) {
10
- (0, tslib_1.__extends)(FederatedInformant, _super);
11
- function FederatedInformant() {
12
- var _this = _super !== null && _super.apply(this, arguments) || this;
13
- _this.displayHeatmap = false;
14
- return _this;
15
- }
16
- /**
17
- * Update the server statistics with the JSON received from the server
18
- * For now it's just the JSON, but we might want to keep it as a dictionary
19
- * @param receivedStatistics statistics received from the server.
20
- */
21
- FederatedInformant.prototype.update = function (receivedStatistics) {
22
- this.currentRound = receivedStatistics.round;
23
- this.currentNumberOfParticipants = receivedStatistics.currentNumberOfParticipants;
24
- this.totalNumberOfParticipants = receivedStatistics.totalNumberOfParticipants;
25
- this.averageNumberOfParticipants = receivedStatistics.averageNumberOfParticipants;
26
- };
27
- FederatedInformant.prototype.isFederated = function () {
28
- return true;
29
- };
30
- return FederatedInformant;
31
- }(_1.Base));
32
- exports.FederatedInformant = FederatedInformant;
@@ -1,4 +0,0 @@
1
- export { Base } from './base';
2
- export { FederatedInformant } from './federated';
3
- export { DecentralizedInformant } from './decentralized';
4
- export { LocalInformant } from './local';
@@ -1,11 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.LocalInformant = exports.DecentralizedInformant = exports.FederatedInformant = exports.Base = void 0;
4
- var base_1 = require("./base");
5
- Object.defineProperty(exports, "Base", { enumerable: true, get: function () { return base_1.Base; } });
6
- var federated_1 = require("./federated");
7
- Object.defineProperty(exports, "FederatedInformant", { enumerable: true, get: function () { return federated_1.FederatedInformant; } });
8
- var decentralized_1 = require("./decentralized");
9
- Object.defineProperty(exports, "DecentralizedInformant", { enumerable: true, get: function () { return decentralized_1.DecentralizedInformant; } });
10
- var local_1 = require("./local");
11
- Object.defineProperty(exports, "LocalInformant", { enumerable: true, get: function () { return local_1.LocalInformant; } });
@@ -1,6 +0,0 @@
1
- import { Task } from '../../task';
2
- import { Base } from '.';
3
- export declare class LocalInformant extends Base {
4
- constructor(task: Task, nbrMessagesToShow?: number);
5
- update(statistics: Record<string, number>): void;
6
- }
@@ -1,20 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.LocalInformant = void 0;
4
- var tslib_1 = require("tslib");
5
- var _1 = require(".");
6
- var LocalInformant = /** @class */ (function (_super) {
7
- (0, tslib_1.__extends)(LocalInformant, _super);
8
- function LocalInformant(task, nbrMessagesToShow) {
9
- var _this = _super.call(this, task, nbrMessagesToShow) || this;
10
- _this.currentNumberOfParticipants = 1;
11
- _this.averageNumberOfParticipants = 1;
12
- _this.totalNumberOfParticipants = 1;
13
- return _this;
14
- }
15
- LocalInformant.prototype.update = function (statistics) {
16
- this.currentRound = statistics.currentRound;
17
- };
18
- return LocalInformant;
19
- }(_1.Base));
20
- exports.LocalInformant = LocalInformant;
@@ -1,18 +0,0 @@
1
- import { Logger } from './logger';
2
- /**
3
- * Same properties as Toaster but on the console
4
- *
5
- * @class Logger
6
- */
7
- export declare class ConsoleLogger extends Logger {
8
- /**
9
- * Logs success message on the console (in green)
10
- * @param {String} message - message to be displayed
11
- */
12
- success(message: string): void;
13
- /**
14
- * Logs error message on the console (in red)
15
- * @param message - message to be displayed
16
- */
17
- error(message: string): void;
18
- }
@@ -1,33 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.ConsoleLogger = void 0;
4
- var tslib_1 = require("tslib");
5
- var chalk_1 = (0, tslib_1.__importDefault)(require("chalk"));
6
- var logger_1 = require("./logger");
7
- /**
8
- * Same properties as Toaster but on the console
9
- *
10
- * @class Logger
11
- */
12
- var ConsoleLogger = /** @class */ (function (_super) {
13
- (0, tslib_1.__extends)(ConsoleLogger, _super);
14
- function ConsoleLogger() {
15
- return _super !== null && _super.apply(this, arguments) || this;
16
- }
17
- /**
18
- * Logs success message on the console (in green)
19
- * @param {String} message - message to be displayed
20
- */
21
- ConsoleLogger.prototype.success = function (message) {
22
- console.log(chalk_1.default.green(message));
23
- };
24
- /**
25
- * Logs error message on the console (in red)
26
- * @param message - message to be displayed
27
- */
28
- ConsoleLogger.prototype.error = function (message) {
29
- console.log(chalk_1.default.red(message));
30
- };
31
- return ConsoleLogger;
32
- }(logger_1.Logger));
33
- exports.ConsoleLogger = ConsoleLogger;
@@ -1,3 +0,0 @@
1
- export { Logger } from './logger';
2
- export { ConsoleLogger } from './console_logger';
3
- export { TrainerLog } from './trainer_logger';
@@ -1,9 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.TrainerLog = exports.ConsoleLogger = exports.Logger = void 0;
4
- var logger_1 = require("./logger");
5
- Object.defineProperty(exports, "Logger", { enumerable: true, get: function () { return logger_1.Logger; } });
6
- var console_logger_1 = require("./console_logger");
7
- Object.defineProperty(exports, "ConsoleLogger", { enumerable: true, get: function () { return console_logger_1.ConsoleLogger; } });
8
- var trainer_logger_1 = require("./trainer_logger");
9
- Object.defineProperty(exports, "TrainerLog", { enumerable: true, get: function () { return trainer_logger_1.TrainerLog; } });
@@ -1,12 +0,0 @@
1
- export declare abstract class Logger {
2
- /**
3
- * Logs sucess message (in green)
4
- * @param message - message to be displayed
5
- */
6
- abstract success(message: string): void;
7
- /**
8
- * Logs error message (in red)
9
- * @param message - message to be displayed
10
- */
11
- abstract error(message: string): void;
12
- }
@@ -1,9 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.Logger = void 0;
4
- var Logger = /** @class */ (function () {
5
- function Logger() {
6
- }
7
- return Logger;
8
- }());
9
- exports.Logger = Logger;
@@ -1,24 +0,0 @@
1
- import { List } from 'immutable';
2
- import { tf } from '..';
3
- import { ConsoleLogger } from '.';
4
- export declare class TrainerLog {
5
- epochs: List<number>;
6
- trainAccuracy: List<number>;
7
- validationAccuracy: List<number>;
8
- loss: List<number>;
9
- add(epoch: number, logs?: tf.Logs): void;
10
- }
11
- /**
12
- *
13
- * @class TrainerLogger
14
- */
15
- export declare class TrainerLogger extends ConsoleLogger {
16
- readonly log: TrainerLog;
17
- readonly saveTrainerLog: boolean;
18
- constructor(saveTrainerLog?: boolean);
19
- onEpochEnd(epoch: number, logs?: tf.Logs): void;
20
- /**
21
- * Display ram usage
22
- */
23
- ramUsage(): void;
24
- }
@@ -1,59 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.TrainerLogger = exports.TrainerLog = void 0;
4
- var tslib_1 = require("tslib");
5
- var immutable_1 = require("immutable");
6
- var __1 = require("..");
7
- var _1 = require(".");
8
- var TrainerLog = /** @class */ (function () {
9
- function TrainerLog() {
10
- this.epochs = (0, immutable_1.List)();
11
- this.trainAccuracy = (0, immutable_1.List)();
12
- this.validationAccuracy = (0, immutable_1.List)();
13
- this.loss = (0, immutable_1.List)();
14
- }
15
- TrainerLog.prototype.add = function (epoch, logs) {
16
- this.epochs = this.epochs.push(epoch);
17
- if (logs !== undefined) {
18
- this.trainAccuracy = this.trainAccuracy.push(logs.acc);
19
- this.validationAccuracy = this.validationAccuracy.push(logs.val_acc);
20
- this.loss = this.loss.push(logs.loss);
21
- }
22
- };
23
- return TrainerLog;
24
- }());
25
- exports.TrainerLog = TrainerLog;
26
- /**
27
- *
28
- * @class TrainerLogger
29
- */
30
- var TrainerLogger = /** @class */ (function (_super) {
31
- (0, tslib_1.__extends)(TrainerLogger, _super);
32
- // TODO: pass savaTrainerLog as false in browser, used for benchmarking
33
- function TrainerLogger(saveTrainerLog) {
34
- if (saveTrainerLog === void 0) { saveTrainerLog = true; }
35
- var _this = _super.call(this) || this;
36
- _this.saveTrainerLog = saveTrainerLog;
37
- _this.log = new TrainerLog();
38
- return _this;
39
- }
40
- TrainerLogger.prototype.onEpochEnd = function (epoch, logs) {
41
- var _a, _b, _c;
42
- // save logs
43
- if (this.saveTrainerLog) {
44
- this.log.add(epoch, logs);
45
- }
46
- // console output
47
- var msg = "Epoch: " + epoch + "\nTrain: " + ((_a = logs === null || logs === void 0 ? void 0 : logs.acc) !== null && _a !== void 0 ? _a : 'undefined') + "\nValidation:" + ((_b = logs === null || logs === void 0 ? void 0 : logs.val_acc) !== null && _b !== void 0 ? _b : 'undefined') + "\nLoss:" + ((_c = logs === null || logs === void 0 ? void 0 : logs.loss) !== null && _c !== void 0 ? _c : 'undefined');
48
- this.success("On epoch end:\n" + msg + "\n");
49
- };
50
- /**
51
- * Display ram usage
52
- */
53
- TrainerLogger.prototype.ramUsage = function () {
54
- this.success("Training RAM usage is = " + __1.tf.memory().numBytes * 0.000001 + " MB");
55
- this.success("Number of allocated tensors = " + __1.tf.memory().numTensors);
56
- };
57
- return TrainerLogger;
58
- }(_1.ConsoleLogger));
59
- exports.TrainerLogger = TrainerLogger;
@@ -1,22 +0,0 @@
1
- import * as tf from '@tensorflow/tfjs';
2
- import { TaskID } from '..';
3
- import { ModelType } from './model_type';
4
- export declare type Path = string;
5
- export interface ModelInfo {
6
- type?: ModelType;
7
- taskID: TaskID;
8
- name: string;
9
- }
10
- export declare type ModelSource = ModelInfo | Path;
11
- export declare abstract class Memory {
12
- abstract getModel(source: ModelSource): Promise<tf.LayersModel>;
13
- abstract deleteModel(source: ModelSource): Promise<void>;
14
- abstract loadModel(source: ModelSource): Promise<void>;
15
- abstract getModelMetadata(source: ModelSource): Promise<object | undefined>;
16
- abstract updateWorkingModel(source: ModelSource, model: tf.LayersModel): Promise<void>;
17
- abstract saveWorkingModel(source: ModelSource): Promise<void>;
18
- abstract downloadModel(source: ModelSource): Promise<void>;
19
- abstract contains(source: ModelSource): Promise<boolean>;
20
- abstract pathFor(source: ModelSource): Path | undefined;
21
- abstract infoFor(source: ModelSource): ModelInfo | undefined;
22
- }
@@ -1,9 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.Memory = void 0;
4
- var Memory = /** @class */ (function () {
5
- function Memory() {
6
- }
7
- return Memory;
8
- }());
9
- exports.Memory = Memory;
@@ -1,14 +0,0 @@
1
- import { tf } from '..';
2
- import { Memory, ModelInfo, Path } from './base';
3
- export declare class Empty extends Memory {
4
- getModelMetadata(): Promise<undefined>;
5
- contains(): Promise<boolean>;
6
- getModel(): Promise<tf.LayersModel>;
7
- loadModel(): Promise<void>;
8
- updateWorkingModel(): Promise<void>;
9
- saveWorkingModel(): Promise<void>;
10
- deleteModel(): Promise<void>;
11
- downloadModel(): Promise<void>;
12
- pathFor(): Path;
13
- infoFor(): ModelInfo;
14
- }
@@ -1,75 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.Empty = void 0;
4
- var tslib_1 = require("tslib");
5
- var base_1 = require("./base");
6
- var Empty = /** @class */ (function (_super) {
7
- (0, tslib_1.__extends)(Empty, _super);
8
- function Empty() {
9
- return _super !== null && _super.apply(this, arguments) || this;
10
- }
11
- Empty.prototype.getModelMetadata = function () {
12
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
13
- return (0, tslib_1.__generator)(this, function (_a) {
14
- return [2 /*return*/, undefined];
15
- });
16
- });
17
- };
18
- Empty.prototype.contains = function () {
19
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
20
- return (0, tslib_1.__generator)(this, function (_a) {
21
- return [2 /*return*/, false];
22
- });
23
- });
24
- };
25
- Empty.prototype.getModel = function () {
26
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
27
- return (0, tslib_1.__generator)(this, function (_a) {
28
- throw new Error('empty');
29
- });
30
- });
31
- };
32
- Empty.prototype.loadModel = function () {
33
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
34
- return (0, tslib_1.__generator)(this, function (_a) {
35
- throw new Error('empty');
36
- });
37
- });
38
- };
39
- Empty.prototype.updateWorkingModel = function () {
40
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
41
- return (0, tslib_1.__generator)(this, function (_a) {
42
- return [2 /*return*/];
43
- });
44
- });
45
- };
46
- Empty.prototype.saveWorkingModel = function () {
47
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
48
- return (0, tslib_1.__generator)(this, function (_a) {
49
- return [2 /*return*/];
50
- });
51
- });
52
- };
53
- Empty.prototype.deleteModel = function () {
54
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
55
- return (0, tslib_1.__generator)(this, function (_a) {
56
- return [2 /*return*/];
57
- });
58
- });
59
- };
60
- Empty.prototype.downloadModel = function () {
61
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
62
- return (0, tslib_1.__generator)(this, function (_a) {
63
- throw new Error('empty');
64
- });
65
- });
66
- };
67
- Empty.prototype.pathFor = function () {
68
- throw new Error('empty');
69
- };
70
- Empty.prototype.infoFor = function () {
71
- throw new Error('empty');
72
- };
73
- return Empty;
74
- }(base_1.Memory));
75
- exports.Empty = Empty;
@@ -1,3 +0,0 @@
1
- export { Empty } from './empty';
2
- export { Memory, ModelInfo, Path, ModelSource } from './base';
3
- export { ModelType } from './model_type';
@@ -1,9 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.ModelType = exports.Memory = exports.Empty = void 0;
4
- var empty_1 = require("./empty");
5
- Object.defineProperty(exports, "Empty", { enumerable: true, get: function () { return empty_1.Empty; } });
6
- var base_1 = require("./base");
7
- Object.defineProperty(exports, "Memory", { enumerable: true, get: function () { return base_1.Memory; } });
8
- var model_type_1 = require("./model_type");
9
- Object.defineProperty(exports, "ModelType", { enumerable: true, get: function () { return model_type_1.ModelType; } });
@@ -1,4 +0,0 @@
1
- export declare enum ModelType {
2
- WORKING = "working",
3
- SAVED = "saved"
4
- }
@@ -1,9 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.ModelType = void 0;
4
- // Type of model to store & retrieve
5
- var ModelType;
6
- (function (ModelType) {
7
- ModelType["WORKING"] = "working";
8
- ModelType["SAVED"] = "saved";
9
- })(ModelType = exports.ModelType || (exports.ModelType = {}));
@@ -1,11 +0,0 @@
1
- import { Task, WeightsContainer } from '.';
2
- /**
3
- * Add task-parametrized Gaussian noise to and clip the weights update between the previous and current rounds.
4
- * The previous round's weights are the last weights pulled from server/peers.
5
- * The current round's weights are obtained after a single round of training, from the previous round's weights.
6
- * @param updatedWeights weights from the current round
7
- * @param staleWeights weights from the previous round
8
- * @param task the task
9
- * @returns the noised weights for the current round
10
- */
11
- export declare function addDifferentialPrivacy(updatedWeights: WeightsContainer, staleWeights: WeightsContainer, task: Task): WeightsContainer;
@@ -1,47 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.addDifferentialPrivacy = void 0;
4
- var _1 = require(".");
5
- /**
6
- * Add task-parametrized Gaussian noise to and clip the weights update between the previous and current rounds.
7
- * The previous round's weights are the last weights pulled from server/peers.
8
- * The current round's weights are obtained after a single round of training, from the previous round's weights.
9
- * @param updatedWeights weights from the current round
10
- * @param staleWeights weights from the previous round
11
- * @param task the task
12
- * @returns the noised weights for the current round
13
- */
14
- function addDifferentialPrivacy(updatedWeights, staleWeights, task) {
15
- var _a, _b;
16
- var noiseScale = (_a = task.trainingInformation) === null || _a === void 0 ? void 0 : _a.noiseScale;
17
- var clippingRadius = (_b = task.trainingInformation) === null || _b === void 0 ? void 0 : _b.clippingRadius;
18
- var weightsDiff = updatedWeights.sub(staleWeights);
19
- var newWeightsDiff;
20
- if (clippingRadius !== undefined) {
21
- // Frobenius norm
22
- var norm_1 = weightsDiff.frobeniusNorm();
23
- newWeightsDiff = weightsDiff.map(function (w) {
24
- var clipped = w.div(Math.max(1, norm_1 / clippingRadius));
25
- if (noiseScale !== undefined) {
26
- // Add clipping and noise
27
- var noise = _1.tf.randomNormal(w.shape, 0, (noiseScale * noiseScale) * (clippingRadius * clippingRadius));
28
- return clipped.add(noise);
29
- }
30
- else {
31
- // Add clipping without any noise
32
- return clipped;
33
- }
34
- });
35
- }
36
- else {
37
- if (noiseScale !== undefined) {
38
- // Add noise without any clipping
39
- newWeightsDiff = weightsDiff.map(function (w) { return _1.tf.randomNormal(w.shape, 0, (noiseScale * noiseScale)); });
40
- }
41
- else {
42
- return updatedWeights;
43
- }
44
- }
45
- return staleWeights.add(newWeightsDiff);
46
- }
47
- exports.addDifferentialPrivacy = addDifferentialPrivacy;
@@ -1,2 +0,0 @@
1
- export * as model from './model';
2
- export * as weights from './weights';
@@ -1,6 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.weights = exports.model = void 0;
4
- var tslib_1 = require("tslib");
5
- exports.model = (0, tslib_1.__importStar)(require("./model"));
6
- exports.weights = (0, tslib_1.__importStar)(require("./weights"));
@@ -1,5 +0,0 @@
1
- import { tf } from '..';
2
- export declare type Encoded = number[];
3
- export declare function isEncoded(raw: unknown): raw is Encoded;
4
- export declare function encode(model: tf.LayersModel): Promise<Encoded>;
5
- export declare function decode(encoded: Encoded): Promise<tf.LayersModel>;