@epfml/discojs 1.0.0 → 2.1.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +28 -8
- package/dist/{async_buffer.d.ts → core/async_buffer.d.ts} +3 -3
- package/dist/{async_buffer.js → core/async_buffer.js} +5 -6
- package/dist/{async_informant.d.ts → core/async_informant.d.ts} +0 -0
- package/dist/{async_informant.js → core/async_informant.js} +0 -0
- package/dist/{client → core/client}/base.d.ts +4 -7
- package/dist/{client → core/client}/base.js +3 -2
- package/dist/core/client/decentralized/base.d.ts +32 -0
- package/dist/core/client/decentralized/base.js +212 -0
- package/dist/core/client/decentralized/clear_text.d.ts +14 -0
- package/dist/core/client/decentralized/clear_text.js +96 -0
- package/dist/{client → core/client}/decentralized/index.d.ts +0 -0
- package/dist/{client → core/client}/decentralized/index.js +0 -0
- package/dist/core/client/decentralized/messages.d.ts +41 -0
- package/dist/core/client/decentralized/messages.js +54 -0
- package/dist/core/client/decentralized/peer.d.ts +26 -0
- package/dist/core/client/decentralized/peer.js +210 -0
- package/dist/core/client/decentralized/peer_pool.d.ts +14 -0
- package/dist/core/client/decentralized/peer_pool.js +92 -0
- package/dist/core/client/decentralized/sec_agg.d.ts +22 -0
- package/dist/core/client/decentralized/sec_agg.js +190 -0
- package/dist/core/client/decentralized/secret_shares.d.ts +3 -0
- package/dist/core/client/decentralized/secret_shares.js +39 -0
- package/dist/core/client/decentralized/types.d.ts +2 -0
- package/dist/core/client/decentralized/types.js +7 -0
- package/dist/core/client/event_connection.d.ts +37 -0
- package/dist/core/client/event_connection.js +158 -0
- package/dist/core/client/federated/client.d.ts +37 -0
- package/dist/core/client/federated/client.js +273 -0
- package/dist/core/client/federated/index.d.ts +2 -0
- package/dist/core/client/federated/index.js +7 -0
- package/dist/core/client/federated/messages.d.ts +38 -0
- package/dist/core/client/federated/messages.js +25 -0
- package/dist/{client → core/client}/index.d.ts +2 -1
- package/dist/{client → core/client}/index.js +3 -3
- package/dist/{client → core/client}/local.d.ts +2 -2
- package/dist/{client → core/client}/local.js +0 -0
- package/dist/core/client/messages.d.ts +28 -0
- package/dist/core/client/messages.js +33 -0
- package/dist/core/client/utils.d.ts +2 -0
- package/dist/core/client/utils.js +19 -0
- package/dist/core/dataset/data/data.d.ts +11 -0
- package/dist/core/dataset/data/data.js +20 -0
- package/dist/core/dataset/data/data_split.d.ts +5 -0
- package/dist/{client/decentralized/types.js → core/dataset/data/data_split.js} +0 -0
- package/dist/core/dataset/data/image_data.d.ts +8 -0
- package/dist/core/dataset/data/image_data.js +64 -0
- package/dist/core/dataset/data/index.d.ts +5 -0
- package/dist/core/dataset/data/index.js +11 -0
- package/dist/core/dataset/data/preprocessing.d.ts +13 -0
- package/dist/core/dataset/data/preprocessing.js +33 -0
- package/dist/core/dataset/data/tabular_data.d.ts +8 -0
- package/dist/core/dataset/data/tabular_data.js +40 -0
- package/dist/{dataset → core/dataset}/data_loader/data_loader.d.ts +4 -11
- package/dist/{dataset → core/dataset}/data_loader/data_loader.js +0 -0
- package/dist/core/dataset/data_loader/image_loader.d.ts +17 -0
- package/dist/core/dataset/data_loader/image_loader.js +141 -0
- package/dist/core/dataset/data_loader/index.d.ts +3 -0
- package/dist/core/dataset/data_loader/index.js +9 -0
- package/dist/core/dataset/data_loader/tabular_loader.d.ts +29 -0
- package/dist/core/dataset/data_loader/tabular_loader.js +101 -0
- package/dist/core/dataset/dataset.d.ts +2 -0
- package/dist/{task/training_information.js → core/dataset/dataset.js} +0 -0
- package/dist/{dataset → core/dataset}/dataset_builder.d.ts +5 -5
- package/dist/{dataset → core/dataset}/dataset_builder.js +14 -10
- package/dist/core/dataset/index.d.ts +4 -0
- package/dist/core/dataset/index.js +14 -0
- package/dist/core/default_tasks/cifar10.d.ts +2 -0
- package/dist/core/default_tasks/cifar10.js +68 -0
- package/dist/core/default_tasks/geotags.d.ts +2 -0
- package/dist/core/default_tasks/geotags.js +69 -0
- package/dist/core/default_tasks/index.d.ts +6 -0
- package/dist/core/default_tasks/index.js +15 -0
- package/dist/core/default_tasks/lus_covid.d.ts +2 -0
- package/dist/core/default_tasks/lus_covid.js +96 -0
- package/dist/core/default_tasks/mnist.d.ts +2 -0
- package/dist/core/default_tasks/mnist.js +69 -0
- package/dist/core/default_tasks/simple_face.d.ts +2 -0
- package/dist/core/default_tasks/simple_face.js +53 -0
- package/dist/core/default_tasks/titanic.d.ts +2 -0
- package/dist/core/default_tasks/titanic.js +97 -0
- package/dist/core/index.d.ts +18 -0
- package/dist/core/index.js +39 -0
- package/dist/{informant → core/informant}/graph_informant.d.ts +0 -0
- package/dist/{informant → core/informant}/graph_informant.js +0 -0
- package/dist/{informant → core/informant}/index.d.ts +0 -0
- package/dist/{informant → core/informant}/index.js +0 -0
- package/dist/{informant → core/informant}/training_informant/base.d.ts +3 -3
- package/dist/{informant → core/informant}/training_informant/base.js +3 -2
- package/dist/{informant → core/informant}/training_informant/decentralized.d.ts +0 -0
- package/dist/{informant → core/informant}/training_informant/decentralized.js +0 -0
- package/dist/{informant → core/informant}/training_informant/federated.d.ts +0 -0
- package/dist/{informant → core/informant}/training_informant/federated.js +0 -0
- package/dist/{informant → core/informant}/training_informant/index.d.ts +0 -0
- package/dist/{informant → core/informant}/training_informant/index.js +0 -0
- package/dist/{informant → core/informant}/training_informant/local.d.ts +2 -2
- package/dist/{informant → core/informant}/training_informant/local.js +2 -2
- package/dist/{logging → core/logging}/console_logger.d.ts +0 -0
- package/dist/{logging → core/logging}/console_logger.js +0 -0
- package/dist/{logging → core/logging}/index.d.ts +0 -0
- package/dist/{logging → core/logging}/index.js +0 -0
- package/dist/{logging → core/logging}/logger.d.ts +0 -0
- package/dist/{logging → core/logging}/logger.js +0 -0
- package/dist/{logging → core/logging}/trainer_logger.d.ts +0 -0
- package/dist/{logging → core/logging}/trainer_logger.js +0 -0
- package/dist/{memory → core/memory}/base.d.ts +2 -2
- package/dist/{memory → core/memory}/base.js +0 -0
- package/dist/{memory → core/memory}/empty.d.ts +0 -0
- package/dist/{memory → core/memory}/empty.js +0 -0
- package/dist/core/memory/index.d.ts +3 -0
- package/dist/core/memory/index.js +9 -0
- package/dist/{memory → core/memory}/model_type.d.ts +0 -0
- package/dist/{memory → core/memory}/model_type.js +0 -0
- package/dist/{privacy.d.ts → core/privacy.d.ts} +2 -3
- package/dist/{privacy.js → core/privacy.js} +3 -16
- package/dist/{serialization → core/serialization}/index.d.ts +0 -0
- package/dist/{serialization → core/serialization}/index.js +0 -0
- package/dist/{serialization → core/serialization}/model.d.ts +0 -0
- package/dist/{serialization → core/serialization}/model.js +0 -0
- package/dist/core/serialization/weights.d.ts +5 -0
- package/dist/{serialization → core/serialization}/weights.js +11 -9
- package/dist/{task → core/task}/data_example.d.ts +0 -0
- package/dist/{task → core/task}/data_example.js +0 -0
- package/dist/core/task/digest.d.ts +5 -0
- package/dist/core/task/digest.js +18 -0
- package/dist/{task → core/task}/display_information.d.ts +5 -5
- package/dist/{task → core/task}/display_information.js +5 -10
- package/dist/{task → core/task}/index.d.ts +3 -0
- package/dist/core/task/index.js +15 -0
- package/dist/core/task/model_compile_data.d.ts +6 -0
- package/dist/core/task/model_compile_data.js +22 -0
- package/dist/{task → core/task}/summary.d.ts +0 -0
- package/dist/{task → core/task}/summary.js +0 -4
- package/dist/{task → core/task}/task.d.ts +4 -2
- package/dist/{task → core/task}/task.js +10 -7
- package/dist/core/task/task_handler.d.ts +5 -0
- package/dist/core/task/task_handler.js +53 -0
- package/dist/core/task/task_provider.d.ts +6 -0
- package/dist/core/task/task_provider.js +13 -0
- package/dist/{task → core/task}/training_information.d.ts +10 -14
- package/dist/core/task/training_information.js +66 -0
- package/dist/core/training/disco.d.ts +23 -0
- package/dist/core/training/disco.js +130 -0
- package/dist/{training → core/training}/index.d.ts +0 -0
- package/dist/{training → core/training}/index.js +0 -0
- package/dist/{training → core/training}/trainer/distributed_trainer.d.ts +1 -2
- package/dist/{training → core/training}/trainer/distributed_trainer.js +6 -5
- package/dist/{training → core/training}/trainer/local_trainer.d.ts +2 -2
- package/dist/{training → core/training}/trainer/local_trainer.js +0 -0
- package/dist/{training → core/training}/trainer/round_tracker.d.ts +0 -0
- package/dist/{training → core/training}/trainer/round_tracker.js +0 -0
- package/dist/{training → core/training}/trainer/trainer.d.ts +1 -2
- package/dist/{training → core/training}/trainer/trainer.js +2 -2
- package/dist/{training → core/training}/trainer/trainer_builder.d.ts +0 -0
- package/dist/{training → core/training}/trainer/trainer_builder.js +0 -0
- package/dist/core/training/training_schemes.d.ts +5 -0
- package/dist/{training → core/training}/training_schemes.js +2 -2
- package/dist/{types.d.ts → core/types.d.ts} +0 -0
- package/dist/{types.js → core/types.js} +0 -0
- package/dist/{validation → core/validation}/index.d.ts +0 -0
- package/dist/{validation → core/validation}/index.js +0 -0
- package/dist/{validation → core/validation}/validator.d.ts +5 -8
- package/dist/{validation → core/validation}/validator.js +9 -11
- package/dist/core/weights/aggregation.d.ts +7 -0
- package/dist/core/weights/aggregation.js +72 -0
- package/dist/core/weights/index.d.ts +2 -0
- package/dist/core/weights/index.js +7 -0
- package/dist/core/weights/weights_container.d.ts +19 -0
- package/dist/core/weights/weights_container.js +64 -0
- package/dist/dataset/data_loader/image_loader.d.ts +3 -15
- package/dist/dataset/data_loader/image_loader.js +12 -125
- package/dist/dataset/data_loader/index.d.ts +2 -3
- package/dist/dataset/data_loader/index.js +3 -5
- package/dist/dataset/data_loader/tabular_loader.d.ts +3 -28
- package/dist/dataset/data_loader/tabular_loader.js +11 -92
- package/dist/imports.d.ts +2 -0
- package/dist/imports.js +7 -0
- package/dist/index.d.ts +2 -19
- package/dist/index.js +3 -39
- package/dist/memory/index.d.ts +1 -3
- package/dist/memory/index.js +3 -7
- package/dist/memory/memory.d.ts +26 -0
- package/dist/memory/memory.js +160 -0
- package/package.json +13 -26
- package/dist/aggregation.d.ts +0 -5
- package/dist/aggregation.js +0 -33
- package/dist/client/decentralized/base.d.ts +0 -43
- package/dist/client/decentralized/base.js +0 -243
- package/dist/client/decentralized/clear_text.d.ts +0 -13
- package/dist/client/decentralized/clear_text.js +0 -78
- package/dist/client/decentralized/messages.d.ts +0 -37
- package/dist/client/decentralized/messages.js +0 -15
- package/dist/client/decentralized/sec_agg.d.ts +0 -18
- package/dist/client/decentralized/sec_agg.js +0 -169
- package/dist/client/decentralized/secret_shares.d.ts +0 -5
- package/dist/client/decentralized/secret_shares.js +0 -58
- package/dist/client/decentralized/types.d.ts +0 -1
- package/dist/client/federated.d.ts +0 -30
- package/dist/client/federated.js +0 -218
- package/dist/dataset/index.d.ts +0 -2
- package/dist/dataset/index.js +0 -7
- package/dist/model_actor.d.ts +0 -16
- package/dist/model_actor.js +0 -20
- package/dist/serialization/weights.d.ts +0 -5
- package/dist/task/index.js +0 -8
- package/dist/task/model_compile_data.d.ts +0 -6
- package/dist/task/model_compile_data.js +0 -12
- package/dist/tasks/cifar10.d.ts +0 -4
- package/dist/tasks/cifar10.js +0 -76
- package/dist/tasks/index.d.ts +0 -5
- package/dist/tasks/index.js +0 -9
- package/dist/tasks/lus_covid.d.ts +0 -4
- package/dist/tasks/lus_covid.js +0 -85
- package/dist/tasks/mnist.d.ts +0 -4
- package/dist/tasks/mnist.js +0 -58
- package/dist/tasks/simple_face.d.ts +0 -4
- package/dist/tasks/simple_face.js +0 -84
- package/dist/tasks/titanic.d.ts +0 -4
- package/dist/tasks/titanic.js +0 -88
- package/dist/tfjs.d.ts +0 -2
- package/dist/tfjs.js +0 -6
- package/dist/training/disco.d.ts +0 -14
- package/dist/training/disco.js +0 -70
- package/dist/training/training_schemes.d.ts +0 -5
|
@@ -1,32 +1,28 @@
|
|
|
1
|
-
import {
|
|
1
|
+
import { Preprocessing } from '../dataset/data/preprocessing';
|
|
2
2
|
import { ModelCompileData } from './model_compile_data';
|
|
3
|
+
export declare function isTrainingInformation(raw: unknown): raw is TrainingInformation;
|
|
3
4
|
export interface TrainingInformation {
|
|
4
5
|
modelID: string;
|
|
5
6
|
epochs: number;
|
|
6
7
|
roundDuration: number;
|
|
7
8
|
validationSplit: number;
|
|
8
9
|
batchSize: number;
|
|
9
|
-
|
|
10
|
+
preprocessingFunctions?: Preprocessing[];
|
|
10
11
|
modelCompileData: ModelCompileData;
|
|
11
12
|
dataType: string;
|
|
12
|
-
maxShareValue?: number;
|
|
13
|
-
minimumReadyPeers?: number;
|
|
14
|
-
decentralizedSecure?: boolean;
|
|
15
|
-
receivedMessagesThreshold?: number;
|
|
16
13
|
inputColumns?: string[];
|
|
17
14
|
outputColumns?: string[];
|
|
18
|
-
threshold?: number;
|
|
19
15
|
IMAGE_H?: number;
|
|
20
16
|
IMAGE_W?: number;
|
|
17
|
+
modelURL?: string;
|
|
21
18
|
LABEL_LIST?: string[];
|
|
22
|
-
aggregateImagesById?: boolean;
|
|
23
19
|
learningRate?: number;
|
|
24
|
-
|
|
25
|
-
csvLabels?: boolean;
|
|
26
|
-
RESIZED_IMAGE_H?: number;
|
|
27
|
-
RESIZED_IMAGE_W?: number;
|
|
28
|
-
LABEL_ASSIGNMENT?: DataExample[];
|
|
29
|
-
scheme?: string;
|
|
20
|
+
scheme: string;
|
|
30
21
|
noiseScale?: number;
|
|
31
22
|
clippingRadius?: number;
|
|
23
|
+
decentralizedSecure?: boolean;
|
|
24
|
+
byzantineRobustAggregator?: boolean;
|
|
25
|
+
tauPercentile?: number;
|
|
26
|
+
maxShareValue?: number;
|
|
27
|
+
minimumReadyPeers?: number;
|
|
32
28
|
}
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
"use strict";
|
|
2
|
+
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
+
exports.isTrainingInformation = void 0;
|
|
4
|
+
var model_compile_data_1 = require("./model_compile_data");
|
|
5
|
+
function isTrainingInformation(raw) {
|
|
6
|
+
if (typeof raw !== 'object') {
|
|
7
|
+
return false;
|
|
8
|
+
}
|
|
9
|
+
if (raw === null) {
|
|
10
|
+
return false;
|
|
11
|
+
}
|
|
12
|
+
var _a = raw, dataType = _a.dataType, scheme = _a.scheme, epochs = _a.epochs,
|
|
13
|
+
// roundDuration,
|
|
14
|
+
validationSplit = _a.validationSplit, batchSize = _a.batchSize, modelCompileData = _a.modelCompileData, modelID = _a.modelID, preprocessingFunctions = _a.preprocessingFunctions, inputColumns = _a.inputColumns, outputColumns = _a.outputColumns, IMAGE_H = _a.IMAGE_H, IMAGE_W = _a.IMAGE_W, roundDuration = _a.roundDuration, modelURL = _a.modelURL, learningRate = _a.learningRate, decentralizedSecure = _a.decentralizedSecure, maxShareValue = _a.maxShareValue, minimumReadyPeers = _a.minimumReadyPeers, LABEL_LIST = _a.LABEL_LIST, noiseScale = _a.noiseScale, clippingRadius = _a.clippingRadius;
|
|
15
|
+
if (typeof dataType !== 'string' ||
|
|
16
|
+
typeof modelID !== 'string' ||
|
|
17
|
+
typeof epochs !== 'number' ||
|
|
18
|
+
typeof batchSize !== 'number' ||
|
|
19
|
+
typeof roundDuration !== 'number' ||
|
|
20
|
+
typeof validationSplit !== 'number' ||
|
|
21
|
+
(modelURL !== undefined && typeof modelURL !== 'string') ||
|
|
22
|
+
(noiseScale !== undefined && typeof noiseScale !== 'number') ||
|
|
23
|
+
(clippingRadius !== undefined && typeof clippingRadius !== 'number') ||
|
|
24
|
+
(learningRate !== undefined && typeof learningRate !== 'number') ||
|
|
25
|
+
(decentralizedSecure !== undefined && typeof decentralizedSecure !== 'boolean') ||
|
|
26
|
+
(maxShareValue !== undefined && typeof maxShareValue !== 'number') ||
|
|
27
|
+
(minimumReadyPeers !== undefined && typeof minimumReadyPeers !== 'number')) {
|
|
28
|
+
return false;
|
|
29
|
+
}
|
|
30
|
+
// interdepences on data type
|
|
31
|
+
switch (dataType) {
|
|
32
|
+
case 'image':
|
|
33
|
+
if (typeof IMAGE_H !== 'number' || typeof IMAGE_W !== 'number') {
|
|
34
|
+
return false;
|
|
35
|
+
}
|
|
36
|
+
break;
|
|
37
|
+
case 'tabular':
|
|
38
|
+
if (!(Array.isArray(inputColumns) && inputColumns.every(function (e) { return typeof e === 'string'; }))) {
|
|
39
|
+
return false;
|
|
40
|
+
}
|
|
41
|
+
if (!(Array.isArray(outputColumns) && outputColumns.every(function (e) { return typeof e === 'string'; }))) {
|
|
42
|
+
return false;
|
|
43
|
+
}
|
|
44
|
+
break;
|
|
45
|
+
}
|
|
46
|
+
// interdepences on scheme
|
|
47
|
+
switch (scheme) {
|
|
48
|
+
case 'decentralized':
|
|
49
|
+
break;
|
|
50
|
+
case 'federated':
|
|
51
|
+
break;
|
|
52
|
+
case 'local':
|
|
53
|
+
break;
|
|
54
|
+
}
|
|
55
|
+
if (!(0, model_compile_data_1.isModelCompileData)(modelCompileData)) {
|
|
56
|
+
return false;
|
|
57
|
+
}
|
|
58
|
+
if (LABEL_LIST !== undefined && !(Array.isArray(LABEL_LIST) && LABEL_LIST.every(function (e) { return typeof e === 'string'; }))) {
|
|
59
|
+
return false;
|
|
60
|
+
}
|
|
61
|
+
if (preprocessingFunctions !== undefined && !(Array.isArray(preprocessingFunctions) && preprocessingFunctions.every(function (e) { return typeof e === 'string'; }))) {
|
|
62
|
+
return false;
|
|
63
|
+
}
|
|
64
|
+
return true;
|
|
65
|
+
}
|
|
66
|
+
exports.isTrainingInformation = isTrainingInformation;
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
import { Client, data, Logger, Task, TrainingInformant, TrainingSchemes, Memory } from '..';
|
|
2
|
+
import { TrainerLog } from '../logging/trainer_logger';
|
|
3
|
+
/**
 * Optional collaborators for a Disco instance.
 * Every field left undefined is filled in by the Disco constructor.
 */
interface DiscoOptions {
    /** Client to use; when omitted, one is built from `url` according to the scheme. */
    client?: Client;
    /** Server address, required when `client` is omitted; strings are converted with `new URL`. */
    url?: string | URL;
    /** Training scheme; when omitted, it is looked up from the task's training information. */
    scheme?: TrainingSchemes;
    /** Informant to report to; when omitted, one matching the scheme is created for the task. */
    informant?: TrainingInformant;
    /** Logger for user feedback; a ConsoleLogger is used when omitted. */
    logger?: Logger;
    /** Model storage; an EmptyMemory is used when omitted. */
    memory?: Memory;
}
|
|
11
|
+
/**
 * Handles the training loop, server communication and user feedback for one task.
 */
export declare class Disco {
    readonly task: Task;
    readonly logger: Logger;
    readonly memory: Memory;
    private readonly client;
    private readonly trainer;
    /**
     * @param task the task to train on
     * @param options collaborators; missing ones are defaulted
     * @throws Error when neither `options.client` nor `options.url` is given,
     *   or when the client or informant is not set up for `task`
     */
    constructor(task: Task, options: DiscoOptions);
    /**
     * Batches and preprocesses the data, connects the client, then trains the model.
     * The training set doubles as validation set when no validation split is given.
     */
    fit(dataTuple: data.DataSplit): Promise<void>;
    /** Stops the training; the client stays connected. */
    pause(): Promise<void>;
    /** Stops the training, then disconnects the client. */
    close(): Promise<void>;
    /** Resolves to the trainer's log. */
    logs(): Promise<TrainerLog>;
}
|
|
23
|
+
export {};
|
|
@@ -0,0 +1,130 @@
|
|
|
1
|
+
"use strict";
|
|
2
|
+
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
+
exports.Disco = void 0;
|
|
4
|
+
var tslib_1 = require("tslib");
|
|
5
|
+
var __1 = require("..");
|
|
6
|
+
var trainer_builder_1 = require("./trainer/trainer_builder");
|
|
7
|
+
// Handles the training loop, server communication & provides the user with feedback.
|
|
8
|
+
var Disco = /** @class */ (function () {
|
|
9
|
+
// client need to be connected
|
|
10
|
+
function Disco(task, options) {
|
|
11
|
+
if (options.scheme === undefined) {
|
|
12
|
+
options.scheme = __1.TrainingSchemes[task.trainingInformation.scheme];
|
|
13
|
+
}
|
|
14
|
+
if (options.client === undefined) {
|
|
15
|
+
if (options.url === undefined) {
|
|
16
|
+
throw new Error('could not determine client from given parameters');
|
|
17
|
+
}
|
|
18
|
+
if (typeof options.url === 'string') {
|
|
19
|
+
options.url = new URL(options.url);
|
|
20
|
+
}
|
|
21
|
+
switch (options.scheme) {
|
|
22
|
+
case __1.TrainingSchemes.FEDERATED:
|
|
23
|
+
options.client = new __1.client.federated.Client(options.url, task);
|
|
24
|
+
break;
|
|
25
|
+
case __1.TrainingSchemes.DECENTRALIZED:
|
|
26
|
+
options.client = new __1.client.federated.Client(options.url, task);
|
|
27
|
+
break;
|
|
28
|
+
default:
|
|
29
|
+
options.client = new __1.client.Local(options.url, task);
|
|
30
|
+
break;
|
|
31
|
+
}
|
|
32
|
+
}
|
|
33
|
+
if (options.informant === undefined) {
|
|
34
|
+
switch (options.scheme) {
|
|
35
|
+
case __1.TrainingSchemes.FEDERATED:
|
|
36
|
+
options.informant = new __1.informant.FederatedInformant(task);
|
|
37
|
+
break;
|
|
38
|
+
case __1.TrainingSchemes.DECENTRALIZED:
|
|
39
|
+
options.informant = new __1.informant.DecentralizedInformant(task);
|
|
40
|
+
break;
|
|
41
|
+
default:
|
|
42
|
+
options.informant = new __1.informant.LocalInformant(task);
|
|
43
|
+
break;
|
|
44
|
+
}
|
|
45
|
+
}
|
|
46
|
+
if (options.logger === undefined) {
|
|
47
|
+
options.logger = new __1.ConsoleLogger();
|
|
48
|
+
}
|
|
49
|
+
if (options.memory === undefined) {
|
|
50
|
+
options.memory = new __1.EmptyMemory();
|
|
51
|
+
}
|
|
52
|
+
if (options.client.task !== task) {
|
|
53
|
+
throw new Error('client not setup for given task');
|
|
54
|
+
}
|
|
55
|
+
if (options.informant.task.taskID !== task.taskID) {
|
|
56
|
+
throw new Error('informant not setup for given task');
|
|
57
|
+
}
|
|
58
|
+
this.task = task;
|
|
59
|
+
this.client = options.client;
|
|
60
|
+
this.memory = options.memory;
|
|
61
|
+
this.logger = options.logger;
|
|
62
|
+
var trainerBuilder = new trainer_builder_1.TrainerBuilder(this.memory, this.task, options.informant);
|
|
63
|
+
this.trainer = trainerBuilder.build(this.client, options.scheme !== __1.TrainingSchemes.LOCAL);
|
|
64
|
+
}
|
|
65
|
+
Disco.prototype.fit = function (dataTuple) {
|
|
66
|
+
return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
|
|
67
|
+
var trainDataset, valDataset;
|
|
68
|
+
return (0, tslib_1.__generator)(this, function (_a) {
|
|
69
|
+
switch (_a.label) {
|
|
70
|
+
case 0:
|
|
71
|
+
this.logger.success('Thank you for your contribution. Data preprocessing has started');
|
|
72
|
+
trainDataset = dataTuple.train.batch().preprocess();
|
|
73
|
+
valDataset = dataTuple.validation !== undefined
|
|
74
|
+
? dataTuple.validation.batch().preprocess()
|
|
75
|
+
: trainDataset;
|
|
76
|
+
return [4 /*yield*/, this.client.connect()];
|
|
77
|
+
case 1:
|
|
78
|
+
_a.sent();
|
|
79
|
+
return [4 /*yield*/, this.trainer];
|
|
80
|
+
case 2: return [4 /*yield*/, (_a.sent()).trainModel(trainDataset.dataset, valDataset.dataset)];
|
|
81
|
+
case 3:
|
|
82
|
+
_a.sent();
|
|
83
|
+
return [2 /*return*/];
|
|
84
|
+
}
|
|
85
|
+
});
|
|
86
|
+
});
|
|
87
|
+
};
|
|
88
|
+
// Stops the training function. Does not disconnect the client.
|
|
89
|
+
Disco.prototype.pause = function () {
|
|
90
|
+
return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
|
|
91
|
+
return (0, tslib_1.__generator)(this, function (_a) {
|
|
92
|
+
switch (_a.label) {
|
|
93
|
+
case 0: return [4 /*yield*/, this.trainer];
|
|
94
|
+
case 1: return [4 /*yield*/, (_a.sent()).stopTraining()];
|
|
95
|
+
case 2:
|
|
96
|
+
_a.sent();
|
|
97
|
+
this.logger.success('Training was successfully interrupted.');
|
|
98
|
+
return [2 /*return*/];
|
|
99
|
+
}
|
|
100
|
+
});
|
|
101
|
+
});
|
|
102
|
+
};
|
|
103
|
+
Disco.prototype.close = function () {
|
|
104
|
+
return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
|
|
105
|
+
return (0, tslib_1.__generator)(this, function (_a) {
|
|
106
|
+
switch (_a.label) {
|
|
107
|
+
case 0: return [4 /*yield*/, this.pause()];
|
|
108
|
+
case 1:
|
|
109
|
+
_a.sent();
|
|
110
|
+
return [4 /*yield*/, this.client.disconnect()];
|
|
111
|
+
case 2:
|
|
112
|
+
_a.sent();
|
|
113
|
+
return [2 /*return*/];
|
|
114
|
+
}
|
|
115
|
+
});
|
|
116
|
+
});
|
|
117
|
+
};
|
|
118
|
+
Disco.prototype.logs = function () {
|
|
119
|
+
return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
|
|
120
|
+
return (0, tslib_1.__generator)(this, function (_a) {
|
|
121
|
+
switch (_a.label) {
|
|
122
|
+
case 0: return [4 /*yield*/, this.trainer];
|
|
123
|
+
case 1: return [2 /*return*/, (_a.sent()).getTrainerLog()];
|
|
124
|
+
}
|
|
125
|
+
});
|
|
126
|
+
});
|
|
127
|
+
};
|
|
128
|
+
return Disco;
|
|
129
|
+
}());
|
|
130
|
+
exports.Disco = Disco;
|
|
File without changes
|
|
File without changes
|
|
@@ -1,5 +1,4 @@
|
|
|
1
|
-
import
|
|
2
|
-
import { Client, Memory, Task, TrainingInformant } from '../..';
|
|
1
|
+
import { tf, Client, Memory, Task, TrainingInformant } from '../..';
|
|
3
2
|
import { Trainer } from './trainer';
|
|
4
3
|
/**
|
|
5
4
|
* Class whose role is to train a model in a distributed way with a given dataset.
|
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
3
|
exports.DistributedTrainer = void 0;
|
|
4
4
|
var tslib_1 = require("tslib");
|
|
5
|
+
var __1 = require("../..");
|
|
5
6
|
var trainer_1 = require("./trainer");
|
|
6
7
|
/**
|
|
7
8
|
* Class whose role is to train a model in a distributed way with a given dataset.
|
|
@@ -25,13 +26,13 @@ var DistributedTrainer = /** @class */ (function (_super) {
|
|
|
25
26
|
return (0, tslib_1.__generator)(this, function (_a) {
|
|
26
27
|
switch (_a.label) {
|
|
27
28
|
case 0:
|
|
28
|
-
currentRoundWeights =
|
|
29
|
-
previousRoundWeights =
|
|
29
|
+
currentRoundWeights = __1.WeightsContainer.from(this.model);
|
|
30
|
+
previousRoundWeights = __1.WeightsContainer.from(this.previousRoundModel);
|
|
30
31
|
return [4 /*yield*/, this.client.onRoundEndCommunication(currentRoundWeights, previousRoundWeights, this.roundTracker.round, this.trainingInformant)];
|
|
31
32
|
case 1:
|
|
32
33
|
aggregatedWeights = _a.sent();
|
|
33
|
-
this.previousRoundModel.setWeights(currentRoundWeights);
|
|
34
|
-
this.model.setWeights(aggregatedWeights);
|
|
34
|
+
this.previousRoundModel.setWeights(currentRoundWeights.weights);
|
|
35
|
+
this.model.setWeights(aggregatedWeights.weights);
|
|
35
36
|
return [4 /*yield*/, this.memory.updateWorkingModel({ taskID: this.task.taskID, name: this.trainingInformation.modelID }, this.model)];
|
|
36
37
|
case 2:
|
|
37
38
|
_a.sent();
|
|
@@ -48,7 +49,7 @@ var DistributedTrainer = /** @class */ (function (_super) {
|
|
|
48
49
|
return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
|
|
49
50
|
return (0, tslib_1.__generator)(this, function (_a) {
|
|
50
51
|
switch (_a.label) {
|
|
51
|
-
case 0: return [4 /*yield*/, this.client.onTrainEndCommunication(
|
|
52
|
+
case 0: return [4 /*yield*/, this.client.onTrainEndCommunication(__1.WeightsContainer.from(this.model), this.trainingInformant)];
|
|
52
53
|
case 1:
|
|
53
54
|
_a.sent();
|
|
54
55
|
return [4 /*yield*/, _super.prototype.onTrainEnd.call(this)];
|
|
@@ -1,5 +1,5 @@
|
|
|
1
|
+
import { tf } from '../..';
|
|
1
2
|
import { Trainer } from './trainer';
|
|
2
|
-
import { Logs } from '@tensorflow/tfjs';
|
|
3
3
|
/** Class whose role is to locally (alone) train a model on a given dataset, without any collaborators.
|
|
4
4
|
*/
|
|
5
5
|
export declare class LocalTrainer extends Trainer {
|
|
@@ -7,5 +7,5 @@ export declare class LocalTrainer extends Trainer {
|
|
|
7
7
|
* Callback called every time a round is over. For local training, a round is typically an epoch
|
|
8
8
|
*/
|
|
9
9
|
onRoundEnd(accuracy: number): Promise<void>;
|
|
10
|
-
protected onEpochEnd(epoch: number, logs?: Logs): void;
|
|
10
|
+
protected onEpochEnd(epoch: number, logs?: tf.Logs): void;
|
|
11
11
|
}
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
@@ -1,5 +1,4 @@
|
|
|
1
|
-
import
|
|
2
|
-
import { Memory, Task, TrainingInformant, TrainingInformation } from '@/.';
|
|
1
|
+
import { tf, Memory, Task, TrainingInformant, TrainingInformation } from '../..';
|
|
3
2
|
import { RoundTracker } from './round_tracker';
|
|
4
3
|
import { TrainerLog } from '../../logging/trainer_logger';
|
|
5
4
|
/** Abstract class whose role is to train a model with a given dataset. This can be either done
|
|
@@ -102,9 +102,9 @@ var Trainer = /** @class */ (function () {
|
|
|
102
102
|
case 0:
|
|
103
103
|
this.resetStopTrainerState();
|
|
104
104
|
// Assign callbacks and start training
|
|
105
|
-
return [4 /*yield*/, this.model.fitDataset(dataset
|
|
105
|
+
return [4 /*yield*/, this.model.fitDataset(dataset, {
|
|
106
106
|
epochs: this.trainingInformation.epochs,
|
|
107
|
-
validationData: valDataset
|
|
107
|
+
validationData: valDataset,
|
|
108
108
|
callbacks: {
|
|
109
109
|
onEpochEnd: function (epoch, logs) { return _this.onEpochEnd(epoch, logs); },
|
|
110
110
|
onBatchEnd: function (epoch, logs) { return (0, tslib_1.__awaiter)(_this, void 0, void 0, function () { return (0, tslib_1.__generator)(this, function (_a) {
|
|
File without changes
|
|
File without changes
|
|
@@ -5,6 +5,6 @@ exports.TrainingSchemes = void 0;
|
|
|
5
5
|
var TrainingSchemes;
|
|
6
6
|
(function (TrainingSchemes) {
|
|
7
7
|
TrainingSchemes["LOCAL"] = "local";
|
|
8
|
-
TrainingSchemes["DECENTRALIZED"] = "
|
|
9
|
-
TrainingSchemes["FEDERATED"] = "
|
|
8
|
+
TrainingSchemes["DECENTRALIZED"] = "decentralized";
|
|
9
|
+
TrainingSchemes["FEDERATED"] = "federated";
|
|
10
10
|
})(TrainingSchemes = exports.TrainingSchemes || (exports.TrainingSchemes = {}));
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
@@ -1,18 +1,15 @@
|
|
|
1
|
-
import * as tf from '@tensorflow/tfjs';
|
|
2
|
-
import { ModelActor } from '../model_actor';
|
|
3
|
-
import { Task } from '@/task';
|
|
4
|
-
import { Data } from '@/dataset';
|
|
5
|
-
import { Logger } from '@/logging';
|
|
6
1
|
import { List } from 'immutable';
|
|
7
|
-
import { Client, Memory, ModelSource } from '..';
|
|
8
|
-
export declare class Validator
|
|
2
|
+
import { tf, data, Task, Logger, Client, Memory, ModelSource } from '..';
|
|
3
|
+
export declare class Validator {
|
|
4
|
+
readonly task: Task;
|
|
5
|
+
readonly logger: Logger;
|
|
9
6
|
private readonly memory;
|
|
10
7
|
private readonly source?;
|
|
11
8
|
private readonly client?;
|
|
12
9
|
private readonly graphInformant;
|
|
13
10
|
private size;
|
|
14
11
|
constructor(task: Task, logger: Logger, memory: Memory, source?: ModelSource | undefined, client?: Client | undefined);
|
|
15
|
-
assess(data: Data): Promise<void>;
|
|
12
|
+
assess(data: data.Data): Promise<void>;
|
|
16
13
|
getModel(): Promise<tf.LayersModel>;
|
|
17
14
|
accuracyData(): List<number>;
|
|
18
15
|
accuracy(): number;
|
|
@@ -2,22 +2,20 @@
|
|
|
2
2
|
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
3
|
exports.Validator = void 0;
|
|
4
4
|
var tslib_1 = require("tslib");
|
|
5
|
-
var model_actor_1 = require("../model_actor");
|
|
6
5
|
var immutable_1 = require("immutable");
|
|
7
6
|
var __1 = require("..");
|
|
8
|
-
var Validator = /** @class */ (function (
|
|
9
|
-
(0, tslib_1.__extends)(Validator, _super);
|
|
7
|
+
var Validator = /** @class */ (function () {
|
|
10
8
|
function Validator(task, logger, memory, source, client) {
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
9
|
+
this.task = task;
|
|
10
|
+
this.logger = logger;
|
|
11
|
+
this.memory = memory;
|
|
12
|
+
this.source = source;
|
|
13
|
+
this.client = client;
|
|
14
|
+
this.graphInformant = new __1.GraphInformant();
|
|
15
|
+
this.size = 0;
|
|
17
16
|
if (source === undefined && client === undefined) {
|
|
18
17
|
throw new Error('cannot identify model');
|
|
19
18
|
}
|
|
20
|
-
return _this;
|
|
21
19
|
}
|
|
22
20
|
Validator.prototype.assess = function (data) {
|
|
23
21
|
var _a, _b, _c;
|
|
@@ -102,5 +100,5 @@ var Validator = /** @class */ (function (_super) {
|
|
|
102
100
|
return this.size;
|
|
103
101
|
};
|
|
104
102
|
return Validator;
|
|
105
|
-
}(
|
|
103
|
+
}());
|
|
106
104
|
exports.Validator = Validator;
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
import { TensorLike, WeightsContainer } from './weights_container';
|
|
2
|
+
/** Any iterable of tensors or numeric array-likes; accepted wherever a WeightsContainer is. */
declare type WeightsLike = Iterable<TensorLike>;
|
|
3
|
+
/** Adds the operands together, tensor by tensor at matching indices. Throws when no operand is given. */
export declare function sum(weights: Iterable<WeightsLike | WeightsContainer>): WeightsContainer;
|
|
4
|
+
/** Folds the operands left to right with a tensor subtraction: the first operand minus each following one. */
export declare function diff(weights: Iterable<WeightsLike | WeightsContainer>): WeightsContainer;
|
|
5
|
+
/** Sum of the operands with every resulting tensor divided by the number of operands. */
export declare function avg(weights: Iterable<WeightsLike | WeightsContainer>): WeightsContainer;
|
|
6
|
+
/** Averages the peers' weights after centering them on currentModel and clipping them; tauPercentile selects, as a quantile of the per-peer Frobenius norms, the clipping threshold. */
export declare function avgClippingWeights(peersWeights: Iterable<WeightsLike | WeightsContainer>, currentModel: WeightsContainer, tauPercentile: number): WeightsContainer;
|
|
7
|
+
export {};
|
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
"use strict";
|
|
2
|
+
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
+
exports.avgClippingWeights = exports.avg = exports.diff = exports.sum = void 0;
|
|
4
|
+
var immutable_1 = require("immutable");
|
|
5
|
+
var __1 = require("..");
|
|
6
|
+
var weights_container_1 = require("./weights_container");
|
|
7
|
+
/**
 * Normalizes a collection of weight sets into an immutable List of WeightsContainer.
 *
 * Entries that already are WeightsContainer instances are kept as they are; any
 * other entry is wrapped in a new WeightsContainer.
 *
 * @param weights iterable of WeightsContainer and/or iterables of tensor-like values
 * @returns immutable List of WeightsContainer, in iteration order
 * @throws Error when the collection is empty
 * @throws Error when any operand holds a different number of tensors than the first one
 */
function parseWeights(weights) {
    var _a;
    var r = (0, immutable_1.List)(weights).map(function (w) {
        return w instanceof weights_container_1.WeightsContainer ? w : new weights_container_1.WeightsContainer(w);
    });
    var weightsSize = (_a = r.first()) === null || _a === void 0 ? void 0 : _a.weights.length;
    if (weightsSize === undefined) {
        throw new Error('no weights to work with');
    }
    // `some`, not `every`: one mismatching operand is enough to reject the input,
    // and a single operand (empty `rest()`) must not be rejected.
    if (r.rest().some(function (w) { return w.weights.length !== weightsSize; })) {
        throw new Error('weights dimensions are different for some of the operands');
    }
    return r;
}
|
|
21
|
+
/**
 * Centers every set of weights on the given model.
 *
 * @param weights iterable of weight sets, normalized through parseWeights
 * @param currentModel WeightsContainer subtracted from each weight set
 * @returns List of WeightsContainer holding `weights[k] - currentModel`, tensor by tensor
 */
function centerWeights(weights, currentModel) {
    var parsed = parseWeights(weights);
    return parsed.map(function (peerModel) {
        return peerModel.mapWith(currentModel, __1.tf.sub);
    });
}
|
|
24
|
+
function clipWeights(modelList, normArray, tau) {
|
|
25
|
+
return modelList.map(function (weights) { return weights.map(function (w, i) { return __1.tf.prod(w, Math.min(1, tau / (normArray[i]))); }); });
|
|
26
|
+
}
|
|
27
|
+
/**
 * Computes the q-quantile of an array of numbers with linear interpolation
 * between the two closest ranks.
 *
 * The input array is left untouched.
 *
 * @param array numbers to take the quantile of
 * @param q quantile position, expected in [0, 1]
 * @returns the interpolated quantile; `undefined` when the array is empty
 */
function computeQuantile(array, q) {
    // Sort a copy: Array.prototype.sort works in place and would otherwise
    // reorder the caller's array.
    var sorted = array.slice().sort(function (a, b) { return a - b; });
    var pos = (sorted.length - 1) * q;
    var base = Math.floor(pos);
    var rest = pos - base;
    if (sorted[base + 1] !== undefined) {
        return sorted[base] + rest * (sorted[base + 1] - sorted[base]);
    }
    // No upper neighbour to interpolate with.
    return sorted[base];
}
|
|
39
|
+
/**
 * Folds several weight sets into one by applying `fn` tensor by tensor.
 *
 * The first weight set is the initial accumulator; each following one is
 * combined into it at matching tensor indices.
 *
 * @param weights iterable of weight sets, normalized through parseWeights
 * @param fn binary tensor operation `(accumulated, next) => combined`
 * @returns WeightsContainer holding the folded tensors
 */
function reduce(weights, fn) {
    var operands = parseWeights(weights);
    return operands.reduce(function (accumulated, current) {
        var combined = accumulated.weights.map(function (tensor, index) {
            return fn(tensor, current.get(index));
        });
        return new weights_container_1.WeightsContainer(combined);
    });
}
|
|
46
|
+
/**
 * Adds the given weight sets together, tensor by tensor.
 *
 * @param weights iterable of weight sets
 * @returns WeightsContainer holding the sums
 */
function sum(weights) {
    var addTensors = __1.tf.add;
    return reduce(weights, addTensors);
}
|
|
49
|
+
exports.sum = sum;
|
|
50
|
+
/**
 * Subtracts the weight sets from left to right, tensor by tensor:
 * the first one minus every following one.
 *
 * @param weights iterable of weight sets
 * @returns WeightsContainer holding the differences
 */
function diff(weights) {
    var subtractTensors = __1.tf.sub;
    return reduce(weights, subtractTensors);
}
|
|
53
|
+
exports.diff = diff;
|
|
54
|
+
/**
 * Averages the given weight sets, tensor by tensor.
 *
 * @param weights iterable of weight sets; it is iterated only once, so
 *   single-pass iterables such as generators are supported
 * @returns WeightsContainer holding the sums divided by the number of weight sets
 */
function avg(weights) {
    // Materialize once: counting and summing the raw iterable separately would
    // walk it twice, leaving nothing to sum for a single-pass iterable.
    var operands = (0, immutable_1.List)(weights);
    var size = operands.size;
    return sum(operands).map(function (ws) { return ws.div(size); });
}
|
|
58
|
+
exports.avg = avg;
|
|
59
|
+
// See: https://arxiv.org/abs/2012.10333
|
|
60
|
+
/**
 * Averages the peers' weights after centering them on the current model and
 * clipping them by norm.
 *
 * @param peersWeights iterable of the peers' weight sets
 * @param currentModel weights of the previous aggregated model, used as the center
 * @param tauPercentile quantile (forwarded to computeQuantile) of the per-peer
 *   norms that is used as the clipping threshold tau
 * @returns the average of the clipped, centered peers' weights
 */
function avgClippingWeights(peersWeights, currentModel, tauPercentile) {
    // Center every peer's weights on the previous model aggregation.
    var centeredPeersWeights = centerWeights(peersWeights, currentModel);
    // Frobenius norm of each peer's centered weights, in peer order.
    var normArray = Array.from(centeredPeersWeights.map(function (model) { return model.frobeniusNorm(); }));
    // tau is the requested quantile of the norms. A copy is handed over because
    // computeQuantile may sort its argument in place, and normArray must stay
    // aligned with the peers for the clipping step.
    var tau = computeQuantile(normArray.slice(), tauPercentile);
    // Clip the centered weights given the norms and tau.
    var centeredMean = clipWeights(centeredPeersWeights, normArray, tau);
    // NOTE(review): the result is the mean of the *centered* weights; currentModel
    // is not added back here — confirm that callers expect that.
    return avg(centeredMean);
}
|
|
72
|
+
exports.avgClippingWeights = avgClippingWeights;
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
"use strict";
|
|
2
|
+
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
+
exports.aggregation = exports.WeightsContainer = void 0;
|
|
4
|
+
var tslib_1 = require("tslib");
|
|
5
|
+
var weights_container_1 = require("./weights_container");
|
|
6
|
+
Object.defineProperty(exports, "WeightsContainer", { enumerable: true, get: function () { return weights_container_1.WeightsContainer; } });
|
|
7
|
+
exports.aggregation = (0, tslib_1.__importStar)(require("./aggregation"));
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
import { tf, Weights } from '..';
|
|
2
|
+
export declare type TensorLike = tf.Tensor | ArrayLike<number>;
|
|
3
|
+
/** Collection of tensors offering pairwise and per-tensor operations that each return a new container. */
export declare class WeightsContainer {
|
|
4
|
+
/** Tensors held by this container. */
private readonly _weights;
|
|
5
|
+
/** Wraps the given values; anything that is not already a tf.Tensor is converted with tf.tensor. */
constructor(weights: Iterable<TensorLike>);
|
|
6
|
+
/** The held tensors as a plain array. */
get weights(): Weights;
|
|
7
|
+
/** Pairwise tf.add with other's tensors; returns a new container. */
add(other: WeightsContainer): WeightsContainer;
|
|
8
|
+
/** Pairwise tf.sub of other's tensors from this container's; returns a new container. */
sub(other: WeightsContainer): WeightsContainer;
|
|
9
|
+
/** Pairs this container's tensors with other's by index and applies fn to each pair. NOTE(review): with different lengths the pairing presumably stops at the shorter one — confirm. */
mapWith(other: WeightsContainer, fn: (a: tf.Tensor, b: tf.Tensor) => tf.Tensor): WeightsContainer;
|
|
10
|
+
/** Applies fn to every tensor (and its index) and returns a new container with the results. */
map(fn: (t: tf.Tensor, i: number) => tf.Tensor): WeightsContainer;
|
|
11
|
+
/** Overload for callbacks that ignore the index. */
map(fn: (t: tf.Tensor) => tf.Tensor): WeightsContainer;
|
|
12
|
+
/** Folds the held tensors with fn, without an explicit initial value. */
reduce(fn: (acc: tf.Tensor, t: tf.Tensor) => tf.Tensor): tf.Tensor;
|
|
13
|
+
/** Tensor at the given index, or undefined when there is none. */
get(index: number): tf.Tensor | undefined;
|
|
14
|
+
/** Square root of the sum of the squared entries of all held tensors. */
frobeniusNorm(): number;
|
|
15
|
+
/** Builds a container from the arguments. */
static of(...weights: TensorLike[]): WeightsContainer;
|
|
16
|
+
/** Builds a container from a model's weights, each obtained through its read(). */
static from(model: tf.LayersModel): WeightsContainer;
|
|
17
|
+
/** Wraps both operands in containers and adds them pairwise. */
static add(a: Iterable<TensorLike>, b: Iterable<TensorLike>): WeightsContainer;
|
|
18
|
+
/** Wraps both operands in containers and subtracts b from a pairwise. */
static sub(a: Iterable<TensorLike>, b: Iterable<TensorLike>): WeightsContainer;
|
|
19
|
+
}
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
"use strict";
|
|
2
|
+
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
+
exports.WeightsContainer = void 0;
|
|
4
|
+
var tslib_1 = require("tslib");
|
|
5
|
+
var immutable_1 = require("immutable");
|
|
6
|
+
var __1 = require("..");
|
|
7
|
+
/**
 * Collection of tensors, stored in an immutable List, with pairwise and
 * per-tensor operations that each return a new container.
 */
var WeightsContainer = /** @class */ (function () {
    /**
     * @param weights iterable of tensors or tensor-like values; values that are
     *   not already tf.Tensor instances are converted with tf.tensor
     */
    function WeightsContainer(weights) {
        this._weights = (0, immutable_1.List)(weights).map(function (w) {
            return w instanceof __1.tf.Tensor ? w : __1.tf.tensor(w);
        });
    }
    Object.defineProperty(WeightsContainer.prototype, "weights", {
        /** The held tensors as a plain array. */
        get: function () {
            return this._weights.toArray();
        },
        enumerable: false,
        configurable: true
    });
    /** Pairwise tf.add with the tensors of `other`. */
    WeightsContainer.prototype.add = function (other) {
        return this.mapWith(other, __1.tf.add);
    };
    /** Pairwise tf.sub of the tensors of `other` from this container's. */
    WeightsContainer.prototype.sub = function (other) {
        return this.mapWith(other, __1.tf.sub);
    };
    /**
     * Pairs this container's tensors with those of `other` by index and applies
     * `fn` to each pair.
     *
     * @returns a new container holding the results of `fn`
     */
    WeightsContainer.prototype.mapWith = function (other, fn) {
        return new WeightsContainer(this._weights
            .zip(other._weights)
            .map(function (_a) {
            var _b = (0, tslib_1.__read)(_a, 2), w1 = _b[0], w2 = _b[1];
            return fn(w1, w2);
        }));
    };
    /** Applies `fn` to every tensor and returns a new container with the results. */
    WeightsContainer.prototype.map = function (fn) {
        return new WeightsContainer(this._weights.map(fn));
    };
    /** Folds the held tensors with `fn`, without an explicit initial value. */
    WeightsContainer.prototype.reduce = function (fn) {
        return this._weights.reduce(fn);
    };
    /** Tensor at `index`, as returned by the underlying List. */
    WeightsContainer.prototype.get = function (index) {
        return this._weights.get(index);
    };
    /**
     * Square root of the sum of the squared entries of all held tensors.
     *
     * @returns the norm as a plain number
     */
    WeightsContainer.prototype.frobeniusNorm = function () {
        var self = this;
        // tf.tidy disposes the intermediate tensors (squares, partial sums) created
        // below; only the plain number leaves the scope. The tensors held by this
        // container were created outside the scope and are left alone.
        return __1.tf.tidy(function () {
            var squaredSum = self
                .map(function (w) { return w.square().sum(); })
                .reduce(function (a, b) { return a.add(b); });
            return Math.sqrt(squaredSum.dataSync()[0]);
        });
    };
    /** Builds a container from the arguments. */
    WeightsContainer.of = function () {
        var weights = [];
        for (var _i = 0; _i < arguments.length; _i++) {
            weights[_i] = arguments[_i];
        }
        return new this(weights);
    };
    /** Builds a container from a model's weights, each obtained through its `read()`. */
    WeightsContainer.from = function (model) {
        return new this(model.weights.map(function (w) { return w.read(); }));
    };
    /** Wraps both operands in containers and adds them pairwise. */
    WeightsContainer.add = function (a, b) {
        return new this(a).add(new this(b));
    };
    /** Wraps both operands in containers and subtracts `b` from `a` pairwise. */
    WeightsContainer.sub = function (a, b) {
        return new this(a).sub(new this(b));
    };
    return WeightsContainer;
}());
|
|
64
|
+
exports.WeightsContainer = WeightsContainer;
|
|
@@ -1,16 +1,4 @@
|
|
|
1
|
-
import { tf } from '../..';
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
/**
|
|
5
|
-
* TODO @s314cy:
|
|
6
|
-
* Load labels and correctly match them with their respective images, with the following constraints:
|
|
7
|
-
* 1. Images are given as 1 image/1 file
|
|
8
|
-
* 2. Labels are given as multiple labels/1 file, each label file can contain a different amount of labels
|
|
9
|
-
*/
|
|
10
|
-
export declare abstract class ImageLoader<Source> extends DataLoader<Source> {
|
|
11
|
-
abstract readImageFrom(source: Source): Promise<tf.Tensor3D>;
|
|
12
|
-
load(image: Source, config?: DataConfig): Promise<Dataset>;
|
|
13
|
-
private buildDataset;
|
|
14
|
-
loadAll(images: Source[], config?: DataConfig): Promise<DataTuple>;
|
|
15
|
-
shuffle(array: number[]): void;
|
|
1
|
+
import { tf, data } from '../..';
|
|
2
|
+
export declare class WebImageLoader extends data.ImageLoader<File> {
|
|
3
|
+
readImageFrom(source: File): Promise<tf.Tensor3D>;
|
|
16
4
|
}
|