@epfml/discojs 0.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (104)
  1. package/README.md +32 -0
  2. package/dist/aggregation.d.ts +3 -0
  3. package/dist/aggregation.js +19 -0
  4. package/dist/async_buffer.d.ts +41 -0
  5. package/dist/async_buffer.js +98 -0
  6. package/dist/async_informant.d.ts +20 -0
  7. package/dist/async_informant.js +69 -0
  8. package/dist/client/base.d.ts +36 -0
  9. package/dist/client/base.js +34 -0
  10. package/dist/client/decentralized.d.ts +23 -0
  11. package/dist/client/decentralized.js +275 -0
  12. package/dist/client/federated.d.ts +30 -0
  13. package/dist/client/federated.js +221 -0
  14. package/dist/client/index.d.ts +4 -0
  15. package/dist/client/index.js +11 -0
  16. package/dist/client/local.d.ts +8 -0
  17. package/dist/client/local.js +34 -0
  18. package/dist/dataset/data_loader/data_loader.d.ts +16 -0
  19. package/dist/dataset/data_loader/data_loader.js +10 -0
  20. package/dist/dataset/data_loader/image_loader.d.ts +14 -0
  21. package/dist/dataset/data_loader/image_loader.js +93 -0
  22. package/dist/dataset/data_loader/index.d.ts +3 -0
  23. package/dist/dataset/data_loader/index.js +9 -0
  24. package/dist/dataset/data_loader/tabular_loader.d.ts +29 -0
  25. package/dist/dataset/data_loader/tabular_loader.js +88 -0
  26. package/dist/dataset/dataset_builder.d.ts +17 -0
  27. package/dist/dataset/dataset_builder.js +80 -0
  28. package/dist/dataset/index.d.ts +2 -0
  29. package/dist/dataset/index.js +7 -0
  30. package/dist/index.d.ts +17 -0
  31. package/dist/index.js +34 -0
  32. package/dist/logging/console_logger.d.ts +18 -0
  33. package/dist/logging/console_logger.js +33 -0
  34. package/dist/logging/index.d.ts +2 -0
  35. package/dist/logging/index.js +7 -0
  36. package/dist/logging/logger.d.ts +12 -0
  37. package/dist/logging/logger.js +9 -0
  38. package/dist/logging/trainer_logger.d.ts +24 -0
  39. package/dist/logging/trainer_logger.js +59 -0
  40. package/dist/memory/base.d.ts +53 -0
  41. package/dist/memory/base.js +9 -0
  42. package/dist/memory/empty.d.ts +12 -0
  43. package/dist/memory/empty.js +69 -0
  44. package/dist/memory/index.d.ts +3 -0
  45. package/dist/memory/index.js +9 -0
  46. package/dist/memory/model_type.d.ts +4 -0
  47. package/dist/memory/model_type.js +9 -0
  48. package/dist/model_actor.d.ts +16 -0
  49. package/dist/model_actor.js +20 -0
  50. package/dist/privacy.d.ts +12 -0
  51. package/dist/privacy.js +60 -0
  52. package/dist/serialization/index.d.ts +2 -0
  53. package/dist/serialization/index.js +6 -0
  54. package/dist/serialization/model.d.ts +5 -0
  55. package/dist/serialization/model.js +55 -0
  56. package/dist/serialization/weights.d.ts +5 -0
  57. package/dist/serialization/weights.js +62 -0
  58. package/dist/task/data_example.d.ts +5 -0
  59. package/dist/task/data_example.js +24 -0
  60. package/dist/task/display_information.d.ts +15 -0
  61. package/dist/task/display_information.js +53 -0
  62. package/dist/task/index.d.ts +3 -0
  63. package/dist/task/index.js +8 -0
  64. package/dist/task/model_compile_data.d.ts +6 -0
  65. package/dist/task/model_compile_data.js +12 -0
  66. package/dist/task/task.d.ts +10 -0
  67. package/dist/task/task.js +32 -0
  68. package/dist/task/training_information.d.ts +29 -0
  69. package/dist/task/training_information.js +2 -0
  70. package/dist/tasks/cifar10.d.ts +4 -0
  71. package/dist/tasks/cifar10.js +74 -0
  72. package/dist/tasks/index.d.ts +5 -0
  73. package/dist/tasks/index.js +9 -0
  74. package/dist/tasks/lus_covid.d.ts +4 -0
  75. package/dist/tasks/lus_covid.js +48 -0
  76. package/dist/tasks/mnist.d.ts +4 -0
  77. package/dist/tasks/mnist.js +56 -0
  78. package/dist/tasks/simple_face.d.ts +4 -0
  79. package/dist/tasks/simple_face.js +88 -0
  80. package/dist/tasks/titanic.d.ts +4 -0
  81. package/dist/tasks/titanic.js +86 -0
  82. package/dist/testing/tester.d.ts +5 -0
  83. package/dist/testing/tester.js +21 -0
  84. package/dist/training/disco.d.ts +12 -0
  85. package/dist/training/disco.js +62 -0
  86. package/dist/training/index.d.ts +2 -0
  87. package/dist/training/index.js +7 -0
  88. package/dist/training/trainer/distributed_trainer.d.ts +21 -0
  89. package/dist/training/trainer/distributed_trainer.js +60 -0
  90. package/dist/training/trainer/local_trainer.d.ts +10 -0
  91. package/dist/training/trainer/local_trainer.js +37 -0
  92. package/dist/training/trainer/round_tracker.d.ts +30 -0
  93. package/dist/training/trainer/round_tracker.js +44 -0
  94. package/dist/training/trainer/trainer.d.ts +66 -0
  95. package/dist/training/trainer/trainer.js +146 -0
  96. package/dist/training/trainer/trainer_builder.d.ts +25 -0
  97. package/dist/training/trainer/trainer_builder.js +102 -0
  98. package/dist/training/training_schemes.d.ts +5 -0
  99. package/dist/training/training_schemes.js +10 -0
  100. package/dist/training_informant.d.ts +88 -0
  101. package/dist/training_informant.js +135 -0
  102. package/dist/types.d.ts +4 -0
  103. package/dist/types.js +2 -0
  104. package/package.json +48 -0
@@ -0,0 +1,55 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.decode = exports.encode = exports.isEncoded = void 0;
4
+ var tslib_1 = require("tslib");
5
+ var tf = (0, tslib_1.__importStar)(require("@tensorflow/tfjs"));
6
+ var msgpack_lite_1 = (0, tslib_1.__importDefault)(require("msgpack-lite"));
7
/**
 * Type guard for the serialized model format produced by `encode`.
 *
 * @param {unknown} raw - value to inspect
 * @returns {boolean} true when `raw` is an array in which every element is a number
 */
function isEncoded(raw) {
    if (!Array.isArray(raw)) {
        return false;
    }
    // "every element is a number" == "no element is something else"
    return !raw.some(function (value) { return typeof value !== 'number'; });
}
10
+ exports.isEncoded = isEncoded;
11
/**
 * Serializes a TensorFlow.js model into a plain array of byte values.
 *
 * The model is saved through an in-memory IO handler whose `save` callback
 * only captures the artifacts it is handed; once `model.save` has completed,
 * those artifacts are msgpack-encoded and the resulting buffer is spread into
 * a number array (the format checked by `isEncoded`).
 *
 * @param model - the model to serialize (must expose `save(handler)`)
 * @returns {Promise<number[]>} the encoded artifacts
 * @throws rejects when `model.save` fails (synchronously or asynchronously),
 *   or when it completes without ever handing artifacts to the handler
 */
function encode(model) {
    var saved;
    var captureHandler = {
        save: function (artifacts) {
            // Keep the artifacts in memory instead of persisting them.
            saved = artifacts;
            return Promise.resolve({
                modelArtifactsInfo: {
                    dateSaved: new Date(),
                    modelTopologyType: 'JSON'
                }
            });
        }
    };
    // Chaining on model.save (instead of discarding its promise) makes save
    // failures reject the returned promise rather than leaving it pending forever.
    return Promise.resolve()
        .then(function () { return model.save(captureHandler); })
        .then(function () {
            if (saved === undefined) {
                throw new Error('model did not provide any artifacts when saved');
            }
            return Array.from(msgpack_lite_1.default.encode(saved).values());
        });
}
39
+ exports.encode = encode;
40
/**
 * Rebuilds a TensorFlow.js layers model from the byte array made by `encode`.
 *
 * @param {number[]} encoded - msgpack-encoded model artifacts
 * @returns {Promise<tf.LayersModel>} the reconstructed model
 */
function decode(encoded) {
    // The executor runs synchronously, so decoding happens immediately and a
    // malformed payload surfaces as a rejected promise instead of a throw.
    return new Promise(function (resolve) {
        var artifacts = msgpack_lite_1.default.decode(encoded);
        // In-memory IO handler that just hands back the decoded artifacts.
        var memoryHandler = {
            load: function () { return artifacts; }
        };
        resolve(tf.loadLayersModel(memoryHandler));
    });
}
55
+ exports.decode = decode;
@@ -0,0 +1,5 @@
1
// Relative path: the '@/...' alias only exists in the package's own build
// config and cannot be resolved by consumers of the published typings.
import { Weights } from '../types';
/** Encoded weights: the msgpack-encoded bytes, as plain numbers. */
export declare type Encoded = number[];
/** Type guard: true when `raw` is an array whose elements are all numbers. */
export declare function isEncoded(raw: unknown): raw is Encoded;
/** Serializes every tensor's shape and data, then msgpack-encodes the list. */
export declare function encode(weights: Weights): Promise<Encoded>;
/**
 * Rebuilds the weight tensors from an `Encoded` byte array.
 * Throws if the payload is not an array of `{ shape, data }` entries.
 */
export declare function decode(encoded: Encoded): Weights;
@@ -0,0 +1,62 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.decode = exports.encode = exports.isEncoded = void 0;
4
+ var tslib_1 = require("tslib");
5
+ var tf = (0, tslib_1.__importStar)(require("@tensorflow/tfjs"));
6
+ var msgpack = (0, tslib_1.__importStar)(require("msgpack-lite"));
7
/**
 * Checks whether one decoded msgpack value is a serialized weight entry:
 * an object with a `shape` array of numbers and a `Float32Array` `data`.
 *
 * @param {unknown} raw - one element of the decoded payload
 * @returns {boolean} true when `raw` has exactly that structure
 */
function isSerialized(raw) {
    var isObject = typeof raw === 'object' && raw !== null;
    if (!isObject || !('shape' in raw) || !('data' in raw)) {
        return false;
    }
    var shape = raw.shape;
    var data = raw.data;
    var numericShape = Array.isArray(shape) &&
        !shape.some(function (dim) { return typeof dim !== 'number'; });
    return numericShape && data instanceof Float32Array;
}
23
/**
 * Type guard for the serialized weights format produced by `encode`.
 *
 * @param {unknown} raw - value to inspect
 * @returns {boolean} true when `raw` is an array in which every element is a number
 */
function isEncoded(raw) {
    if (!Array.isArray(raw)) {
        return false;
    }
    // "every element is a number" == "no element is something else"
    return !raw.some(function (value) { return typeof value !== 'number'; });
}
26
+ exports.isEncoded = isEncoded;
27
/**
 * Serializes a list of weight tensors into a plain array of byte values.
 *
 * Each tensor becomes a `{ shape, data }` record, with its values read through
 * `tensor.data()`. The list of records is msgpack-encoded and the resulting
 * buffer is spread into a number array.
 *
 * @param weights - the weight tensors to serialize
 * @returns {Promise<number[]>} the encoded bytes
 */
function encode(weights) {
    // Runs `thunk` right away; a synchronous throw becomes a rejection, the
    // same way it would inside an async function.
    var settle = function (thunk) {
        return new Promise(function (resolve) { resolve(thunk()); });
    };
    var serializeOne = function (tensor) {
        return settle(function () {
            var shape = tensor.shape;
            return Promise.resolve(tensor.data()).then(function (values) {
                return { shape: shape, data: values };
            });
        });
    };
    return settle(function () {
        return Promise.all(weights.map(serializeOne));
    }).then(function (serialized) {
        return Array.from(msgpack.encode(serialized).values());
    });
}
54
+ exports.encode = encode;
55
/**
 * Rebuilds weight tensors from the byte array produced by `encode`.
 *
 * @param {number[]} encoded - msgpack-encoded list of `{ shape, data }` entries
 * @returns the decoded tensors, in their encoded order
 * @throws {Error} when the payload is not an array of serialized weights
 */
function decode(encoded) {
    var entries = msgpack.decode(encoded);
    var wellFormed = Array.isArray(entries) && entries.every(isSerialized);
    if (!wellFormed) {
        throw new Error('expected to decode an array of serialized weights');
    }
    var toTensor = function (entry) {
        return tf.tensor(entry.data, entry.shape);
    };
    return entries.map(toTensor);
}
62
+ exports.decode = decode;
@@ -0,0 +1,5 @@
1
/** One example entry: a column name paired with a string or numeric value. */
export interface DataExample {
    columnName: string;
    columnData: string | number;
}
/**
 * Type guard: true when `raw` is a non-null object whose only own keys are
 * `columnName` (a string) and `columnData` (a string or a number).
 */
export declare function isDataExample(raw: unknown): raw is DataExample;
@@ -0,0 +1,24 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.isDataExample = void 0;
4
+ var immutable_1 = require("immutable");
5
/**
 * Type guard for `DataExample` values.
 *
 * Accepts a non-null object whose own enumerable keys are exactly
 * `columnName` and `columnData`, where `columnName` is a string and
 * `columnData` is a string or a number.
 *
 * @param {unknown} raw - value to inspect
 * @returns {boolean} whether `raw` has that exact shape
 */
function isDataExample(raw) {
    if (raw === null || typeof raw !== 'object') {
        return false;
    }
    // Object.keys never lists a key twice, so "exactly these two keys" is a
    // size check plus two membership checks.
    var keys = Object.keys(raw);
    var hasExactKeys = keys.length === 2 &&
        keys.indexOf('columnName') !== -1 &&
        keys.indexOf('columnData') !== -1;
    if (!hasExactKeys) {
        return false;
    }
    var columnName = raw.columnName;
    var columnData = raw.columnData;
    var dataIsValid = typeof columnData === 'string' || typeof columnData === 'number';
    return typeof columnName === 'string' && dataIsValid;
}
24
+ exports.isDataExample = isDataExample;
@@ -0,0 +1,15 @@
1
+ import { DataExample } from './data_example';
2
/**
 * Human-readable description of a task. The six string fields listed first
 * are required by `isDisplayInformation`; the remaining ones are optional.
 */
export interface DisplayInformation {
    taskTitle: string;
    summary: string;
    overview: string;
    tradeoffs: string;
    dataFormatInformation: string;
    dataExampleText: string;
    /** Free-text description of the model. */
    model?: string;
    dataExample?: DataExample[];
    headers?: string[];
    /** Example image location, e.g. './cifar10-example.png' in the bundled tasks. */
    dataExampleImage?: string;
    limitations?: string;
}
/**
 * Type guard for `DisplayInformation`. It checks the required string fields
 * and the types of any optional fields that are present.
 */
export declare function isDisplayInformation(raw: unknown): raw is DisplayInformation;
@@ -0,0 +1,53 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.isDisplayInformation = void 0;
4
+ var immutable_1 = require("immutable");
5
+ var data_example_1 = require("./data_example");
6
/**
 * Type guard for `DisplayInformation` values.
 *
 * Rules:
 * - `dataExampleText`, `dataFormatInformation`, `overview`, `summary`,
 *   `taskTitle` and `tradeoffs` must be own enumerable keys holding strings.
 * - When not `undefined`, `model`, `dataExampleImage` and `limitations` must
 *   be strings.
 * - When not `undefined`, `dataExample` must be an array of `DataExample`.
 * - When not `undefined`, `headers` must be an array of strings.
 *
 * @param {unknown} raw - value to inspect
 * @returns {boolean} whether `raw` follows the rules above
 */
function isDisplayInformation(raw) {
    if (raw === null || typeof raw !== 'object') {
        return false;
    }
    var keys = Object.keys(raw);
    var requiredFields = ['dataExampleText', 'dataFormatInformation', 'overview', 'summary', 'taskTitle', 'tradeoffs'];
    var hasAllRequired = requiredFields.every(function (field) {
        return keys.indexOf(field) !== -1;
    });
    if (!hasAllRequired) {
        return false;
    }
    var info = raw;
    var dataExample = info.dataExample;
    var dataExampleImage = info.dataExampleImage;
    var dataExampleText = info.dataExampleText;
    var dataFormatInformation = info.dataFormatInformation;
    var headers = info.headers;
    var limitations = info.limitations;
    var model = info.model;
    var overview = info.overview;
    var summary = info.summary;
    var taskTitle = info.taskTitle;
    var tradeoffs = info.tradeoffs;

    var isString = function (value) { return typeof value === 'string'; };
    var isAbsentOrString = function (value) { return value === undefined || isString(value); };
    var isAbsentOrArrayOf = function (value, predicate) {
        return value === undefined || (Array.isArray(value) && value.every(predicate));
    };

    var mandatoryOk = [dataExampleText, dataFormatInformation, overview, summary, taskTitle, tradeoffs].every(isString);
    if (!mandatoryOk || ![model, dataExampleImage, limitations].every(isAbsentOrString)) {
        return false;
    }
    return isAbsentOrArrayOf(dataExample, data_example_1.isDataExample) &&
        isAbsentOrArrayOf(headers, isString);
}
53
+ exports.isDisplayInformation = isDisplayInformation;
@@ -0,0 +1,3 @@
1
+ export { isTask, Task, isTaskID, TaskID } from './task';
2
+ export { isDisplayInformation, DisplayInformation } from './display_information';
3
+ export { TrainingInformation } from './training_information';
@@ -0,0 +1,8 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.isDisplayInformation = exports.isTaskID = exports.isTask = void 0;
4
+ var task_1 = require("./task");
5
+ Object.defineProperty(exports, "isTask", { enumerable: true, get: function () { return task_1.isTask; } });
6
+ Object.defineProperty(exports, "isTaskID", { enumerable: true, get: function () { return task_1.isTaskID; } });
7
+ var display_information_1 = require("./display_information");
8
+ Object.defineProperty(exports, "isDisplayInformation", { enumerable: true, get: function () { return display_information_1.isDisplayInformation; } });
@@ -0,0 +1,6 @@
1
/**
 * Compilation settings for a model, kept as plain names, for example
 * 'adam', 'categoricalCrossentropy' and ['accuracy'] in the bundled tasks.
 */
export declare class ModelCompileData {
    /** Optimizer name. */
    readonly optimizer: string;
    /** Loss function name. */
    readonly loss: string;
    /** Metric names. */
    readonly metrics: string[];
    /** Stores the three arguments unchanged on the instance. */
    constructor(optimizer: string, loss: string, metrics: string[]);
}
@@ -0,0 +1,12 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.ModelCompileData = void 0;
4
/**
 * Plain holder for a model's compilation settings.
 *
 * @class
 * @param {string} optimizer - optimizer name
 * @param {string} loss - loss function name
 * @param {string[]} metrics - metric names
 */
var ModelCompileData = function ModelCompileData(optimizer, loss, metrics) {
    // Listed in this order so the instance's own keys enumerate as
    // optimizer, loss, metrics.
    Object.assign(this, {
        optimizer: optimizer,
        loss: loss,
        metrics: metrics
    });
};
12
+ exports.ModelCompileData = ModelCompileData;
@@ -0,0 +1,10 @@
1
+ import { DisplayInformation } from './display_information';
2
+ import { TrainingInformation } from './training_information';
3
/** Identifier of a task; `isTaskID` accepts any string. */
export declare type TaskID = string;
/** A task: its identifier plus optional display and training information. */
export interface Task {
    taskID: TaskID;
    displayInformation?: DisplayInformation;
    trainingInformation?: TrainingInformation;
}
/** Type guard: true for any string. */
export declare function isTaskID(obj: unknown): obj is TaskID;
/**
 * Type guard for `Task`. It requires a string `taskID` and, when present, a
 * valid `displayInformation`. `trainingInformation` is not validated yet.
 */
export declare function isTask(raw: unknown): raw is Task;
@@ -0,0 +1,32 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.isTask = exports.isTaskID = void 0;
4
+ var display_information_1 = require("./display_information");
5
/**
 * Type guard for task identifiers.
 *
 * @param {unknown} obj - value to inspect
 * @returns {boolean} true exactly when `obj` is a primitive string
 */
function isTaskID(obj) {
    var kind = typeof obj;
    return kind === 'string';
}
8
+ exports.isTaskID = isTaskID;
9
/**
 * Type guard for `Task` values.
 *
 * Requires a non-null object with a `taskID` property (own or inherited)
 * holding a string. If `displayInformation` is not `undefined`, it must also
 * pass `isDisplayInformation`.
 *
 * @param {unknown} raw - value to inspect
 * @returns {boolean} whether `raw` is an acceptable task
 */
function isTask(raw) {
    if (raw === null || typeof raw !== 'object' || !('taskID' in raw)) {
        return false;
    }
    var taskID = raw.taskID;
    var displayInformation = raw.displayInformation;
    if (typeof taskID !== 'string') {
        return false;
    }
    // TODO check for TrainingInformation
    return displayInformation === undefined ||
        display_information_1.isDisplayInformation(displayInformation);
}
32
+ exports.isTask = isTask;
@@ -0,0 +1,29 @@
1
+ import { DataExample } from './data_example';
2
+ import { ModelCompileData } from './model_compile_data';
3
+ export interface TrainingInformation {
4
+ modelID: string;
5
+ epochs: number;
6
+ roundDuration: number;
7
+ validationSplit: number;
8
+ batchSize: number;
9
+ preprocessFunctions: string[];
10
+ modelCompileData: ModelCompileData;
11
+ dataType: string;
12
+ receivedMessagesThreshold?: number;
13
+ inputColumns?: string[];
14
+ outputColumns?: string[];
15
+ threshold?: number;
16
+ IMAGE_H?: number;
17
+ IMAGE_W?: number;
18
+ LABEL_LIST?: string[];
19
+ aggregateImagesById?: boolean;
20
+ learningRate?: number;
21
+ NUM_CLASSES?: number;
22
+ csvLabels?: boolean;
23
+ RESIZED_IMAGE_H?: number;
24
+ RESIZED_IMAGE_W?: number;
25
+ LABEL_ASSIGNMENT?: DataExample[];
26
+ scheme?: string;
27
+ noiseScale?: number;
28
+ clippingRadius?: number;
29
+ }
@@ -0,0 +1,2 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
@@ -0,0 +1,4 @@
1
+ import * as tf from '@tensorflow/tfjs';
2
+ import { Task } from '../task';
3
/** Description of the CIFAR-10 task (display and training information). */
export declare const task: Task;
/**
 * Builds the CIFAR-10 model. It is asynchronous because a pretrained
 * MobileNet is downloaded first and extended with a new output layer.
 */
export declare function model(): Promise<tf.LayersModel>;
@@ -0,0 +1,74 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.model = exports.task = void 0;
4
+ var tslib_1 = require("tslib");
5
+ var tf = (0, tslib_1.__importStar)(require("@tensorflow/tfjs"));
6
+ exports.task = {
7
+ taskID: 'cifar10',
8
+ displayInformation: {
9
+ taskTitle: 'CIFAR10',
10
+ summary: 'In this challenge, we ask you to classify images into categories based on the objects shown on the image.',
11
+ overview: 'The CIFAR-10 dataset is a collection of images that are commonly used to train machine learning and computer vision algorithms. It is one of the most widely used datasets for machine learning research.',
12
+ limitations: 'The training data is limited to small images of size 32x32.',
13
+ tradeoffs: 'Training success strongly depends on label distribution',
14
+ dataFormatInformation: 'Images should be of .png format and of size 32x32. <br> The label file should be .csv, where each row contains a file_name, class. <br> <br> e.g. if you have images: 0.png (of a frog) and 1.png (of a car) <br> labels.csv contains: (Note that no header is needed)<br> 0.png, frog <br> 1.png, car',
15
+ dataExampleText: 'Below you can find 10 random examples from each of the 10 classes in the dataset.',
16
+ dataExampleImage: './cifar10-example.png'
17
+ },
18
+ trainingInformation: {
19
+ modelID: 'cifar10-model',
20
+ epochs: 10,
21
+ roundDuration: 10,
22
+ validationSplit: 0.2,
23
+ batchSize: 10,
24
+ modelCompileData: {
25
+ optimizer: 'adam',
26
+ loss: 'categoricalCrossentropy',
27
+ metrics: ['accuracy']
28
+ },
29
+ threshold: 1,
30
+ dataType: 'image',
31
+ csvLabels: true,
32
+ IMAGE_H: 32,
33
+ IMAGE_W: 32,
34
+ preprocessFunctions: ['resize'],
35
+ RESIZED_IMAGE_H: 224,
36
+ RESIZED_IMAGE_W: 224,
37
+ LABEL_LIST: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'],
38
+ LABEL_ASSIGNMENT: [
39
+ { columnName: 'airplane', columnData: 0 },
40
+ { columnName: 'automobile', columnData: 1 },
41
+ { columnName: 'bird', columnData: 2 },
42
+ { columnName: 'cat', columnData: 3 },
43
+ { columnName: 'deer', columnData: 4 },
44
+ { columnName: 'dog', columnData: 5 },
45
+ { columnName: 'frog', columnData: 6 },
46
+ { columnName: 'horse', columnData: 7 },
47
+ { columnName: 'ship', columnData: 8 },
48
+ { columnName: 'truck', columnData: 9 }
49
+ ],
50
+ scheme: 'Decentralized'
51
+ }
52
+ };
53
/**
 * Builds the CIFAR-10 model by transfer learning.
 *
 * Steps:
 * 1. Load MobileNet v1 (0.25, 224) from the TF.js model storage.
 * 2. Take the output of its 'global_average_pooling2d_1' layer.
 * 3. Feed it into a new 10-unit softmax dense layer named 'denseModified'.
 *
 * @returns {Promise<tf.LayersModel>} the assembled model, named 'modelModified'
 */
function model() {
    var mobilenetUrl = 'https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json';
    return tf.loadLayersModel(mobilenetUrl).then(function (mobilenet) {
        var pooling = mobilenet.getLayer('global_average_pooling2d_1');
        var head = tf.layers.dense({ units: 10, activation: 'softmax', name: 'denseModified' });
        var predictions = head.apply(pooling.output);
        return tf.model({
            inputs: mobilenet.input,
            outputs: predictions,
            name: 'modelModified'
        });
    });
}
74
+ exports.model = model;
@@ -0,0 +1,5 @@
1
+ export * as cifar10 from './cifar10';
2
+ export * as lus_covid from './lus_covid';
3
+ export * as mnist from './mnist';
4
+ export * as simple_face from './simple_face';
5
+ export * as titanic from './titanic';
@@ -0,0 +1,9 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.titanic = exports.simple_face = exports.mnist = exports.lus_covid = exports.cifar10 = void 0;
4
+ var tslib_1 = require("tslib");
5
+ exports.cifar10 = (0, tslib_1.__importStar)(require("./cifar10"));
6
+ exports.lus_covid = (0, tslib_1.__importStar)(require("./lus_covid"));
7
+ exports.mnist = (0, tslib_1.__importStar)(require("./mnist"));
8
+ exports.simple_face = (0, tslib_1.__importStar)(require("./simple_face"));
9
+ exports.titanic = (0, tslib_1.__importStar)(require("./titanic"));
@@ -0,0 +1,4 @@
1
+ import * as tf from '@tensorflow/tfjs';
2
+ import { Task } from '../task';
3
/** Description of the COVID lung-ultrasound task. */
export declare const task: Task;
/** Builds the dense classifier used by this task; it is returned uncompiled. */
export declare function model(): tf.LayersModel;
@@ -0,0 +1,48 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.model = exports.task = void 0;
4
+ var tslib_1 = require("tslib");
5
+ var tf = (0, tslib_1.__importStar)(require("@tensorflow/tfjs"));
6
+ exports.task = {
7
+ taskID: 'lus_covid',
8
+ displayInformation: {
9
+ taskTitle: 'COVID Lung Ultrasound',
10
+ summary: "Do you have a dataset of lung ultrasound images on patients <b>suspected of Lower Respiratory Tract infection (LRTI) during the COVID pandemic</b>? <br> Learn how to discriminate between COVID positive and negative patients by joining this task. <br><br> Don’t have a dataset of your own? Download a sample of a few cases <a class='underline text-primary-dark dark:text-primary-light' href='https://drive.switch.ch/index.php/s/zM5ZrUWK3taaIly'>here</a>.",
11
+ overview: "Do you have a dataset of lung ultrasound images on patients <b>suspected of Lower Respiratory Tract infection (LRTI) during the COVID pandemic</b>? <br> Learn how to discriminate between COVID positive and negative patients by joining this task. <br><br> Don’t have a dataset of your own? Download a sample of a few cases <a class='underline text-primary-dark dark:text-primary-light' href='https://drive.switch.ch/index.php/s/zM5ZrUWK3taaIly'>here</a>.",
12
+ model: 'We use a simplified* version of the <b>DeepChest model</b>: A deep learning model developed in our lab (intelligent Global Health). On a cohort of 400 Swiss patients suspected of LRTI, the model obtained over 90% area under the ROC curve for this task. <br><br>*Simplified to ensure smooth running on your browser, the performance is minimally affected. Details of the adaptations are below <br>- <b>Removed</b>: positional embedding (i.e. we don’t take the anatomic position into consideration). Rather, the model now does mean pooling over the feature vector of the images for each patient <br>- <b>Replaced</b>: ResNet18 by Mobilenet',
13
+ tradeoffs: 'We are using a simpler version of DeepChest in order to be able to run it on the browser.',
14
+ dataFormatInformation: 'This model takes as input an image dataset. It consists on a set of lung ultrasound images per patient with its corresponding label of covid positive or negative. Moreover, to identify the images per patient you have to follow the follwing naming pattern: "patientId_*.png"',
15
+ dataExampleText: 'Below you can find an example of an expected lung image for patient 2 named: 2_QAID_1.masked.reshaped.squared.224.png',
16
+ dataExampleImage: './2_QAID_1.masked.reshaped.squared.224.png'
17
+ },
18
+ trainingInformation: {
19
+ modelID: 'lus-covid-model',
20
+ epochs: 15,
21
+ roundDuration: 10,
22
+ validationSplit: 0.2,
23
+ batchSize: 2,
24
+ modelCompileData: {
25
+ optimizer: 'adam',
26
+ loss: 'binaryCrossentropy',
27
+ metrics: ['accuracy']
28
+ },
29
+ learningRate: 0.05,
30
+ threshold: 2,
31
+ IMAGE_H: 224,
32
+ IMAGE_W: 224,
33
+ preprocessFunctions: [],
34
+ LABEL_LIST: ['COVID-Positive', 'COVID-Negative'],
35
+ NUM_CLASSES: 2,
36
+ dataType: 'image',
37
+ aggregateImagesById: true,
38
+ scheme: 'Decentralized'
39
+ }
40
+ };
41
/**
 * Builds the lung-ultrasound classifier: three dense layers that take a
 * 1000-value input vector and end in a 2-unit softmax.
 *
 * @returns {tf.Sequential} the model, returned without compiling it
 */
function model() {
    var layerConfigs = [
        { inputShape: [1000], units: 512, activation: 'relu' },
        { units: 64, activation: 'relu' },
        { units: 2, activation: 'softmax' }
    ];
    var network = tf.sequential();
    layerConfigs.forEach(function (config) {
        network.add(tf.layers.dense(config));
    });
    return network;
}
48
+ exports.model = model;
@@ -0,0 +1,4 @@
1
+ import * as tf from '@tensorflow/tfjs';
2
+ import { Task } from '../task';
3
/** Description of the MNIST handwritten-digit task. */
export declare const task: Task;
/** Builds the small convolutional classifier used by this task; it is returned uncompiled. */
export declare function model(): tf.LayersModel;
@@ -0,0 +1,56 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.model = exports.task = void 0;
4
+ var tslib_1 = require("tslib");
5
+ var tf = (0, tslib_1.__importStar)(require("@tensorflow/tfjs"));
6
+ exports.task = {
7
+ taskID: 'mnist',
8
+ displayInformation: {
9
+ taskTitle: 'MNIST',
10
+ summary: "Test our platform by using a publicly available <b>image</b> dataset. <br><br> Download the classic MNIST imagebank of hand-written numbers <a class='underline text-primary-dark dark:text-primary-light' href='https://www.kaggle.com/scolianni/mnistasjpg'>here</a>. <br> This model learns to identify hand written numbers.",
11
+ overview: 'The MNIST handwritten digit classification problem is a standard dataset used in computer vision and deep learning. Although the dataset is effectively solved, we use it to test our Decentralised Learning algorithms and platform.',
12
+ model: 'The current model is a very simple CNN and its main goal is to test the app and the Decentralizsed Learning functionality.',
13
+ tradeoffs: 'We are using a simple model, first a 2d convolutional layer > max pooling > 2d convolutional layer > max pooling > convolutional layer > 2 dense layers.',
14
+ dataFormatInformation: 'This model is trained on images corresponding to digits 0 to 9. You can upload each digit image of your dataset in the box corresponding to its label. The model taskes images of size 28x28 as input.',
15
+ dataExampleText: 'Below you can find an example of an expected image representing the digit 9.',
16
+ dataExampleImage: './9-mnist-example.png'
17
+ },
18
+ trainingInformation: {
19
+ modelID: 'mnist-model',
20
+ epochs: 10,
21
+ roundDuration: 10,
22
+ validationSplit: 0.2,
23
+ batchSize: 30,
24
+ modelCompileData: {
25
+ optimizer: 'rmsprop',
26
+ loss: 'categoricalCrossentropy',
27
+ metrics: ['accuracy']
28
+ },
29
+ threshold: 1,
30
+ dataType: 'image',
31
+ IMAGE_H: 28,
32
+ IMAGE_W: 28,
33
+ preprocessFunctions: [],
34
+ LABEL_LIST: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'],
35
+ aggregateImagesById: false,
36
+ scheme: 'Decentralized'
37
+ }
38
+ };
39
/**
 * Builds the MNIST CNN with input shape [28, 28, 3] and a 10-way softmax output.
 *
 * @returns {tf.Sequential} the model, returned without compiling it
 */
function model() {
    var net = tf.sequential();
    var layers = tf.layers;
    // Feature extractor: conv -> pool -> conv -> pool -> conv.
    net.add(layers.conv2d({
        inputShape: [28, 28, 3],
        kernelSize: 3,
        filters: 16,
        activation: 'relu'
    }));
    net.add(layers.maxPooling2d({ poolSize: 2, strides: 2 }));
    net.add(layers.conv2d({ kernelSize: 3, filters: 32, activation: 'relu' }));
    net.add(layers.maxPooling2d({ poolSize: 2, strides: 2 }));
    net.add(layers.conv2d({ kernelSize: 3, filters: 32, activation: 'relu' }));
    // Classifier head: flatten -> dense(64) -> 10-way softmax.
    net.add(layers.flatten({}));
    net.add(layers.dense({ units: 64, activation: 'relu' }));
    net.add(layers.dense({ units: 10, activation: 'softmax' }));
    return net;
}
56
+ exports.model = model;
@@ -0,0 +1,4 @@
1
+ import * as tf from '@tensorflow/tfjs';
2
+ import { Task } from '../task';
3
/** Description of the Simple Face (child vs. adult) task. */
export declare const task: Task;
/**
 * Builds and compiles the Simple Face CNN.
 * @param imageWidth - input width, 200 by default
 * @param imageHeight - input height, 200 by default
 * @param imageChannels - input channels, 3 by default
 * @param numOutputClasses - number of softmax output units, 2 by default
 */
export declare function model(imageWidth?: number, imageHeight?: number, imageChannels?: number, numOutputClasses?: number): tf.LayersModel;
@@ -0,0 +1,88 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.model = exports.task = void 0;
4
+ var tslib_1 = require("tslib");
5
+ var tf = (0, tslib_1.__importStar)(require("@tensorflow/tfjs"));
6
+ exports.task = {
7
+ taskID: 'simple_face',
8
+ displayInformation: {
9
+ taskTitle: 'Simple Face',
10
+ summary: 'Can you detect if the person in a picture is a child or an adult?',
11
+ overview: 'Simple face is a small subset of face_task from Kaggle',
12
+ limitations: 'The training data is limited to small images of size 200x200.',
13
+ tradeoffs: 'Training success strongly depends on label distribution',
14
+ dataFormatInformation: '',
15
+ dataExampleText: 'Bellow you can find an example',
16
+ dataExampleImage: './simple_face-example.png'
17
+ },
18
+ trainingInformation: {
19
+ modelID: 'simple_face-model',
20
+ epochs: 50,
21
+ roundDuration: 10,
22
+ validationSplit: 0.2,
23
+ batchSize: 10,
24
+ preprocessFunctions: [],
25
+ modelCompileData: {
26
+ optimizer: 'adam',
27
+ loss: 'categoricalCrossentropy',
28
+ metrics: ['accuracy']
29
+ },
30
+ dataType: 'image',
31
+ csvLabels: false,
32
+ IMAGE_H: 200,
33
+ IMAGE_W: 200,
34
+ LABEL_LIST: ['child', 'adult']
35
+ }
36
+ };
37
/**
 * Builds and compiles the Simple Face CNN classifier.
 *
 * Architecture: conv2d(5x5, 8 filters) -> max-pool(2x2) -> conv2d(5x5,
 * 16 filters) -> max-pool(2x2) -> flatten -> dense softmax.
 *
 * @param {number} [imageWidth=200] - input image width
 * @param {number} [imageHeight=200] - input image height
 * @param {number} [imageChannels=3] - number of input channels
 * @param {number} [numOutputClasses=2] - number of softmax output units
 * @returns {tf.Sequential} the model, already compiled with Adam,
 *   categorical cross-entropy and the accuracy metric
 */
function model(imageWidth, imageHeight, imageChannels, numOutputClasses) {
    if (imageWidth === void 0) { imageWidth = 200; }
    if (imageHeight === void 0) { imageHeight = 200; }
    if (imageChannels === void 0) { imageChannels = 3; }
    if (numOutputClasses === void 0) { numOutputClasses = 2; }
    var model = tf.sequential();
    // The first layer must declare the input shape. TF.js image tensors are
    // laid out as [height, width, channels], so the height comes first. This
    // only matters for non-square inputs; the 200x200 defaults are unaffected.
    model.add(tf.layers.conv2d({
        inputShape: [imageHeight, imageWidth, imageChannels],
        kernelSize: 5,
        filters: 8,
        strides: 1,
        activation: 'relu',
        kernelInitializer: 'varianceScaling'
    }));
    // Max pooling downsamples by keeping the maximum of each 2x2 region.
    model.add(tf.layers.maxPooling2d({ poolSize: [2, 2], strides: [2, 2] }));
    // A second conv2d + max-pooling stage, with more filters.
    model.add(tf.layers.conv2d({
        kernelSize: 5,
        filters: 16,
        strides: 1,
        activation: 'relu',
        kernelInitializer: 'varianceScaling'
    }));
    model.add(tf.layers.maxPooling2d({ poolSize: [2, 2], strides: [2, 2] }));
    // Flatten the 2D feature maps into a 1D vector for the dense output layer.
    model.add(tf.layers.flatten());
    // Output layer: one softmax unit per class (numOutputClasses, 2 by
    // default, matching the task's 'child' / 'adult' labels).
    model.add(tf.layers.dense({
        units: numOutputClasses,
        kernelInitializer: 'varianceScaling',
        activation: 'softmax'
    }));
    // Compile here so the returned model is ready to train.
    var optimizer = tf.train.adam();
    model.compile({
        optimizer: optimizer,
        loss: 'categoricalCrossentropy',
        metrics: ['accuracy']
    });
    return model;
}
88
+ exports.model = model;