@epfml/discojs 0.1.0 → 2.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (208) hide show
  1. package/README.md +28 -8
  2. package/dist/{async_buffer.d.ts → core/async_buffer.d.ts} +3 -3
  3. package/dist/{async_buffer.js → core/async_buffer.js} +5 -6
  4. package/dist/{async_informant.d.ts → core/async_informant.d.ts} +0 -0
  5. package/dist/{async_informant.js → core/async_informant.js} +0 -0
  6. package/dist/{client → core/client}/base.d.ts +4 -7
  7. package/dist/{client → core/client}/base.js +3 -2
  8. package/dist/core/client/decentralized/base.d.ts +32 -0
  9. package/dist/core/client/decentralized/base.js +212 -0
  10. package/dist/core/client/decentralized/clear_text.d.ts +14 -0
  11. package/dist/core/client/decentralized/clear_text.js +96 -0
  12. package/dist/{client → core/client}/decentralized/index.d.ts +0 -0
  13. package/dist/{client → core/client}/decentralized/index.js +0 -0
  14. package/dist/core/client/decentralized/messages.d.ts +41 -0
  15. package/dist/core/client/decentralized/messages.js +54 -0
  16. package/dist/core/client/decentralized/peer.d.ts +26 -0
  17. package/dist/core/client/decentralized/peer.js +210 -0
  18. package/dist/core/client/decentralized/peer_pool.d.ts +14 -0
  19. package/dist/core/client/decentralized/peer_pool.js +92 -0
  20. package/dist/core/client/decentralized/sec_agg.d.ts +22 -0
  21. package/dist/core/client/decentralized/sec_agg.js +190 -0
  22. package/dist/core/client/decentralized/secret_shares.d.ts +3 -0
  23. package/dist/core/client/decentralized/secret_shares.js +39 -0
  24. package/dist/core/client/decentralized/types.d.ts +2 -0
  25. package/dist/core/client/decentralized/types.js +7 -0
  26. package/dist/core/client/event_connection.d.ts +37 -0
  27. package/dist/core/client/event_connection.js +158 -0
  28. package/dist/core/client/federated/client.d.ts +37 -0
  29. package/dist/core/client/federated/client.js +273 -0
  30. package/dist/core/client/federated/index.d.ts +2 -0
  31. package/dist/core/client/federated/index.js +7 -0
  32. package/dist/core/client/federated/messages.d.ts +38 -0
  33. package/dist/core/client/federated/messages.js +25 -0
  34. package/dist/{client → core/client}/index.d.ts +2 -1
  35. package/dist/{client → core/client}/index.js +3 -3
  36. package/dist/{client → core/client}/local.d.ts +2 -2
  37. package/dist/{client → core/client}/local.js +0 -0
  38. package/dist/core/client/messages.d.ts +28 -0
  39. package/dist/core/client/messages.js +33 -0
  40. package/dist/core/client/utils.d.ts +2 -0
  41. package/dist/core/client/utils.js +19 -0
  42. package/dist/core/dataset/data/data.d.ts +11 -0
  43. package/dist/core/dataset/data/data.js +20 -0
  44. package/dist/core/dataset/data/data_split.d.ts +5 -0
  45. package/dist/{client/decentralized/types.js → core/dataset/data/data_split.js} +0 -0
  46. package/dist/core/dataset/data/image_data.d.ts +8 -0
  47. package/dist/core/dataset/data/image_data.js +64 -0
  48. package/dist/core/dataset/data/index.d.ts +5 -0
  49. package/dist/core/dataset/data/index.js +11 -0
  50. package/dist/core/dataset/data/preprocessing.d.ts +13 -0
  51. package/dist/core/dataset/data/preprocessing.js +33 -0
  52. package/dist/core/dataset/data/tabular_data.d.ts +8 -0
  53. package/dist/core/dataset/data/tabular_data.js +40 -0
  54. package/dist/{dataset → core/dataset}/data_loader/data_loader.d.ts +4 -11
  55. package/dist/{dataset → core/dataset}/data_loader/data_loader.js +0 -0
  56. package/dist/core/dataset/data_loader/image_loader.d.ts +17 -0
  57. package/dist/core/dataset/data_loader/image_loader.js +141 -0
  58. package/dist/core/dataset/data_loader/index.d.ts +3 -0
  59. package/dist/core/dataset/data_loader/index.js +9 -0
  60. package/dist/core/dataset/data_loader/tabular_loader.d.ts +29 -0
  61. package/dist/core/dataset/data_loader/tabular_loader.js +101 -0
  62. package/dist/core/dataset/dataset.d.ts +2 -0
  63. package/dist/{task/training_information.js → core/dataset/dataset.js} +0 -0
  64. package/dist/{dataset → core/dataset}/dataset_builder.d.ts +5 -5
  65. package/dist/{dataset → core/dataset}/dataset_builder.js +14 -10
  66. package/dist/core/dataset/index.d.ts +4 -0
  67. package/dist/core/dataset/index.js +14 -0
  68. package/dist/core/index.d.ts +18 -0
  69. package/dist/core/index.js +41 -0
  70. package/dist/{informant → core/informant}/graph_informant.d.ts +0 -0
  71. package/dist/{informant → core/informant}/graph_informant.js +0 -0
  72. package/dist/{informant → core/informant}/index.d.ts +0 -0
  73. package/dist/{informant → core/informant}/index.js +0 -0
  74. package/dist/{informant → core/informant}/training_informant/base.d.ts +3 -3
  75. package/dist/{informant → core/informant}/training_informant/base.js +3 -2
  76. package/dist/{informant → core/informant}/training_informant/decentralized.d.ts +0 -0
  77. package/dist/{informant → core/informant}/training_informant/decentralized.js +0 -0
  78. package/dist/{informant → core/informant}/training_informant/federated.d.ts +0 -0
  79. package/dist/{informant → core/informant}/training_informant/federated.js +0 -0
  80. package/dist/{informant → core/informant}/training_informant/index.d.ts +0 -0
  81. package/dist/{informant → core/informant}/training_informant/index.js +0 -0
  82. package/dist/{informant → core/informant}/training_informant/local.d.ts +2 -2
  83. package/dist/{informant → core/informant}/training_informant/local.js +2 -2
  84. package/dist/{logging → core/logging}/console_logger.d.ts +0 -0
  85. package/dist/{logging → core/logging}/console_logger.js +0 -0
  86. package/dist/{logging → core/logging}/index.d.ts +0 -0
  87. package/dist/{logging → core/logging}/index.js +0 -0
  88. package/dist/{logging → core/logging}/logger.d.ts +0 -0
  89. package/dist/{logging → core/logging}/logger.js +0 -0
  90. package/dist/{logging → core/logging}/trainer_logger.d.ts +0 -0
  91. package/dist/{logging → core/logging}/trainer_logger.js +0 -0
  92. package/dist/{memory → core/memory}/base.d.ts +2 -2
  93. package/dist/{memory → core/memory}/base.js +0 -0
  94. package/dist/{memory → core/memory}/empty.d.ts +0 -0
  95. package/dist/{memory → core/memory}/empty.js +0 -0
  96. package/dist/core/memory/index.d.ts +3 -0
  97. package/dist/core/memory/index.js +9 -0
  98. package/dist/{memory → core/memory}/model_type.d.ts +0 -0
  99. package/dist/{memory → core/memory}/model_type.js +0 -0
  100. package/dist/{privacy.d.ts → core/privacy.d.ts} +2 -3
  101. package/dist/{privacy.js → core/privacy.js} +3 -16
  102. package/dist/{serialization → core/serialization}/index.d.ts +0 -0
  103. package/dist/{serialization → core/serialization}/index.js +0 -0
  104. package/dist/{serialization → core/serialization}/model.d.ts +0 -0
  105. package/dist/{serialization → core/serialization}/model.js +0 -0
  106. package/dist/core/serialization/weights.d.ts +5 -0
  107. package/dist/{serialization → core/serialization}/weights.js +11 -9
  108. package/dist/{task → core/task}/data_example.d.ts +0 -0
  109. package/dist/{task → core/task}/data_example.js +0 -0
  110. package/dist/{task → core/task}/display_information.d.ts +5 -5
  111. package/dist/{task → core/task}/display_information.js +5 -10
  112. package/dist/{task → core/task}/index.d.ts +0 -0
  113. package/dist/{task → core/task}/index.js +0 -0
  114. package/dist/core/task/model_compile_data.d.ts +6 -0
  115. package/dist/core/task/model_compile_data.js +22 -0
  116. package/dist/{task → core/task}/summary.d.ts +0 -0
  117. package/dist/{task → core/task}/summary.js +0 -4
  118. package/dist/{task → core/task}/task.d.ts +2 -2
  119. package/dist/{task → core/task}/task.js +6 -7
  120. package/dist/{task → core/task}/training_information.d.ts +10 -14
  121. package/dist/core/task/training_information.js +66 -0
  122. package/dist/{tasks → core/tasks}/cifar10.d.ts +1 -2
  123. package/dist/{tasks → core/tasks}/cifar10.js +12 -23
  124. package/dist/core/tasks/geotags.d.ts +3 -0
  125. package/dist/core/tasks/geotags.js +67 -0
  126. package/dist/{tasks → core/tasks}/index.d.ts +2 -1
  127. package/dist/{tasks → core/tasks}/index.js +3 -2
  128. package/dist/core/tasks/lus_covid.d.ts +3 -0
  129. package/dist/{tasks → core/tasks}/lus_covid.js +26 -24
  130. package/dist/{tasks → core/tasks}/mnist.d.ts +1 -2
  131. package/dist/{tasks → core/tasks}/mnist.js +18 -16
  132. package/dist/core/tasks/simple_face.d.ts +2 -0
  133. package/dist/core/tasks/simple_face.js +41 -0
  134. package/dist/{tasks → core/tasks}/titanic.d.ts +1 -2
  135. package/dist/{tasks → core/tasks}/titanic.js +11 -11
  136. package/dist/core/training/disco.d.ts +23 -0
  137. package/dist/core/training/disco.js +130 -0
  138. package/dist/{training → core/training}/index.d.ts +0 -0
  139. package/dist/{training → core/training}/index.js +0 -0
  140. package/dist/{training → core/training}/trainer/distributed_trainer.d.ts +1 -2
  141. package/dist/{training → core/training}/trainer/distributed_trainer.js +6 -5
  142. package/dist/{training → core/training}/trainer/local_trainer.d.ts +2 -2
  143. package/dist/{training → core/training}/trainer/local_trainer.js +0 -0
  144. package/dist/{training → core/training}/trainer/round_tracker.d.ts +0 -0
  145. package/dist/{training → core/training}/trainer/round_tracker.js +0 -0
  146. package/dist/{training → core/training}/trainer/trainer.d.ts +1 -2
  147. package/dist/{training → core/training}/trainer/trainer.js +2 -2
  148. package/dist/{training → core/training}/trainer/trainer_builder.d.ts +0 -0
  149. package/dist/{training → core/training}/trainer/trainer_builder.js +0 -0
  150. package/dist/core/training/training_schemes.d.ts +5 -0
  151. package/dist/{training → core/training}/training_schemes.js +2 -2
  152. package/dist/{types.d.ts → core/types.d.ts} +0 -0
  153. package/dist/{types.js → core/types.js} +0 -0
  154. package/dist/{validation → core/validation}/index.d.ts +0 -0
  155. package/dist/{validation → core/validation}/index.js +0 -0
  156. package/dist/{validation → core/validation}/validator.d.ts +5 -8
  157. package/dist/{validation → core/validation}/validator.js +9 -11
  158. package/dist/core/weights/aggregation.d.ts +8 -0
  159. package/dist/core/weights/aggregation.js +96 -0
  160. package/dist/core/weights/index.d.ts +2 -0
  161. package/dist/core/weights/index.js +7 -0
  162. package/dist/core/weights/weights_container.d.ts +19 -0
  163. package/dist/core/weights/weights_container.js +64 -0
  164. package/dist/dataset/data_loader/image_loader.d.ts +3 -15
  165. package/dist/dataset/data_loader/image_loader.js +12 -125
  166. package/dist/dataset/data_loader/index.d.ts +2 -3
  167. package/dist/dataset/data_loader/index.js +3 -5
  168. package/dist/dataset/data_loader/tabular_loader.d.ts +3 -28
  169. package/dist/dataset/data_loader/tabular_loader.js +11 -92
  170. package/dist/imports.d.ts +2 -0
  171. package/dist/imports.js +7 -0
  172. package/dist/index.d.ts +2 -19
  173. package/dist/index.js +3 -39
  174. package/dist/memory/index.d.ts +1 -3
  175. package/dist/memory/index.js +3 -7
  176. package/dist/memory/memory.d.ts +26 -0
  177. package/dist/memory/memory.js +160 -0
  178. package/package.json +13 -26
  179. package/dist/aggregation.d.ts +0 -5
  180. package/dist/aggregation.js +0 -33
  181. package/dist/client/decentralized/base.d.ts +0 -43
  182. package/dist/client/decentralized/base.js +0 -243
  183. package/dist/client/decentralized/clear_text.d.ts +0 -13
  184. package/dist/client/decentralized/clear_text.js +0 -78
  185. package/dist/client/decentralized/messages.d.ts +0 -37
  186. package/dist/client/decentralized/messages.js +0 -15
  187. package/dist/client/decentralized/sec_agg.d.ts +0 -18
  188. package/dist/client/decentralized/sec_agg.js +0 -169
  189. package/dist/client/decentralized/secret_shares.d.ts +0 -5
  190. package/dist/client/decentralized/secret_shares.js +0 -58
  191. package/dist/client/decentralized/types.d.ts +0 -1
  192. package/dist/client/federated.d.ts +0 -30
  193. package/dist/client/federated.js +0 -218
  194. package/dist/dataset/index.d.ts +0 -2
  195. package/dist/dataset/index.js +0 -7
  196. package/dist/model_actor.d.ts +0 -16
  197. package/dist/model_actor.js +0 -20
  198. package/dist/serialization/weights.d.ts +0 -5
  199. package/dist/task/model_compile_data.d.ts +0 -6
  200. package/dist/task/model_compile_data.js +0 -12
  201. package/dist/tasks/lus_covid.d.ts +0 -4
  202. package/dist/tasks/simple_face.d.ts +0 -4
  203. package/dist/tasks/simple_face.js +0 -84
  204. package/dist/tfjs.d.ts +0 -2
  205. package/dist/tfjs.js +0 -6
  206. package/dist/training/disco.d.ts +0 -14
  207. package/dist/training/disco.js +0 -70
  208. package/dist/training/training_schemes.d.ts +0 -5
@@ -12,7 +12,7 @@ var DatasetBuilder = /** @class */ (function () {
12
12
  }
13
13
  DatasetBuilder.prototype.addFiles = function (sources, label) {
14
14
  if (this.built) {
15
- throw new Error('builder already consumed');
15
+ this.resetBuiltState();
16
16
  }
17
17
  if (label === undefined) {
18
18
  this.sources = this.sources.concat(sources);
@@ -29,7 +29,7 @@ var DatasetBuilder = /** @class */ (function () {
29
29
  };
30
30
  DatasetBuilder.prototype.clearFiles = function (label) {
31
31
  if (this.built) {
32
- throw new Error('builder already consumed');
32
+ this.resetBuiltState();
33
33
  }
34
34
  if (label === undefined) {
35
35
  this.sources = [];
@@ -38,6 +38,11 @@ var DatasetBuilder = /** @class */ (function () {
38
38
  this.labelledSources.delete(label);
39
39
  }
40
40
  };
41
+ // If files are added or removed, then this should be called since the latest
42
+ // version of the dataset_builder has not yet been built.
43
+ DatasetBuilder.prototype.resetBuiltState = function () {
44
+ this.built = false;
45
+ };
41
46
  DatasetBuilder.prototype.getLabels = function () {
42
47
  // We need to duplicate the labels as we need one for each soure.
43
48
  // Say for label A we have sources [img1, img2, img3], then we
@@ -50,29 +55,28 @@ var DatasetBuilder = /** @class */ (function () {
50
55
  return labels.flat();
51
56
  };
52
57
  DatasetBuilder.prototype.build = function (config) {
53
- var _a, _b;
54
58
  return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
55
59
  var dataTuple, defaultConfig, defaultConfig, sources;
56
- return (0, tslib_1.__generator)(this, function (_c) {
57
- switch (_c.label) {
60
+ return (0, tslib_1.__generator)(this, function (_a) {
61
+ switch (_a.label) {
58
62
  case 0:
59
63
  // Require that at leat one source collection is non-empty, but not both
60
64
  if ((this.sources.length > 0) === (this.labelledSources.size > 0)) {
61
- throw new Error('invalid sources');
65
+ throw new Error('Please provide dataset input files');
62
66
  }
63
67
  if (!(this.sources.length > 0)) return [3 /*break*/, 2];
64
- defaultConfig = (0, tslib_1.__assign)({ features: (_a = this.task.trainingInformation) === null || _a === void 0 ? void 0 : _a.inputColumns, labels: (_b = this.task.trainingInformation) === null || _b === void 0 ? void 0 : _b.outputColumns }, config);
68
+ defaultConfig = (0, tslib_1.__assign)({ features: this.task.trainingInformation.inputColumns, labels: this.task.trainingInformation.outputColumns }, config);
65
69
  return [4 /*yield*/, this.dataLoader.loadAll(this.sources, defaultConfig)];
66
70
  case 1:
67
- dataTuple = _c.sent();
71
+ dataTuple = _a.sent();
68
72
  return [3 /*break*/, 4];
69
73
  case 2:
70
74
  defaultConfig = (0, tslib_1.__assign)({ labels: this.getLabels() }, config);
71
75
  sources = Array.from(this.labelledSources.values()).flat();
72
76
  return [4 /*yield*/, this.dataLoader.loadAll(sources, defaultConfig)];
73
77
  case 3:
74
- dataTuple = _c.sent();
75
- _c.label = 4;
78
+ dataTuple = _a.sent();
79
+ _a.label = 4;
76
80
  case 4:
77
81
  // TODO @s314cy: Support .csv labels for image datasets (supervised training or testing)
78
82
  this.built = true;
@@ -0,0 +1,4 @@
1
+ export { Dataset } from './dataset';
2
+ export { DatasetBuilder } from './dataset_builder';
3
+ export { DataSplit, Data, TabularData, ImageData, ImagePreprocessing } from './data';
4
+ export { ImageLoader, TabularLoader, DataLoader } from './data_loader';
@@ -0,0 +1,14 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.DataLoader = exports.TabularLoader = exports.ImageLoader = exports.ImagePreprocessing = exports.ImageData = exports.TabularData = exports.Data = exports.DatasetBuilder = void 0;
4
+ var dataset_builder_1 = require("./dataset_builder");
5
+ Object.defineProperty(exports, "DatasetBuilder", { enumerable: true, get: function () { return dataset_builder_1.DatasetBuilder; } });
6
+ var data_1 = require("./data");
7
+ Object.defineProperty(exports, "Data", { enumerable: true, get: function () { return data_1.Data; } });
8
+ Object.defineProperty(exports, "TabularData", { enumerable: true, get: function () { return data_1.TabularData; } });
9
+ Object.defineProperty(exports, "ImageData", { enumerable: true, get: function () { return data_1.ImageData; } });
10
+ Object.defineProperty(exports, "ImagePreprocessing", { enumerable: true, get: function () { return data_1.ImagePreprocessing; } });
11
+ var data_loader_1 = require("./data_loader");
12
+ Object.defineProperty(exports, "ImageLoader", { enumerable: true, get: function () { return data_loader_1.ImageLoader; } });
13
+ Object.defineProperty(exports, "TabularLoader", { enumerable: true, get: function () { return data_loader_1.TabularLoader; } });
14
+ Object.defineProperty(exports, "DataLoader", { enumerable: true, get: function () { return data_loader_1.DataLoader; } });
@@ -0,0 +1,18 @@
1
+ export * as tf from '@tensorflow/tfjs';
2
+ export * as data from './dataset';
3
+ export * as serialization from './serialization';
4
+ export * as training from './training';
5
+ export * as privacy from './privacy';
6
+ export { GraphInformant, TrainingInformant, informant } from './informant';
7
+ export { Base as Client } from './client';
8
+ export * as client from './client';
9
+ export { WeightsContainer, aggregation } from './weights';
10
+ export { AsyncBuffer } from './async_buffer';
11
+ export { AsyncInformant } from './async_informant';
12
+ export { Logger, ConsoleLogger, TrainerLog } from './logging';
13
+ export { Memory, ModelType, ModelInfo, Path, ModelSource, Empty as EmptyMemory } from './memory';
14
+ export { Disco, TrainingSchemes } from './training';
15
+ export { Validator } from './validation';
16
+ export { TrainingInformation, DisplayInformation, isTask, Task, isTaskID, TaskID } from './task';
17
+ export * as tasks from './tasks';
18
+ export * from './types';
@@ -0,0 +1,41 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.tasks = exports.isTaskID = exports.isTask = exports.Validator = exports.TrainingSchemes = exports.Disco = exports.EmptyMemory = exports.ModelType = exports.Memory = exports.TrainerLog = exports.ConsoleLogger = exports.Logger = exports.AsyncInformant = exports.AsyncBuffer = exports.aggregation = exports.WeightsContainer = exports.client = exports.Client = exports.informant = exports.TrainingInformant = exports.GraphInformant = exports.privacy = exports.training = exports.serialization = exports.data = exports.tf = void 0;
4
+ var tslib_1 = require("tslib");
5
+ exports.tf = (0, tslib_1.__importStar)(require("@tensorflow/tfjs"));
6
+ exports.data = (0, tslib_1.__importStar)(require("./dataset"));
7
+ exports.serialization = (0, tslib_1.__importStar)(require("./serialization"));
8
+ exports.training = (0, tslib_1.__importStar)(require("./training"));
9
+ exports.privacy = (0, tslib_1.__importStar)(require("./privacy"));
10
+ var informant_1 = require("./informant");
11
+ Object.defineProperty(exports, "GraphInformant", { enumerable: true, get: function () { return informant_1.GraphInformant; } });
12
+ Object.defineProperty(exports, "TrainingInformant", { enumerable: true, get: function () { return informant_1.TrainingInformant; } });
13
+ Object.defineProperty(exports, "informant", { enumerable: true, get: function () { return informant_1.informant; } });
14
+ var client_1 = require("./client");
15
+ Object.defineProperty(exports, "Client", { enumerable: true, get: function () { return client_1.Base; } });
16
+ exports.client = (0, tslib_1.__importStar)(require("./client"));
17
+ var weights_1 = require("./weights");
18
+ Object.defineProperty(exports, "WeightsContainer", { enumerable: true, get: function () { return weights_1.WeightsContainer; } });
19
+ Object.defineProperty(exports, "aggregation", { enumerable: true, get: function () { return weights_1.aggregation; } });
20
+ var async_buffer_1 = require("./async_buffer");
21
+ Object.defineProperty(exports, "AsyncBuffer", { enumerable: true, get: function () { return async_buffer_1.AsyncBuffer; } });
22
+ var async_informant_1 = require("./async_informant");
23
+ Object.defineProperty(exports, "AsyncInformant", { enumerable: true, get: function () { return async_informant_1.AsyncInformant; } });
24
+ var logging_1 = require("./logging");
25
+ Object.defineProperty(exports, "Logger", { enumerable: true, get: function () { return logging_1.Logger; } });
26
+ Object.defineProperty(exports, "ConsoleLogger", { enumerable: true, get: function () { return logging_1.ConsoleLogger; } });
27
+ Object.defineProperty(exports, "TrainerLog", { enumerable: true, get: function () { return logging_1.TrainerLog; } });
28
+ var memory_1 = require("./memory");
29
+ Object.defineProperty(exports, "Memory", { enumerable: true, get: function () { return memory_1.Memory; } });
30
+ Object.defineProperty(exports, "ModelType", { enumerable: true, get: function () { return memory_1.ModelType; } });
31
+ Object.defineProperty(exports, "EmptyMemory", { enumerable: true, get: function () { return memory_1.Empty; } });
32
+ var training_1 = require("./training");
33
+ Object.defineProperty(exports, "Disco", { enumerable: true, get: function () { return training_1.Disco; } });
34
+ Object.defineProperty(exports, "TrainingSchemes", { enumerable: true, get: function () { return training_1.TrainingSchemes; } });
35
+ var validation_1 = require("./validation");
36
+ Object.defineProperty(exports, "Validator", { enumerable: true, get: function () { return validation_1.Validator; } });
37
+ var task_1 = require("./task");
38
+ Object.defineProperty(exports, "isTask", { enumerable: true, get: function () { return task_1.isTask; } });
39
+ Object.defineProperty(exports, "isTaskID", { enumerable: true, get: function () { return task_1.isTaskID; } });
40
+ exports.tasks = (0, tslib_1.__importStar)(require("./tasks"));
41
+ (0, tslib_1.__exportStar)(require("./types"), exports);
File without changes
File without changes
@@ -1,8 +1,8 @@
1
1
  import { List } from 'immutable';
2
- import { TaskID } from '@/task';
2
+ import { Task } from '../../task';
3
3
  import { GraphInformant } from '../graph_informant';
4
4
  export declare abstract class Base {
5
- readonly taskID: TaskID;
5
+ readonly task: Task;
6
6
  private readonly nbrMessagesToShow;
7
7
  private messages;
8
8
  protected readonly trainingGraphInformant: GraphInformant;
@@ -11,7 +11,7 @@ export declare abstract class Base {
11
11
  protected currentNumberOfParticipants: number;
12
12
  protected totalNumberOfParticipants: number;
13
13
  protected averageNumberOfParticipants: number;
14
- constructor(taskID: TaskID, nbrMessagesToShow: number);
14
+ constructor(task: Task, nbrMessagesToShow?: number);
15
15
  abstract update(statistics: Record<string, number>): void;
16
16
  addMessage(msg: string): void;
17
17
  getMessages(): string[];
@@ -4,8 +4,9 @@ exports.Base = void 0;
4
4
  var immutable_1 = require("immutable");
5
5
  var graph_informant_1 = require("../graph_informant");
6
6
  var Base = /** @class */ (function () {
7
- function Base(taskID, nbrMessagesToShow) {
8
- this.taskID = taskID;
7
+ function Base(task, nbrMessagesToShow) {
8
+ if (nbrMessagesToShow === void 0) { nbrMessagesToShow = 10; }
9
+ this.task = task;
9
10
  this.nbrMessagesToShow = nbrMessagesToShow;
10
11
  // written feedback
11
12
  this.messages = (0, immutable_1.List)();
@@ -1,6 +1,6 @@
1
- import { TaskID } from '@/task';
1
+ import { Task } from '../../task';
2
2
  import { Base } from '.';
3
3
  export declare class LocalInformant extends Base {
4
- constructor(taskID: TaskID, nbrMessagesToShow: number);
4
+ constructor(task: Task, nbrMessagesToShow?: number);
5
5
  update(statistics: Record<string, number>): void;
6
6
  }
@@ -5,8 +5,8 @@ var tslib_1 = require("tslib");
5
5
  var _1 = require(".");
6
6
  var LocalInformant = /** @class */ (function (_super) {
7
7
  (0, tslib_1.__extends)(LocalInformant, _super);
8
- function LocalInformant(taskID, nbrMessagesToShow) {
9
- var _this = _super.call(this, taskID, nbrMessagesToShow) || this;
8
+ function LocalInformant(task, nbrMessagesToShow) {
9
+ var _this = _super.call(this, task, nbrMessagesToShow) || this;
10
10
  _this.currentNumberOfParticipants = 1;
11
11
  _this.averageNumberOfParticipants = 1;
12
12
  _this.totalNumberOfParticipants = 1;
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
@@ -5,14 +5,14 @@ export declare type Path = string;
5
5
  export interface ModelInfo {
6
6
  type?: ModelType;
7
7
  taskID: TaskID;
8
- name?: string;
8
+ name: string;
9
9
  }
10
10
  export declare type ModelSource = ModelInfo | Path;
11
11
  export declare abstract class Memory {
12
12
  abstract getModel(source: ModelSource): Promise<tf.LayersModel>;
13
13
  abstract deleteModel(source: ModelSource): Promise<void>;
14
14
  abstract loadModel(source: ModelSource): Promise<void>;
15
- abstract getModelMetadata(source: ModelSource): Promise<tf.io.ModelArtifactsInfo | undefined>;
15
+ abstract getModelMetadata(source: ModelSource): Promise<object | undefined>;
16
16
  abstract updateWorkingModel(source: ModelSource, model: tf.LayersModel): Promise<void>;
17
17
  abstract saveWorkingModel(source: ModelSource): Promise<void>;
18
18
  abstract downloadModel(source: ModelSource): Promise<void>;
File without changes
File without changes
File without changes
@@ -0,0 +1,3 @@
1
+ export { Empty } from './empty';
2
+ export { Memory, ModelInfo, Path, ModelSource } from './base';
3
+ export { ModelType } from './model_type';
@@ -0,0 +1,9 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.ModelType = exports.Memory = exports.Empty = void 0;
4
+ var empty_1 = require("./empty");
5
+ Object.defineProperty(exports, "Empty", { enumerable: true, get: function () { return empty_1.Empty; } });
6
+ var base_1 = require("./base");
7
+ Object.defineProperty(exports, "Memory", { enumerable: true, get: function () { return base_1.Memory; } });
8
+ var model_type_1 = require("./model_type");
9
+ Object.defineProperty(exports, "ModelType", { enumerable: true, get: function () { return model_type_1.ModelType; } });
File without changes
File without changes
@@ -1,5 +1,4 @@
1
- import { Weights } from '@/types';
2
- import { Task } from '@/task';
1
+ import { Task, WeightsContainer } from '.';
3
2
  /**
4
3
  * Add task-parametrized Gaussian noise to and clip the weights update between the previous and current rounds.
5
4
  * The previous round's weights are the last weights pulled from server/peers.
@@ -9,4 +8,4 @@ import { Task } from '@/task';
9
8
  * @param task the task
10
9
  * @returns the noised weights for the current round
11
10
  */
12
- export declare function addDifferentialPrivacy(updatedWeights: Weights, staleWeights: Weights, task: Task): Weights;
11
+ export declare function addDifferentialPrivacy(updatedWeights: WeightsContainer, staleWeights: WeightsContainer, task: Task): WeightsContainer;
@@ -1,8 +1,6 @@
1
1
  "use strict";
2
2
  Object.defineProperty(exports, "__esModule", { value: true });
3
3
  exports.addDifferentialPrivacy = void 0;
4
- var tslib_1 = require("tslib");
5
- var immutable_1 = require("immutable");
6
4
  var _1 = require(".");
7
5
  /**
8
6
  * Add task-parametrized Gaussian noise to and clip the weights update between the previous and current rounds.
@@ -17,16 +15,11 @@ function addDifferentialPrivacy(updatedWeights, staleWeights, task) {
17
15
  var _a, _b;
18
16
  var noiseScale = (_a = task.trainingInformation) === null || _a === void 0 ? void 0 : _a.noiseScale;
19
17
  var clippingRadius = (_b = task.trainingInformation) === null || _b === void 0 ? void 0 : _b.clippingRadius;
20
- var weightsDiff = (0, immutable_1.List)(updatedWeights)
21
- .zip((0, immutable_1.List)(staleWeights))
22
- .map(function (_a) {
23
- var _b = (0, tslib_1.__read)(_a, 2), w1 = _b[0], w2 = _b[1];
24
- return w1.add(-w2);
25
- });
18
+ var weightsDiff = updatedWeights.sub(staleWeights);
26
19
  var newWeightsDiff;
27
20
  if (clippingRadius !== undefined) {
28
21
  // Frobenius norm
29
- var norm_1 = Math.sqrt(weightsDiff.map(function (w) { return w.square().sum().dataSync()[0]; }).reduce(function (a, b) { return a + b; }));
22
+ var norm_1 = weightsDiff.frobeniusNorm();
30
23
  newWeightsDiff = weightsDiff.map(function (w) {
31
24
  var clipped = w.div(Math.max(1, norm_1 / clippingRadius));
32
25
  if (noiseScale !== undefined) {
@@ -49,12 +42,6 @@ function addDifferentialPrivacy(updatedWeights, staleWeights, task) {
49
42
  return updatedWeights;
50
43
  }
51
44
  }
52
- return (0, immutable_1.List)(staleWeights)
53
- .zip(newWeightsDiff)
54
- .map(function (_a) {
55
- var _b = (0, tslib_1.__read)(_a, 2), w = _b[0], d = _b[1];
56
- return w.add(d);
57
- })
58
- .toArray();
45
+ return staleWeights.add(newWeightsDiff);
59
46
  }
60
47
  exports.addDifferentialPrivacy = addDifferentialPrivacy;
@@ -0,0 +1,5 @@
1
+ import { WeightsContainer } from '..';
2
+ export declare type Encoded = number[];
3
+ export declare function isEncoded(raw: unknown): raw is Encoded;
4
+ export declare function encode(weights: WeightsContainer): Promise<Encoded>;
5
+ export declare function decode(encoded: Encoded): WeightsContainer;
@@ -2,8 +2,8 @@
2
2
  Object.defineProperty(exports, "__esModule", { value: true });
3
3
  exports.decode = exports.encode = exports.isEncoded = void 0;
4
4
  var tslib_1 = require("tslib");
5
- var __1 = require("..");
6
5
  var msgpack = (0, tslib_1.__importStar)(require("msgpack-lite"));
6
+ var __1 = require("..");
7
7
  function isSerialized(raw) {
8
8
  if (typeof raw !== 'object' || raw === null) {
9
9
  return false;
@@ -13,7 +13,7 @@ function isSerialized(raw) {
13
13
  }
14
14
  var _a = raw, shape = _a.shape, data = _a.data;
15
15
  if (!(Array.isArray(shape) && shape.every(function (e) { return typeof e === 'number'; })) ||
16
- !(data instanceof Float32Array)) {
16
+ !(Array.isArray(data) && data.every(function (e) { return typeof e === 'number'; }))) {
17
17
  return false;
18
18
  }
19
19
  // eslint-disable-next-line
@@ -30,17 +30,19 @@ function encode(weights) {
30
30
  var _this = this;
31
31
  return (0, tslib_1.__generator)(this, function (_a) {
32
32
  switch (_a.label) {
33
- case 0: return [4 /*yield*/, Promise.all(weights.map(function (t) { return (0, tslib_1.__awaiter)(_this, void 0, void 0, function () {
33
+ case 0: return [4 /*yield*/, Promise.all(weights.weights.map(function (t) { return (0, tslib_1.__awaiter)(_this, void 0, void 0, function () {
34
34
  var _a;
35
- return (0, tslib_1.__generator)(this, function (_b) {
36
- switch (_b.label) {
35
+ var _b;
36
+ return (0, tslib_1.__generator)(this, function (_c) {
37
+ switch (_c.label) {
37
38
  case 0:
38
- _a = {
39
+ _b = {
39
40
  shape: t.shape
40
41
  };
42
+ _a = [[]];
41
43
  return [4 /*yield*/, t.data()];
42
- case 1: return [2 /*return*/, (_a.data = _b.sent(),
43
- _a)];
44
+ case 1: return [2 /*return*/, (_b.data = tslib_1.__spreadArray.apply(void 0, _a.concat([tslib_1.__read.apply(void 0, [_c.sent()]), false])),
45
+ _b)];
44
46
  }
45
47
  });
46
48
  }); }))];
@@ -57,6 +59,6 @@ function decode(encoded) {
57
59
  if (!(Array.isArray(raw) && raw.every(isSerialized))) {
58
60
  throw new Error('expected to decode an array of serialized weights');
59
61
  }
60
- return raw.map(function (w) { return __1.tf.tensor(w.data, w.shape); });
62
+ return new __1.WeightsContainer(raw.map(function (w) { return __1.tf.tensor(w.data, w.shape); }));
61
63
  }
62
64
  exports.decode = decode;
File without changes
File without changes
@@ -2,11 +2,11 @@ import { Summary } from './summary';
2
2
  import { DataExample } from './data_example';
3
3
  export declare function isDisplayInformation(raw: unknown): raw is DisplayInformation;
4
4
  export interface DisplayInformation {
5
- taskTitle: string;
6
- summary: Summary;
7
- tradeoffs: string;
8
- dataFormatInformation: string;
9
- dataExampleText: string;
5
+ taskTitle?: string;
6
+ summary?: Summary;
7
+ tradeoffs?: string;
8
+ dataFormatInformation?: string;
9
+ dataExampleText?: string;
10
10
  model?: string;
11
11
  dataExample?: DataExample[];
12
12
  headers?: string[];
@@ -1,7 +1,6 @@
1
1
  "use strict";
2
2
  Object.defineProperty(exports, "__esModule", { value: true });
3
3
  exports.isDisplayInformation = void 0;
4
- var immutable_1 = require("immutable");
5
4
  var summary_1 = require("./summary");
6
5
  var data_example_1 = require("./data_example");
7
6
  function isDisplayInformation(raw) {
@@ -11,21 +10,17 @@ function isDisplayInformation(raw) {
11
10
  if (raw === null) {
12
11
  return false;
13
12
  }
14
- var requiredFields = immutable_1.Set.of('dataExampleText', 'dataFormatInformation', 'summary', 'taskTitle', 'tradeoffs');
15
- if (!requiredFields.isSubset(Object.keys(raw))) {
16
- return false;
17
- }
18
13
  var _a = raw, dataExample = _a.dataExample, dataExampleImage = _a.dataExampleImage, dataExampleText = _a.dataExampleText, dataFormatInformation = _a.dataFormatInformation, headers = _a.headers, limitations = _a.limitations, model = _a.model, summary = _a.summary, taskTitle = _a.taskTitle, tradeoffs = _a.tradeoffs;
19
- if (typeof dataExampleText !== 'string' ||
20
- typeof dataFormatInformation !== 'string' ||
21
- typeof taskTitle !== 'string' ||
22
- typeof tradeoffs !== 'string' ||
14
+ if (typeof taskTitle !== 'string' ||
15
+ (dataExampleText !== undefined && typeof dataExampleText !== 'string') ||
16
+ (dataFormatInformation !== undefined && typeof dataFormatInformation !== 'string') ||
17
+ (tradeoffs !== undefined && typeof tradeoffs !== 'string') ||
23
18
  (model !== undefined && typeof model !== 'string') ||
24
19
  (dataExampleImage !== undefined && typeof dataExampleImage !== 'string') ||
25
20
  (limitations !== undefined && typeof limitations !== 'string')) {
26
21
  return false;
27
22
  }
28
- if (!(0, summary_1.isSummary)(summary)) {
23
+ if (summary !== undefined && !(0, summary_1.isSummary)(summary)) {
29
24
  return false;
30
25
  }
31
26
  if (dataExample !== undefined && !(Array.isArray(dataExample) &&
File without changes
File without changes
@@ -0,0 +1,6 @@
1
+ export declare function isModelCompileData(raw: unknown): raw is ModelCompileData;
2
+ export interface ModelCompileData {
3
+ optimizer: string;
4
+ loss: string;
5
+ metrics: string[];
6
+ }
@@ -0,0 +1,22 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.isModelCompileData = void 0;
4
+ function isModelCompileData(raw) {
5
+ if (typeof raw !== 'object') {
6
+ return false;
7
+ }
8
+ if (raw === null) {
9
+ return false;
10
+ }
11
+ var _a = raw, optimizer = _a.optimizer, loss = _a.loss, metrics = _a.metrics;
12
+ if (typeof optimizer !== 'string' ||
13
+ typeof loss !== 'string') {
14
+ return false;
15
+ }
16
+ if (!(Array.isArray(metrics) &&
17
+ metrics.every(function (e) { return typeof e === 'string'; }))) {
18
+ return false;
19
+ }
20
+ return true;
21
+ }
22
+ exports.isModelCompileData = isModelCompileData;
File without changes
@@ -1,7 +1,6 @@
1
1
  "use strict";
2
2
  Object.defineProperty(exports, "__esModule", { value: true });
3
3
  exports.isSummary = void 0;
4
- var immutable_1 = require("immutable");
5
4
  function isSummary(raw) {
6
5
  if (typeof raw !== 'object') {
7
6
  return false;
@@ -9,9 +8,6 @@ function isSummary(raw) {
9
8
  if (raw === null) {
10
9
  return false;
11
10
  }
12
- if (!(0, immutable_1.Set)(Object.keys(raw)).equals(immutable_1.Set.of('preview', 'overview'))) {
13
- return false;
14
- }
15
11
  var _a = raw, preview = _a.preview, overview = _a.overview;
16
12
  if (!(typeof preview === 'string' && typeof overview === 'string')) {
17
13
  return false;
@@ -5,6 +5,6 @@ export declare function isTaskID(obj: unknown): obj is TaskID;
5
5
  export declare function isTask(raw: unknown): raw is Task;
6
6
  export interface Task {
7
7
  taskID: TaskID;
8
- displayInformation?: DisplayInformation;
9
- trainingInformation?: TrainingInformation;
8
+ displayInformation: DisplayInformation;
9
+ trainingInformation: TrainingInformation;
10
10
  }
@@ -2,6 +2,7 @@
2
2
  Object.defineProperty(exports, "__esModule", { value: true });
3
3
  exports.isTask = exports.isTaskID = void 0;
4
4
  var display_information_1 = require("./display_information");
5
+ var training_information_1 = require("./training_information");
5
6
  function isTaskID(obj) {
6
7
  return typeof obj === 'string';
7
8
  }
@@ -13,20 +14,18 @@ function isTask(raw) {
13
14
  if (raw === null) {
14
15
  return false;
15
16
  }
16
- if (!('taskID' in raw)) {
17
+ var _a = raw, taskID = _a.taskID, displayInformation = _a.displayInformation, trainingInformation = _a.trainingInformation;
18
+ if (typeof taskID !== 'string') {
17
19
  return false;
18
20
  }
19
- var _a = raw, taskID = _a.taskID, displayInformation = _a.displayInformation;
20
- if (typeof taskID !== 'string') {
21
+ if (!(0, display_information_1.isDisplayInformation)(displayInformation)) {
21
22
  return false;
22
23
  }
23
- if (displayInformation !== undefined &&
24
- !(0, display_information_1.isDisplayInformation)(displayInformation)) {
24
+ if (!(0, training_information_1.isTrainingInformation)(trainingInformation)) {
25
25
  return false;
26
26
  }
27
- // TODO check for TrainingInformation
28
27
  // eslint-disable-next-line @typescript-eslint/no-unused-vars
29
- var _ = { taskID: taskID, displayInformation: displayInformation };
28
+ var _ = { taskID: taskID, displayInformation: displayInformation, trainingInformation: trainingInformation };
30
29
  return true;
31
30
  }
32
31
  exports.isTask = isTask;