@epfml/discojs-node 2.0.0 → 2.1.2-p20240506085037.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (182)
  1. package/dist/data/image_loader.d.ts +5 -0
  2. package/dist/data/image_loader.js +11 -0
  3. package/dist/data/index.d.ts +3 -0
  4. package/dist/data/index.js +3 -0
  5. package/dist/data/tabular_loader.d.ts +4 -0
  6. package/dist/data/tabular_loader.js +11 -0
  7. package/dist/data/text_loader.d.ts +4 -0
  8. package/dist/data/text_loader.js +14 -0
  9. package/dist/index.d.ts +2 -2
  10. package/dist/index.js +2 -6
  11. package/package.json +14 -17
  12. package/README.md +0 -53
  13. package/dist/core/async_buffer.d.ts +0 -41
  14. package/dist/core/async_buffer.js +0 -97
  15. package/dist/core/async_informant.d.ts +0 -20
  16. package/dist/core/async_informant.js +0 -69
  17. package/dist/core/client/base.d.ts +0 -33
  18. package/dist/core/client/base.js +0 -35
  19. package/dist/core/client/decentralized/base.d.ts +0 -32
  20. package/dist/core/client/decentralized/base.js +0 -212
  21. package/dist/core/client/decentralized/clear_text.d.ts +0 -14
  22. package/dist/core/client/decentralized/clear_text.js +0 -96
  23. package/dist/core/client/decentralized/index.d.ts +0 -4
  24. package/dist/core/client/decentralized/index.js +0 -9
  25. package/dist/core/client/decentralized/messages.d.ts +0 -41
  26. package/dist/core/client/decentralized/messages.js +0 -54
  27. package/dist/core/client/decentralized/peer.d.ts +0 -26
  28. package/dist/core/client/decentralized/peer.js +0 -210
  29. package/dist/core/client/decentralized/peer_pool.d.ts +0 -14
  30. package/dist/core/client/decentralized/peer_pool.js +0 -92
  31. package/dist/core/client/decentralized/sec_agg.d.ts +0 -22
  32. package/dist/core/client/decentralized/sec_agg.js +0 -190
  33. package/dist/core/client/decentralized/secret_shares.d.ts +0 -3
  34. package/dist/core/client/decentralized/secret_shares.js +0 -39
  35. package/dist/core/client/decentralized/types.d.ts +0 -2
  36. package/dist/core/client/decentralized/types.js +0 -7
  37. package/dist/core/client/event_connection.d.ts +0 -37
  38. package/dist/core/client/event_connection.js +0 -158
  39. package/dist/core/client/federated/client.d.ts +0 -37
  40. package/dist/core/client/federated/client.js +0 -273
  41. package/dist/core/client/federated/index.d.ts +0 -2
  42. package/dist/core/client/federated/index.js +0 -7
  43. package/dist/core/client/federated/messages.d.ts +0 -38
  44. package/dist/core/client/federated/messages.js +0 -25
  45. package/dist/core/client/index.d.ts +0 -5
  46. package/dist/core/client/index.js +0 -11
  47. package/dist/core/client/local.d.ts +0 -8
  48. package/dist/core/client/local.js +0 -36
  49. package/dist/core/client/messages.d.ts +0 -28
  50. package/dist/core/client/messages.js +0 -33
  51. package/dist/core/client/utils.d.ts +0 -2
  52. package/dist/core/client/utils.js +0 -19
  53. package/dist/core/dataset/data/data.d.ts +0 -11
  54. package/dist/core/dataset/data/data.js +0 -20
  55. package/dist/core/dataset/data/data_split.d.ts +0 -5
  56. package/dist/core/dataset/data/data_split.js +0 -2
  57. package/dist/core/dataset/data/image_data.d.ts +0 -8
  58. package/dist/core/dataset/data/image_data.js +0 -64
  59. package/dist/core/dataset/data/index.d.ts +0 -5
  60. package/dist/core/dataset/data/index.js +0 -11
  61. package/dist/core/dataset/data/preprocessing.d.ts +0 -13
  62. package/dist/core/dataset/data/preprocessing.js +0 -33
  63. package/dist/core/dataset/data/tabular_data.d.ts +0 -8
  64. package/dist/core/dataset/data/tabular_data.js +0 -40
  65. package/dist/core/dataset/data_loader/data_loader.d.ts +0 -15
  66. package/dist/core/dataset/data_loader/data_loader.js +0 -10
  67. package/dist/core/dataset/data_loader/image_loader.d.ts +0 -17
  68. package/dist/core/dataset/data_loader/image_loader.js +0 -141
  69. package/dist/core/dataset/data_loader/index.d.ts +0 -3
  70. package/dist/core/dataset/data_loader/index.js +0 -9
  71. package/dist/core/dataset/data_loader/tabular_loader.d.ts +0 -29
  72. package/dist/core/dataset/data_loader/tabular_loader.js +0 -101
  73. package/dist/core/dataset/dataset.d.ts +0 -2
  74. package/dist/core/dataset/dataset.js +0 -2
  75. package/dist/core/dataset/dataset_builder.d.ts +0 -18
  76. package/dist/core/dataset/dataset_builder.js +0 -96
  77. package/dist/core/dataset/index.d.ts +0 -4
  78. package/dist/core/dataset/index.js +0 -14
  79. package/dist/core/index.d.ts +0 -18
  80. package/dist/core/index.js +0 -41
  81. package/dist/core/informant/graph_informant.d.ts +0 -10
  82. package/dist/core/informant/graph_informant.js +0 -23
  83. package/dist/core/informant/index.d.ts +0 -3
  84. package/dist/core/informant/index.js +0 -9
  85. package/dist/core/informant/training_informant/base.d.ts +0 -31
  86. package/dist/core/informant/training_informant/base.js +0 -83
  87. package/dist/core/informant/training_informant/decentralized.d.ts +0 -5
  88. package/dist/core/informant/training_informant/decentralized.js +0 -22
  89. package/dist/core/informant/training_informant/federated.d.ts +0 -14
  90. package/dist/core/informant/training_informant/federated.js +0 -32
  91. package/dist/core/informant/training_informant/index.d.ts +0 -4
  92. package/dist/core/informant/training_informant/index.js +0 -11
  93. package/dist/core/informant/training_informant/local.d.ts +0 -6
  94. package/dist/core/informant/training_informant/local.js +0 -20
  95. package/dist/core/logging/console_logger.d.ts +0 -18
  96. package/dist/core/logging/console_logger.js +0 -33
  97. package/dist/core/logging/index.d.ts +0 -3
  98. package/dist/core/logging/index.js +0 -9
  99. package/dist/core/logging/logger.d.ts +0 -12
  100. package/dist/core/logging/logger.js +0 -9
  101. package/dist/core/logging/trainer_logger.d.ts +0 -24
  102. package/dist/core/logging/trainer_logger.js +0 -59
  103. package/dist/core/memory/base.d.ts +0 -22
  104. package/dist/core/memory/base.js +0 -9
  105. package/dist/core/memory/empty.d.ts +0 -14
  106. package/dist/core/memory/empty.js +0 -75
  107. package/dist/core/memory/index.d.ts +0 -3
  108. package/dist/core/memory/index.js +0 -9
  109. package/dist/core/memory/model_type.d.ts +0 -4
  110. package/dist/core/memory/model_type.js +0 -9
  111. package/dist/core/privacy.d.ts +0 -11
  112. package/dist/core/privacy.js +0 -47
  113. package/dist/core/serialization/index.d.ts +0 -2
  114. package/dist/core/serialization/index.js +0 -6
  115. package/dist/core/serialization/model.d.ts +0 -5
  116. package/dist/core/serialization/model.js +0 -55
  117. package/dist/core/serialization/weights.d.ts +0 -5
  118. package/dist/core/serialization/weights.js +0 -64
  119. package/dist/core/task/data_example.d.ts +0 -5
  120. package/dist/core/task/data_example.js +0 -24
  121. package/dist/core/task/display_information.d.ts +0 -15
  122. package/dist/core/task/display_information.js +0 -49
  123. package/dist/core/task/index.d.ts +0 -3
  124. package/dist/core/task/index.js +0 -8
  125. package/dist/core/task/model_compile_data.d.ts +0 -6
  126. package/dist/core/task/model_compile_data.js +0 -22
  127. package/dist/core/task/summary.d.ts +0 -5
  128. package/dist/core/task/summary.js +0 -19
  129. package/dist/core/task/task.d.ts +0 -10
  130. package/dist/core/task/task.js +0 -31
  131. package/dist/core/task/training_information.d.ts +0 -28
  132. package/dist/core/task/training_information.js +0 -66
  133. package/dist/core/tasks/cifar10.d.ts +0 -3
  134. package/dist/core/tasks/cifar10.js +0 -65
  135. package/dist/core/tasks/geotags.d.ts +0 -3
  136. package/dist/core/tasks/geotags.js +0 -67
  137. package/dist/core/tasks/index.d.ts +0 -6
  138. package/dist/core/tasks/index.js +0 -10
  139. package/dist/core/tasks/lus_covid.d.ts +0 -3
  140. package/dist/core/tasks/lus_covid.js +0 -87
  141. package/dist/core/tasks/mnist.d.ts +0 -3
  142. package/dist/core/tasks/mnist.js +0 -60
  143. package/dist/core/tasks/simple_face.d.ts +0 -2
  144. package/dist/core/tasks/simple_face.js +0 -41
  145. package/dist/core/tasks/titanic.d.ts +0 -3
  146. package/dist/core/tasks/titanic.js +0 -88
  147. package/dist/core/training/disco.d.ts +0 -23
  148. package/dist/core/training/disco.js +0 -130
  149. package/dist/core/training/index.d.ts +0 -2
  150. package/dist/core/training/index.js +0 -7
  151. package/dist/core/training/trainer/distributed_trainer.d.ts +0 -20
  152. package/dist/core/training/trainer/distributed_trainer.js +0 -65
  153. package/dist/core/training/trainer/local_trainer.d.ts +0 -11
  154. package/dist/core/training/trainer/local_trainer.js +0 -34
  155. package/dist/core/training/trainer/round_tracker.d.ts +0 -30
  156. package/dist/core/training/trainer/round_tracker.js +0 -47
  157. package/dist/core/training/trainer/trainer.d.ts +0 -65
  158. package/dist/core/training/trainer/trainer.js +0 -160
  159. package/dist/core/training/trainer/trainer_builder.d.ts +0 -25
  160. package/dist/core/training/trainer/trainer_builder.js +0 -95
  161. package/dist/core/training/training_schemes.d.ts +0 -5
  162. package/dist/core/training/training_schemes.js +0 -10
  163. package/dist/core/types.d.ts +0 -4
  164. package/dist/core/types.js +0 -2
  165. package/dist/core/validation/index.d.ts +0 -1
  166. package/dist/core/validation/index.js +0 -5
  167. package/dist/core/validation/validator.d.ts +0 -17
  168. package/dist/core/validation/validator.js +0 -104
  169. package/dist/core/weights/aggregation.d.ts +0 -8
  170. package/dist/core/weights/aggregation.js +0 -96
  171. package/dist/core/weights/index.d.ts +0 -2
  172. package/dist/core/weights/index.js +0 -7
  173. package/dist/core/weights/weights_container.d.ts +0 -19
  174. package/dist/core/weights/weights_container.js +0 -64
  175. package/dist/dataset/data_loader/image_loader.d.ts +0 -4
  176. package/dist/dataset/data_loader/image_loader.js +0 -21
  177. package/dist/dataset/data_loader/index.d.ts +0 -2
  178. package/dist/dataset/data_loader/index.js +0 -7
  179. package/dist/dataset/data_loader/tabular_loader.d.ts +0 -4
  180. package/dist/dataset/data_loader/tabular_loader.js +0 -20
  181. package/dist/imports.d.ts +0 -1
  182. package/dist/imports.js +0 -5
@@ -1,47 +0,0 @@
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.addDifferentialPrivacy = void 0;
var _1 = require(".");
/**
 * Clip and/or add Gaussian noise to the weight update between two rounds, then
 * apply the resulting update back onto the previous round's weights.
 *
 * The update is `updatedWeights - staleWeights`. Behaviour depends on the task's
 * `trainingInformation`:
 *  - `clippingRadius` set: every tensor of the update is divided by
 *    `max(1, ||update||_F / clippingRadius)`, bounding the update's Frobenius
 *    norm by `clippingRadius`; if `noiseScale` is also set, zero-mean Gaussian
 *    noise with standard deviation `noiseScale * clippingRadius` is added.
 *  - only `noiseScale` set: zero-mean Gaussian noise with standard deviation
 *    `noiseScale` is added to every tensor of the update.
 *  - neither set: `updatedWeights` is returned unchanged.
 *
 * @param updatedWeights weights from the current round (after local training)
 * @param staleWeights weights from the previous round (last pulled from server/peers)
 * @param task the task whose training information parametrises clipping and noise
 * @returns the weights for the current round with the clipped/noised update applied
 */
function addDifferentialPrivacy(updatedWeights, staleWeights, task) {
    var info = task.trainingInformation;
    var noiseScale = info === null || info === void 0 ? void 0 : info.noiseScale;
    var clippingRadius = info === null || info === void 0 ? void 0 : info.clippingRadius;
    if (clippingRadius === undefined && noiseScale === undefined) {
        // Nothing to do: avoid allocating the (unused) difference tensors.
        return updatedWeights;
    }
    var weightsDiff = updatedWeights.sub(staleWeights);
    var newWeightsDiff;
    if (clippingRadius !== undefined) {
        // Frobenius norm of the whole update, shared by every tensor's scaling.
        var scale = Math.max(1, weightsDiff.frobeniusNorm() / clippingRadius);
        // tf.randomNormal expects a standard deviation, not a variance: the
        // Gaussian mechanism's sigma * C is passed as-is (previously sigma^2 * C^2).
        var clippedStdDev = noiseScale === undefined ? undefined : noiseScale * clippingRadius;
        newWeightsDiff = weightsDiff.map(function (w) {
            var clipped = w.div(scale);
            if (clippedStdDev === undefined) {
                // Clipping without any noise
                return clipped;
            }
            // Clipping and noise
            return clipped.add(_1.tf.randomNormal(w.shape, 0, clippedStdDev));
        });
    }
    else {
        // Noise without clipping: perturb the update rather than replace it
        // (previously the update was overwritten by pure noise, discarding training).
        newWeightsDiff = weightsDiff.map(function (w) {
            return w.add(_1.tf.randomNormal(w.shape, 0, noiseScale));
        });
    }
    return staleWeights.add(newWeightsDiff);
}
exports.addDifferentialPrivacy = addDifferentialPrivacy;
@@ -1,2 +0,0 @@
/** Serialisation helpers for models (`model`) and weight containers (`weights`). */
export * as weights from './weights';
export * as model from './model';
@@ -1,6 +0,0 @@
"use strict";
// Namespace re-exports: `serialization.model` and `serialization.weights`.
Object.defineProperty(exports, "__esModule", { value: true });
exports.model = void 0;
exports.weights = void 0;
var importStar = require("tslib").__importStar;
exports.model = importStar(require("./model"));
exports.weights = importStar(require("./weights"));
@@ -1,5 +0,0 @@
import { tf } from '..';
/** A msgpack-encoded model, stored as an array of byte values. */
export declare type Encoded = Array<number>;
/** Type guard: `raw` is an array whose every element is a number. */
export declare function isEncoded(raw: unknown): raw is Encoded;
/** Serialise the artifacts produced by `model.save` into an {@link Encoded} array. */
export declare function encode(model: tf.LayersModel): Promise<Encoded>;
/** Load a layers model back from the output of {@link encode}. */
export declare function decode(encoded: Encoded): Promise<tf.LayersModel>;
@@ -1,55 +0,0 @@
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.decode = exports.encode = exports.isEncoded = void 0;
var tslib_1 = require("tslib");
var __1 = require("..");
var msgpack_lite_1 = (0, tslib_1.__importDefault)(require("msgpack-lite"));
/**
 * Type guard for the output of {@link encode}: an array of numbers.
 * @param raw any value
 * @returns true when `raw` is an array whose every element is a number
 */
function isEncoded(raw) {
    return Array.isArray(raw) && raw.every(function (r) { return typeof r === 'number'; });
}
exports.isEncoded = isEncoded;
/**
 * Serialise a layers model with msgpack.
 *
 * The model is "saved" through an in-memory IO handler that captures the
 * artifacts handed to it by `model.save`; those artifacts are msgpack-encoded
 * and returned as an array of byte values.
 *
 * @param model the model to serialise
 * @returns a promise of the encoded bytes; it rejects if `model.save` fails
 */
function encode(model) {
    return new Promise(function (resolve, reject) {
        model.save({
            save: function (artifacts) {
                resolve(artifacts);
                return Promise.resolve({
                    modelArtifactsInfo: {
                        dateSaved: new Date(),
                        modelTopologyType: 'JSON'
                    }
                });
            }
        })
            // Previously this promise was voided: a failing save left encode() pending
            // forever and the error was lost. reject() is a no-op once resolved.
            .catch(reject);
    }).then(function (saved) {
        return Array.from(msgpack_lite_1.default.encode(saved));
    });
}
exports.encode = encode;
/**
 * Rebuild a layers model from bytes produced by {@link encode}.
 * @param encoded msgpack bytes of the saved model artifacts
 * @returns a promise of the loaded model; it rejects if decoding or loading fails
 */
function decode(encoded) {
    return new Promise(function (resolve) {
        var raw = msgpack_lite_1.default.decode(encoded);
        resolve(__1.tf.loadLayersModel({
            load: function () { return raw; }
        }));
    });
}
exports.decode = decode;
@@ -1,5 +0,0 @@
import { WeightsContainer } from '..';
/** msgpack-encoded weights, stored as an array of byte values. */
export declare type Encoded = Array<number>;
/** Type guard: `raw` is an array whose every element is a number. */
export declare function isEncoded(raw: unknown): raw is Encoded;
/** Serialise every tensor's shape and values (dtype is not stored). */
export declare function encode(weights: WeightsContainer): Promise<Encoded>;
/** Rebuild a weights container; throws if the bytes are not an array of serialised tensors. */
export declare function decode(encoded: Encoded): WeightsContainer;
@@ -1,64 +0,0 @@
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.decode = exports.encode = exports.isEncoded = void 0;
var tslib_1 = require("tslib");
var msgpack = (0, tslib_1.__importStar)(require("msgpack-lite"));
var __1 = require("..");
/** True when `xs` is an array whose every element is a number. */
function isNumberArray(xs) {
    return Array.isArray(xs) && xs.every(function (x) { return typeof x === 'number'; });
}
/**
 * Type guard for a single serialised tensor: a non-null object carrying
 * numeric-array `shape` and `data` properties.
 */
function isSerialized(raw) {
    if (typeof raw !== 'object' || raw === null) {
        return false;
    }
    if (!('shape' in raw) || !('data' in raw)) {
        return false;
    }
    var shape = raw.shape;
    var data = raw.data;
    return isNumberArray(shape) && isNumberArray(data);
}
/**
 * Type guard for the output of {@link encode}: an array of numbers.
 */
function isEncoded(raw) {
    return isNumberArray(raw);
}
exports.isEncoded = isEncoded;
/**
 * Serialise each tensor of `weights` as `{ shape, data }` and msgpack-encode the list.
 * @param weights the container to serialise
 * @returns a promise of the encoded bytes as an array of numbers
 */
function encode(weights) {
    return new Promise(function (resolve) {
        var perTensor = weights.weights.map(function (tensor) {
            return new Promise(function (done) {
                var shape = tensor.shape;
                done(tensor.data().then(function (values) {
                    return { shape: shape, data: Array.from(values) };
                }));
            });
        });
        resolve(Promise.all(perTensor));
    }).then(function (serialized) {
        return Array.from(msgpack.encode(serialized));
    });
}
exports.encode = encode;
/**
 * Decode bytes produced by {@link encode} back into a weights container.
 * @param encoded msgpack bytes
 * @returns a container holding one tensor per serialised entry
 * @throws Error when the decoded value is not an array of serialised tensors
 */
function decode(encoded) {
    var raw = msgpack.decode(encoded);
    var valid = Array.isArray(raw) && raw.every(isSerialized);
    if (!valid) {
        throw new Error('expected to decode an array of serialized weights');
    }
    var tensors = raw.map(function (entry) { return __1.tf.tensor(entry.data, entry.shape); });
    return new __1.WeightsContainer(tensors);
}
exports.decode = decode;
@@ -1,5 +0,0 @@
/** Type guard: an object with exactly the keys `columnName` and `columnData`. */
export declare function isDataExample(raw: unknown): raw is DataExample;
/** A single (column name, value) pair. */
export interface DataExample {
    columnName: string;
    columnData: number | string;
}
@@ -1,24 +0,0 @@
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.isDataExample = void 0;
var immutable_1 = require("immutable");
/**
 * Type guard for DataExample.
 * Accepts only non-null objects whose own enumerable keys are exactly
 * `columnName` (a string) and `columnData` (a string or a number).
 */
function isDataExample(raw) {
    if (raw === null || typeof raw !== 'object') {
        return false;
    }
    var keys = (0, immutable_1.Set)(Object.keys(raw));
    var expected = (0, immutable_1.Set)(['columnName', 'columnData']);
    if (!keys.equals(expected)) {
        return false;
    }
    var name = raw.columnName;
    var value = raw.columnData;
    var valueOk = typeof value === 'string' || typeof value === 'number';
    return typeof name === 'string' && valueOk;
}
exports.isDataExample = isDataExample;
@@ -1,15 +0,0 @@
import { Summary } from './summary';
import { DataExample } from './data_example';
/** Type guard for {@link DisplayInformation}. */
export declare function isDisplayInformation(raw: unknown): raw is DisplayInformation;
/** Textual and visual information describing a task. */
export interface DisplayInformation {
    taskTitle?: string;
    summary?: Summary;
    tradeoffs?: string;
    limitations?: string;
    model?: string;
    dataFormatInformation?: string;
    dataExampleText?: string;
    dataExample?: Array<DataExample>;
    dataExampleImage?: string;
    headers?: Array<string>;
}
@@ -1,49 +0,0 @@
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.isDisplayInformation = void 0;
var summary_1 = require("./summary");
var data_example_1 = require("./data_example");
/** True when `value` is absent (undefined) or a string. */
function isOptionalString(value) {
    return value === undefined || typeof value === 'string';
}
/**
 * Type guard for DisplayInformation.
 *
 * `taskTitle` must be a string. `dataExampleText`, `dataFormatInformation`,
 * `tradeoffs`, `model`, `dataExampleImage` and `limitations` must be strings
 * when present. `summary` must satisfy isSummary, `dataExample` must be an
 * array of data examples, and `headers` must be an array of strings when present.
 *
 * NOTE(review): the accompanying declaration types `taskTitle` as optional,
 * yet this guard rejects objects without it. Confirm which one is intended.
 */
function isDisplayInformation(raw) {
    if (raw === null || typeof raw !== 'object') {
        return false;
    }
    var info = raw;
    var optionalText = [
        info.dataExampleText,
        info.dataFormatInformation,
        info.tradeoffs,
        info.model,
        info.dataExampleImage,
        info.limitations
    ];
    if (typeof info.taskTitle !== 'string' || !optionalText.every(isOptionalString)) {
        return false;
    }
    var summary = info.summary;
    if (summary !== undefined && !(0, summary_1.isSummary)(summary)) {
        return false;
    }
    var dataExample = info.dataExample;
    var dataExampleOk = Array.isArray(dataExample) && dataExample.every(data_example_1.isDataExample);
    if (dataExample !== undefined && !dataExampleOk) {
        return false;
    }
    var headers = info.headers;
    var headersOk = Array.isArray(headers) && headers.every(function (h) { return typeof h === 'string'; });
    return headers === undefined || headersOk;
}
exports.isDisplayInformation = isDisplayInformation;
@@ -1,3 +0,0 @@
// Public surface of the task module: types and their runtime type guards.
export { Task, TaskID, isTask, isTaskID } from './task';
export { DisplayInformation, isDisplayInformation } from './display_information';
export { TrainingInformation } from './training_information';
@@ -1,8 +0,0 @@
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.isDisplayInformation = exports.isTaskID = exports.isTask = void 0;
/** Expose `source[name]` on `exports` as an enumerable, live getter. */
function reexport(source, name) {
    Object.defineProperty(exports, name, {
        enumerable: true,
        get: function () { return source[name]; }
    });
}
var task_1 = require("./task");
reexport(task_1, "isTask");
reexport(task_1, "isTaskID");
var display_information_1 = require("./display_information");
reexport(display_information_1, "isDisplayInformation");
@@ -1,6 +0,0 @@
/** Type guard for {@link ModelCompileData}. */
export declare function isModelCompileData(raw: unknown): raw is ModelCompileData;
/** Optimizer, loss and metric names used when compiling a model. */
export interface ModelCompileData {
    optimizer: string;
    loss: string;
    metrics: Array<string>;
}
@@ -1,22 +0,0 @@
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.isModelCompileData = void 0;
/**
 * Type guard for ModelCompileData: `optimizer` and `loss` are strings and
 * `metrics` is an array of strings.
 */
function isModelCompileData(raw) {
    if (raw === null || typeof raw !== 'object') {
        return false;
    }
    var optimizer = raw.optimizer;
    var loss = raw.loss;
    var metrics = raw.metrics;
    if (typeof optimizer === 'string' && typeof loss === 'string') {
        return Array.isArray(metrics) &&
            metrics.every(function (metric) { return typeof metric === 'string'; });
    }
    return false;
}
exports.isModelCompileData = isModelCompileData;
@@ -1,5 +0,0 @@
/** Type guard for {@link Summary}. */
export declare function isSummary(raw: unknown): raw is Summary;
/** Short preview and longer overview texts of a task. */
export interface Summary {
    overview: string;
    preview: string;
}
@@ -1,19 +0,0 @@
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.isSummary = void 0;
/**
 * Type guard for Summary: a non-null object whose `preview` and `overview`
 * are both strings.
 */
function isSummary(raw) {
    if (raw === null || typeof raw !== 'object') {
        return false;
    }
    var preview = raw.preview;
    var overview = raw.overview;
    return typeof preview === 'string' && typeof overview === 'string';
}
exports.isSummary = isSummary;
@@ -1,10 +0,0 @@
import { DisplayInformation } from './display_information';
import { TrainingInformation } from './training_information';
/** Identifier of a task; {@link isTaskID} accepts any string. */
export declare type TaskID = string;
/** Type guard for {@link TaskID}. */
export declare function isTaskID(obj: unknown): obj is TaskID;
/** Type guard for {@link Task}. */
export declare function isTask(raw: unknown): raw is Task;
/** A task: its identifier plus display and training information. */
export interface Task {
    taskID: TaskID;
    displayInformation: DisplayInformation;
    trainingInformation: TrainingInformation;
}
@@ -1,31 +0,0 @@
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.isTask = exports.isTaskID = void 0;
var display_information_1 = require("./display_information");
var training_information_1 = require("./training_information");
/** Any string is a valid task ID. */
function isTaskID(obj) {
    return typeof obj === 'string';
}
exports.isTaskID = isTaskID;
/**
 * Type guard for Task: a non-null object with a string `taskID`, a valid
 * `displayInformation` and a valid `trainingInformation`.
 */
function isTask(raw) {
    if (raw === null || typeof raw !== 'object') {
        return false;
    }
    var taskID = raw.taskID;
    var displayInformation = raw.displayInformation;
    var trainingInformation = raw.trainingInformation;
    return isTaskID(taskID) &&
        (0, display_information_1.isDisplayInformation)(displayInformation) &&
        (0, training_information_1.isTrainingInformation)(trainingInformation);
}
exports.isTask = isTask;
@@ -1,28 +0,0 @@
import { Preprocessing } from '../dataset/data/preprocessing';
import { ModelCompileData } from './model_compile_data';
/** Type guard for {@link TrainingInformation}. */
export declare function isTrainingInformation(raw: unknown): raw is TrainingInformation;
/** Parameters controlling how a task's model is trained. */
export interface TrainingInformation {
    modelID: string;
    epochs: number;
    roundDuration: number;
    validationSplit: number;
    batchSize: number;
    preprocessingFunctions?: Array<Preprocessing>;
    modelCompileData: ModelCompileData;
    dataType: string;
    /** Column names; checked as required for 'tabular' data. */
    inputColumns?: Array<string>;
    outputColumns?: Array<string>;
    /** Image dimensions; checked as required for 'image' data. */
    IMAGE_H?: number;
    IMAGE_W?: number;
    modelURL?: string;
    LABEL_LIST?: Array<string>;
    learningRate?: number;
    scheme: string;
    /** Differential privacy: Gaussian noise scale and update clipping radius. */
    noiseScale?: number;
    clippingRadius?: number;
    decentralizedSecure?: boolean;
    byzantineRobustAggregator?: boolean;
    tauPercentile?: number;
    maxShareValue?: number;
    minimumReadyPeers?: number;
}
@@ -1,66 +0,0 @@
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.isTrainingInformation = void 0;
var model_compile_data_1 = require("./model_compile_data");
/** True when `value` is an array whose every element is a string. */
function isStringArray(value) {
    return Array.isArray(value) && value.every(function (e) { return typeof e === 'string'; });
}
/** True when `value` is undefined or its `typeof` equals `type`. */
function isOptionalOfType(value, type) {
    return value === undefined || typeof value === type;
}
/**
 * Type guard for TrainingInformation.
 *
 * - Required: `dataType`, `scheme`, `modelID` (strings); `epochs`, `batchSize`,
 *   `roundDuration`, `validationSplit` (numbers); `modelCompileData` (must
 *   satisfy isModelCompileData).
 * - Optional, type-checked when present: `modelURL` (string); `noiseScale`,
 *   `clippingRadius`, `learningRate`, `tauPercentile`, `maxShareValue`,
 *   `minimumReadyPeers`, `IMAGE_H`, `IMAGE_W` (numbers); `decentralizedSecure`,
 *   `byzantineRobustAggregator` (booleans); `inputColumns`, `outputColumns`,
 *   `LABEL_LIST`, `preprocessingFunctions` (string arrays).
 * - `IMAGE_H`/`IMAGE_W` are required when `dataType` is 'image', and
 *   `inputColumns`/`outputColumns` when it is 'tabular'.
 *
 * `scheme` is only required to be a string. Its value is not restricted,
 * because task definitions use capitalised spellings such as 'Decentralized'.
 * Previously `scheme` was not validated at all.
 *
 * NOTE(review): `preprocessingFunctions` elements must be strings. Confirm this
 * matches the runtime values of the Preprocessing enum.
 */
function isTrainingInformation(raw) {
    if (typeof raw !== 'object' || raw === null) {
        return false;
    }
    var dataType = raw.dataType;
    var scheme = raw.scheme;
    var epochs = raw.epochs;
    var roundDuration = raw.roundDuration;
    var validationSplit = raw.validationSplit;
    var batchSize = raw.batchSize;
    var modelCompileData = raw.modelCompileData;
    var modelID = raw.modelID;
    var preprocessingFunctions = raw.preprocessingFunctions;
    var inputColumns = raw.inputColumns;
    var outputColumns = raw.outputColumns;
    var IMAGE_H = raw.IMAGE_H;
    var IMAGE_W = raw.IMAGE_W;
    var modelURL = raw.modelURL;
    var learningRate = raw.learningRate;
    var decentralizedSecure = raw.decentralizedSecure;
    var byzantineRobustAggregator = raw.byzantineRobustAggregator;
    var tauPercentile = raw.tauPercentile;
    var maxShareValue = raw.maxShareValue;
    var minimumReadyPeers = raw.minimumReadyPeers;
    var LABEL_LIST = raw.LABEL_LIST;
    var noiseScale = raw.noiseScale;
    var clippingRadius = raw.clippingRadius;
    if (typeof dataType !== 'string' ||
        typeof scheme !== 'string' ||
        typeof modelID !== 'string' ||
        typeof epochs !== 'number' ||
        typeof batchSize !== 'number' ||
        typeof roundDuration !== 'number' ||
        typeof validationSplit !== 'number' ||
        !isOptionalOfType(modelURL, 'string') ||
        !isOptionalOfType(noiseScale, 'number') ||
        !isOptionalOfType(clippingRadius, 'number') ||
        !isOptionalOfType(learningRate, 'number') ||
        !isOptionalOfType(decentralizedSecure, 'boolean') ||
        !isOptionalOfType(byzantineRobustAggregator, 'boolean') ||
        !isOptionalOfType(tauPercentile, 'number') ||
        !isOptionalOfType(maxShareValue, 'number') ||
        !isOptionalOfType(minimumReadyPeers, 'number') ||
        !isOptionalOfType(IMAGE_H, 'number') ||
        !isOptionalOfType(IMAGE_W, 'number')) {
        return false;
    }
    // Fields that become mandatory for a given data type.
    if (dataType === 'image' && (typeof IMAGE_H !== 'number' || typeof IMAGE_W !== 'number')) {
        return false;
    }
    if (dataType === 'tabular' && !(isStringArray(inputColumns) && isStringArray(outputColumns))) {
        return false;
    }
    if ((inputColumns !== undefined && !isStringArray(inputColumns)) ||
        (outputColumns !== undefined && !isStringArray(outputColumns)) ||
        (LABEL_LIST !== undefined && !isStringArray(LABEL_LIST)) ||
        (preprocessingFunctions !== undefined && !isStringArray(preprocessingFunctions))) {
        return false;
    }
    return (0, model_compile_data_1.isModelCompileData)(modelCompileData);
}
exports.isTrainingInformation = isTrainingInformation;
@@ -1,3 +0,0 @@
import { tf, Task } from '..';
/** Description of the CIFAR10 task. */
export declare const task: Task;
/** Build a fresh model for the CIFAR10 task. */
export declare function model(): Promise<tf.LayersModel>;
@@ -1,65 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.model = exports.task = void 0;
4
- var tslib_1 = require("tslib");
5
- var __1 = require("..");
6
- exports.task = {
7
- taskID: 'cifar10',
8
- displayInformation: {
9
- taskTitle: 'CIFAR10',
10
- summary: {
11
- preview: 'In this challenge, we ask you to classify images into categories based on the objects shown on the image.',
12
- overview: 'The CIFAR-10 dataset is a collection of images that are commonly used to train machine learning and computer vision algorithms. It is one of the most widely used datasets for machine learning research.'
13
- },
14
- limitations: 'The training data is limited to small images of size 32x32.',
15
- tradeoffs: 'Training success strongly depends on label distribution',
16
- dataFormatInformation: 'Images should be of .png format and of size 32x32. <br> The label file should be .csv, where each row contains a file_name, class. <br> <br> e.g. if you have images: 0.png (of a frog) and 1.png (of a car) <br> labels.csv contains: (Note that no header is needed)<br> 0.png, frog <br> 1.png, car',
17
- dataExampleText: 'Below you can find 10 random examples from each of the 10 classes in the dataset.',
18
- dataExampleImage: 'https://storage.googleapis.com/deai-313515.appspot.com/example_training_data/cifar10-example.png'
19
- },
20
- trainingInformation: {
21
- modelID: 'cifar10-model',
22
- epochs: 10,
23
- roundDuration: 10,
24
- validationSplit: 0.2,
25
- batchSize: 10,
26
- modelCompileData: {
27
- optimizer: 'sgd',
28
- loss: 'categoricalCrossentropy',
29
- metrics: ['accuracy']
30
- },
31
- dataType: 'image',
32
- IMAGE_H: 32,
33
- IMAGE_W: 32,
34
- preprocessingFunctions: [],
35
- LABEL_LIST: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'],
36
- scheme: 'Decentralized',
37
- noiseScale: undefined,
38
- clippingRadius: 20,
39
- decentralizedSecure: true,
40
- minimumReadyPeers: 3,
41
- maxShareValue: 100
42
- }
43
- };
44
- function model() {
45
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
46
- var mobilenet, x, predictions;
47
- return (0, tslib_1.__generator)(this, function (_a) {
48
- switch (_a.label) {
49
- case 0: return [4 /*yield*/, __1.tf.loadLayersModel('https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json')];
50
- case 1:
51
- mobilenet = _a.sent();
52
- x = mobilenet.getLayer('global_average_pooling2d_1');
53
- predictions = __1.tf.layers
54
- .dense({ units: 10, activation: 'softmax', name: 'denseModified' })
55
- .apply(x.output);
56
- return [2 /*return*/, __1.tf.model({
57
- inputs: mobilenet.input,
58
- outputs: predictions,
59
- name: 'modelModified'
60
- })];
61
- }
62
- });
63
- });
64
- }
65
- exports.model = model;
@@ -1,3 +0,0 @@
import { tf, Task } from '..';
/** Description of the GeoTags task. */
export declare const task: Task;
/** Build a model for the GeoTags task; the argument is ignored. */
export declare function model(_?: string): Promise<tf.LayersModel>;
@@ -1,67 +0,0 @@
1
- "use strict";
2
- Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.model = exports.task = void 0;
4
- var tslib_1 = require("tslib");
5
- var __1 = require("..");
6
- var immutable_1 = require("immutable");
7
- exports.task = {
8
- taskID: 'geotags',
9
- displayInformation: {
10
- taskTitle: 'GeoTags',
11
- summary: {
12
- preview: 'In this challenge, we predict the geo-location of a photo given its pixels in terms of a cell number of a grid built on top of Switzerland',
13
- overview: 'The geotags dataset is a collection of images with geo-location information used to train a machine learning algorithm to predict the location of a photo given its pixels.'
14
- },
15
- limitations: 'The training data is limited to images of size 224x224.',
16
- tradeoffs: 'Training success strongly depends on label distribution',
17
- dataFormatInformation: 'Images should be of .png format and of size 224x224. <br> The label file should be .csv, where each row contains a file_name, class. The class is the cell number of a the given grid of Switzerland. '
18
- },
19
- trainingInformation: {
20
- modelID: 'geotags-model',
21
- epochs: 10,
22
- roundDuration: 10,
23
- validationSplit: 0.2,
24
- batchSize: 10,
25
- modelCompileData: {
26
- optimizer: 'adam',
27
- loss: 'categoricalCrossentropy',
28
- metrics: ['accuracy']
29
- },
30
- dataType: 'image',
31
- IMAGE_H: 224,
32
- IMAGE_W: 224,
33
- preprocessingFunctions: [__1.data.ImagePreprocessing.Resize],
34
- LABEL_LIST: (0, immutable_1.Range)(0, 140).map(String).toArray(),
35
- scheme: 'Federated',
36
- noiseScale: undefined,
37
- clippingRadius: 20,
38
- decentralizedSecure: true,
39
- minimumReadyPeers: 3,
40
- maxShareValue: 100
41
- }
42
- };
43
- function model(_) {
44
- if (_ === void 0) { _ = ''; }
45
- return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
46
- var pretrainedModel, numLayers, model;
47
- return (0, tslib_1.__generator)(this, function (_a) {
48
- switch (_a.label) {
49
- case 0: return [4 /*yield*/, __1.tf.loadLayersModel('https://storage.googleapis.com/epfl-disco-models/geotags/v2/model.json')];
50
- case 1:
51
- pretrainedModel = _a.sent();
52
- numLayers = pretrainedModel.layers.length;
53
- pretrainedModel.layers.forEach(function (layer) { layer.trainable = false; });
54
- pretrainedModel.layers[numLayers - 1].trainable = true;
55
- model = __1.tf.sequential({
56
- layers: [
57
- __1.tf.layers.inputLayer({ inputShape: [224, 224, 3] }),
58
- __1.tf.layers.rescaling({ scale: 1 / 127.5, offset: -1 }),
59
- pretrainedModel
60
- ]
61
- });
62
- return [2 /*return*/, model];
63
- }
64
- });
65
- });
66
- }
67
- exports.model = model;
@@ -1,6 +0,0 @@
// Built-in task definitions, each exposing `task` and `model`.
export * as cifar10 from './cifar10';
export * as geotags from './geotags';
export * as lus_covid from './lus_covid';
export * as mnist from './mnist';
export * as simple_face from './simple_face';
export * as titanic from './titanic';