@epfml/discojs 1.0.0 → 2.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (208) hide show
  1. package/README.md +28 -8
  2. package/dist/{async_buffer.d.ts → core/async_buffer.d.ts} +3 -3
  3. package/dist/{async_buffer.js → core/async_buffer.js} +5 -6
  4. package/dist/{async_informant.d.ts → core/async_informant.d.ts} +0 -0
  5. package/dist/{async_informant.js → core/async_informant.js} +0 -0
  6. package/dist/{client → core/client}/base.d.ts +4 -7
  7. package/dist/{client → core/client}/base.js +3 -2
  8. package/dist/core/client/decentralized/base.d.ts +32 -0
  9. package/dist/core/client/decentralized/base.js +212 -0
  10. package/dist/core/client/decentralized/clear_text.d.ts +14 -0
  11. package/dist/core/client/decentralized/clear_text.js +96 -0
  12. package/dist/{client → core/client}/decentralized/index.d.ts +0 -0
  13. package/dist/{client → core/client}/decentralized/index.js +0 -0
  14. package/dist/core/client/decentralized/messages.d.ts +41 -0
  15. package/dist/core/client/decentralized/messages.js +54 -0
  16. package/dist/core/client/decentralized/peer.d.ts +26 -0
  17. package/dist/core/client/decentralized/peer.js +210 -0
  18. package/dist/core/client/decentralized/peer_pool.d.ts +14 -0
  19. package/dist/core/client/decentralized/peer_pool.js +92 -0
  20. package/dist/core/client/decentralized/sec_agg.d.ts +22 -0
  21. package/dist/core/client/decentralized/sec_agg.js +190 -0
  22. package/dist/core/client/decentralized/secret_shares.d.ts +3 -0
  23. package/dist/core/client/decentralized/secret_shares.js +39 -0
  24. package/dist/core/client/decentralized/types.d.ts +2 -0
  25. package/dist/core/client/decentralized/types.js +7 -0
  26. package/dist/core/client/event_connection.d.ts +37 -0
  27. package/dist/core/client/event_connection.js +158 -0
  28. package/dist/core/client/federated/client.d.ts +37 -0
  29. package/dist/core/client/federated/client.js +273 -0
  30. package/dist/core/client/federated/index.d.ts +2 -0
  31. package/dist/core/client/federated/index.js +7 -0
  32. package/dist/core/client/federated/messages.d.ts +38 -0
  33. package/dist/core/client/federated/messages.js +25 -0
  34. package/dist/{client → core/client}/index.d.ts +2 -1
  35. package/dist/{client → core/client}/index.js +3 -3
  36. package/dist/{client → core/client}/local.d.ts +2 -2
  37. package/dist/{client → core/client}/local.js +0 -0
  38. package/dist/core/client/messages.d.ts +28 -0
  39. package/dist/core/client/messages.js +33 -0
  40. package/dist/core/client/utils.d.ts +2 -0
  41. package/dist/core/client/utils.js +19 -0
  42. package/dist/core/dataset/data/data.d.ts +11 -0
  43. package/dist/core/dataset/data/data.js +20 -0
  44. package/dist/core/dataset/data/data_split.d.ts +5 -0
  45. package/dist/{client/decentralized/types.js → core/dataset/data/data_split.js} +0 -0
  46. package/dist/core/dataset/data/image_data.d.ts +8 -0
  47. package/dist/core/dataset/data/image_data.js +64 -0
  48. package/dist/core/dataset/data/index.d.ts +5 -0
  49. package/dist/core/dataset/data/index.js +11 -0
  50. package/dist/core/dataset/data/preprocessing.d.ts +13 -0
  51. package/dist/core/dataset/data/preprocessing.js +33 -0
  52. package/dist/core/dataset/data/tabular_data.d.ts +8 -0
  53. package/dist/core/dataset/data/tabular_data.js +40 -0
  54. package/dist/{dataset → core/dataset}/data_loader/data_loader.d.ts +4 -11
  55. package/dist/{dataset → core/dataset}/data_loader/data_loader.js +0 -0
  56. package/dist/core/dataset/data_loader/image_loader.d.ts +17 -0
  57. package/dist/core/dataset/data_loader/image_loader.js +141 -0
  58. package/dist/core/dataset/data_loader/index.d.ts +3 -0
  59. package/dist/core/dataset/data_loader/index.js +9 -0
  60. package/dist/core/dataset/data_loader/tabular_loader.d.ts +29 -0
  61. package/dist/core/dataset/data_loader/tabular_loader.js +101 -0
  62. package/dist/core/dataset/dataset.d.ts +2 -0
  63. package/dist/{task/training_information.js → core/dataset/dataset.js} +0 -0
  64. package/dist/{dataset → core/dataset}/dataset_builder.d.ts +5 -5
  65. package/dist/{dataset → core/dataset}/dataset_builder.js +14 -10
  66. package/dist/core/dataset/index.d.ts +4 -0
  67. package/dist/core/dataset/index.js +14 -0
  68. package/dist/core/index.d.ts +18 -0
  69. package/dist/core/index.js +41 -0
  70. package/dist/{informant → core/informant}/graph_informant.d.ts +0 -0
  71. package/dist/{informant → core/informant}/graph_informant.js +0 -0
  72. package/dist/{informant → core/informant}/index.d.ts +0 -0
  73. package/dist/{informant → core/informant}/index.js +0 -0
  74. package/dist/{informant → core/informant}/training_informant/base.d.ts +3 -3
  75. package/dist/{informant → core/informant}/training_informant/base.js +3 -2
  76. package/dist/{informant → core/informant}/training_informant/decentralized.d.ts +0 -0
  77. package/dist/{informant → core/informant}/training_informant/decentralized.js +0 -0
  78. package/dist/{informant → core/informant}/training_informant/federated.d.ts +0 -0
  79. package/dist/{informant → core/informant}/training_informant/federated.js +0 -0
  80. package/dist/{informant → core/informant}/training_informant/index.d.ts +0 -0
  81. package/dist/{informant → core/informant}/training_informant/index.js +0 -0
  82. package/dist/{informant → core/informant}/training_informant/local.d.ts +2 -2
  83. package/dist/{informant → core/informant}/training_informant/local.js +2 -2
  84. package/dist/{logging → core/logging}/console_logger.d.ts +0 -0
  85. package/dist/{logging → core/logging}/console_logger.js +0 -0
  86. package/dist/{logging → core/logging}/index.d.ts +0 -0
  87. package/dist/{logging → core/logging}/index.js +0 -0
  88. package/dist/{logging → core/logging}/logger.d.ts +0 -0
  89. package/dist/{logging → core/logging}/logger.js +0 -0
  90. package/dist/{logging → core/logging}/trainer_logger.d.ts +0 -0
  91. package/dist/{logging → core/logging}/trainer_logger.js +0 -0
  92. package/dist/{memory → core/memory}/base.d.ts +2 -2
  93. package/dist/{memory → core/memory}/base.js +0 -0
  94. package/dist/{memory → core/memory}/empty.d.ts +0 -0
  95. package/dist/{memory → core/memory}/empty.js +0 -0
  96. package/dist/core/memory/index.d.ts +3 -0
  97. package/dist/core/memory/index.js +9 -0
  98. package/dist/{memory → core/memory}/model_type.d.ts +0 -0
  99. package/dist/{memory → core/memory}/model_type.js +0 -0
  100. package/dist/{privacy.d.ts → core/privacy.d.ts} +2 -3
  101. package/dist/{privacy.js → core/privacy.js} +3 -16
  102. package/dist/{serialization → core/serialization}/index.d.ts +0 -0
  103. package/dist/{serialization → core/serialization}/index.js +0 -0
  104. package/dist/{serialization → core/serialization}/model.d.ts +0 -0
  105. package/dist/{serialization → core/serialization}/model.js +0 -0
  106. package/dist/core/serialization/weights.d.ts +5 -0
  107. package/dist/{serialization → core/serialization}/weights.js +11 -9
  108. package/dist/{task → core/task}/data_example.d.ts +0 -0
  109. package/dist/{task → core/task}/data_example.js +0 -0
  110. package/dist/{task → core/task}/display_information.d.ts +5 -5
  111. package/dist/{task → core/task}/display_information.js +5 -10
  112. package/dist/{task → core/task}/index.d.ts +0 -0
  113. package/dist/{task → core/task}/index.js +0 -0
  114. package/dist/core/task/model_compile_data.d.ts +6 -0
  115. package/dist/core/task/model_compile_data.js +22 -0
  116. package/dist/{task → core/task}/summary.d.ts +0 -0
  117. package/dist/{task → core/task}/summary.js +0 -4
  118. package/dist/{task → core/task}/task.d.ts +2 -2
  119. package/dist/{task → core/task}/task.js +6 -7
  120. package/dist/{task → core/task}/training_information.d.ts +10 -14
  121. package/dist/core/task/training_information.js +66 -0
  122. package/dist/{tasks → core/tasks}/cifar10.d.ts +1 -2
  123. package/dist/{tasks → core/tasks}/cifar10.js +12 -23
  124. package/dist/core/tasks/geotags.d.ts +3 -0
  125. package/dist/core/tasks/geotags.js +67 -0
  126. package/dist/{tasks → core/tasks}/index.d.ts +2 -1
  127. package/dist/{tasks → core/tasks}/index.js +3 -2
  128. package/dist/core/tasks/lus_covid.d.ts +3 -0
  129. package/dist/{tasks → core/tasks}/lus_covid.js +26 -24
  130. package/dist/{tasks → core/tasks}/mnist.d.ts +1 -2
  131. package/dist/{tasks → core/tasks}/mnist.js +18 -16
  132. package/dist/core/tasks/simple_face.d.ts +2 -0
  133. package/dist/core/tasks/simple_face.js +41 -0
  134. package/dist/{tasks → core/tasks}/titanic.d.ts +1 -2
  135. package/dist/{tasks → core/tasks}/titanic.js +11 -11
  136. package/dist/core/training/disco.d.ts +23 -0
  137. package/dist/core/training/disco.js +130 -0
  138. package/dist/{training → core/training}/index.d.ts +0 -0
  139. package/dist/{training → core/training}/index.js +0 -0
  140. package/dist/{training → core/training}/trainer/distributed_trainer.d.ts +1 -2
  141. package/dist/{training → core/training}/trainer/distributed_trainer.js +6 -5
  142. package/dist/{training → core/training}/trainer/local_trainer.d.ts +2 -2
  143. package/dist/{training → core/training}/trainer/local_trainer.js +0 -0
  144. package/dist/{training → core/training}/trainer/round_tracker.d.ts +0 -0
  145. package/dist/{training → core/training}/trainer/round_tracker.js +0 -0
  146. package/dist/{training → core/training}/trainer/trainer.d.ts +1 -2
  147. package/dist/{training → core/training}/trainer/trainer.js +2 -2
  148. package/dist/{training → core/training}/trainer/trainer_builder.d.ts +0 -0
  149. package/dist/{training → core/training}/trainer/trainer_builder.js +0 -0
  150. package/dist/core/training/training_schemes.d.ts +5 -0
  151. package/dist/{training → core/training}/training_schemes.js +2 -2
  152. package/dist/{types.d.ts → core/types.d.ts} +0 -0
  153. package/dist/{types.js → core/types.js} +0 -0
  154. package/dist/{validation → core/validation}/index.d.ts +0 -0
  155. package/dist/{validation → core/validation}/index.js +0 -0
  156. package/dist/{validation → core/validation}/validator.d.ts +5 -8
  157. package/dist/{validation → core/validation}/validator.js +9 -11
  158. package/dist/core/weights/aggregation.d.ts +8 -0
  159. package/dist/core/weights/aggregation.js +96 -0
  160. package/dist/core/weights/index.d.ts +2 -0
  161. package/dist/core/weights/index.js +7 -0
  162. package/dist/core/weights/weights_container.d.ts +19 -0
  163. package/dist/core/weights/weights_container.js +64 -0
  164. package/dist/dataset/data_loader/image_loader.d.ts +3 -15
  165. package/dist/dataset/data_loader/image_loader.js +12 -125
  166. package/dist/dataset/data_loader/index.d.ts +2 -3
  167. package/dist/dataset/data_loader/index.js +3 -5
  168. package/dist/dataset/data_loader/tabular_loader.d.ts +3 -28
  169. package/dist/dataset/data_loader/tabular_loader.js +11 -92
  170. package/dist/imports.d.ts +2 -0
  171. package/dist/imports.js +7 -0
  172. package/dist/index.d.ts +2 -19
  173. package/dist/index.js +3 -39
  174. package/dist/memory/index.d.ts +1 -3
  175. package/dist/memory/index.js +3 -7
  176. package/dist/memory/memory.d.ts +26 -0
  177. package/dist/memory/memory.js +160 -0
  178. package/package.json +14 -27
  179. package/dist/aggregation.d.ts +0 -5
  180. package/dist/aggregation.js +0 -33
  181. package/dist/client/decentralized/base.d.ts +0 -43
  182. package/dist/client/decentralized/base.js +0 -243
  183. package/dist/client/decentralized/clear_text.d.ts +0 -13
  184. package/dist/client/decentralized/clear_text.js +0 -78
  185. package/dist/client/decentralized/messages.d.ts +0 -37
  186. package/dist/client/decentralized/messages.js +0 -15
  187. package/dist/client/decentralized/sec_agg.d.ts +0 -18
  188. package/dist/client/decentralized/sec_agg.js +0 -169
  189. package/dist/client/decentralized/secret_shares.d.ts +0 -5
  190. package/dist/client/decentralized/secret_shares.js +0 -58
  191. package/dist/client/decentralized/types.d.ts +0 -1
  192. package/dist/client/federated.d.ts +0 -30
  193. package/dist/client/federated.js +0 -218
  194. package/dist/dataset/index.d.ts +0 -2
  195. package/dist/dataset/index.js +0 -7
  196. package/dist/model_actor.d.ts +0 -16
  197. package/dist/model_actor.js +0 -20
  198. package/dist/serialization/weights.d.ts +0 -5
  199. package/dist/task/model_compile_data.d.ts +0 -6
  200. package/dist/task/model_compile_data.js +0 -12
  201. package/dist/tasks/lus_covid.d.ts +0 -4
  202. package/dist/tasks/simple_face.d.ts +0 -4
  203. package/dist/tasks/simple_face.js +0 -84
  204. package/dist/tfjs.d.ts +0 -2
  205. package/dist/tfjs.js +0 -6
  206. package/dist/training/disco.d.ts +0 -14
  207. package/dist/training/disco.js +0 -70
  208. package/dist/training/training_schemes.d.ts +0 -5
@@ -1,32 +1,28 @@
1
- import { DataExample } from './data_example';
1
+ import { Preprocessing } from '../dataset/data/preprocessing';
2
2
  import { ModelCompileData } from './model_compile_data';
3
+ export declare function isTrainingInformation(raw: unknown): raw is TrainingInformation;
3
4
  export interface TrainingInformation {
4
5
  modelID: string;
5
6
  epochs: number;
6
7
  roundDuration: number;
7
8
  validationSplit: number;
8
9
  batchSize: number;
9
- preprocessFunctions: string[];
10
+ preprocessingFunctions?: Preprocessing[];
10
11
  modelCompileData: ModelCompileData;
11
12
  dataType: string;
12
- maxShareValue?: number;
13
- minimumReadyPeers?: number;
14
- decentralizedSecure?: boolean;
15
- receivedMessagesThreshold?: number;
16
13
  inputColumns?: string[];
17
14
  outputColumns?: string[];
18
- threshold?: number;
19
15
  IMAGE_H?: number;
20
16
  IMAGE_W?: number;
17
+ modelURL?: string;
21
18
  LABEL_LIST?: string[];
22
- aggregateImagesById?: boolean;
23
19
  learningRate?: number;
24
- NUM_CLASSES?: number;
25
- csvLabels?: boolean;
26
- RESIZED_IMAGE_H?: number;
27
- RESIZED_IMAGE_W?: number;
28
- LABEL_ASSIGNMENT?: DataExample[];
29
- scheme?: string;
20
+ scheme: string;
30
21
  noiseScale?: number;
31
22
  clippingRadius?: number;
23
+ decentralizedSecure?: boolean;
24
+ byzantineRobustAggregator?: boolean;
25
+ tauPercentile?: number;
26
+ maxShareValue?: number;
27
+ minimumReadyPeers?: number;
32
28
  }
@@ -0,0 +1,66 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.isTrainingInformation = void 0;
4
+ var model_compile_data_1 = require("./model_compile_data");
5
+ function isTrainingInformation(raw) {
6
+ if (typeof raw !== 'object') {
7
+ return false;
8
+ }
9
+ if (raw === null) {
10
+ return false;
11
+ }
12
+ var _a = raw, dataType = _a.dataType, scheme = _a.scheme, epochs = _a.epochs,
13
+ // roundDuration,
14
+ validationSplit = _a.validationSplit, batchSize = _a.batchSize, modelCompileData = _a.modelCompileData, modelID = _a.modelID, preprocessingFunctions = _a.preprocessingFunctions, inputColumns = _a.inputColumns, outputColumns = _a.outputColumns, IMAGE_H = _a.IMAGE_H, IMAGE_W = _a.IMAGE_W, roundDuration = _a.roundDuration, modelURL = _a.modelURL, learningRate = _a.learningRate, decentralizedSecure = _a.decentralizedSecure, maxShareValue = _a.maxShareValue, minimumReadyPeers = _a.minimumReadyPeers, LABEL_LIST = _a.LABEL_LIST, noiseScale = _a.noiseScale, clippingRadius = _a.clippingRadius;
15
+ if (typeof dataType !== 'string' ||
16
+ typeof modelID !== 'string' ||
17
+ typeof epochs !== 'number' ||
18
+ typeof batchSize !== 'number' ||
19
+ typeof roundDuration !== 'number' ||
20
+ typeof validationSplit !== 'number' ||
21
+ (modelURL !== undefined && typeof modelURL !== 'string') ||
22
+ (noiseScale !== undefined && typeof noiseScale !== 'number') ||
23
+ (clippingRadius !== undefined && typeof clippingRadius !== 'number') ||
24
+ (learningRate !== undefined && typeof learningRate !== 'number') ||
25
+ (decentralizedSecure !== undefined && typeof decentralizedSecure !== 'boolean') ||
26
+ (maxShareValue !== undefined && typeof maxShareValue !== 'number') ||
27
+ (minimumReadyPeers !== undefined && typeof minimumReadyPeers !== 'number')) {
28
+ return false;
29
+ }
30
+ // interdependencies on data type
31
+ switch (dataType) {
32
+ case 'image':
33
+ if (typeof IMAGE_H !== 'number' || typeof IMAGE_W !== 'number') {
34
+ return false;
35
+ }
36
+ break;
37
+ case 'tabular':
38
+ if (!(Array.isArray(inputColumns) && inputColumns.every(function (e) { return typeof e === 'string'; }))) {
39
+ return false;
40
+ }
41
+ if (!(Array.isArray(outputColumns) && outputColumns.every(function (e) { return typeof e === 'string'; }))) {
42
+ return false;
43
+ }
44
+ break;
45
+ }
46
+ // interdependencies on scheme
47
+ switch (scheme) {
48
+ case 'decentralized':
49
+ break;
50
+ case 'federated':
51
+ break;
52
+ case 'local':
53
+ break;
54
+ }
55
+ if (!(0, model_compile_data_1.isModelCompileData)(modelCompileData)) {
56
+ return false;
57
+ }
58
+ if (LABEL_LIST !== undefined && !(Array.isArray(LABEL_LIST) && LABEL_LIST.every(function (e) { return typeof e === 'string'; }))) {
59
+ return false;
60
+ }
61
+ if (preprocessingFunctions !== undefined && !(Array.isArray(preprocessingFunctions) && preprocessingFunctions.every(function (e) { return typeof e === 'string'; }))) {
62
+ return false;
63
+ }
64
+ return true;
65
+ }
66
+ exports.isTrainingInformation = isTrainingInformation;
@@ -1,4 +1,3 @@
1
- import * as tf from '@tensorflow/tfjs';
2
- import { Task } from '../task';
1
+ import { tf, Task } from '..';
3
2
  export declare const task: Task;
4
3
  export declare function model(): Promise<tf.LayersModel>;
@@ -2,7 +2,7 @@
2
2
  Object.defineProperty(exports, "__esModule", { value: true });
3
3
  exports.model = exports.task = void 0;
4
4
  var tslib_1 = require("tslib");
5
- var tf = (0, tslib_1.__importStar)(require("@tensorflow/tfjs"));
5
+ var __1 = require("..");
6
6
  exports.task = {
7
7
  taskID: 'cifar10',
8
8
  displayInformation: {
@@ -15,7 +15,7 @@ exports.task = {
15
15
  tradeoffs: 'Training success strongly depends on label distribution',
16
16
  dataFormatInformation: 'Images should be of .png format and of size 32x32. <br> The label file should be .csv, where each row contains a file_name, class. <br> <br> e.g. if you have images: 0.png (of a frog) and 1.png (of a car) <br> labels.csv contains: (Note that no header is needed)<br> 0.png, frog <br> 1.png, car',
17
17
  dataExampleText: 'Below you can find 10 random examples from each of the 10 classes in the dataset.',
18
- dataExampleImage: './cifar10-example.png'
18
+ dataExampleImage: 'https://storage.googleapis.com/deai-313515.appspot.com/example_training_data/cifar10-example.png'
19
19
  },
20
20
  trainingInformation: {
21
21
  modelID: 'cifar10-model',
@@ -28,28 +28,17 @@ exports.task = {
28
28
  loss: 'categoricalCrossentropy',
29
29
  metrics: ['accuracy']
30
30
  },
31
- threshold: 1,
32
31
  dataType: 'image',
33
- csvLabels: true,
34
32
  IMAGE_H: 32,
35
33
  IMAGE_W: 32,
36
- preprocessFunctions: ['resize'],
37
- RESIZED_IMAGE_H: 224,
38
- RESIZED_IMAGE_W: 224,
34
+ preprocessingFunctions: [],
39
35
  LABEL_LIST: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'],
40
- LABEL_ASSIGNMENT: [
41
- { columnName: 'airplane', columnData: 0 },
42
- { columnName: 'automobile', columnData: 1 },
43
- { columnName: 'bird', columnData: 2 },
44
- { columnName: 'cat', columnData: 3 },
45
- { columnName: 'deer', columnData: 4 },
46
- { columnName: 'dog', columnData: 5 },
47
- { columnName: 'frog', columnData: 6 },
48
- { columnName: 'horse', columnData: 7 },
49
- { columnName: 'ship', columnData: 8 },
50
- { columnName: 'truck', columnData: 9 }
51
- ],
52
- scheme: 'Decentralized'
36
+ scheme: 'Decentralized',
37
+ noiseScale: undefined,
38
+ clippingRadius: 20,
39
+ decentralizedSecure: true,
40
+ minimumReadyPeers: 3,
41
+ maxShareValue: 100
53
42
  }
54
43
  };
55
44
  function model() {
@@ -57,14 +46,14 @@ function model() {
57
46
  var mobilenet, x, predictions;
58
47
  return (0, tslib_1.__generator)(this, function (_a) {
59
48
  switch (_a.label) {
60
- case 0: return [4 /*yield*/, tf.loadLayersModel('https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json')];
49
+ case 0: return [4 /*yield*/, __1.tf.loadLayersModel('https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json')];
61
50
  case 1:
62
51
  mobilenet = _a.sent();
63
52
  x = mobilenet.getLayer('global_average_pooling2d_1');
64
- predictions = tf.layers
53
+ predictions = __1.tf.layers
65
54
  .dense({ units: 10, activation: 'softmax', name: 'denseModified' })
66
55
  .apply(x.output);
67
- return [2 /*return*/, tf.model({
56
+ return [2 /*return*/, __1.tf.model({
68
57
  inputs: mobilenet.input,
69
58
  outputs: predictions,
70
59
  name: 'modelModified'
@@ -0,0 +1,3 @@
1
+ import { tf, Task } from '..';
2
+ export declare const task: Task;
3
+ export declare function model(_?: string): Promise<tf.LayersModel>;
@@ -0,0 +1,67 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.model = exports.task = void 0;
4
+ var tslib_1 = require("tslib");
5
+ var __1 = require("..");
6
+ var immutable_1 = require("immutable");
7
+ exports.task = {
8
+ taskID: 'geotags',
9
+ displayInformation: {
10
+ taskTitle: 'GeoTags',
11
+ summary: {
12
+ preview: 'In this challenge, we predict the geo-location of a photo given its pixels in terms of a cell number of a grid built on top of Switzerland',
13
+ overview: 'The geotags dataset is a collection of images with geo-location information used to train a machine learning algorithm to predict the location of a photo given its pixels.'
14
+ },
15
+ limitations: 'The training data is limited to images of size 224x224.',
16
+ tradeoffs: 'Training success strongly depends on label distribution',
17
+ dataFormatInformation: 'Images should be of .png format and of size 224x224. <br> The label file should be .csv, where each row contains a file_name, class. The class is the cell number of a the given grid of Switzerland. '
18
+ },
19
+ trainingInformation: {
20
+ modelID: 'geotags-model',
21
+ epochs: 10,
22
+ roundDuration: 10,
23
+ validationSplit: 0.2,
24
+ batchSize: 10,
25
+ modelCompileData: {
26
+ optimizer: 'adam',
27
+ loss: 'categoricalCrossentropy',
28
+ metrics: ['accuracy']
29
+ },
30
+ dataType: 'image',
31
+ IMAGE_H: 224,
32
+ IMAGE_W: 224,
33
+ preprocessingFunctions: [__1.data.ImagePreprocessing.Resize],
34
+ LABEL_LIST: (0, immutable_1.Range)(0, 140).map(String).toArray(),
35
+ scheme: 'Federated',
36
+ noiseScale: undefined,
37
+ clippingRadius: 20,
38
+ decentralizedSecure: true,
39
+ minimumReadyPeers: 3,
40
+ maxShareValue: 100
41
+ }
42
+ };
43
+ function model(_) {
44
+ if (_ === void 0) { _ = ''; }
45
+ return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
46
+ var pretrainedModel, numLayers, model;
47
+ return (0, tslib_1.__generator)(this, function (_a) {
48
+ switch (_a.label) {
49
+ case 0: return [4 /*yield*/, __1.tf.loadLayersModel('https://storage.googleapis.com/epfl-disco-models/geotags/v2/model.json')];
50
+ case 1:
51
+ pretrainedModel = _a.sent();
52
+ numLayers = pretrainedModel.layers.length;
53
+ pretrainedModel.layers.forEach(function (layer) { layer.trainable = false; });
54
+ pretrainedModel.layers[numLayers - 1].trainable = true;
55
+ model = __1.tf.sequential({
56
+ layers: [
57
+ __1.tf.layers.inputLayer({ inputShape: [224, 224, 3] }),
58
+ __1.tf.layers.rescaling({ scale: 1 / 127.5, offset: -1 }),
59
+ pretrainedModel
60
+ ]
61
+ });
62
+ return [2 /*return*/, model];
63
+ }
64
+ });
65
+ });
66
+ }
67
+ exports.model = model;
@@ -1,5 +1,6 @@
1
1
  export * as cifar10 from './cifar10';
2
2
  export * as lus_covid from './lus_covid';
3
3
  export * as mnist from './mnist';
4
- export * as simple_face from './simple_face';
5
4
  export * as titanic from './titanic';
5
+ export * as simple_face from './simple_face';
6
+ export * as geotags from './geotags';
@@ -1,9 +1,10 @@
1
1
  "use strict";
2
2
  Object.defineProperty(exports, "__esModule", { value: true });
3
- exports.titanic = exports.simple_face = exports.mnist = exports.lus_covid = exports.cifar10 = void 0;
3
+ exports.geotags = exports.simple_face = exports.titanic = exports.mnist = exports.lus_covid = exports.cifar10 = void 0;
4
4
  var tslib_1 = require("tslib");
5
5
  exports.cifar10 = (0, tslib_1.__importStar)(require("./cifar10"));
6
6
  exports.lus_covid = (0, tslib_1.__importStar)(require("./lus_covid"));
7
7
  exports.mnist = (0, tslib_1.__importStar)(require("./mnist"));
8
- exports.simple_face = (0, tslib_1.__importStar)(require("./simple_face"));
9
8
  exports.titanic = (0, tslib_1.__importStar)(require("./titanic"));
9
+ exports.simple_face = (0, tslib_1.__importStar)(require("./simple_face"));
10
+ exports.geotags = (0, tslib_1.__importStar)(require("./geotags"));
@@ -0,0 +1,3 @@
1
+ import { tf, Task } from '..';
2
+ export declare const task: Task;
3
+ export declare function model(): tf.LayersModel;
@@ -1,20 +1,20 @@
1
1
  "use strict";
2
2
  Object.defineProperty(exports, "__esModule", { value: true });
3
3
  exports.model = exports.task = void 0;
4
- var tslib_1 = require("tslib");
5
- var tf = (0, tslib_1.__importStar)(require("@tensorflow/tfjs"));
4
+ var __1 = require("..");
6
5
  exports.task = {
7
6
  taskID: 'lus_covid',
8
7
  displayInformation: {
9
8
  taskTitle: 'COVID Lung Ultrasound',
10
9
  summary: {
11
- preview: "Do you have a dataset of lung ultrasound images on patients <b>suspected of Lower Respiratory Tract infection (LRTI) during the COVID pandemic</b>? <br> Learn how to discriminate between COVID positive and negative patients by joining this task. <br><br> Don’t have a dataset of your own? Download a sample of a few cases <a class='underline' href='https://drive.switch.ch/index.php/s/zM5ZrUWK3taaIly'>here</a>.",
12
- overview: "Do you have a dataset of lung ultrasound images on patients <b>suspected of Lower Respiratory Tract infection (LRTI) during the COVID pandemic</b>? <br> Learn how to discriminate between COVID positive and negative patients by joining this task. <br><br> Don’t have a dataset of your own? Download a sample of a few cases <a class='underline' href='https://drive.switch.ch/index.php/s/zM5ZrUWK3taaIly'>here</a>."
10
+ preview: 'Do you have a dataset of lung ultrasound images on patients <b>suspected of Lower Respiratory Tract infection (LRTI) during the COVID pandemic</b>? <br> Learn how to discriminate between COVID positive and negative patients by joining this task.',
11
+ overview: "Don’t have a dataset of your own? Download a sample of a few cases <a class='underline' href='https://drive.switch.ch/index.php/s/zM5ZrUWK3taaIly' target='_blank'>here</a>."
13
12
  },
14
13
  model: "We use a simplified* version of the <b>DeepChest model</b>: A deep learning model developed in our lab (<a class='underline' href='https://www.epfl.ch/labs/mlo/igh-intelligent-global-health/'>intelligent Global Health</a>.). On a cohort of 400 Swiss patients suspected of LRTI, the model obtained over 90% area under the ROC curve for this task. <br><br>*Simplified to ensure smooth running on your browser, the performance is minimally affected. Details of the adaptations are below <br>- <b>Removed</b>: positional embedding (i.e. we don’t take the anatomic position into consideration). Rather, the model now does mean pooling over the feature vector of the images for each patient <br>- <b>Replaced</b>: ResNet18 by Mobilenet",
15
14
  tradeoffs: 'We are using a simpler version of DeepChest in order to be able to run it on the browser.',
16
15
  dataFormatInformation: 'This model takes as input an image dataset. It consists on a set of lung ultrasound images per patient with its corresponding label of covid positive or negative. Moreover, to identify the images per patient you have to follow the follwing naming pattern: "patientId_*.png"',
17
- dataExampleText: 'Below you can find an example of an expected lung image for patient 2 named: 2_QAID_1.masked.reshaped.squared.224.png'
16
+ dataExampleText: 'Below you can find an example of an expected lung image for patient 2 named: 2_QAID_1.masked.reshaped.squared.224.png',
17
+ dataExampleImage: 'https://storage.googleapis.com/deai-313515.appspot.com/example_training_data/2_QAID_1.masked.reshaped.squared.224.png'
18
18
  },
19
19
  trainingInformation: {
20
20
  modelID: 'lus-covid-model',
@@ -28,27 +28,29 @@ exports.task = {
28
28
  metrics: ['accuracy']
29
29
  },
30
30
  learningRate: 0.001,
31
- threshold: 2,
32
31
  IMAGE_H: 100,
33
32
  IMAGE_W: 100,
34
- preprocessFunctions: [],
33
+ preprocessingFunctions: [__1.data.ImagePreprocessing.Resize],
35
34
  LABEL_LIST: ['COVID-Positive', 'COVID-Negative'],
36
- NUM_CLASSES: 2,
37
35
  dataType: 'image',
38
- aggregateImagesById: true,
39
- scheme: 'Decentralized'
36
+ scheme: 'Decentralized',
37
+ noiseScale: undefined,
38
+ clippingRadius: 20,
39
+ decentralizedSecure: true,
40
+ minimumReadyPeers: 3,
41
+ maxShareValue: 100
40
42
  }
41
43
  };
42
- function model(imageHeight, imageWidth, imageChannels, numOutputClasses) {
43
- if (imageHeight === void 0) { imageHeight = 100; }
44
- if (imageWidth === void 0) { imageWidth = 100; }
45
- if (imageChannels === void 0) { imageChannels = 3; }
46
- if (numOutputClasses === void 0) { numOutputClasses = 2; }
47
- var model = tf.sequential();
44
+ function model() {
45
+ var imageHeight = 100;
46
+ var imageWidth = 100;
47
+ var imageChannels = 3;
48
+ var numOutputClasses = 2;
49
+ var model = __1.tf.sequential();
48
50
  // In the first layer of our convolutional neural network we have
49
51
  // to specify the input shape. Then we specify some parameters for
50
52
  // the convolution operation that takes place in this layer.
51
- model.add(tf.layers.conv2d({
53
+ model.add(__1.tf.layers.conv2d({
52
54
  inputShape: [imageHeight, imageWidth, imageChannels],
53
55
  kernelSize: 5,
54
56
  filters: 8,
@@ -58,24 +60,24 @@ function model(imageHeight, imageWidth, imageChannels, numOutputClasses) {
58
60
  }));
59
61
  // The MaxPooling layer acts as a sort of downsampling using max values
60
62
  // in a region instead of averaging.
61
- model.add(tf.layers.maxPooling2d({ poolSize: [2, 2], strides: [2, 2] }));
63
+ model.add(__1.tf.layers.maxPooling2d({ poolSize: [2, 2], strides: [2, 2] }));
62
64
  // Repeat another conv2d + maxPooling stack.
63
65
  // Note that we have more filters in the convolution.
64
- model.add(tf.layers.conv2d({
66
+ model.add(__1.tf.layers.conv2d({
65
67
  kernelSize: 5,
66
68
  filters: 16,
67
69
  strides: 1,
68
70
  activation: 'relu',
69
71
  kernelInitializer: 'varianceScaling'
70
72
  }));
71
- model.add(tf.layers.maxPooling2d({ poolSize: [2, 2], strides: [2, 2] }));
73
+ model.add(__1.tf.layers.maxPooling2d({ poolSize: [2, 2], strides: [2, 2] }));
72
74
  // Now we flatten the output from the 2D filters into a 1D vector to prepare
73
75
  // it for input into our last layer. This is common practice when feeding
74
76
  // higher dimensional data to a final classification output layer.
75
- model.add(tf.layers.flatten());
76
- // Our last layer is a dense layer which has 10 output units, one for each
77
- // output class (i.e. 0, 1, 2, 3, 4, 5, 6, 7, 8, 9).
78
- model.add(tf.layers.dense({
77
+ model.add(__1.tf.layers.flatten());
78
+ // Our last layer is a dense layer which has 2 output units, one for each
79
+ // output class.
80
+ model.add(__1.tf.layers.dense({
79
81
  units: numOutputClasses,
80
82
  kernelInitializer: 'varianceScaling',
81
83
  activation: 'softmax'
@@ -1,4 +1,3 @@
1
- import * as tf from '@tensorflow/tfjs';
2
- import { Task } from '../task';
1
+ import { tf, Task } from '..';
3
2
  export declare const task: Task;
4
3
  export declare function model(): tf.LayersModel;
@@ -1,8 +1,7 @@
1
1
  "use strict";
2
2
  Object.defineProperty(exports, "__esModule", { value: true });
3
3
  exports.model = exports.task = void 0;
4
- var tslib_1 = require("tslib");
5
- var tf = (0, tslib_1.__importStar)(require("@tensorflow/tfjs"));
4
+ var __1 = require("..");
6
5
  exports.task = {
7
6
  taskID: 'mnist',
8
7
  displayInformation: {
@@ -15,7 +14,7 @@ exports.task = {
15
14
  tradeoffs: 'We are using a simple model, first a 2d convolutional layer > max pooling > 2d convolutional layer > max pooling > convolutional layer > 2 dense layers.',
16
15
  dataFormatInformation: 'This model is trained on images corresponding to digits 0 to 9. You can upload each digit image of your dataset in the box corresponding to its label. The model taskes images of size 28x28 as input.',
17
16
  dataExampleText: 'Below you can find an example of an expected image representing the digit 9.',
18
- dataExampleImage: './9-mnist-example.png'
17
+ dataExampleImage: 'http://storage.googleapis.com/deai-313515.appspot.com/example_training_data/9-mnist-example.png'
19
18
  },
20
19
  trainingInformation: {
21
20
  modelID: 'mnist-model',
@@ -28,31 +27,34 @@ exports.task = {
28
27
  loss: 'categoricalCrossentropy',
29
28
  metrics: ['accuracy']
30
29
  },
31
- threshold: 1,
32
30
  dataType: 'image',
33
31
  IMAGE_H: 28,
34
32
  IMAGE_W: 28,
35
- preprocessFunctions: [],
33
+ preprocessingFunctions: [],
36
34
  LABEL_LIST: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'],
37
- aggregateImagesById: false,
38
- scheme: 'Decentralized'
35
+ scheme: 'Decentralized',
36
+ noiseScale: undefined,
37
+ clippingRadius: 20,
38
+ decentralizedSecure: true,
39
+ minimumReadyPeers: 3,
40
+ maxShareValue: 100
39
41
  }
40
42
  };
41
43
  function model() {
42
- var model = tf.sequential();
43
- model.add(tf.layers.conv2d({
44
+ var model = __1.tf.sequential();
45
+ model.add(__1.tf.layers.conv2d({
44
46
  inputShape: [28, 28, 3],
45
47
  kernelSize: 3,
46
48
  filters: 16,
47
49
  activation: 'relu'
48
50
  }));
49
- model.add(tf.layers.maxPooling2d({ poolSize: 2, strides: 2 }));
50
- model.add(tf.layers.conv2d({ kernelSize: 3, filters: 32, activation: 'relu' }));
51
- model.add(tf.layers.maxPooling2d({ poolSize: 2, strides: 2 }));
52
- model.add(tf.layers.conv2d({ kernelSize: 3, filters: 32, activation: 'relu' }));
53
- model.add(tf.layers.flatten({}));
54
- model.add(tf.layers.dense({ units: 64, activation: 'relu' }));
55
- model.add(tf.layers.dense({ units: 10, activation: 'softmax' }));
51
+ model.add(__1.tf.layers.maxPooling2d({ poolSize: 2, strides: 2 }));
52
+ model.add(__1.tf.layers.conv2d({ kernelSize: 3, filters: 32, activation: 'relu' }));
53
+ model.add(__1.tf.layers.maxPooling2d({ poolSize: 2, strides: 2 }));
54
+ model.add(__1.tf.layers.conv2d({ kernelSize: 3, filters: 32, activation: 'relu' }));
55
+ model.add(__1.tf.layers.flatten({}));
56
+ model.add(__1.tf.layers.dense({ units: 64, activation: 'relu' }));
57
+ model.add(__1.tf.layers.dense({ units: 10, activation: 'softmax' }));
56
58
  return model;
57
59
  }
58
60
  exports.model = model;
@@ -0,0 +1,2 @@
1
+ import { Task } from '..';
2
+ export declare const task: Task;
@@ -0,0 +1,41 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.task = void 0;
4
+ var __1 = require("..");
5
+ exports.task = {
6
+ taskID: 'simple_face',
7
+ displayInformation: {
8
+ taskTitle: 'Simple Face',
9
+ summary: {
10
+ preview: 'Can you detect if the person in a picture is a child or an adult?',
11
+ overview: 'Simple face is a small subset of face_task from Kaggle'
12
+ },
13
+ limitations: 'The training data is limited to small images of size 200x200.',
14
+ tradeoffs: 'Training success strongly depends on label distribution',
15
+ dataFormatInformation: '',
16
+ dataExampleText: 'Below you find an example',
17
+ dataExampleImage: 'https://storage.googleapis.com/deai-313515.appspot.com/example_training_data/simple_face-example.png'
18
+ },
19
+ trainingInformation: {
20
+ modelID: 'simple_face-model',
21
+ epochs: 50,
22
+ modelURL: 'https://storage.googleapis.com/deai-313515.appspot.com/models/mobileNetV2_35_alpha_2_classes/model.json',
23
+ roundDuration: 1,
24
+ validationSplit: 0.2,
25
+ batchSize: 10,
26
+ preprocessingFunctions: [__1.data.ImagePreprocessing.Normalize],
27
+ learningRate: 0.001,
28
+ modelCompileData: {
29
+ optimizer: 'sgd',
30
+ loss: 'categoricalCrossentropy',
31
+ metrics: ['accuracy']
32
+ },
33
+ dataType: 'image',
34
+ IMAGE_H: 200,
35
+ IMAGE_W: 200,
36
+ LABEL_LIST: ['child', 'adult'],
37
+ scheme: 'Federated',
38
+ noiseScale: undefined,
39
+ clippingRadius: undefined
40
+ }
41
+ };
@@ -1,4 +1,3 @@
1
- import * as tf from '@tensorflow/tfjs';
2
- import { Task } from '..';
1
+ import { tf, Task } from '..';
3
2
  export declare const task: Task;
4
3
  export declare function model(): tf.LayersModel;
@@ -1,14 +1,13 @@
1
1
  "use strict";
2
2
  Object.defineProperty(exports, "__esModule", { value: true });
3
3
  exports.model = exports.task = void 0;
4
- var tslib_1 = require("tslib");
5
- var tf = (0, tslib_1.__importStar)(require("@tensorflow/tfjs"));
4
+ var __1 = require("..");
6
5
  exports.task = {
7
6
  taskID: 'titanic',
8
7
  displayInformation: {
9
8
  taskTitle: 'Titanic',
10
9
  summary: {
11
- preview: "Test our platform by using a publicly available <b>tabular</b> dataset. <br><br> Download the passenger list from the Titanic shipwreck here: <a class='underline text-primary-dark dark:text-primary-light' href='http://s3.amazonaws.com/assets.datacamp.com/course/Kaggle/train.csv'>train.csv</a> and <a class='underline text-primary-dark dark:text-primary-light' href='http://s3.amazonaws.com/assets.datacamp.com/course/Kaggle/train.csv'>test.csv</a> (more info <a class='underline text-primary-dark dark:text-primary-light' href='https://www.kaggle.com/c/titanic'>here</a>). <br> This model predicts the type of person most likely to survive/die in the historic ship accident, based on their characteristics (sex, age, class etc.).",
10
+ preview: "Test our platform by using a publicly available <b>tabular</b> dataset. <br><br> Download the passenger list from the Titanic shipwreck here: <a class='underline text-primary-dark dark:text-primary-light' href='https://github.com/epfml/disco/raw/develop/example_training_data/titanic_train.csv'>titanic_train.csv</a> (more info <a class='underline text-primary-dark dark:text-primary-light' href='https://www.kaggle.com/c/titanic'>here</a>). <br> This model predicts the type of person most likely to survive/die in the historic ship accident, based on their characteristics (sex, age, class etc.).",
12
11
  overview: 'We all know the unfortunate story of the Titanic: this flamboyant new transatlantic boat that sunk in 1912 in the North Atlantic Ocean. Today, we revist this tragedy by trying to predict the survival odds of the passenger given some basic features.'
13
12
  },
14
13
  model: 'The current model does not normalize the given data and applies only a very simple pre-processing of the data.',
@@ -50,13 +49,12 @@ exports.task = {
50
49
  roundDuration: 10,
51
50
  validationSplit: 0,
52
51
  batchSize: 30,
53
- preprocessFunctions: [],
52
+ preprocessingFunctions: [],
54
53
  modelCompileData: {
55
54
  optimizer: 'rmsprop',
56
55
  loss: 'binaryCrossentropy',
57
56
  metrics: ['accuracy']
58
57
  },
59
- receivedMessagesThreshold: 1,
60
58
  dataType: 'tabular',
61
59
  inputColumns: [
62
60
  'PassengerId',
@@ -69,20 +67,22 @@ exports.task = {
69
67
  outputColumns: [
70
68
  'Survived'
71
69
  ],
72
- scheme: 'Federated'
70
+ scheme: 'Federated',
71
+ noiseScale: undefined,
72
+ clippingRadius: undefined
73
73
  }
74
74
  };
75
75
  function model() {
76
- var model = tf.sequential();
77
- model.add(tf.layers.dense({
76
+ var model = __1.tf.sequential();
77
+ model.add(__1.tf.layers.dense({
78
78
  inputShape: [6],
79
79
  units: 124,
80
80
  activation: 'relu',
81
81
  kernelInitializer: 'leCunNormal'
82
82
  }));
83
- model.add(tf.layers.dense({ units: 64, activation: 'relu' }));
84
- model.add(tf.layers.dense({ units: 32, activation: 'relu' }));
85
- model.add(tf.layers.dense({ units: 1, activation: 'sigmoid' }));
83
+ model.add(__1.tf.layers.dense({ units: 64, activation: 'relu' }));
84
+ model.add(__1.tf.layers.dense({ units: 32, activation: 'relu' }));
85
+ model.add(__1.tf.layers.dense({ units: 1, activation: 'sigmoid' }));
86
86
  return model;
87
87
  }
88
88
  exports.model = model;
@@ -0,0 +1,23 @@
1
+ import { Client, data, Logger, Task, TrainingInformant, TrainingSchemes, Memory } from '..';
2
+ import { TrainerLog } from '../logging/trainer_logger';
3
+ interface DiscoOptions {
4
+ client?: Client;
5
+ url?: string | URL;
6
+ scheme?: TrainingSchemes;
7
+ informant?: TrainingInformant;
8
+ logger?: Logger;
9
+ memory?: Memory;
10
+ }
11
+ export declare class Disco {
12
+ readonly task: Task;
13
+ readonly logger: Logger;
14
+ readonly memory: Memory;
15
+ private readonly client;
16
+ private readonly trainer;
17
+ constructor(task: Task, options: DiscoOptions);
18
+ fit(dataTuple: data.DataSplit): Promise<void>;
19
+ pause(): Promise<void>;
20
+ close(): Promise<void>;
21
+ logs(): Promise<TrainerLog>;
22
+ }
23
+ export {};