@epfml/discojs 2.0.0 → 2.1.2-p20240506085037.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/aggregator/base.d.ts +180 -0
- package/dist/aggregator/base.js +236 -0
- package/dist/aggregator/get.d.ts +16 -0
- package/dist/aggregator/get.js +31 -0
- package/dist/aggregator/index.d.ts +7 -0
- package/dist/aggregator/index.js +4 -0
- package/dist/aggregator/mean.d.ts +23 -0
- package/dist/aggregator/mean.js +69 -0
- package/dist/aggregator/secure.d.ts +27 -0
- package/dist/aggregator/secure.js +91 -0
- package/dist/async_informant.d.ts +15 -0
- package/dist/async_informant.js +42 -0
- package/dist/client/base.d.ts +76 -0
- package/dist/client/base.js +88 -0
- package/dist/client/decentralized/base.d.ts +32 -0
- package/dist/client/decentralized/base.js +192 -0
- package/dist/client/decentralized/index.d.ts +2 -0
- package/dist/client/decentralized/index.js +2 -0
- package/dist/client/decentralized/messages.d.ts +28 -0
- package/dist/client/decentralized/messages.js +44 -0
- package/dist/client/decentralized/peer.d.ts +40 -0
- package/dist/client/decentralized/peer.js +189 -0
- package/dist/client/decentralized/peer_pool.d.ts +12 -0
- package/dist/client/decentralized/peer_pool.js +44 -0
- package/dist/client/event_connection.d.ts +34 -0
- package/dist/client/event_connection.js +105 -0
- package/dist/client/federated/base.d.ts +54 -0
- package/dist/client/federated/base.js +151 -0
- package/dist/client/federated/index.d.ts +2 -0
- package/dist/client/federated/index.js +2 -0
- package/dist/client/federated/messages.d.ts +30 -0
- package/dist/client/federated/messages.js +24 -0
- package/dist/client/index.d.ts +8 -0
- package/dist/client/index.js +8 -0
- package/dist/client/local.d.ts +3 -0
- package/dist/client/local.js +3 -0
- package/dist/client/messages.d.ts +30 -0
- package/dist/client/messages.js +26 -0
- package/dist/client/types.d.ts +2 -0
- package/dist/client/types.js +4 -0
- package/dist/client/utils.d.ts +2 -0
- package/dist/client/utils.js +7 -0
- package/dist/dataset/data/data.d.ts +48 -0
- package/dist/dataset/data/data.js +72 -0
- package/dist/dataset/data/data_split.d.ts +8 -0
- package/dist/dataset/data/data_split.js +1 -0
- package/dist/dataset/data/image_data.d.ts +11 -0
- package/dist/dataset/data/image_data.js +38 -0
- package/dist/dataset/data/index.d.ts +6 -0
- package/dist/dataset/data/index.js +5 -0
- package/dist/dataset/data/preprocessing/base.d.ts +16 -0
- package/dist/dataset/data/preprocessing/base.js +1 -0
- package/dist/dataset/data/preprocessing/image_preprocessing.d.ts +13 -0
- package/dist/dataset/data/preprocessing/image_preprocessing.js +40 -0
- package/dist/dataset/data/preprocessing/index.d.ts +4 -0
- package/dist/dataset/data/preprocessing/index.js +3 -0
- package/dist/dataset/data/preprocessing/tabular_preprocessing.d.ts +13 -0
- package/dist/dataset/data/preprocessing/tabular_preprocessing.js +45 -0
- package/dist/dataset/data/preprocessing/text_preprocessing.d.ts +13 -0
- package/dist/dataset/data/preprocessing/text_preprocessing.js +85 -0
- package/dist/dataset/data/tabular_data.d.ts +11 -0
- package/dist/dataset/data/tabular_data.js +25 -0
- package/dist/dataset/data/text_data.d.ts +11 -0
- package/dist/dataset/data/text_data.js +14 -0
- package/dist/{core/dataset → dataset}/data_loader/data_loader.d.ts +3 -5
- package/dist/dataset/data_loader/data_loader.js +2 -0
- package/dist/dataset/data_loader/image_loader.d.ts +20 -3
- package/dist/dataset/data_loader/image_loader.js +98 -23
- package/dist/dataset/data_loader/index.d.ts +5 -2
- package/dist/dataset/data_loader/index.js +4 -7
- package/dist/dataset/data_loader/tabular_loader.d.ts +34 -3
- package/dist/dataset/data_loader/tabular_loader.js +75 -15
- package/dist/dataset/data_loader/text_loader.d.ts +14 -0
- package/dist/dataset/data_loader/text_loader.js +25 -0
- package/dist/dataset/dataset.d.ts +5 -0
- package/dist/dataset/dataset.js +1 -0
- package/dist/dataset/dataset_builder.d.ts +60 -0
- package/dist/dataset/dataset_builder.js +142 -0
- package/dist/dataset/index.d.ts +5 -0
- package/dist/dataset/index.js +3 -0
- package/dist/default_tasks/cifar10/index.d.ts +2 -0
- package/dist/default_tasks/cifar10/index.js +60 -0
- package/dist/default_tasks/cifar10/model.d.ts +434 -0
- package/dist/default_tasks/cifar10/model.js +2385 -0
- package/dist/default_tasks/geotags/index.d.ts +2 -0
- package/dist/default_tasks/geotags/index.js +65 -0
- package/dist/default_tasks/geotags/model.d.ts +593 -0
- package/dist/default_tasks/geotags/model.js +4715 -0
- package/dist/default_tasks/index.d.ts +8 -0
- package/dist/default_tasks/index.js +8 -0
- package/dist/default_tasks/lus_covid.d.ts +2 -0
- package/dist/default_tasks/lus_covid.js +89 -0
- package/dist/default_tasks/mnist.d.ts +2 -0
- package/dist/default_tasks/mnist.js +61 -0
- package/dist/default_tasks/simple_face/index.d.ts +2 -0
- package/dist/default_tasks/simple_face/index.js +48 -0
- package/dist/default_tasks/simple_face/model.d.ts +513 -0
- package/dist/default_tasks/simple_face/model.js +4301 -0
- package/dist/default_tasks/skin_mnist.d.ts +2 -0
- package/dist/default_tasks/skin_mnist.js +80 -0
- package/dist/default_tasks/titanic.d.ts +2 -0
- package/dist/default_tasks/titanic.js +88 -0
- package/dist/default_tasks/wikitext.d.ts +2 -0
- package/dist/default_tasks/wikitext.js +38 -0
- package/dist/index.d.ts +18 -2
- package/dist/index.js +18 -6
- package/dist/{core/informant → informant}/graph_informant.d.ts +1 -1
- package/dist/informant/graph_informant.js +20 -0
- package/dist/informant/index.d.ts +1 -0
- package/dist/informant/index.js +1 -0
- package/dist/{core/logging → logging}/console_logger.d.ts +2 -2
- package/dist/logging/console_logger.js +22 -0
- package/dist/logging/index.d.ts +2 -0
- package/dist/logging/index.js +1 -0
- package/dist/{core/logging → logging}/logger.d.ts +3 -3
- package/dist/logging/logger.js +1 -0
- package/dist/memory/base.d.ts +119 -0
- package/dist/memory/base.js +9 -0
- package/dist/memory/empty.d.ts +20 -0
- package/dist/memory/empty.js +43 -0
- package/dist/memory/index.d.ts +3 -1
- package/dist/memory/index.js +3 -5
- package/dist/memory/model_type.d.ts +9 -0
- package/dist/memory/model_type.js +10 -0
- package/dist/{core/privacy.d.ts → privacy.d.ts} +1 -1
- package/dist/{core/privacy.js → privacy.js} +11 -16
- package/dist/serialization/index.d.ts +2 -0
- package/dist/serialization/index.js +2 -0
- package/dist/serialization/model.d.ts +5 -0
- package/dist/serialization/model.js +67 -0
- package/dist/{core/serialization → serialization}/weights.d.ts +2 -2
- package/dist/serialization/weights.js +37 -0
- package/dist/task/data_example.js +14 -0
- package/dist/task/digest.d.ts +5 -0
- package/dist/task/digest.js +14 -0
- package/dist/{core/task → task}/display_information.d.ts +5 -3
- package/dist/task/display_information.js +46 -0
- package/dist/task/index.d.ts +7 -0
- package/dist/task/index.js +5 -0
- package/dist/task/label_type.d.ts +9 -0
- package/dist/task/label_type.js +28 -0
- package/dist/task/summary.js +13 -0
- package/dist/task/task.d.ts +12 -0
- package/dist/task/task.js +22 -0
- package/dist/task/task_handler.d.ts +5 -0
- package/dist/task/task_handler.js +20 -0
- package/dist/task/task_provider.d.ts +5 -0
- package/dist/task/task_provider.js +1 -0
- package/dist/{core/task → task}/training_information.d.ts +9 -10
- package/dist/task/training_information.js +88 -0
- package/dist/training/disco.d.ts +40 -0
- package/dist/training/disco.js +107 -0
- package/dist/training/index.d.ts +2 -0
- package/dist/training/index.js +1 -0
- package/dist/training/trainer/distributed_trainer.d.ts +20 -0
- package/dist/training/trainer/distributed_trainer.js +36 -0
- package/dist/training/trainer/local_trainer.d.ts +12 -0
- package/dist/training/trainer/local_trainer.js +19 -0
- package/dist/training/trainer/trainer.d.ts +33 -0
- package/dist/training/trainer/trainer.js +52 -0
- package/dist/{core/training → training}/trainer/trainer_builder.d.ts +5 -7
- package/dist/training/trainer/trainer_builder.js +43 -0
- package/dist/types.d.ts +8 -0
- package/dist/types.js +1 -0
- package/dist/utils/event_emitter.d.ts +40 -0
- package/dist/utils/event_emitter.js +57 -0
- package/dist/validation/index.d.ts +1 -0
- package/dist/validation/index.js +1 -0
- package/dist/validation/validator.d.ts +28 -0
- package/dist/validation/validator.js +132 -0
- package/dist/weights/aggregation.d.ts +21 -0
- package/dist/weights/aggregation.js +44 -0
- package/dist/weights/index.d.ts +2 -0
- package/dist/weights/index.js +2 -0
- package/dist/weights/weights_container.d.ts +68 -0
- package/dist/weights/weights_container.js +96 -0
- package/package.json +25 -16
- package/README.md +0 -53
- package/dist/core/async_buffer.d.ts +0 -41
- package/dist/core/async_buffer.js +0 -97
- package/dist/core/async_informant.d.ts +0 -20
- package/dist/core/async_informant.js +0 -69
- package/dist/core/client/base.d.ts +0 -33
- package/dist/core/client/base.js +0 -35
- package/dist/core/client/decentralized/base.d.ts +0 -32
- package/dist/core/client/decentralized/base.js +0 -212
- package/dist/core/client/decentralized/clear_text.d.ts +0 -14
- package/dist/core/client/decentralized/clear_text.js +0 -96
- package/dist/core/client/decentralized/index.d.ts +0 -4
- package/dist/core/client/decentralized/index.js +0 -9
- package/dist/core/client/decentralized/messages.d.ts +0 -41
- package/dist/core/client/decentralized/messages.js +0 -54
- package/dist/core/client/decentralized/peer.d.ts +0 -26
- package/dist/core/client/decentralized/peer.js +0 -210
- package/dist/core/client/decentralized/peer_pool.d.ts +0 -14
- package/dist/core/client/decentralized/peer_pool.js +0 -92
- package/dist/core/client/decentralized/sec_agg.d.ts +0 -22
- package/dist/core/client/decentralized/sec_agg.js +0 -190
- package/dist/core/client/decentralized/secret_shares.d.ts +0 -3
- package/dist/core/client/decentralized/secret_shares.js +0 -39
- package/dist/core/client/decentralized/types.d.ts +0 -2
- package/dist/core/client/decentralized/types.js +0 -7
- package/dist/core/client/event_connection.d.ts +0 -37
- package/dist/core/client/event_connection.js +0 -158
- package/dist/core/client/federated/client.d.ts +0 -37
- package/dist/core/client/federated/client.js +0 -273
- package/dist/core/client/federated/index.d.ts +0 -2
- package/dist/core/client/federated/index.js +0 -7
- package/dist/core/client/federated/messages.d.ts +0 -38
- package/dist/core/client/federated/messages.js +0 -25
- package/dist/core/client/index.d.ts +0 -5
- package/dist/core/client/index.js +0 -11
- package/dist/core/client/local.d.ts +0 -8
- package/dist/core/client/local.js +0 -36
- package/dist/core/client/messages.d.ts +0 -28
- package/dist/core/client/messages.js +0 -33
- package/dist/core/client/utils.d.ts +0 -2
- package/dist/core/client/utils.js +0 -19
- package/dist/core/dataset/data/data.d.ts +0 -11
- package/dist/core/dataset/data/data.js +0 -20
- package/dist/core/dataset/data/data_split.d.ts +0 -5
- package/dist/core/dataset/data/data_split.js +0 -2
- package/dist/core/dataset/data/image_data.d.ts +0 -8
- package/dist/core/dataset/data/image_data.js +0 -64
- package/dist/core/dataset/data/index.d.ts +0 -5
- package/dist/core/dataset/data/index.js +0 -11
- package/dist/core/dataset/data/preprocessing.d.ts +0 -13
- package/dist/core/dataset/data/preprocessing.js +0 -33
- package/dist/core/dataset/data/tabular_data.d.ts +0 -8
- package/dist/core/dataset/data/tabular_data.js +0 -40
- package/dist/core/dataset/data_loader/data_loader.js +0 -10
- package/dist/core/dataset/data_loader/image_loader.d.ts +0 -17
- package/dist/core/dataset/data_loader/image_loader.js +0 -141
- package/dist/core/dataset/data_loader/index.d.ts +0 -3
- package/dist/core/dataset/data_loader/index.js +0 -9
- package/dist/core/dataset/data_loader/tabular_loader.d.ts +0 -29
- package/dist/core/dataset/data_loader/tabular_loader.js +0 -101
- package/dist/core/dataset/dataset.d.ts +0 -2
- package/dist/core/dataset/dataset.js +0 -2
- package/dist/core/dataset/dataset_builder.d.ts +0 -18
- package/dist/core/dataset/dataset_builder.js +0 -96
- package/dist/core/dataset/index.d.ts +0 -4
- package/dist/core/dataset/index.js +0 -14
- package/dist/core/index.d.ts +0 -18
- package/dist/core/index.js +0 -41
- package/dist/core/informant/graph_informant.js +0 -23
- package/dist/core/informant/index.d.ts +0 -3
- package/dist/core/informant/index.js +0 -9
- package/dist/core/informant/training_informant/base.d.ts +0 -31
- package/dist/core/informant/training_informant/base.js +0 -83
- package/dist/core/informant/training_informant/decentralized.d.ts +0 -5
- package/dist/core/informant/training_informant/decentralized.js +0 -22
- package/dist/core/informant/training_informant/federated.d.ts +0 -14
- package/dist/core/informant/training_informant/federated.js +0 -32
- package/dist/core/informant/training_informant/index.d.ts +0 -4
- package/dist/core/informant/training_informant/index.js +0 -11
- package/dist/core/informant/training_informant/local.d.ts +0 -6
- package/dist/core/informant/training_informant/local.js +0 -20
- package/dist/core/logging/console_logger.js +0 -33
- package/dist/core/logging/index.d.ts +0 -3
- package/dist/core/logging/index.js +0 -9
- package/dist/core/logging/logger.js +0 -9
- package/dist/core/logging/trainer_logger.d.ts +0 -24
- package/dist/core/logging/trainer_logger.js +0 -59
- package/dist/core/memory/base.d.ts +0 -22
- package/dist/core/memory/base.js +0 -9
- package/dist/core/memory/empty.d.ts +0 -14
- package/dist/core/memory/empty.js +0 -75
- package/dist/core/memory/index.d.ts +0 -3
- package/dist/core/memory/index.js +0 -9
- package/dist/core/memory/model_type.d.ts +0 -4
- package/dist/core/memory/model_type.js +0 -9
- package/dist/core/serialization/index.d.ts +0 -2
- package/dist/core/serialization/index.js +0 -6
- package/dist/core/serialization/model.d.ts +0 -5
- package/dist/core/serialization/model.js +0 -55
- package/dist/core/serialization/weights.js +0 -64
- package/dist/core/task/data_example.js +0 -24
- package/dist/core/task/display_information.js +0 -49
- package/dist/core/task/index.d.ts +0 -3
- package/dist/core/task/index.js +0 -8
- package/dist/core/task/model_compile_data.d.ts +0 -6
- package/dist/core/task/model_compile_data.js +0 -22
- package/dist/core/task/summary.js +0 -19
- package/dist/core/task/task.d.ts +0 -10
- package/dist/core/task/task.js +0 -31
- package/dist/core/task/training_information.js +0 -66
- package/dist/core/tasks/cifar10.d.ts +0 -3
- package/dist/core/tasks/cifar10.js +0 -65
- package/dist/core/tasks/geotags.d.ts +0 -3
- package/dist/core/tasks/geotags.js +0 -67
- package/dist/core/tasks/index.d.ts +0 -6
- package/dist/core/tasks/index.js +0 -10
- package/dist/core/tasks/lus_covid.d.ts +0 -3
- package/dist/core/tasks/lus_covid.js +0 -87
- package/dist/core/tasks/mnist.d.ts +0 -3
- package/dist/core/tasks/mnist.js +0 -60
- package/dist/core/tasks/simple_face.d.ts +0 -2
- package/dist/core/tasks/simple_face.js +0 -41
- package/dist/core/tasks/titanic.d.ts +0 -3
- package/dist/core/tasks/titanic.js +0 -88
- package/dist/core/training/disco.d.ts +0 -23
- package/dist/core/training/disco.js +0 -130
- package/dist/core/training/index.d.ts +0 -2
- package/dist/core/training/index.js +0 -7
- package/dist/core/training/trainer/distributed_trainer.d.ts +0 -20
- package/dist/core/training/trainer/distributed_trainer.js +0 -65
- package/dist/core/training/trainer/local_trainer.d.ts +0 -11
- package/dist/core/training/trainer/local_trainer.js +0 -34
- package/dist/core/training/trainer/round_tracker.d.ts +0 -30
- package/dist/core/training/trainer/round_tracker.js +0 -47
- package/dist/core/training/trainer/trainer.d.ts +0 -65
- package/dist/core/training/trainer/trainer.js +0 -160
- package/dist/core/training/trainer/trainer_builder.js +0 -95
- package/dist/core/training/training_schemes.d.ts +0 -5
- package/dist/core/training/training_schemes.js +0 -10
- package/dist/core/types.d.ts +0 -4
- package/dist/core/types.js +0 -2
- package/dist/core/validation/index.d.ts +0 -1
- package/dist/core/validation/index.js +0 -5
- package/dist/core/validation/validator.d.ts +0 -17
- package/dist/core/validation/validator.js +0 -104
- package/dist/core/weights/aggregation.d.ts +0 -8
- package/dist/core/weights/aggregation.js +0 -96
- package/dist/core/weights/index.d.ts +0 -2
- package/dist/core/weights/index.js +0 -7
- package/dist/core/weights/weights_container.d.ts +0 -19
- package/dist/core/weights/weights_container.js +0 -64
- package/dist/imports.d.ts +0 -2
- package/dist/imports.js +0 -7
- package/dist/memory/memory.d.ts +0 -26
- package/dist/memory/memory.js +0 -160
- package/dist/{core/task → task}/data_example.d.ts +1 -1
- package/dist/{core/task → task}/summary.d.ts +1 -1
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
import { client as clients, EmptyMemory, ConsoleLogger } from '../index.js';
|
|
2
|
+
import { MeanAggregator } from '../aggregator/mean.js';
|
|
3
|
+
import { TrainerBuilder } from './trainer/trainer_builder.js';
|
|
4
|
+
/**
 * Top-level class handling distributed training from a client's perspective. It is meant to be
 * a convenient object providing a reduced yet complete API that wraps model training,
 * communication with nodes, logs and model memory.
 */
export class Disco {
    task;
    logger;
    memory;
    client;
    trainer;
    /**
     * @param task task to train on; its `trainingInformation.scheme` is the default scheme
     * @param options overrides; every entry left `undefined` is filled with a default:
     *   `scheme` from the task, `aggregator` = new MeanAggregator, `logger` = new ConsoleLogger,
     *   `memory` = new EmptyMemory, and `client` built from `url` according to the scheme.
     *   The given object itself is never modified, so it can be reused for other tasks.
     * @throws Error when no `client` is given and there is no `url` to build one from,
     *   or when the client was set up for another task
     */
    constructor(task, options) {
        // Resolve defaults on a shallow copy: writing them into the caller's object would leak
        // the built client into it, and a later `new Disco(otherTask, options)` would then
        // reuse that client and fail the task check below.
        const opts = { ...options };
        if (opts.scheme === undefined) {
            opts.scheme = task.trainingInformation.scheme;
        }
        if (opts.aggregator === undefined) {
            opts.aggregator = new MeanAggregator();
        }
        if (opts.client === undefined) {
            if (opts.url === undefined) {
                throw new Error('could not determine client from given parameters');
            }
            if (typeof opts.url === 'string') {
                opts.url = new URL(opts.url);
            }
            switch (opts.scheme) {
                case 'federated':
                    opts.client = new clients.federated.FederatedClient(opts.url, task, opts.aggregator);
                    break;
                case 'decentralized':
                    opts.client = new clients.decentralized.DecentralizedClient(opts.url, task, opts.aggregator);
                    break;
                case 'local':
                    opts.client = new clients.Local(opts.url, task, opts.aggregator);
                    break;
                default: {
                    // exhaustiveness check left over from the TypeScript source
                    const _ = opts.scheme;
                    throw new Error('should never happen');
                }
            }
        }
        if (opts.logger === undefined) {
            opts.logger = new ConsoleLogger();
        }
        if (opts.memory === undefined) {
            opts.memory = new EmptyMemory();
        }
        if (opts.client.task !== task) {
            throw new Error('client not setup for given task');
        }
        this.task = task;
        this.client = opts.client;
        this.memory = opts.memory;
        this.logger = opts.logger;
        const trainerBuilder = new TrainerBuilder(this.memory, this.task);
        // only the 'local' scheme gets a non-distributed trainer
        this.trainer = trainerBuilder.build(this.client, opts.scheme !== 'local');
    }
    /**
     * Starts a training instance for the Disco object's task on the provided data tuple.
     * Connects the client, logs a summary of every round and yields the round logs,
     * augmented with the number of participants (known nodes plus ourself).
     * @param dataTuple the data tuple; when it has no validation split, the training
     *   data is also used for validation
     */
    // TODO RoundLogs should contain number of participants but Trainer doesn't need client
    async *fit(dataTuple) {
        this.logger.success("Training started.");
        const trainData = dataTuple.train.preprocess().batch();
        const validationData = dataTuple.validation?.preprocess().batch() ?? trainData;
        await this.client.connect();
        const trainer = await this.trainer;
        for await (const roundLogs of trainer.fitModel(trainData.dataset, validationData.dataset)) {
            let msg = `Round: ${roundLogs.round}\n`;
            for (const epochLogs of roundLogs.epochs.values()) {
                msg += ` Epoch: ${epochLogs.epoch}\n`;
                msg += ` Training loss: ${epochLogs.training.loss}\n`;
                if (epochLogs.training.accuracy !== undefined) {
                    msg += ` Training accuracy: ${epochLogs.training.accuracy}\n`;
                }
                if (epochLogs.validation !== undefined) {
                    msg += ` Validation loss: ${epochLogs.validation.loss}\n`;
                    // same guard as for training accuracy: don't log "undefined"
                    if (epochLogs.validation.accuracy !== undefined) {
                        msg += ` Validation accuracy: ${epochLogs.validation.accuracy}\n`;
                    }
                }
            }
            this.logger.success(msg);
            yield {
                ...roundLogs,
                participants: this.client.nodes.size + 1 // add ourself
            };
        }
        this.logger.success("Training finished.");
    }
    /**
     * Stops the ongoing training instance without disconnecting the client.
     */
    async pause() {
        const trainer = await this.trainer;
        await trainer.stopTraining();
    }
    /**
     * Completely stops the ongoing training instance and disconnects the client.
     */
    async close() {
        await this.pause();
        await this.client.disconnect();
    }
}
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
export { Disco } from './disco.js';
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
import type { Model, Memory, Task, client as clients } from "../../index.js";
|
|
2
|
+
import { Trainer } from "./trainer.js";
|
|
3
|
+
/**
 * Class whose role is to train a model in a distributed way with a given dataset.
 */
export declare class DistributedTrainer extends Trainer {
    private readonly task;
    private readonly memory;
    private readonly client;
    private readonly aggregator;
    /**
     * DistributedTrainer constructor. Takes the task and model like Trainer, plus the memory
     * in which the working model is saved after every round and a client which takes care of
     * communicating weights (the client's aggregator is given the model).
     */
    constructor(task: Task, memory: Memory, model: Model, client: clients.Client);
    /**
     * Callback called every time a round begins; hands the current weights to the
     * client's round-begin communication.
     */
    onRoundBegin(round: number): Promise<void>;
    /**
     * Callback called every time a round is over
     */
    onRoundEnd(round: number): Promise<void>;
}
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
import { Trainer } from "./trainer.js";
|
|
2
|
+
/**
 * Trainer for collaborative (federated or decentralized) training: the client is involved
 * at the start and end of every round, and the working model is saved once a round ends.
 */
export class DistributedTrainer extends Trainer {
    task;
    memory;
    client;
    aggregator;
    /**
     * @param task task being trained (also forwarded to Trainer)
     * @param memory where the working model is saved at the end of each round
     * @param model model to train; it is also registered on the client's aggregator
     * @param client client communicating weights with the other nodes; its aggregator is reused
     */
    constructor(task, memory, model, client) {
        super(task, model);
        this.task = task;
        this.memory = memory;
        this.client = client;
        const { aggregator } = this.client;
        this.aggregator = aggregator;
        aggregator.setModel(model);
    }
    /** Round-begin hook: lets the client run its round-begin communication with our weights. */
    async onRoundBegin(round) {
        await this.client.onRoundBeginCommunication(this.model.weights, round);
    }
    /**
     * Round-end hook: lets the client run its round-end communication, adopts the
     * aggregator's weights when it holds a model, then persists the working model.
     */
    async onRoundEnd(round) {
        await this.client.onRoundEndCommunication(this.model.weights, round);
        const aggregatedModel = this.aggregator.model;
        // Aggregation happens asynchronously inside the aggregator; the trainer only
        // catches up with it once its own round of training is complete.
        if (aggregatedModel !== undefined) {
            this.model.weights = aggregatedModel.weights;
        }
        const modelInfo = {
            taskID: this.task.id,
            name: this.task.trainingInformation.modelID,
        };
        await this.memory.updateWorkingModel(modelInfo, this.model);
    }
}
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
import type { Memory, Model, Task } from "../../index.js";
|
|
2
|
+
import { Trainer } from "./trainer.js";
|
|
3
|
+
/**
 * Trains a model alone, on a given local dataset, without any collaborator.
 */
export declare class LocalTrainer extends Trainer {
    private readonly memory;
    private readonly task;
    constructor(task: Task, memory: Memory, model: Model);
    /** Hook run when a round begins. */
    onRoundBegin(): Promise<void>;
    /** Hook run when a round ends. */
    onRoundEnd(): Promise<void>;
}
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
import { Trainer } from "./trainer.js";
|
|
2
|
+
/**
 * Class whose role is to locally (alone) train a model on a given dataset,
 * without any collaborators.
 */
export class LocalTrainer extends Trainer {
    task;
    memory;
    /**
     * @param task task being trained; its id and modelID identify the model saved in memory
     * @param memory where the working model is saved at the end of every round
     * @param model model to train
     */
    constructor(task, memory, model) {
        super(task, model);
        this.task = task;
        this.memory = memory;
    }
    /** Nothing to prepare when training alone. */
    async onRoundBegin() {
        // no collaborators to communicate with: the async function already resolves to undefined
    }
    /** Saves the current model as the task's working model. */
    async onRoundEnd() {
        await this.memory.updateWorkingModel({ taskID: this.task.id, name: this.task.trainingInformation.modelID }, this.model);
    }
}
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
import type tf from "@tensorflow/tfjs";
|
|
2
|
+
import { List } from "immutable";
|
|
3
|
+
import type { Model, Task } from "../../index.js";
|
|
4
|
+
import { EpochLogs } from "../../models/model.js";
|
|
5
|
+
/** Logs produced by one round of training. */
export interface RoundLogs {
    /** round number (currently equal to the epoch number, see `Trainer.fitModel`) */
    round: number;
    /** logs of the epochs run during this round */
    epochs: List<EpochLogs>;
}
/** Abstract class whose role is to train a model with a given dataset. This can be either done
 * locally (alone) or in a distributed way with collaborators.
 *
 * 1. Call `fitModel(dataset, valDataset)` to start training.
 * 2. It calls `onRoundBegin` before training starts and `onRoundEnd` once a round has ended.
 *
 * `onRoundBegin` and `onRoundEnd` need to be implemented to specify what actions to do around
 * a round, such as a communication step with collaborators.
 */
export declare abstract class Trainer {
    #private;
    readonly model: Model;
    private training?;
    constructor(task: Task, model: Model);
    protected abstract onRoundBegin(round: number): Promise<void>;
    protected abstract onRoundEnd(round: number): Promise<void>;
    /**
     * Request stop training to be used from the Disco instance or any class that is taking care of the trainer.
     */
    stopTraining(): Promise<void>;
    /**
     * Start training the model with the given dataset, yielding logs as training progresses.
     * @param dataset training dataset
     * @param valDataset validation dataset
     * @throws Error if a training is already running on this trainer
     */
    fitModel(dataset: tf.data.Dataset<tf.TensorContainer>, valDataset: tf.data.Dataset<tf.TensorContainer>): AsyncGenerator<RoundLogs>;
}
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
import { List } from "immutable";
|
|
2
|
+
/** Abstract class whose role is to train a model with a given dataset. This can be either done
 * locally (alone) or in a distributed way with collaborators.
 *
 * 1. Call `fitModel(dataset, valDataset)` to start training.
 * 2. It calls `onRoundBegin(0)` before the first epoch, then `onRoundEnd` followed by
 *    `onRoundBegin` every `roundDuration` epochs, and a final `onRoundEnd` once the model
 *    has finished training.
 *
 * Subclasses implement `onRoundBegin` / `onRoundEnd` to specify what actions to do around a
 * round, such as a communication step with collaborators.
 */
export class Trainer {
    model;
    #roundDuration;
    #epochs;
    /** generator of the ongoing training; `undefined` when no training is running */
    training;
    /**
     * @param task task whose `trainingInformation` provides `roundDuration` and `epochs`
     * @param model model to train
     */
    constructor(task, model) {
        this.model = model;
        this.#roundDuration = task.trainingInformation.roundDuration;
        this.#epochs = task.trainingInformation.epochs;
    }
    /**
     * Request stop training to be used from the Disco instance or any class that is taking care of the trainer.
     */
    async stopTraining() {
        await this.training?.return();
    }
    /**
     * Start training the model with the given dataset, yielding the logs of every epoch.
     *
     * The trainer is released (`training` reset to `undefined`) however the run ends:
     * normal completion, an error thrown by the model or a hook, or the consumer
     * stopping the iteration early.
     *
     * @param dataset training dataset
     * @param valDataset validation dataset
     * @throws Error if a training is already running on this trainer
     */
    async *fitModel(dataset, valDataset) {
        if (this.training !== undefined) {
            throw new Error("training already running, cancel it before launching a new one");
        }
        await this.onRoundBegin(0);
        this.training = this.model.train(dataset, valDataset, this.#epochs);
        try {
            for await (const logs of this.training) {
                // for now, round (sharing on network) == epoch (full pass over local data)
                yield {
                    round: logs.epoch,
                    epochs: List.of(logs),
                };
                if (logs.epoch % this.#roundDuration === 0) {
                    const round = Math.trunc(logs.epoch / this.#roundDuration);
                    await this.onRoundEnd(round);
                    await this.onRoundBegin(round);
                }
            }
            // NOTE(review): if the model reports its last epoch as `epochs` and `epochs` is a
            // multiple of `roundDuration`, this round was already ended in the loop — confirm intended.
            const round = Math.trunc(this.#epochs / this.#roundDuration);
            await this.onRoundEnd(round);
        }
        finally {
            // Without this, an aborted or failed run would leave `training` set and every
            // later call would throw "training already running".
            this.training = undefined;
        }
    }
}
|
|
@@ -1,13 +1,12 @@
|
|
|
1
|
-
import {
|
|
2
|
-
import { Trainer } from './trainer';
|
|
1
|
+
import type { client as clients, Task, Memory } from '../../index.js';
|
|
2
|
+
import type { Trainer } from './trainer.js';
|
|
3
3
|
/**
|
|
4
4
|
* A class that helps build the Trainer and auxiliary classes.
|
|
5
5
|
*/
|
|
6
6
|
export declare class TrainerBuilder {
|
|
7
7
|
private readonly memory;
|
|
8
8
|
private readonly task;
|
|
9
|
-
|
|
10
|
-
constructor(memory: Memory, task: Task, trainingInformant: TrainingInformant);
|
|
9
|
+
constructor(memory: Memory, task: Task);
|
|
11
10
|
/**
|
|
12
11
|
* Builds a trainer object.
|
|
13
12
|
*
|
|
@@ -15,11 +14,10 @@ export declare class TrainerBuilder {
|
|
|
15
14
|
* @param distributed whether to build a distributed or local trainer
|
|
16
15
|
* @returns
|
|
17
16
|
*/
|
|
18
|
-
build(client: Client, distributed?: boolean): Promise<Trainer>;
|
|
17
|
+
build(client: clients.Client, distributed?: boolean): Promise<Trainer>;
|
|
19
18
|
/**
|
|
20
|
-
* If a model exists in memory,
|
|
19
|
+
* If a model exists in memory, load it, otherwise load model from server
|
|
21
20
|
* @returns
|
|
22
21
|
*/
|
|
23
22
|
private getModel;
|
|
24
|
-
private updateModelInformation;
|
|
25
23
|
}
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
import { ModelType } from '../../index.js';
|
|
2
|
+
import { DistributedTrainer } from './distributed_trainer.js';
|
|
3
|
+
import { LocalTrainer } from './local_trainer.js';
|
|
4
|
+
/**
 * Builds the trainer matching a training setup, after fetching the model to train.
 */
export class TrainerBuilder {
    memory;
    task;
    constructor(memory, task) {
        this.memory = memory;
        this.task = task;
    }
    /**
     * Builds a trainer object.
     *
     * @param client client to share weights with (either distributed or federated)
     * @param distributed when true a DistributedTrainer is built, otherwise a LocalTrainer
     * @returns the trainer, wrapping the model returned by `getModel`
     */
    async build(client, distributed = false) {
        const model = await this.getModel(client);
        return distributed
            ? new DistributedTrainer(this.task, this.memory, model, client)
            : new LocalTrainer(this.task, this.memory, model);
    }
    /**
     * Returns the task's working model from memory when memory contains it,
     * otherwise the latest model provided by the client.
     * @throws TypeError when the task has no model ID
     */
    async getModel(client) {
        const modelID = this.task.trainingInformation?.modelID;
        if (modelID === undefined) {
            throw new TypeError('model ID is undefined');
        }
        const info = { type: ModelType.WORKING, taskID: this.task.id, name: modelID };
        const inMemory = await this.memory.contains(info);
        if (inMemory) {
            return await this.memory.getModel(info);
        }
        return await client.getLatestModel();
    }
}
|
package/dist/types.d.ts
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
import type { Map } from 'immutable';
|
|
2
|
+
import type { WeightsContainer } from './index.js';
|
|
3
|
+
import type { NodeID } from './client/index.js';
|
|
4
|
+
/** Path string (plain alias of `string`). */
export type Path = string;
/** Key of a metadata entry (plain alias of `string`). */
export type MetadataKey = string;
/** Value of a metadata entry (plain alias of `string`). */
export type MetadataValue = string;
/**
 * Features of a sample: a single number or a nested number array of depth one
 * to five.
 */
export type Features =
    | number
    | number[]
    | number[][]
    | number[][][]
    | number[][][][]
    | number[][][][][];
/** One `WeightsContainer` per node, keyed by `NodeID`. */
export type Contributions = Map<NodeID, WeightsContainer>;
|
package/dist/types.js
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
export {};
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
/** Handler called with the value emitted for an event. */
type Listener<T> = (value: T) => void;
/**
 * Typed event emitter: handlers are registered per event name and called with
 * the value passed to `emit`.
 *
 * @typeParam I object/mapping from event name to emitted value type
 */
export declare class EventEmitter<I extends Record<string, unknown>> {
    private listeners;
    /**
     * @param initialListeners optional mapping of event name to listener; every
     * entry is registered as if passed to `on` on the new instance
     */
    constructor(initialListeners?: {
        [E in keyof I]?: Listener<I[E]>;
    });
    /**
     * Registers `listener` for every future emission of `event`.
     *
     * @param event event name to listen to
     * @param listener handler to call
     */
    on<E extends keyof I>(event: E, listener: Listener<I[E]>): void;
    /**
     * Registers `listener` for the next emission of `event` only.
     *
     * @param event event name to listen to
     * @param listener handler to call next time
     */
    once<E extends keyof I>(event: E, listener: Listener<I[E]>): void;
    /**
     * Calls the listeners registered for `event` with `value`.
     *
     * @param event event name whose listeners are called
     * @param value value passed to each listener
     */
    emit<E extends keyof I>(event: E, value: I[E]): void;
}
/** `EventEmitter` accepting any string event name, each carrying an `unknown` value. */
export declare class Sink extends EventEmitter<Record<string, unknown>> {
}
export {};
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
// inspired by https://danilafe.com/blog/typescript_typesafe_events/
|
|
2
|
+
import { List } from 'immutable';
|
|
3
|
+
/**
 * Call handlers on given events.
 *
 * For each event name the emitter keeps an immutable `List` of
 * `[once, listener]` pairs. Registering replaces that list instead of
 * mutating it, so `emit` always iterates a stable snapshot.
 *
 * @typeParam I object/mapping from event name to emitted value type
 */
export class EventEmitter {
    // Prototype-less dictionary: with a plain `{}`, event names such as
    // "toString" or "constructor" would resolve to inherited Object.prototype
    // members, which are not Lists and would crash `on`, `once` and `emit`.
    listeners = Object.create(null);
    /**
     * @param initialListeners object/mapping of event name to listener, as if using `on` on created instance
     */
    constructor(initialListeners = {}) {
        // Only the mapping's own entries are registered; inherited enumerable
        // properties (which `for...in` would also visit) are ignored.
        for (const [event, listener] of Object.entries(initialListeners ?? {})) {
            if (listener !== undefined) {
                this.on(event, listener);
            }
        }
    }
    /**
     * Register listener to call on event
     *
     * @param event event name to listen to
     * @param listener handler to call
     */
    on(event, listener) {
        const eventListeners = this.listeners[event] ?? List();
        this.listeners[event] = eventListeners.push([false, listener]);
    }
    /**
     * Register listener to call once on next event
     *
     * @param event event name to listen to
     * @param listener handler to call next time
     */
    once(event, listener) {
        const eventListeners = this.listeners[event] ?? List();
        this.listeners[event] = eventListeners.push([true, listener]);
    }
    /**
     * Send value to registered listeners of event name
     *
     * One-shot listeners are removed before any listener runs, so a listener
     * that emits the same event again cannot trigger them a second time.
     * Listeners registered while emitting are not called for this emission,
     * because iteration runs over the list captured on entry.
     *
     * @param event send to listeners of event name
     * @param value what to call listeners with
     */
    emit(event, value) {
        const eventListeners = this.listeners[event] ?? List();
        this.listeners[event] = eventListeners.filterNot(([once]) => once);
        // Block body on purpose: Immutable's `forEach` stops iterating when the
        // callback returns `false`, so a listener's return value must not leak.
        eventListeners.forEach(([, listener]) => {
            listener(value);
        });
    }
}
/** `EventEmitter` for all events */
export class Sink extends EventEmitter {
}
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
export { Validator } from './validator.js';
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
export { Validator } from './validator.js';
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
import { List } from 'immutable';
|
|
2
|
+
import type { data, Model, Task, Logger, client as clients, Memory, ModelSource, Features } from '../index.js';
|
|
3
|
+
/**
 * Evaluates a model on labelled data (`assess`) or runs it on unlabelled data
 * (`predict`). The model comes from `memory` when `source` is stored there,
 * otherwise from `client`.
 */
export declare class Validator {
    /** Task whose training information provides the batch size. */
    readonly task: Task;
    /** Receives validation summaries. */
    readonly logger: Logger;
    private readonly memory;
    private readonly source?;
    private readonly client?;
    private readonly graphInformant;
    private size;
    private _confusionMatrix;
    /**
     * @param task task to validate for
     * @param logger receives validation summaries
     * @param memory model store consulted for `source`
     * @param source model to look up in `memory`
     * @param client client to fetch the latest model from
     * @throws Error when neither `source` nor `client` is given
     */
    constructor(task: Task, logger: Logger, memory: Memory, source?: ModelSource | undefined, client?: clients.Client | undefined);
    private getLabel;
    /**
     * Predicts every sample of `data` and compares each prediction with its label.
     *
     * @param data labelled data to validate on
     * @param useConfusionMatrix whether to also compute `confusionMatrix`
     * @returns one entry per sample with its label, prediction and features
     */
    assess(data: data.Data, useConfusionMatrix?: boolean): Promise<{
        groundTruth: number;
        pred: number;
        features: Features;
    }[]>;
    /**
     * Runs the model on every sample of `data`.
     *
     * @param data data to run the model on
     * @returns one entry per sample with its features and prediction
     */
    predict(data: data.Data): Promise<{
        features: Features;
        pred: number;
    }[]>;
    /** Loads the model from memory when available, otherwise through the client. */
    getModel(): Promise<Model>;
    /** Accuracy values recorded during validation. */
    get accuracyData(): List<number>;
    /** Accuracy reported by the internal graph informant. */
    get accuracy(): number;
    /** Number of samples visited by `assess`. */
    get visitedSamples(): number;
    /** Confusion matrix computed by `assess`, if requested. */
    get confusionMatrix(): number[][] | undefined;
}
|
|
@@ -0,0 +1,132 @@
|
|
|
1
|
+
import { List } from 'immutable';
|
|
2
|
+
import * as tf from '@tensorflow/tfjs';
|
|
3
|
+
import { GraphInformant } from '../index.js';
|
|
4
|
+
/**
 * Evaluates a model on labelled data (`assess`) or runs it on unlabelled data
 * (`predict`). The model is loaded from `memory` when `source` is stored there,
 * otherwise fetched through `client`.
 */
export class Validator {
    task;
    logger;
    memory;
    source;
    client;
    graphInformant = new GraphInformant();
    size = 0;
    _confusionMatrix;
    /**
     * @param task task whose `trainingInformation.batchSize` sets the batch size
     * @param logger receives the validation summary of `assess`
     * @param memory model store consulted for `source`
     * @param source model to look up in `memory` (optional)
     * @param client client to fetch the latest model from (optional)
     * @throws Error when neither `source` nor `client` is given
     */
    constructor(task, logger, memory, source, client) {
        this.task = task;
        this.logger = logger;
        this.memory = memory;
        this.source = source;
        this.client = client;
        if (source === undefined && client === undefined) {
            throw new Error('To initialize a Validator, either or both a source and client need to be specified');
        }
    }
    /**
     * Reduces a batch of scores to one class index per row: a single column is
     * thresholded at 0.5, two columns are reduced with argMax. The reduced
     * tensor is disposed once its values have been read.
     *
     * @throws Error when the tensor does not have exactly 1 or 2 columns
     */
    async getLabel(ys) {
        let labels;
        switch (ys.shape[1]) {
            case 1:
                // tidy releases the temporary 0.5 scalar
                labels = tf.tidy(() => ys.greaterEqual(tf.scalar(0.5)));
                break;
            case 2:
                labels = ys.argMax(1);
                break;
            default:
                throw new Error(`unable to reduce tensor of shape: ${ys.shape.toString()}`);
        }
        try {
            return await labels.data();
        }
        finally {
            labels.dispose();
        }
    }
    /**
     * Predicts every sample of `data` and compares each prediction with the
     * sample's label, updating the accuracy graph after every batch.
     *
     * @param data labelled data; after `preprocess()` each element must have `xs` and `ys`
     * @param useConfusionMatrix also compute `confusionMatrix` from this call's labels and predictions
     * @returns one `{ groundTruth, pred, features }` entry per visited sample
     * @throws TypeError when the task has no batch size or the features are not an array
     * @throws Error when an element lacks `xs`/`ys` or the confusion matrix cannot be computed
     */
    async assess(data, useConfusionMatrix = false) {
        const batchSize = this.task.trainingInformation?.batchSize;
        if (batchSize === undefined) {
            throw new TypeError('Batch size is undefined');
        }
        const model = await this.getModel();
        const features = [];
        const groundTruth = [];
        let hits = 0;
        // `hits` only counts this call, so the sample counter restarts too;
        // otherwise a second call would divide its hits by every sample ever seen.
        this.size = 0;
        // Get model predictions per batch and flatten the result
        // Also build the features and ground truth arrays
        const predictions = (await data.preprocess().dataset.batch(batchSize)
            .mapAsync(async (e) => {
                if (typeof e !== 'object' || e === null || !('xs' in e) || !('ys' in e)) {
                    throw new Error('Input data is missing a feature or the label');
                }
                const xs = e.xs;
                const ys = await this.getLabel(e.ys);
                // the raw model output is only needed to derive labels; free it afterwards
                const output = await model.predict(xs);
                let pred;
                try {
                    pred = await this.getLabel(output);
                }
                finally {
                    output.dispose();
                }
                const currentFeatures = await xs.array();
                if (!Array.isArray(currentFeatures)) {
                    throw new TypeError('Data format is incorrect');
                }
                // push instead of concat: avoids re-copying all accumulated features per batch
                for (const sample of currentFeatures) {
                    features.push(sample);
                }
                for (const label of ys) {
                    groundTruth.push(label);
                }
                this.size += xs.shape[0];
                const compared = Math.min(pred.length, ys.length);
                for (let i = 0; i < compared; i++) {
                    if (pred[i] === ys[i]) {
                        hits++;
                    }
                }
                this.graphInformant.updateAccuracy(hits / this.size);
                return Array.from(pred);
            }).toArray()).flat();
        this.logger.success(`Obtained validation accuracy of ${this.accuracy}`);
        this.logger.success(`Visited ${this.visitedSamples} samples`);
        if (useConfusionMatrix) {
            try {
                // getLabel only yields binary labels (threshold or argMax over two
                // columns), so use at least two classes; widen if larger ids appear.
                let numClasses = 2;
                for (const label of groundTruth) {
                    numClasses = Math.max(numClasses, label + 1);
                }
                for (const label of predictions) {
                    numClasses = Math.max(numClasses, label + 1);
                }
                const matrix = tf.tidy(() => tf.math.confusionMatrix(groundTruth, predictions, numClasses));
                try {
                    this._confusionMatrix = matrix.arraySync();
                }
                finally {
                    matrix.dispose();
                }
            }
            catch (e) {
                console.error(e instanceof Error ? e.message : e);
                throw new Error('Failed to compute the confusion matrix');
            }
        }
        return List(groundTruth)
            .zip(List(predictions), List(features))
            .map(([gt, p, f]) => ({ groundTruth: gt, pred: p, features: f }))
            .toArray();
    }
    /**
     * Runs the model on every sample of `data`.
     *
     * @param data data whose preprocessed elements are feature tensors
     * @returns one `{ features, pred }` entry per sample
     * @throws TypeError when the task has no batch size or the features are not an array
     */
    async predict(data) {
        const batchSize = this.task.trainingInformation?.batchSize;
        if (batchSize === undefined) {
            throw new TypeError('Batch size is undefined');
        }
        const model = await this.getModel();
        const features = [];
        // Get model prediction per batch and flatten the result
        // Also incrementally build the features array
        const predictions = (await data.preprocess().dataset.batch(batchSize)
            .mapAsync(async (e) => {
                const xs = e;
                const currentFeatures = await xs.array();
                if (!Array.isArray(currentFeatures)) {
                    throw new TypeError('Data format is incorrect');
                }
                for (const sample of currentFeatures) {
                    features.push(sample);
                }
                // the raw model output is only needed to derive labels; free it afterwards
                const output = await model.predict(xs);
                let pred;
                try {
                    pred = await this.getLabel(output);
                }
                finally {
                    output.dispose();
                }
                return Array.from(pred);
            }).toArray()).flat();
        return List(features).zip(List(predictions))
            .map(([f, p]) => ({ features: f, pred: p }))
            .toArray();
    }
    /**
     * Loads the model from `memory` when `source` is stored there, otherwise
     * fetches the latest model through `client`.
     *
     * @throws Error when neither option yields a model
     */
    async getModel() {
        if (this.source !== undefined && await this.memory.contains(this.source)) {
            return await this.memory.getModel(this.source);
        }
        if (this.client !== undefined) {
            return await this.client.getLatestModel();
        }
        throw new Error('Could not load the model');
    }
    /** Accuracy values recorded by the graph informant. */
    get accuracyData() {
        return this.graphInformant.data();
    }
    /** Accuracy reported by the graph informant. */
    get accuracy() {
        return this.graphInformant.accuracy();
    }
    /** Number of samples visited by the latest `assess` call. */
    get visitedSamples() {
        return this.size;
    }
    /** Confusion matrix from the last `assess` call that requested one, if any. */
    get confusionMatrix() {
        return this._confusionMatrix;
    }
}
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
import type { TensorLike } from './weights_container.js';
|
|
2
|
+
import { WeightsContainer } from './weights_container.js';
|
|
3
|
+
/** Weights given as an iterable of tensor-like values. */
type WeightsLike = Iterable<TensorLike>;
/**
 * Sums the given iterable of weights entry-wise.
 *
 * @param weights the weights to sum, each either a `WeightsContainer` or an
 * iterable of tensor-like values
 * @returns the entry-wise sum
 */
export declare function sum(weights: Iterable<WeightsLike | WeightsContainer>): WeightsContainer;
/**
 * Computes the successive entry-wise difference between the weights of the
 * given iterable. The result depends on the order in which `weights` is
 * iterated.
 *
 * @param weights the weights to subtract, in order
 * @returns the entry-wise difference
 */
export declare function diff(weights: Iterable<WeightsLike | WeightsContainer>): WeightsContainer;
/**
 * Averages the given iterable of weights entry-wise.
 *
 * @param weights the weights to average
 * @returns the entry-wise average
 */
export declare function avg(weights: Iterable<WeightsLike | WeightsContainer>): WeightsContainer;
export {};
|