@epfml/discojs 0.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +32 -0
- package/dist/aggregation.d.ts +3 -0
- package/dist/aggregation.js +19 -0
- package/dist/async_buffer.d.ts +41 -0
- package/dist/async_buffer.js +98 -0
- package/dist/async_informant.d.ts +20 -0
- package/dist/async_informant.js +69 -0
- package/dist/client/base.d.ts +36 -0
- package/dist/client/base.js +34 -0
- package/dist/client/decentralized.d.ts +23 -0
- package/dist/client/decentralized.js +275 -0
- package/dist/client/federated.d.ts +30 -0
- package/dist/client/federated.js +221 -0
- package/dist/client/index.d.ts +4 -0
- package/dist/client/index.js +11 -0
- package/dist/client/local.d.ts +8 -0
- package/dist/client/local.js +34 -0
- package/dist/dataset/data_loader/data_loader.d.ts +16 -0
- package/dist/dataset/data_loader/data_loader.js +10 -0
- package/dist/dataset/data_loader/image_loader.d.ts +14 -0
- package/dist/dataset/data_loader/image_loader.js +93 -0
- package/dist/dataset/data_loader/index.d.ts +3 -0
- package/dist/dataset/data_loader/index.js +9 -0
- package/dist/dataset/data_loader/tabular_loader.d.ts +29 -0
- package/dist/dataset/data_loader/tabular_loader.js +88 -0
- package/dist/dataset/dataset_builder.d.ts +17 -0
- package/dist/dataset/dataset_builder.js +80 -0
- package/dist/dataset/index.d.ts +2 -0
- package/dist/dataset/index.js +7 -0
- package/dist/index.d.ts +17 -0
- package/dist/index.js +34 -0
- package/dist/logging/console_logger.d.ts +18 -0
- package/dist/logging/console_logger.js +33 -0
- package/dist/logging/index.d.ts +2 -0
- package/dist/logging/index.js +7 -0
- package/dist/logging/logger.d.ts +12 -0
- package/dist/logging/logger.js +9 -0
- package/dist/logging/trainer_logger.d.ts +24 -0
- package/dist/logging/trainer_logger.js +59 -0
- package/dist/memory/base.d.ts +53 -0
- package/dist/memory/base.js +9 -0
- package/dist/memory/empty.d.ts +12 -0
- package/dist/memory/empty.js +69 -0
- package/dist/memory/index.d.ts +3 -0
- package/dist/memory/index.js +9 -0
- package/dist/memory/model_type.d.ts +4 -0
- package/dist/memory/model_type.js +9 -0
- package/dist/model_actor.d.ts +16 -0
- package/dist/model_actor.js +20 -0
- package/dist/privacy.d.ts +12 -0
- package/dist/privacy.js +60 -0
- package/dist/serialization/index.d.ts +2 -0
- package/dist/serialization/index.js +6 -0
- package/dist/serialization/model.d.ts +5 -0
- package/dist/serialization/model.js +55 -0
- package/dist/serialization/weights.d.ts +5 -0
- package/dist/serialization/weights.js +62 -0
- package/dist/task/data_example.d.ts +5 -0
- package/dist/task/data_example.js +24 -0
- package/dist/task/display_information.d.ts +15 -0
- package/dist/task/display_information.js +53 -0
- package/dist/task/index.d.ts +3 -0
- package/dist/task/index.js +8 -0
- package/dist/task/model_compile_data.d.ts +6 -0
- package/dist/task/model_compile_data.js +12 -0
- package/dist/task/task.d.ts +10 -0
- package/dist/task/task.js +32 -0
- package/dist/task/training_information.d.ts +29 -0
- package/dist/task/training_information.js +2 -0
- package/dist/tasks/cifar10.d.ts +4 -0
- package/dist/tasks/cifar10.js +74 -0
- package/dist/tasks/index.d.ts +5 -0
- package/dist/tasks/index.js +9 -0
- package/dist/tasks/lus_covid.d.ts +4 -0
- package/dist/tasks/lus_covid.js +48 -0
- package/dist/tasks/mnist.d.ts +4 -0
- package/dist/tasks/mnist.js +56 -0
- package/dist/tasks/simple_face.d.ts +4 -0
- package/dist/tasks/simple_face.js +88 -0
- package/dist/tasks/titanic.d.ts +4 -0
- package/dist/tasks/titanic.js +86 -0
- package/dist/testing/tester.d.ts +5 -0
- package/dist/testing/tester.js +21 -0
- package/dist/training/disco.d.ts +12 -0
- package/dist/training/disco.js +62 -0
- package/dist/training/index.d.ts +2 -0
- package/dist/training/index.js +7 -0
- package/dist/training/trainer/distributed_trainer.d.ts +21 -0
- package/dist/training/trainer/distributed_trainer.js +60 -0
- package/dist/training/trainer/local_trainer.d.ts +10 -0
- package/dist/training/trainer/local_trainer.js +37 -0
- package/dist/training/trainer/round_tracker.d.ts +30 -0
- package/dist/training/trainer/round_tracker.js +44 -0
- package/dist/training/trainer/trainer.d.ts +66 -0
- package/dist/training/trainer/trainer.js +146 -0
- package/dist/training/trainer/trainer_builder.d.ts +25 -0
- package/dist/training/trainer/trainer_builder.js +102 -0
- package/dist/training/training_schemes.d.ts +5 -0
- package/dist/training/training_schemes.js +10 -0
- package/dist/training_informant.d.ts +88 -0
- package/dist/training_informant.js +135 -0
- package/dist/types.d.ts +4 -0
- package/dist/types.js +2 -0
- package/package.json +48 -0
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
"use strict";
|
|
2
|
+
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
+
exports.DatasetBuilder = void 0;
|
|
4
|
+
var tslib_1 = require("tslib");
|
|
5
|
+
/**
 * Gathers data sources and loads them into a dataset through a data loader.
 *
 * Sources are either all unlabelled (their features/labels are then taken from
 * the task's `inputColumns`/`outputColumns`) or all grouped under a label;
 * having both kinds, or none, makes `build()` fail. Once `build()` succeeds the
 * builder is consumed and files can no longer be added or cleared.
 */
var DatasetBuilder = /** @class */ (function () {
    /**
     * @param dataLoader loader whose `loadAll(sources, config)` produces the dataset
     * @param task task whose `trainingInformation` provides the input/output columns
     */
    function DatasetBuilder(dataLoader, task) {
        this.dataLoader = dataLoader;
        this.task = task;
        // unlabelled sources
        this.sources = [];
        // label -> every source added under that label, in insertion order
        this.labelledSources = new Map();
        this.built = false;
    }
    /**
     * Adds sources, unlabelled or under `label`. Sources added under a label
     * that already has files are appended to it.
     * @param sources the sources to add
     * @param label optional label shared by all the given sources
     * @throws Error('builder already consumed') after a successful build
     */
    DatasetBuilder.prototype.addFiles = function (sources, label) {
        if (this.built) {
            throw new Error('builder already consumed');
        }
        if (label === undefined) {
            this.sources = this.sources.concat(sources);
            return;
        }
        if (sources.length === 0) {
            return;
        }
        // Accumulate: setting `(label, source)` once per source overwrote the
        // entry and silently kept only the last file of each label.
        var previous = this.labelledSources.get(label);
        this.labelledSources.set(label, previous === undefined ? sources.slice() : previous.concat(sources));
    };
    /**
     * Removes all unlabelled sources, or every source of `label` when given.
     * @throws Error('builder already consumed') after a successful build
     */
    DatasetBuilder.prototype.clearFiles = function (label) {
        if (this.built) {
            throw new Error('builder already consumed');
        }
        if (label === undefined) {
            this.sources = [];
        }
        else {
            this.labelledSources.delete(label);
        }
    };
    /**
     * Loads the registered sources into a dataset and marks the builder as consumed.
     * For labelled sources, `config.labels` holds one label per source, aligned
     * with the sources passed to the loader.
     * @returns a promise of whatever `dataLoader.loadAll` resolves to; it rejects
     *          with Error('invalid sources') unless exactly one kind of source is present
     */
    DatasetBuilder.prototype.build = function () {
        var _this = this;
        return new Promise(function (resolve) {
            var hasUnlabelled = _this.sources.length > 0;
            var hasLabelled = _this.labelledSources.size > 0;
            // Require that at least one source collection is non-empty, but not both
            if (hasUnlabelled === hasLabelled) {
                throw new Error('invalid sources');
            }
            if (hasUnlabelled) {
                var info = _this.task.trainingInformation;
                resolve(_this.dataLoader.loadAll(_this.sources, {
                    features: info === null || info === undefined ? undefined : info.inputColumns,
                    labels: info === null || info === undefined ? undefined : info.outputColumns
                }));
                return;
            }
            var files = [];
            var labels = [];
            _this.labelledSources.forEach(function (group, label) {
                group.forEach(function (source) {
                    files.push(source);
                    labels.push(label);
                });
            });
            resolve(_this.dataLoader.loadAll(files, { labels: labels }));
        }).then(function (data) {
            // TODO @s314cy: Support .csv labels for image datasets
            _this.built = true;
            return data;
        });
    };
    /** @returns whether `build()` has already succeeded */
    DatasetBuilder.prototype.isBuilt = function () {
        return this.built;
    };
    /** @returns the number of registered sources (unlabelled or labelled, whichever is larger) */
    DatasetBuilder.prototype.size = function () {
        var labelledCount = 0;
        this.labelledSources.forEach(function (group) {
            labelledCount += group.length;
        });
        return Math.max(this.sources.length, labelledCount);
    };
    return DatasetBuilder;
}());
|
|
80
|
+
exports.DatasetBuilder = DatasetBuilder;
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
"use strict";
|
|
2
|
+
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
+
exports.DatasetBuilder = void 0;
|
|
4
|
+
var tslib_1 = require("tslib");
|
|
5
|
+
var dataset_builder_1 = require("./dataset_builder");
|
|
6
|
+
Object.defineProperty(exports, "DatasetBuilder", { enumerable: true, get: function () { return dataset_builder_1.DatasetBuilder; } });
|
|
7
|
+
(0, tslib_1.__exportStar)(require("./data_loader"), exports);
|
package/dist/index.d.ts
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
export * as aggregation from './aggregation';
|
|
2
|
+
export * as dataset from './dataset';
|
|
3
|
+
export * as serialization from './serialization';
|
|
4
|
+
export * as tasks from './tasks';
|
|
5
|
+
export * as training from './training';
|
|
6
|
+
export * as privacy from './privacy';
|
|
7
|
+
export { Base as Client } from './client';
|
|
8
|
+
export * as client from './client';
|
|
9
|
+
export { AsyncBuffer } from './async_buffer';
|
|
10
|
+
export { AsyncInformant } from './async_informant';
|
|
11
|
+
export { Logger, ConsoleLogger } from './logging';
|
|
12
|
+
export { Memory, ModelType, Empty as EmptyMemory } from './memory';
|
|
13
|
+
export { ModelActor } from './model_actor';
|
|
14
|
+
export { TrainingInformation, DisplayInformation, isTask, Task, isTaskID, TaskID } from './task';
|
|
15
|
+
export { TrainingInformant } from './training_informant';
|
|
16
|
+
export { TrainingSchemes } from './training/training_schemes';
|
|
17
|
+
export * from './types';
|
package/dist/index.js
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
"use strict";
|
|
2
|
+
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
+
exports.TrainingSchemes = exports.TrainingInformant = exports.isTaskID = exports.isTask = exports.ModelActor = exports.EmptyMemory = exports.ModelType = exports.Memory = exports.ConsoleLogger = exports.Logger = exports.AsyncInformant = exports.AsyncBuffer = exports.client = exports.Client = exports.privacy = exports.training = exports.tasks = exports.serialization = exports.dataset = exports.aggregation = void 0;
|
|
4
|
+
var tslib_1 = require("tslib");
|
|
5
|
+
exports.aggregation = (0, tslib_1.__importStar)(require("./aggregation"));
|
|
6
|
+
exports.dataset = (0, tslib_1.__importStar)(require("./dataset"));
|
|
7
|
+
exports.serialization = (0, tslib_1.__importStar)(require("./serialization"));
|
|
8
|
+
exports.tasks = (0, tslib_1.__importStar)(require("./tasks"));
|
|
9
|
+
exports.training = (0, tslib_1.__importStar)(require("./training"));
|
|
10
|
+
exports.privacy = (0, tslib_1.__importStar)(require("./privacy"));
|
|
11
|
+
var client_1 = require("./client");
|
|
12
|
+
Object.defineProperty(exports, "Client", { enumerable: true, get: function () { return client_1.Base; } });
|
|
13
|
+
exports.client = (0, tslib_1.__importStar)(require("./client"));
|
|
14
|
+
var async_buffer_1 = require("./async_buffer");
|
|
15
|
+
Object.defineProperty(exports, "AsyncBuffer", { enumerable: true, get: function () { return async_buffer_1.AsyncBuffer; } });
|
|
16
|
+
var async_informant_1 = require("./async_informant");
|
|
17
|
+
Object.defineProperty(exports, "AsyncInformant", { enumerable: true, get: function () { return async_informant_1.AsyncInformant; } });
|
|
18
|
+
var logging_1 = require("./logging");
|
|
19
|
+
Object.defineProperty(exports, "Logger", { enumerable: true, get: function () { return logging_1.Logger; } });
|
|
20
|
+
Object.defineProperty(exports, "ConsoleLogger", { enumerable: true, get: function () { return logging_1.ConsoleLogger; } });
|
|
21
|
+
var memory_1 = require("./memory");
|
|
22
|
+
Object.defineProperty(exports, "Memory", { enumerable: true, get: function () { return memory_1.Memory; } });
|
|
23
|
+
Object.defineProperty(exports, "ModelType", { enumerable: true, get: function () { return memory_1.ModelType; } });
|
|
24
|
+
Object.defineProperty(exports, "EmptyMemory", { enumerable: true, get: function () { return memory_1.Empty; } });
|
|
25
|
+
var model_actor_1 = require("./model_actor");
|
|
26
|
+
Object.defineProperty(exports, "ModelActor", { enumerable: true, get: function () { return model_actor_1.ModelActor; } });
|
|
27
|
+
var task_1 = require("./task");
|
|
28
|
+
Object.defineProperty(exports, "isTask", { enumerable: true, get: function () { return task_1.isTask; } });
|
|
29
|
+
Object.defineProperty(exports, "isTaskID", { enumerable: true, get: function () { return task_1.isTaskID; } });
|
|
30
|
+
var training_informant_1 = require("./training_informant");
|
|
31
|
+
Object.defineProperty(exports, "TrainingInformant", { enumerable: true, get: function () { return training_informant_1.TrainingInformant; } });
|
|
32
|
+
var training_schemes_1 = require("./training/training_schemes");
|
|
33
|
+
Object.defineProperty(exports, "TrainingSchemes", { enumerable: true, get: function () { return training_schemes_1.TrainingSchemes; } });
|
|
34
|
+
(0, tslib_1.__exportStar)(require("./types"), exports);
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
import { Logger } from './logger';
|
|
2
|
+
/**
 * Logger with the same interface as the Toaster, but writing colored
 * messages to the console.
 *
 * @class ConsoleLogger
 */
export declare class ConsoleLogger extends Logger {
    /**
     * Logs a success message on the console (in green)
     * @param message - message to be displayed
     */
    success(message: string): void;
    /**
     * Logs an error message on the console (in red)
     * @param message - message to be displayed
     */
    error(message: string): void;
}
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
"use strict";
|
|
2
|
+
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
+
exports.ConsoleLogger = void 0;
|
|
4
|
+
var tslib_1 = require("tslib");
|
|
5
|
+
var chalk_1 = (0, tslib_1.__importDefault)(require("chalk"));
|
|
6
|
+
var logger_1 = require("./logger");
|
|
7
|
+
/**
 * Logger with the same interface as the Toaster, but writing colored
 * messages to the console.
 *
 * @class ConsoleLogger
 */
var ConsoleLogger = /** @class */ (function (_super) {
    (0, tslib_1.__extends)(ConsoleLogger, _super);
    function ConsoleLogger() {
        return _super !== null && _super.apply(this, arguments) || this;
    }
    /**
     * Logs a success message on the console (in green)
     * @param message - message to be displayed
     */
    ConsoleLogger.prototype.success = function (message) {
        console.log(chalk_1.default.green(message));
    };
    /**
     * Logs an error message on the console (in red)
     * Uses `console.error` so errors are reported on the error stream / error
     * level instead of being mixed with regular output.
     * @param message - message to be displayed
     */
    ConsoleLogger.prototype.error = function (message) {
        console.error(chalk_1.default.red(message));
    };
    return ConsoleLogger;
}(logger_1.Logger));
|
|
33
|
+
exports.ConsoleLogger = ConsoleLogger;
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
"use strict";
|
|
2
|
+
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
+
exports.ConsoleLogger = exports.Logger = void 0;
|
|
4
|
+
var logger_1 = require("./logger");
|
|
5
|
+
Object.defineProperty(exports, "Logger", { enumerable: true, get: function () { return logger_1.Logger; } });
|
|
6
|
+
var console_logger_1 = require("./console_logger");
|
|
7
|
+
Object.defineProperty(exports, "ConsoleLogger", { enumerable: true, get: function () { return console_logger_1.ConsoleLogger; } });
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
/**
 * Abstract logging interface, implemented e.g. by `ConsoleLogger` and held by
 * model actors (`ModelActor.logger`).
 */
export declare abstract class Logger {
    /**
     * Logs a success message (ConsoleLogger prints it in green)
     * @param message - message to be displayed
     */
    abstract success(message: string): void;
    /**
     * Logs an error message (ConsoleLogger prints it in red)
     * @param message - message to be displayed
     */
    abstract error(message: string): void;
}
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
import * as tf from '@tensorflow/tfjs';
|
|
2
|
+
import { List } from 'immutable';
|
|
3
|
+
import { ConsoleLogger } from '.';
|
|
4
|
+
/**
 * Training history accumulated epoch by epoch, stored in immutable lists.
 */
export declare class TrainerLog {
    /** epoch numbers, in the order they were added */
    epochs: List<number>;
    /** training accuracy (`logs.acc`) of each epoch logged with metrics */
    trainAccuracy: List<number>;
    /** validation accuracy (`logs.val_acc`) of each epoch logged with metrics */
    validationAccuracy: List<number>;
    /** loss (`logs.loss`) of each epoch logged with metrics */
    loss: List<number>;
    /**
     * Records an epoch. Metrics are appended only when `logs` is given, so the
     * metric lists may end up shorter than `epochs`.
     * @param epoch the epoch number
     * @param logs TFJS logs of that epoch
     */
    add(epoch: number, logs?: tf.Logs): void;
}
|
|
11
|
+
/**
 * Console logger that additionally records end-of-epoch metrics in a
 * `TrainerLog`.
 *
 * @class TrainerLogger
 */
export declare class TrainerLogger extends ConsoleLogger {
    /** metrics recorded by `onEpochEnd` (only filled when `saveTrainerLog` is true) */
    readonly log: TrainerLog;
    /** whether `onEpochEnd` stores the metrics in `log` */
    readonly saveTrainerLog: boolean;
    /**
     * @param saveTrainerLog whether to record epoch metrics in `log`; defaults to `true`
     */
    constructor(saveTrainerLog?: boolean);
    /**
     * Records the epoch metrics in `log` (if enabled), then prints training
     * accuracy, validation accuracy and loss as a success message.
     * @param epoch the epoch that just ended
     * @param logs TFJS logs of that epoch (`acc`, `val_acc` and `loss` are read)
     */
    onEpochEnd(epoch: number, logs?: tf.Logs): void;
    /**
     * Displays TFJS memory usage: allocated bytes (in MB) and the number of
     * allocated tensors.
     */
    ramUsage(): void;
}
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
"use strict";
|
|
2
|
+
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
+
exports.TrainerLogger = exports.TrainerLog = void 0;
|
|
4
|
+
var tslib_1 = require("tslib");
|
|
5
|
+
var tf = (0, tslib_1.__importStar)(require("@tensorflow/tfjs"));
|
|
6
|
+
var immutable_1 = require("immutable");
|
|
7
|
+
var _1 = require(".");
|
|
8
|
+
/**
 * Training history accumulated epoch by epoch, stored in immutable lists.
 */
var TrainerLog = /** @class */ (function () {
    function TrainerLog() {
        var List = immutable_1.List;
        this.epochs = List();
        this.trainAccuracy = List();
        this.validationAccuracy = List();
        this.loss = List();
    }
    /**
     * Records an epoch; the accuracy/loss lists only grow when `logs` is provided.
     * @param epoch the epoch number
     * @param logs TFJS logs of that epoch
     */
    TrainerLog.prototype.add = function (epoch, logs) {
        this.epochs = this.epochs.push(epoch);
        if (logs === undefined) {
            return;
        }
        this.trainAccuracy = this.trainAccuracy.push(logs.acc);
        this.validationAccuracy = this.validationAccuracy.push(logs.val_acc);
        this.loss = this.loss.push(logs.loss);
    };
    return TrainerLog;
}());
|
|
25
|
+
exports.TrainerLog = TrainerLog;
|
|
26
|
+
/**
 * Console logger that additionally records end-of-epoch metrics in a TrainerLog.
 *
 * @class TrainerLogger
 */
var TrainerLogger = /** @class */ (function (_super) {
    (0, tslib_1.__extends)(TrainerLogger, _super);
    // TODO: pass saveTrainerLog as false in browser, used for benchmarking
    /**
     * @param saveTrainerLog whether epoch metrics are stored in `log` (default: true)
     */
    function TrainerLogger(saveTrainerLog) {
        var logger = _super.call(this) || this;
        logger.saveTrainerLog = saveTrainerLog === undefined ? true : saveTrainerLog;
        logger.log = new TrainerLog();
        return logger;
    }
    /**
     * Stores the epoch metrics (when enabled) and prints them; a missing
     * metric is printed as 'undefined'.
     */
    TrainerLogger.prototype.onEpochEnd = function (epoch, logs) {
        if (this.saveTrainerLog) {
            this.log.add(epoch, logs);
        }
        var metric = function (name) {
            var value = logs === null || logs === undefined ? undefined : logs[name];
            return value === null || value === undefined ? 'undefined' : value;
        };
        var lines = [
            'Train: ' + metric('acc'),
            'Validation:' + metric('val_acc'),
            'Loss:' + metric('loss')
        ];
        this.success('On epoch end:\n' + lines.join('\n'));
    };
    /**
     * Display ram usage: TFJS allocated bytes (in MB) and number of tensors.
     */
    TrainerLogger.prototype.ramUsage = function () {
        var megabytes = tf.memory().numBytes * 0.000001;
        this.success("Training RAM usage is = " + megabytes + " MB");
        this.success("Number of allocated tensors = " + tf.memory().numTensors);
    };
    return TrainerLogger;
}(_1.ConsoleLogger));
|
|
59
|
+
exports.TrainerLogger = TrainerLogger;
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
import * as tf from '@tensorflow/tfjs';
|
|
2
|
+
import { TaskID } from '..';
|
|
3
|
+
import { ModelType } from './model_type';
|
|
4
|
+
/**
 * Abstract storage of TFJS models, addressed by model type (working or saved,
 * see `ModelType`), task ID and model name.
 */
export declare abstract class Memory {
    /**
     * Fetches metadata of the model.
     * @param type whether the working or a saved model is meant
     * @param taskID the model's corresponding task
     * @param modelName the model's file name
     * @returns the model's artifacts info; may be `undefined` (the `Empty` memory always resolves to `undefined`)
     */
    abstract getModelMetadata(type: ModelType, taskID: TaskID, modelName: string): Promise<tf.io.ModelArtifactsInfo | undefined>;
    /**
     * Loads the model of the given type and returns it as a fresh TFJS object.
     * @param type whether the working or a saved model is meant
     * @param taskID the model's corresponding task
     * @param modelName the model's file name
     */
    abstract getModel(type: ModelType, taskID: TaskID, modelName: string): Promise<tf.LayersModel>;
    /**
     * Loads a model from the model library into the current working model.
     * @param taskID the saved model's corresponding task
     * @param modelName the saved model's file name
     */
    abstract loadSavedModel(taskID: TaskID, modelName: string): Promise<void>;
    /**
     * Loads a fresh TFJS model object into the current working model.
     * @param taskID the working model's corresponding task
     * @param modelName the working model's file name
     * @param model the fresh model
     */
    abstract updateWorkingModel(taskID: TaskID, modelName: string, model: tf.LayersModel): Promise<void>;
    /**
     * Adds the current working model to the model library.
     * @param taskID the working model's corresponding task
     * @param modelName the working model's file name
     */
    abstract saveWorkingModel(taskID: TaskID, modelName: string): Promise<void>;
    /**
     * Removes the model from memory.
     * @param type whether the working or a saved model is meant
     * @param taskID the model's corresponding task
     * @param modelName the model's file name
     */
    abstract deleteModel(type: ModelType, taskID: TaskID, modelName: string): Promise<void>;
    /**
     * Downloads a previously saved model.
     * @param taskID the saved model's corresponding task
     * @param modelName the saved model's file name
     */
    abstract downloadSavedModel(taskID: TaskID, modelName: string): Promise<void>;
    /**
     * Tells whether a model matching the given type, task and name is held in
     * memory (exact matching semantics are up to the implementation; `Empty`
     * always resolves to `false`).
     * @param modelType whether the working or a saved model is meant
     * @param taskID the model's corresponding task
     * @param modelName the model's file name
     */
    abstract contains(modelType: ModelType, taskID: TaskID, modelName: string): Promise<boolean>;
}
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
import * as tf from '@tensorflow/tfjs';
|
|
2
|
+
import { Memory } from './base';
|
|
3
|
+
/**
 * Memory implementation that stores nothing:
 * - `getModelMetadata` resolves to `undefined` and `contains` to `false`;
 * - `getModel`, `loadSavedModel` and `downloadSavedModel` reject with `Error('empty')`;
 * - `updateWorkingModel`, `saveWorkingModel` and `deleteModel` resolve without doing anything.
 */
export declare class Empty extends Memory {
    getModelMetadata(): Promise<undefined>;
    contains(): Promise<boolean>;
    getModel(): Promise<tf.LayersModel>;
    loadSavedModel(): Promise<void>;
    updateWorkingModel(): Promise<void>;
    saveWorkingModel(): Promise<void>;
    deleteModel(): Promise<void>;
    downloadSavedModel(): Promise<void>;
}
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
"use strict";
|
|
2
|
+
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
+
exports.Empty = void 0;
|
|
4
|
+
var tslib_1 = require("tslib");
|
|
5
|
+
var base_1 = require("./base");
|
|
6
|
+
/**
 * Memory implementation that stores nothing: lookups find nothing, writes are
 * no-ops, and reads of an actual model reject with Error('empty').
 */
var Empty = /** @class */ (function (_super) {
    (0, tslib_1.__extends)(Empty, _super);
    function Empty() {
        return _super !== null && _super.apply(this, arguments) || this;
    }
    // Builds a method whose returned promise resolves to `value`.
    function resolvingTo(value) {
        return function () {
            return Promise.resolve(value);
        };
    }
    // Method whose returned promise rejects with a fresh Error('empty') per call.
    function rejectEmpty() {
        return new Promise(function (_resolve, reject) {
            reject(new Error('empty'));
        });
    }
    var proto = Empty.prototype;
    proto.getModelMetadata = resolvingTo(undefined);
    proto.contains = resolvingTo(false);
    proto.getModel = rejectEmpty;
    proto.loadSavedModel = rejectEmpty;
    proto.updateWorkingModel = resolvingTo(undefined);
    proto.saveWorkingModel = resolvingTo(undefined);
    proto.deleteModel = resolvingTo(undefined);
    proto.downloadSavedModel = rejectEmpty;
    return Empty;
}(base_1.Memory));
|
|
69
|
+
exports.Empty = Empty;
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
"use strict";
|
|
2
|
+
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
+
exports.ModelType = exports.Memory = exports.Empty = void 0;
|
|
4
|
+
var empty_1 = require("./empty");
|
|
5
|
+
Object.defineProperty(exports, "Empty", { enumerable: true, get: function () { return empty_1.Empty; } });
|
|
6
|
+
var base_1 = require("./base");
|
|
7
|
+
Object.defineProperty(exports, "Memory", { enumerable: true, get: function () { return base_1.Memory; } });
|
|
8
|
+
var model_type_1 = require("./model_type");
|
|
9
|
+
Object.defineProperty(exports, "ModelType", { enumerable: true, get: function () { return model_type_1.ModelType; } });
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
"use strict";
|
|
2
|
+
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
+
exports.ModelType = void 0;
|
|
4
|
+
// Type of model to store & retrieve: the current working model, or a model
// saved in the model library. Populates (or reuses) `exports.ModelType`.
var ModelType = exports.ModelType || (exports.ModelType = {});
ModelType.WORKING = "working";
ModelType.SAVED = "saved";
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
// Relative paths: the `@/` alias of the sources is not rewritten by tsc, so it
// would not resolve for consumers of the published typings.
import { Logger } from './logging/logger';
import { Task } from './task';
/**
 * Base class for all actors of the system (e.g. trainer, tester, etc.)
 * holding the parameters they commonly use.
 */
export declare class ModelActor {
    /** task the actor works on */
    task: Task;
    /** logging system (e.g. toaster) */
    logger: Logger;
    /**
     * @param task - task the actor works on
     * @param logger - logging system (e.g. toaster)
     */
    constructor(task: Task, logger: Logger);
}
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
"use strict";
|
|
2
|
+
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
+
exports.ModelActor = void 0;
|
|
4
|
+
/**
 * Base class for all actors of the system (e.g. trainer, tester, etc.),
 * holding the parameters they commonly use.
 */
var ModelActor = /** @class */ (function () {
    /**
     * @param {Task} task - task the actor works on
     * @param {Logger} logger - logging system (e.g. toaster)
     */
    function ModelActor(task, logger) {
        var actor = this;
        actor.task = task;
        actor.logger = logger;
    }
    return ModelActor;
})();
|
|
20
|
+
exports.ModelActor = ModelActor;
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
// Relative paths: the `@/` alias of the sources is not rewritten by tsc, so it
// would not resolve for consumers of the published typings.
import { Weights } from './types';
import { Task } from './task';
/**
 * Clips the weights update between the previous and current rounds and/or adds
 * Gaussian noise to it, as parametrized by the task's training information
 * (`clippingRadius`, `noiseScale`).
 * The previous round's weights are the last weights pulled from server/peers.
 * The current round's weights are obtained after a single round of training, from the previous round's weights.
 * @param updatedWeights weights from the current round
 * @param staleWeights weights from the previous round
 * @param task the task
 * @returns the clipped and/or noised weights for the current round
 */
export declare function addDifferentialPrivacy(updatedWeights: Weights, staleWeights: Weights, task: Task): Weights;
|
package/dist/privacy.js
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
"use strict";
|
|
2
|
+
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
+
exports.addDifferentialPrivacy = void 0;
|
|
4
|
+
var tslib_1 = require("tslib");
|
|
5
|
+
var immutable_1 = require("immutable");
|
|
6
|
+
var tf = (0, tslib_1.__importStar)(require("@tensorflow/tfjs"));
|
|
7
|
+
/**
 * Clips and/or adds Gaussian noise to the weights update between the previous
 * and current rounds, as parametrized by the task's training information.
 * The previous round's weights are the last weights pulled from server/peers.
 * The current round's weights are obtained after a single round of training,
 * from the previous round's weights.
 *
 * With d = updatedWeights - staleWeights (per tensor) and ||d|| the Frobenius
 * norm of the whole update:
 * - `clippingRadius` C scales d by 1 / max(1, ||d|| / C);
 * - `noiseScale` s adds N(0, (s * C)^2) noise to every entry of d, or
 *   N(0, s^2) noise when no clipping radius is set.
 * The result is staleWeights + d. When neither parameter is set,
 * `updatedWeights` is returned as is.
 *
 * @param updatedWeights weights from the current round
 * @param staleWeights weights from the previous round
 * @param task the task
 * @returns the clipped and/or noised weights for the current round
 */
function addDifferentialPrivacy(updatedWeights, staleWeights, task) {
    var info = task.trainingInformation;
    var noiseScale = info === null || info === undefined ? undefined : info.noiseScale;
    var clippingRadius = info === null || info === undefined ? undefined : info.clippingRadius;
    if (noiseScale === undefined && clippingRadius === undefined) {
        return updatedWeights;
    }
    // tf.tidy disposes every intermediate tensor (differences, noise, clipped
    // updates) and keeps only the returned weights.
    return tf.tidy(function () {
        // `sub` rather than `w1.add(-w2)`: unary minus on a Tensor object
        // evaluates to NaN, which turned the whole update into NaN.
        var weightsDiff = (0, immutable_1.List)(updatedWeights)
            .zip((0, immutable_1.List)(staleWeights))
            .map(function (pair) { return pair[0].sub(pair[1]); });
        var divisor = 1;
        var noiseStdDev = noiseScale;
        if (clippingRadius !== undefined) {
            // Frobenius norm of the whole update (all tensors together);
            // seeded reduce so an empty weight list does not fail.
            var squaredNorm = weightsDiff
                .map(function (w) { return w.square().sum().dataSync()[0]; })
                .reduce(function (acc, x) { return acc + x; }, 0);
            divisor = Math.max(1, Math.sqrt(squaredNorm) / clippingRadius);
            if (noiseScale !== undefined) {
                // tf.randomNormal expects a standard deviation, not a variance:
                // the noise std is s * C, not (s * C)^2.
                noiseStdDev = noiseScale * clippingRadius;
            }
        }
        var newWeightsDiff = weightsDiff.map(function (w) {
            var clipped = clippingRadius === undefined ? w : w.div(divisor);
            // Noise is added to the (clipped) update instead of replacing it.
            return noiseStdDev === undefined
                ? clipped
                : clipped.add(tf.randomNormal(w.shape, 0, noiseStdDev));
        });
        return (0, immutable_1.List)(staleWeights)
            .zip(newWeightsDiff)
            .map(function (pair) { return pair[0].add(pair[1]); })
            .toArray();
    });
}
|
|
60
|
+
exports.addDifferentialPrivacy = addDifferentialPrivacy;
|
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
"use strict";
|
|
2
|
+
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
+
exports.weights = exports.model = void 0;
|
|
4
|
+
var tslib_1 = require("tslib");
|
|
5
|
+
exports.model = (0, tslib_1.__importStar)(require("./model"));
|
|
6
|
+
exports.weights = (0, tslib_1.__importStar)(require("./weights"));
|
|
@@ -0,0 +1,5 @@
|
|
|
1
|
+
import * as tf from '@tensorflow/tfjs';
|
|
2
|
+
/** Encoded representation of a model: a flat array of numbers. */
export declare type Encoded = number[];
/** Type guard telling whether `raw` has the `Encoded` shape. */
export declare function isEncoded(raw: unknown): raw is Encoded;
/** Encodes a TFJS layers model into its `Encoded` form. */
export declare function encode(model: tf.LayersModel): Promise<Encoded>;
/** Rebuilds a TFJS layers model from `Encoded` data (presumably the inverse of `encode` — confirm against model.js). */
export declare function decode(encoded: Encoded): Promise<tf.LayersModel>;
|