@epfml/discojs 2.1.1 → 2.1.2-p20240506085037.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/aggregator/base.d.ts +180 -0
- package/dist/aggregator/base.js +236 -0
- package/dist/aggregator/get.d.ts +16 -0
- package/dist/aggregator/get.js +31 -0
- package/dist/aggregator/index.d.ts +7 -0
- package/dist/aggregator/index.js +4 -0
- package/dist/aggregator/mean.d.ts +23 -0
- package/dist/aggregator/mean.js +69 -0
- package/dist/aggregator/secure.d.ts +27 -0
- package/dist/aggregator/secure.js +91 -0
- package/dist/async_informant.d.ts +15 -0
- package/dist/async_informant.js +42 -0
- package/dist/client/base.d.ts +76 -0
- package/dist/client/base.js +88 -0
- package/dist/client/decentralized/base.d.ts +32 -0
- package/dist/client/decentralized/base.js +192 -0
- package/dist/client/decentralized/index.d.ts +2 -0
- package/dist/client/decentralized/index.js +2 -0
- package/dist/client/decentralized/messages.d.ts +28 -0
- package/dist/client/decentralized/messages.js +44 -0
- package/dist/client/decentralized/peer.d.ts +40 -0
- package/dist/client/decentralized/peer.js +189 -0
- package/dist/client/decentralized/peer_pool.d.ts +12 -0
- package/dist/client/decentralized/peer_pool.js +44 -0
- package/dist/client/event_connection.d.ts +34 -0
- package/dist/client/event_connection.js +105 -0
- package/dist/client/federated/base.d.ts +54 -0
- package/dist/client/federated/base.js +151 -0
- package/dist/client/federated/index.d.ts +2 -0
- package/dist/client/federated/index.js +2 -0
- package/dist/client/federated/messages.d.ts +30 -0
- package/dist/client/federated/messages.js +24 -0
- package/dist/client/index.d.ts +8 -0
- package/dist/client/index.js +8 -0
- package/dist/client/local.d.ts +3 -0
- package/dist/client/local.js +3 -0
- package/dist/client/messages.d.ts +30 -0
- package/dist/client/messages.js +26 -0
- package/dist/client/types.d.ts +2 -0
- package/dist/client/types.js +4 -0
- package/dist/client/utils.d.ts +2 -0
- package/dist/client/utils.js +7 -0
- package/dist/dataset/data/data.d.ts +48 -0
- package/dist/dataset/data/data.js +72 -0
- package/dist/dataset/data/data_split.d.ts +8 -0
- package/dist/dataset/data/data_split.js +1 -0
- package/dist/dataset/data/image_data.d.ts +11 -0
- package/dist/dataset/data/image_data.js +38 -0
- package/dist/dataset/data/index.d.ts +6 -0
- package/dist/dataset/data/index.js +5 -0
- package/dist/dataset/data/preprocessing/base.d.ts +16 -0
- package/dist/dataset/data/preprocessing/base.js +1 -0
- package/dist/dataset/data/preprocessing/image_preprocessing.d.ts +13 -0
- package/dist/dataset/data/preprocessing/image_preprocessing.js +40 -0
- package/dist/dataset/data/preprocessing/index.d.ts +4 -0
- package/dist/dataset/data/preprocessing/index.js +3 -0
- package/dist/dataset/data/preprocessing/tabular_preprocessing.d.ts +13 -0
- package/dist/dataset/data/preprocessing/tabular_preprocessing.js +45 -0
- package/dist/dataset/data/preprocessing/text_preprocessing.d.ts +13 -0
- package/dist/dataset/data/preprocessing/text_preprocessing.js +85 -0
- package/dist/dataset/data/tabular_data.d.ts +11 -0
- package/dist/dataset/data/tabular_data.js +25 -0
- package/dist/dataset/data/text_data.d.ts +11 -0
- package/dist/dataset/data/text_data.js +14 -0
- package/dist/{core/dataset → dataset}/data_loader/data_loader.d.ts +3 -5
- package/dist/dataset/data_loader/data_loader.js +2 -0
- package/dist/dataset/data_loader/image_loader.d.ts +20 -3
- package/dist/dataset/data_loader/image_loader.js +98 -23
- package/dist/dataset/data_loader/index.d.ts +5 -2
- package/dist/dataset/data_loader/index.js +4 -7
- package/dist/dataset/data_loader/tabular_loader.d.ts +34 -3
- package/dist/dataset/data_loader/tabular_loader.js +75 -15
- package/dist/dataset/data_loader/text_loader.d.ts +14 -0
- package/dist/dataset/data_loader/text_loader.js +25 -0
- package/dist/dataset/dataset.d.ts +5 -0
- package/dist/dataset/dataset.js +1 -0
- package/dist/dataset/dataset_builder.d.ts +60 -0
- package/dist/dataset/dataset_builder.js +142 -0
- package/dist/dataset/index.d.ts +5 -0
- package/dist/dataset/index.js +3 -0
- package/dist/default_tasks/cifar10/index.d.ts +2 -0
- package/dist/{core/default_tasks/cifar10.js → default_tasks/cifar10/index.js} +28 -36
- package/dist/default_tasks/cifar10/model.d.ts +434 -0
- package/dist/default_tasks/cifar10/model.js +2385 -0
- package/dist/default_tasks/geotags/index.d.ts +2 -0
- package/dist/default_tasks/geotags/index.js +65 -0
- package/dist/default_tasks/geotags/model.d.ts +593 -0
- package/dist/default_tasks/geotags/model.js +4715 -0
- package/dist/default_tasks/index.d.ts +8 -0
- package/dist/default_tasks/index.js +8 -0
- package/dist/default_tasks/lus_covid.d.ts +2 -0
- package/dist/default_tasks/lus_covid.js +89 -0
- package/dist/default_tasks/mnist.d.ts +2 -0
- package/dist/{core/default_tasks → default_tasks}/mnist.js +26 -34
- package/dist/default_tasks/simple_face/index.d.ts +2 -0
- package/dist/{core/default_tasks/simple_face.js → default_tasks/simple_face/index.js} +17 -22
- package/dist/default_tasks/simple_face/model.d.ts +513 -0
- package/dist/default_tasks/simple_face/model.js +4301 -0
- package/dist/default_tasks/skin_mnist.d.ts +2 -0
- package/dist/default_tasks/skin_mnist.js +80 -0
- package/dist/default_tasks/titanic.d.ts +2 -0
- package/dist/{core/default_tasks → default_tasks}/titanic.js +24 -33
- package/dist/default_tasks/wikitext.d.ts +2 -0
- package/dist/default_tasks/wikitext.js +38 -0
- package/dist/index.d.ts +18 -2
- package/dist/index.js +18 -6
- package/dist/{core/informant → informant}/graph_informant.d.ts +1 -1
- package/dist/informant/graph_informant.js +20 -0
- package/dist/informant/index.d.ts +1 -0
- package/dist/informant/index.js +1 -0
- package/dist/{core/logging → logging}/console_logger.d.ts +2 -2
- package/dist/logging/console_logger.js +22 -0
- package/dist/logging/index.d.ts +2 -0
- package/dist/logging/index.js +1 -0
- package/dist/{core/logging → logging}/logger.d.ts +3 -3
- package/dist/logging/logger.js +1 -0
- package/dist/memory/base.d.ts +119 -0
- package/dist/memory/base.js +9 -0
- package/dist/memory/empty.d.ts +20 -0
- package/dist/memory/empty.js +43 -0
- package/dist/memory/index.d.ts +3 -1
- package/dist/memory/index.js +3 -5
- package/dist/memory/model_type.d.ts +9 -0
- package/dist/memory/model_type.js +10 -0
- package/dist/{core/privacy.d.ts → privacy.d.ts} +1 -1
- package/dist/{core/privacy.js → privacy.js} +11 -16
- package/dist/serialization/index.d.ts +2 -0
- package/dist/serialization/index.js +2 -0
- package/dist/serialization/model.d.ts +5 -0
- package/dist/serialization/model.js +67 -0
- package/dist/{core/serialization → serialization}/weights.d.ts +2 -2
- package/dist/serialization/weights.js +37 -0
- package/dist/task/data_example.js +14 -0
- package/dist/task/digest.js +14 -0
- package/dist/{core/task → task}/display_information.d.ts +5 -3
- package/dist/task/display_information.js +46 -0
- package/dist/task/index.d.ts +7 -0
- package/dist/task/index.js +5 -0
- package/dist/task/label_type.d.ts +9 -0
- package/dist/task/label_type.js +28 -0
- package/dist/task/summary.js +13 -0
- package/dist/{core/task → task}/task.d.ts +7 -7
- package/dist/task/task.js +22 -0
- package/dist/task/task_handler.d.ts +5 -0
- package/dist/task/task_handler.js +20 -0
- package/dist/task/task_provider.d.ts +5 -0
- package/dist/task/task_provider.js +1 -0
- package/dist/{core/task → task}/training_information.d.ts +9 -10
- package/dist/task/training_information.js +88 -0
- package/dist/training/disco.d.ts +40 -0
- package/dist/training/disco.js +107 -0
- package/dist/training/index.d.ts +2 -0
- package/dist/training/index.js +1 -0
- package/dist/training/trainer/distributed_trainer.d.ts +20 -0
- package/dist/training/trainer/distributed_trainer.js +36 -0
- package/dist/training/trainer/local_trainer.d.ts +12 -0
- package/dist/training/trainer/local_trainer.js +19 -0
- package/dist/training/trainer/trainer.d.ts +33 -0
- package/dist/training/trainer/trainer.js +52 -0
- package/dist/{core/training → training}/trainer/trainer_builder.d.ts +5 -7
- package/dist/training/trainer/trainer_builder.js +43 -0
- package/dist/types.d.ts +8 -0
- package/dist/types.js +1 -0
- package/dist/utils/event_emitter.d.ts +40 -0
- package/dist/utils/event_emitter.js +57 -0
- package/dist/validation/index.d.ts +1 -0
- package/dist/validation/index.js +1 -0
- package/dist/validation/validator.d.ts +28 -0
- package/dist/validation/validator.js +132 -0
- package/dist/weights/aggregation.d.ts +21 -0
- package/dist/weights/aggregation.js +44 -0
- package/dist/weights/index.d.ts +2 -0
- package/dist/weights/index.js +2 -0
- package/dist/weights/weights_container.d.ts +68 -0
- package/dist/weights/weights_container.js +96 -0
- package/package.json +24 -15
- package/README.md +0 -53
- package/dist/core/async_buffer.d.ts +0 -41
- package/dist/core/async_buffer.js +0 -97
- package/dist/core/async_informant.d.ts +0 -20
- package/dist/core/async_informant.js +0 -69
- package/dist/core/client/base.d.ts +0 -33
- package/dist/core/client/base.js +0 -35
- package/dist/core/client/decentralized/base.d.ts +0 -32
- package/dist/core/client/decentralized/base.js +0 -212
- package/dist/core/client/decentralized/clear_text.d.ts +0 -14
- package/dist/core/client/decentralized/clear_text.js +0 -96
- package/dist/core/client/decentralized/index.d.ts +0 -4
- package/dist/core/client/decentralized/index.js +0 -9
- package/dist/core/client/decentralized/messages.d.ts +0 -41
- package/dist/core/client/decentralized/messages.js +0 -54
- package/dist/core/client/decentralized/peer.d.ts +0 -26
- package/dist/core/client/decentralized/peer.js +0 -210
- package/dist/core/client/decentralized/peer_pool.d.ts +0 -14
- package/dist/core/client/decentralized/peer_pool.js +0 -92
- package/dist/core/client/decentralized/sec_agg.d.ts +0 -22
- package/dist/core/client/decentralized/sec_agg.js +0 -190
- package/dist/core/client/decentralized/secret_shares.d.ts +0 -3
- package/dist/core/client/decentralized/secret_shares.js +0 -39
- package/dist/core/client/decentralized/types.d.ts +0 -2
- package/dist/core/client/decentralized/types.js +0 -7
- package/dist/core/client/event_connection.d.ts +0 -37
- package/dist/core/client/event_connection.js +0 -158
- package/dist/core/client/federated/client.d.ts +0 -37
- package/dist/core/client/federated/client.js +0 -273
- package/dist/core/client/federated/index.d.ts +0 -2
- package/dist/core/client/federated/index.js +0 -7
- package/dist/core/client/federated/messages.d.ts +0 -38
- package/dist/core/client/federated/messages.js +0 -25
- package/dist/core/client/index.d.ts +0 -5
- package/dist/core/client/index.js +0 -11
- package/dist/core/client/local.d.ts +0 -8
- package/dist/core/client/local.js +0 -36
- package/dist/core/client/messages.d.ts +0 -28
- package/dist/core/client/messages.js +0 -33
- package/dist/core/client/utils.d.ts +0 -2
- package/dist/core/client/utils.js +0 -19
- package/dist/core/dataset/data/data.d.ts +0 -11
- package/dist/core/dataset/data/data.js +0 -20
- package/dist/core/dataset/data/data_split.d.ts +0 -5
- package/dist/core/dataset/data/data_split.js +0 -2
- package/dist/core/dataset/data/image_data.d.ts +0 -8
- package/dist/core/dataset/data/image_data.js +0 -64
- package/dist/core/dataset/data/index.d.ts +0 -5
- package/dist/core/dataset/data/index.js +0 -11
- package/dist/core/dataset/data/preprocessing.d.ts +0 -13
- package/dist/core/dataset/data/preprocessing.js +0 -33
- package/dist/core/dataset/data/tabular_data.d.ts +0 -8
- package/dist/core/dataset/data/tabular_data.js +0 -40
- package/dist/core/dataset/data_loader/data_loader.js +0 -10
- package/dist/core/dataset/data_loader/image_loader.d.ts +0 -17
- package/dist/core/dataset/data_loader/image_loader.js +0 -141
- package/dist/core/dataset/data_loader/index.d.ts +0 -3
- package/dist/core/dataset/data_loader/index.js +0 -9
- package/dist/core/dataset/data_loader/tabular_loader.d.ts +0 -29
- package/dist/core/dataset/data_loader/tabular_loader.js +0 -101
- package/dist/core/dataset/dataset.d.ts +0 -2
- package/dist/core/dataset/dataset.js +0 -2
- package/dist/core/dataset/dataset_builder.d.ts +0 -18
- package/dist/core/dataset/dataset_builder.js +0 -96
- package/dist/core/dataset/index.d.ts +0 -4
- package/dist/core/dataset/index.js +0 -14
- package/dist/core/default_tasks/cifar10.d.ts +0 -2
- package/dist/core/default_tasks/geotags.d.ts +0 -2
- package/dist/core/default_tasks/geotags.js +0 -69
- package/dist/core/default_tasks/index.d.ts +0 -6
- package/dist/core/default_tasks/index.js +0 -15
- package/dist/core/default_tasks/lus_covid.d.ts +0 -2
- package/dist/core/default_tasks/lus_covid.js +0 -96
- package/dist/core/default_tasks/mnist.d.ts +0 -2
- package/dist/core/default_tasks/simple_face.d.ts +0 -2
- package/dist/core/default_tasks/titanic.d.ts +0 -2
- package/dist/core/index.d.ts +0 -18
- package/dist/core/index.js +0 -39
- package/dist/core/informant/graph_informant.js +0 -23
- package/dist/core/informant/index.d.ts +0 -3
- package/dist/core/informant/index.js +0 -9
- package/dist/core/informant/training_informant/base.d.ts +0 -31
- package/dist/core/informant/training_informant/base.js +0 -83
- package/dist/core/informant/training_informant/decentralized.d.ts +0 -5
- package/dist/core/informant/training_informant/decentralized.js +0 -22
- package/dist/core/informant/training_informant/federated.d.ts +0 -14
- package/dist/core/informant/training_informant/federated.js +0 -32
- package/dist/core/informant/training_informant/index.d.ts +0 -4
- package/dist/core/informant/training_informant/index.js +0 -11
- package/dist/core/informant/training_informant/local.d.ts +0 -6
- package/dist/core/informant/training_informant/local.js +0 -20
- package/dist/core/logging/console_logger.js +0 -33
- package/dist/core/logging/index.d.ts +0 -3
- package/dist/core/logging/index.js +0 -9
- package/dist/core/logging/logger.js +0 -9
- package/dist/core/logging/trainer_logger.d.ts +0 -24
- package/dist/core/logging/trainer_logger.js +0 -59
- package/dist/core/memory/base.d.ts +0 -22
- package/dist/core/memory/base.js +0 -9
- package/dist/core/memory/empty.d.ts +0 -14
- package/dist/core/memory/empty.js +0 -75
- package/dist/core/memory/index.d.ts +0 -3
- package/dist/core/memory/index.js +0 -9
- package/dist/core/memory/model_type.d.ts +0 -4
- package/dist/core/memory/model_type.js +0 -9
- package/dist/core/serialization/index.d.ts +0 -2
- package/dist/core/serialization/index.js +0 -6
- package/dist/core/serialization/model.d.ts +0 -5
- package/dist/core/serialization/model.js +0 -55
- package/dist/core/serialization/weights.js +0 -64
- package/dist/core/task/data_example.js +0 -24
- package/dist/core/task/digest.js +0 -18
- package/dist/core/task/display_information.js +0 -49
- package/dist/core/task/index.d.ts +0 -6
- package/dist/core/task/index.js +0 -15
- package/dist/core/task/model_compile_data.d.ts +0 -6
- package/dist/core/task/model_compile_data.js +0 -22
- package/dist/core/task/summary.js +0 -19
- package/dist/core/task/task.js +0 -35
- package/dist/core/task/task_handler.d.ts +0 -5
- package/dist/core/task/task_handler.js +0 -53
- package/dist/core/task/task_provider.d.ts +0 -6
- package/dist/core/task/task_provider.js +0 -13
- package/dist/core/task/training_information.js +0 -66
- package/dist/core/training/disco.d.ts +0 -23
- package/dist/core/training/disco.js +0 -130
- package/dist/core/training/index.d.ts +0 -2
- package/dist/core/training/index.js +0 -7
- package/dist/core/training/trainer/distributed_trainer.d.ts +0 -20
- package/dist/core/training/trainer/distributed_trainer.js +0 -65
- package/dist/core/training/trainer/local_trainer.d.ts +0 -11
- package/dist/core/training/trainer/local_trainer.js +0 -34
- package/dist/core/training/trainer/round_tracker.d.ts +0 -30
- package/dist/core/training/trainer/round_tracker.js +0 -47
- package/dist/core/training/trainer/trainer.d.ts +0 -65
- package/dist/core/training/trainer/trainer.js +0 -160
- package/dist/core/training/trainer/trainer_builder.js +0 -95
- package/dist/core/training/training_schemes.d.ts +0 -5
- package/dist/core/training/training_schemes.js +0 -10
- package/dist/core/types.d.ts +0 -4
- package/dist/core/types.js +0 -2
- package/dist/core/validation/index.d.ts +0 -1
- package/dist/core/validation/index.js +0 -5
- package/dist/core/validation/validator.d.ts +0 -17
- package/dist/core/validation/validator.js +0 -104
- package/dist/core/weights/aggregation.d.ts +0 -7
- package/dist/core/weights/aggregation.js +0 -72
- package/dist/core/weights/index.d.ts +0 -2
- package/dist/core/weights/index.js +0 -7
- package/dist/core/weights/weights_container.d.ts +0 -19
- package/dist/core/weights/weights_container.js +0 -64
- package/dist/imports.d.ts +0 -2
- package/dist/imports.js +0 -7
- package/dist/memory/memory.d.ts +0 -26
- package/dist/memory/memory.js +0 -160
- package/dist/{core/task → task}/data_example.d.ts +1 -1
- package/dist/{core/task → task}/digest.d.ts +0 -0
- package/dist/{core/task → task}/summary.d.ts +1 -1
|
@@ -1,7 +1,4 @@
|
|
|
1
|
-
|
|
2
|
-
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
-
exports.addDifferentialPrivacy = void 0;
|
|
4
|
-
var _1 = require(".");
|
|
1
|
+
import * as tf from '@tensorflow/tfjs';
|
|
5
2
|
/**
|
|
6
3
|
* Add task-parametrized Gaussian noise to and clip the weights update between the previous and current rounds.
|
|
7
4
|
* The previous round's weights are the last weights pulled from server/peers.
|
|
@@ -11,20 +8,19 @@ var _1 = require(".");
|
|
|
11
8
|
* @param task the task
|
|
12
9
|
* @returns the noised weights for the current round
|
|
13
10
|
*/
|
|
14
|
-
function addDifferentialPrivacy(updatedWeights, staleWeights, task) {
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
var newWeightsDiff;
|
|
11
|
+
export function addDifferentialPrivacy(updatedWeights, staleWeights, task) {
|
|
12
|
+
const noiseScale = task.trainingInformation?.noiseScale;
|
|
13
|
+
const clippingRadius = task.trainingInformation?.clippingRadius;
|
|
14
|
+
const weightsDiff = updatedWeights.sub(staleWeights);
|
|
15
|
+
let newWeightsDiff;
|
|
20
16
|
if (clippingRadius !== undefined) {
|
|
21
17
|
// Frobenius norm
|
|
22
|
-
|
|
23
|
-
newWeightsDiff = weightsDiff.map(
|
|
24
|
-
|
|
18
|
+
const norm = weightsDiff.frobeniusNorm();
|
|
19
|
+
newWeightsDiff = weightsDiff.map((w) => {
|
|
20
|
+
const clipped = w.div(Math.max(1, norm / clippingRadius));
|
|
25
21
|
if (noiseScale !== undefined) {
|
|
26
22
|
// Add clipping and noise
|
|
27
|
-
|
|
23
|
+
const noise = tf.randomNormal(w.shape, 0, (noiseScale * noiseScale) * (clippingRadius * clippingRadius));
|
|
28
24
|
return clipped.add(noise);
|
|
29
25
|
}
|
|
30
26
|
else {
|
|
@@ -36,7 +32,7 @@ function addDifferentialPrivacy(updatedWeights, staleWeights, task) {
|
|
|
36
32
|
else {
|
|
37
33
|
if (noiseScale !== undefined) {
|
|
38
34
|
// Add noise without any clipping
|
|
39
|
-
newWeightsDiff = weightsDiff.map(
|
|
35
|
+
newWeightsDiff = weightsDiff.map((w) => tf.randomNormal(w.shape, 0, (noiseScale * noiseScale)));
|
|
40
36
|
}
|
|
41
37
|
else {
|
|
42
38
|
return updatedWeights;
|
|
@@ -44,4 +40,3 @@ function addDifferentialPrivacy(updatedWeights, staleWeights, task) {
|
|
|
44
40
|
}
|
|
45
41
|
return staleWeights.add(newWeightsDiff);
|
|
46
42
|
}
|
|
47
|
-
exports.addDifferentialPrivacy = addDifferentialPrivacy;
|
|
@@ -0,0 +1,5 @@
|
|
|
1
|
+
import type { Model } from '../index.js';
|
|
2
|
+
export type Encoded = Uint8Array;
|
|
3
|
+
export declare function isEncoded(raw: unknown): raw is Encoded;
|
|
4
|
+
export declare function encode(model: Model): Promise<Encoded>;
|
|
5
|
+
export declare function decode(encoded: unknown): Promise<Model>;
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
import msgpack from 'msgpack-lite';
|
|
2
|
+
import { models, serialization } from '../index.js';
|
|
3
|
+
const Type = {
|
|
4
|
+
TFJS: 0,
|
|
5
|
+
GPT: 1
|
|
6
|
+
};
|
|
7
|
+
export function isEncoded(raw) {
|
|
8
|
+
return raw instanceof Uint8Array;
|
|
9
|
+
}
|
|
10
|
+
export async function encode(model) {
|
|
11
|
+
if (model instanceof models.TFJS) {
|
|
12
|
+
const serialized = await model.serialize();
|
|
13
|
+
return msgpack.encode([Type.TFJS, serialized]);
|
|
14
|
+
}
|
|
15
|
+
if (model instanceof models.GPT) {
|
|
16
|
+
const { weights, config } = model.serialize();
|
|
17
|
+
const serializedWeights = await serialization.weights.encode(weights);
|
|
18
|
+
return msgpack.encode([Type.GPT, serializedWeights, config]);
|
|
19
|
+
}
|
|
20
|
+
throw new Error('unknown model type');
|
|
21
|
+
}
|
|
22
|
+
export async function decode(encoded) {
|
|
23
|
+
if (!isEncoded(encoded)) {
|
|
24
|
+
throw new Error("Invalid encoding, raw encoding isn't an instance of Uint8Array");
|
|
25
|
+
}
|
|
26
|
+
const raw = msgpack.decode(encoded);
|
|
27
|
+
if (!Array.isArray(raw) || raw.length < 2) {
|
|
28
|
+
throw new Error("invalid encoding, encoding isn't an array or doesn't contain enough values");
|
|
29
|
+
}
|
|
30
|
+
const type = raw[0];
|
|
31
|
+
if (typeof type !== 'number') {
|
|
32
|
+
throw new Error('invalid encoding, first encoding field should be the model type');
|
|
33
|
+
}
|
|
34
|
+
const rawModel = raw[1];
|
|
35
|
+
switch (type) {
|
|
36
|
+
case Type.TFJS:
|
|
37
|
+
if (raw.length !== 2) {
|
|
38
|
+
throw new Error('invalid encoding, TFJS model encoding should be an array of length 2');
|
|
39
|
+
}
|
|
40
|
+
// TODO totally unsafe casting
|
|
41
|
+
return await models.TFJS.deserialize(rawModel);
|
|
42
|
+
case Type.GPT: {
|
|
43
|
+
let config;
|
|
44
|
+
if (raw.length == 2) {
|
|
45
|
+
config = undefined;
|
|
46
|
+
}
|
|
47
|
+
else if (raw.length == 3) {
|
|
48
|
+
config = raw[2];
|
|
49
|
+
}
|
|
50
|
+
else {
|
|
51
|
+
throw new Error('invalid encoding, gpt-tfjs model encoding should be an array of length 2 or 3');
|
|
52
|
+
}
|
|
53
|
+
if (!Array.isArray(rawModel)) {
|
|
54
|
+
throw new Error('invalid encoding, gpt-tfjs model weights should be an array');
|
|
55
|
+
}
|
|
56
|
+
const arr = rawModel;
|
|
57
|
+
if (arr.some((r) => typeof r !== 'number')) {
|
|
58
|
+
throw new Error("invalid encoding, gpt-tfjs weights should be numbers");
|
|
59
|
+
}
|
|
60
|
+
const nums = arr;
|
|
61
|
+
const weights = serialization.weights.decode(nums);
|
|
62
|
+
return models.GPT.deserialize({ weights, config });
|
|
63
|
+
}
|
|
64
|
+
default:
|
|
65
|
+
throw new Error('invalid encoding, model type unrecognized');
|
|
66
|
+
}
|
|
67
|
+
}
|
|
@@ -1,5 +1,5 @@
|
|
|
1
|
-
import { WeightsContainer } from '
|
|
2
|
-
export
|
|
1
|
+
import { WeightsContainer } from '../index.js';
|
|
2
|
+
export type Encoded = number[];
|
|
3
3
|
export declare function isEncoded(raw: unknown): raw is Encoded;
|
|
4
4
|
export declare function encode(weights: WeightsContainer): Promise<Encoded>;
|
|
5
5
|
export declare function decode(encoded: Encoded): WeightsContainer;
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
import * as msgpack from 'msgpack-lite';
|
|
2
|
+
import * as tf from '@tensorflow/tfjs';
|
|
3
|
+
import { WeightsContainer } from '../index.js';
|
|
4
|
+
function isSerialized(raw) {
|
|
5
|
+
if (typeof raw !== 'object' || raw === null) {
|
|
6
|
+
return false;
|
|
7
|
+
}
|
|
8
|
+
const { shape, data } = raw;
|
|
9
|
+
if (!(Array.isArray(shape) && shape.every((e) => typeof e === 'number')) ||
|
|
10
|
+
!(Array.isArray(data) && data.every((e) => typeof e === 'number'))) {
|
|
11
|
+
return false;
|
|
12
|
+
}
|
|
13
|
+
const _ = {
|
|
14
|
+
shape: shape,
|
|
15
|
+
data: data,
|
|
16
|
+
};
|
|
17
|
+
return true;
|
|
18
|
+
}
|
|
19
|
+
export function isEncoded(raw) {
|
|
20
|
+
return Array.isArray(raw) && raw.every((e) => typeof e === 'number');
|
|
21
|
+
}
|
|
22
|
+
export async function encode(weights) {
|
|
23
|
+
const serialized = await Promise.all(weights.weights.map(async (t) => {
|
|
24
|
+
return {
|
|
25
|
+
shape: t.shape,
|
|
26
|
+
data: [...await t.data()]
|
|
27
|
+
};
|
|
28
|
+
}));
|
|
29
|
+
return [...msgpack.encode(serialized).values()];
|
|
30
|
+
}
|
|
31
|
+
export function decode(encoded) {
|
|
32
|
+
const raw = msgpack.decode(encoded);
|
|
33
|
+
if (!(Array.isArray(raw) && raw.every(isSerialized))) {
|
|
34
|
+
throw new Error('expected to decode an array of serialized weights');
|
|
35
|
+
}
|
|
36
|
+
return new WeightsContainer(raw.map((w) => tf.tensor(w.data, w.shape)));
|
|
37
|
+
}
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
export function isDataExample(raw) {
|
|
2
|
+
if (typeof raw !== 'object' || raw === null) {
|
|
3
|
+
return false;
|
|
4
|
+
}
|
|
5
|
+
const { columnName, columnData } = raw;
|
|
6
|
+
if (typeof columnName !== 'string' ||
|
|
7
|
+
(typeof columnData !== 'string' && typeof columnData !== 'number')) {
|
|
8
|
+
return false;
|
|
9
|
+
}
|
|
10
|
+
const repack = { columnName, columnData };
|
|
11
|
+
const _correct = repack;
|
|
12
|
+
const _total = repack;
|
|
13
|
+
return true;
|
|
14
|
+
}
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
export function isDigest(raw) {
|
|
2
|
+
if (typeof raw !== 'object' || raw === null) {
|
|
3
|
+
return false;
|
|
4
|
+
}
|
|
5
|
+
const { algorithm, value } = raw;
|
|
6
|
+
if (!(typeof algorithm === 'string' &&
|
|
7
|
+
typeof value === 'string')) {
|
|
8
|
+
return false;
|
|
9
|
+
}
|
|
10
|
+
const repack = { algorithm, value };
|
|
11
|
+
const _correct = repack;
|
|
12
|
+
const _total = repack;
|
|
13
|
+
return true;
|
|
14
|
+
}
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
import { Summary } from './summary';
|
|
2
|
-
import { DataExample } from './data_example';
|
|
3
|
-
|
|
1
|
+
import { type Summary } from './summary.js';
|
|
2
|
+
import { type DataExample } from './data_example.js';
|
|
3
|
+
import { type LabelType } from './label_type.js';
|
|
4
4
|
export interface DisplayInformation {
|
|
5
5
|
taskTitle?: string;
|
|
6
6
|
summary?: Summary;
|
|
@@ -12,4 +12,6 @@ export interface DisplayInformation {
|
|
|
12
12
|
headers?: string[];
|
|
13
13
|
dataExampleImage?: string;
|
|
14
14
|
limitations?: string;
|
|
15
|
+
labelDisplay?: LabelType;
|
|
15
16
|
}
|
|
17
|
+
export declare function isDisplayInformation(raw: unknown): raw is DisplayInformation;
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
import { isSummary } from './summary.js';
|
|
2
|
+
import { isDataExample } from './data_example.js';
|
|
3
|
+
import { isLabelType } from './label_type.js';
|
|
4
|
+
export function isDisplayInformation(raw) {
|
|
5
|
+
if (typeof raw !== 'object' || raw === null) {
|
|
6
|
+
return false;
|
|
7
|
+
}
|
|
8
|
+
const { dataExample, dataExampleImage, dataExampleText, dataFormatInformation, headers, labelDisplay, limitations, model, summary, taskTitle, tradeoffs } = raw;
|
|
9
|
+
if (typeof taskTitle !== 'string' ||
|
|
10
|
+
(dataExampleText !== undefined && typeof dataExampleText !== 'string') ||
|
|
11
|
+
(dataFormatInformation !== undefined && typeof dataFormatInformation !== 'string') ||
|
|
12
|
+
(tradeoffs !== undefined && typeof tradeoffs !== 'string') ||
|
|
13
|
+
(model !== undefined && typeof model !== 'string') ||
|
|
14
|
+
(dataExampleImage !== undefined && typeof dataExampleImage !== 'string') ||
|
|
15
|
+
(labelDisplay !== undefined && !isLabelType(labelDisplay)) ||
|
|
16
|
+
(limitations !== undefined && typeof limitations !== 'string')) {
|
|
17
|
+
return false;
|
|
18
|
+
}
|
|
19
|
+
if (summary !== undefined && !isSummary(summary)) {
|
|
20
|
+
return false;
|
|
21
|
+
}
|
|
22
|
+
if (dataExample !== undefined && !(Array.isArray(dataExample) &&
|
|
23
|
+
dataExample.every(isDataExample))) {
|
|
24
|
+
return false;
|
|
25
|
+
}
|
|
26
|
+
if (headers !== undefined && !(Array.isArray(headers) &&
|
|
27
|
+
headers.every((e) => typeof e === 'string'))) {
|
|
28
|
+
return false;
|
|
29
|
+
}
|
|
30
|
+
const repack = {
|
|
31
|
+
dataExample,
|
|
32
|
+
dataExampleImage,
|
|
33
|
+
dataExampleText,
|
|
34
|
+
dataFormatInformation,
|
|
35
|
+
headers,
|
|
36
|
+
labelDisplay,
|
|
37
|
+
limitations,
|
|
38
|
+
model,
|
|
39
|
+
summary,
|
|
40
|
+
taskTitle,
|
|
41
|
+
tradeoffs,
|
|
42
|
+
};
|
|
43
|
+
const _correct = repack;
|
|
44
|
+
const _total = repack;
|
|
45
|
+
return true;
|
|
46
|
+
}
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
export { isTask, type Task, isTaskID, type TaskID } from './task.js';
|
|
2
|
+
export { type TaskProvider } from './task_provider.js';
|
|
3
|
+
export { isDigest, type Digest } from './digest.js';
|
|
4
|
+
export { isDisplayInformation, type DisplayInformation } from './display_information.js';
|
|
5
|
+
export type { TrainingInformation } from './training_information.js';
|
|
6
|
+
export { pushTask, fetchTasks } from './task_handler.js';
|
|
7
|
+
export { LabelTypeEnum } from './label_type.js';
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
export var LabelTypeEnum;
|
|
2
|
+
(function (LabelTypeEnum) {
|
|
3
|
+
LabelTypeEnum[LabelTypeEnum["TEXT"] = 0] = "TEXT";
|
|
4
|
+
LabelTypeEnum[LabelTypeEnum["POLYGON_MAP"] = 1] = "POLYGON_MAP";
|
|
5
|
+
})(LabelTypeEnum || (LabelTypeEnum = {}));
|
|
6
|
+
function isLabelTypeEnum(raw) {
|
|
7
|
+
switch (raw) {
|
|
8
|
+
case LabelTypeEnum.TEXT: break;
|
|
9
|
+
case LabelTypeEnum.POLYGON_MAP: break;
|
|
10
|
+
default: return false;
|
|
11
|
+
}
|
|
12
|
+
const _ = raw;
|
|
13
|
+
return true;
|
|
14
|
+
}
|
|
15
|
+
export function isLabelType(raw) {
|
|
16
|
+
if (typeof raw !== 'object' || raw === null) {
|
|
17
|
+
return false;
|
|
18
|
+
}
|
|
19
|
+
const { labelType, mapBaseUrl } = raw;
|
|
20
|
+
if (!isLabelTypeEnum(labelType) ||
|
|
21
|
+
(mapBaseUrl !== undefined && typeof mapBaseUrl !== 'string')) {
|
|
22
|
+
return false;
|
|
23
|
+
}
|
|
24
|
+
const repack = { labelType, mapBaseUrl };
|
|
25
|
+
const _correct = repack;
|
|
26
|
+
const _total = repack;
|
|
27
|
+
return true;
|
|
28
|
+
}
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
export function isSummary(raw) {
|
|
2
|
+
if (typeof raw !== 'object' || raw === null) {
|
|
3
|
+
return false;
|
|
4
|
+
}
|
|
5
|
+
const { preview, overview } = raw;
|
|
6
|
+
if (!(typeof preview === 'string' && typeof overview === 'string')) {
|
|
7
|
+
return false;
|
|
8
|
+
}
|
|
9
|
+
const repack = { preview, overview };
|
|
10
|
+
const _correct = repack;
|
|
11
|
+
const _total = repack;
|
|
12
|
+
return true;
|
|
13
|
+
}
|
|
@@ -1,12 +1,12 @@
|
|
|
1
|
-
import { DisplayInformation } from './display_information';
|
|
2
|
-
import { TrainingInformation } from './training_information';
|
|
3
|
-
import { Digest } from './digest';
|
|
4
|
-
export
|
|
5
|
-
export declare function isTaskID(obj: unknown): obj is TaskID;
|
|
6
|
-
export declare function isTask(raw: unknown): raw is Task;
|
|
1
|
+
import { type DisplayInformation } from './display_information.js';
|
|
2
|
+
import { type TrainingInformation } from './training_information.js';
|
|
3
|
+
import { type Digest } from './digest.js';
|
|
4
|
+
export type TaskID = string;
|
|
7
5
|
export interface Task {
|
|
8
|
-
|
|
6
|
+
id: TaskID;
|
|
9
7
|
digest?: Digest;
|
|
10
8
|
displayInformation: DisplayInformation;
|
|
11
9
|
trainingInformation: TrainingInformation;
|
|
12
10
|
}
|
|
11
|
+
export declare function isTaskID(obj: unknown): obj is TaskID;
|
|
12
|
+
export declare function isTask(raw: unknown): raw is Task;
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
import { isDisplayInformation } from './display_information.js';
|
|
2
|
+
import { isTrainingInformation } from './training_information.js';
|
|
3
|
+
import { isDigest } from './digest.js';
|
|
4
|
+
export function isTaskID(obj) {
|
|
5
|
+
return typeof obj === 'string';
|
|
6
|
+
}
|
|
7
|
+
/**
 * Runtime guard for task objects.
 *
 * A task is a non-null object with a valid `id`, valid `displayInformation`
 * and `trainingInformation`, and, only when present, a valid `digest`.
 * Validators run in that order and evaluation stops at the first failure.
 *
 * @param raw value to inspect
 * @returns true when every check passes, false otherwise
 */
export function isTask(raw) {
  if (raw === null || typeof raw !== 'object') {
    return false;
  }
  const { id, digest, displayInformation, trainingInformation } = raw;
  if (!isTaskID(id)) {
    return false;
  }
  // the digest is optional: it is only validated when one was provided
  if (digest !== undefined && !isDigest(digest)) {
    return false;
  }
  if (!isDisplayInformation(displayInformation)) {
    return false;
  }
  if (!isTrainingInformation(trainingInformation)) {
    return false;
  }
  return true;
}
|
|
@@ -0,0 +1,5 @@
|
|
|
1
|
+
import { Map } from 'immutable';
|
|
2
|
+
import type { Model } from '../index.js';
|
|
3
|
+
import type { Task, TaskID } from './task.js';
|
|
4
|
+
/** Sends `task` together with `model` to the server at `url`; the returned promise carries no value. */
export declare function pushTask(url: URL, task: Task, model: Model): Promise<void>;
|
|
5
|
+
/** Retrieves the tasks served at `url`, as an immutable Map from TaskID to Task. */
export declare function fetchTasks(url: URL): Promise<Map<TaskID, Task>>;
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
import axios from 'axios';
|
|
2
|
+
import { Map } from 'immutable';
|
|
3
|
+
import { serialization } from '../index.js';
|
|
4
|
+
import { isTask } from './task.js';
|
|
5
|
+
// Path segment combined with the caller-provided URL by both pushTask and fetchTasks.
const TASK_ENDPOINT = 'tasks';
|
|
6
|
+
/**
 * Uploads a task and its model with a single POST request.
 *
 * The request body holds the task itself, the encoded model and the encoded
 * model weights.
 *
 * @param url base URL; the tasks endpoint is resolved relative to it
 * @param task task description to send
 * @param model model sent along with the task
 * @returns a promise that settles once the POST request has completed
 */
export async function pushTask(url, task, model) {
  // Resolve the endpoint the same way fetchTasks does. Plain concatenation onto
  // `url.href` produces a broken address as soon as the URL carries a query, a
  // fragment, or a path that does not end with a slash.
  const endpoint = new URL(TASK_ENDPOINT, url).href;
  await axios.post(endpoint, {
    task,
    model: await serialization.model.encode(model),
    weights: await serialization.weights.encode(model.weights)
  });
}
|
|
13
|
+
/**
 * Downloads the tasks exposed at the tasks endpoint under `url`.
 *
 * @param url base URL; the tasks endpoint is resolved relative to it
 * @returns an immutable Map associating each task id with its task
 * @throws Error('invalid tasks response') when the payload is not an array
 *         made exclusively of valid tasks
 */
export async function fetchTasks(url) {
  const endpoint = new URL(TASK_ENDPOINT, url);
  const { data } = await axios.get(endpoint.href);
  // reject anything that is not an array made exclusively of tasks
  if (!Array.isArray(data) || !data.every(isTask)) {
    throw new Error('invalid tasks response');
  }
  const entries = data.map((task) => [task.id, task]);
  return Map(entries);
}
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
export {};
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
import {
|
|
2
|
-
import {
|
|
3
|
-
|
|
1
|
+
import type { AggregatorChoice } from '../aggregator/get.js';
|
|
2
|
+
import type { Preprocessing } from '../dataset/data/preprocessing/index.js';
|
|
3
|
+
import { PreTrainedTokenizer } from '@xenova/transformers';
|
|
4
4
|
export interface TrainingInformation {
|
|
5
5
|
modelID: string;
|
|
6
6
|
epochs: number;
|
|
@@ -8,21 +8,20 @@ export interface TrainingInformation {
|
|
|
8
8
|
validationSplit: number;
|
|
9
9
|
batchSize: number;
|
|
10
10
|
preprocessingFunctions?: Preprocessing[];
|
|
11
|
-
|
|
12
|
-
dataType: string;
|
|
11
|
+
dataType: 'image' | 'tabular' | 'text';
|
|
13
12
|
inputColumns?: string[];
|
|
14
13
|
outputColumns?: string[];
|
|
15
14
|
IMAGE_H?: number;
|
|
16
15
|
IMAGE_W?: number;
|
|
17
|
-
modelURL?: string;
|
|
18
16
|
LABEL_LIST?: string[];
|
|
19
|
-
|
|
20
|
-
scheme: string;
|
|
17
|
+
scheme: 'decentralized' | 'federated' | 'local';
|
|
21
18
|
noiseScale?: number;
|
|
22
19
|
clippingRadius?: number;
|
|
23
20
|
decentralizedSecure?: boolean;
|
|
24
|
-
byzantineRobustAggregator?: boolean;
|
|
25
|
-
tauPercentile?: number;
|
|
26
21
|
maxShareValue?: number;
|
|
27
22
|
minimumReadyPeers?: number;
|
|
23
|
+
aggregator?: AggregatorChoice;
|
|
24
|
+
tokenizer?: string | PreTrainedTokenizer;
|
|
25
|
+
maxSequenceLength?: number;
|
|
28
26
|
}
|
|
27
|
+
/** Type guard narrowing an unknown value to the TrainingInformation interface declared above. */
export declare function isTrainingInformation(raw: unknown): raw is TrainingInformation;
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
import { PreTrainedTokenizer } from '@xenova/transformers';
|
|
2
|
+
/**
 * Tells whether `raw` is an array holding nothing but primitive strings.
 * An empty array qualifies.
 *
 * @param raw value to inspect
 * @returns true for arrays of strings, false for anything else
 */
function isStringArray(raw) {
  // Array.isArray only narrows to any[], hence the per-element check
  return Array.isArray(raw) && raw.every((item) => typeof item === 'string');
}
|
|
9
|
+
export function isTrainingInformation(raw) {
|
|
10
|
+
if (typeof raw !== 'object' || raw === null) {
|
|
11
|
+
return false;
|
|
12
|
+
}
|
|
13
|
+
const { IMAGE_H, IMAGE_W, LABEL_LIST, aggregator, batchSize, clippingRadius, dataType, decentralizedSecure, epochs, inputColumns, maxShareValue, minimumReadyPeers, modelID, noiseScale, outputColumns, preprocessingFunctions, roundDuration, scheme, validationSplit, tokenizer, maxSequenceLength, } = raw;
|
|
14
|
+
if (typeof dataType !== 'string' ||
|
|
15
|
+
typeof modelID !== 'string' ||
|
|
16
|
+
typeof epochs !== 'number' ||
|
|
17
|
+
typeof batchSize !== 'number' ||
|
|
18
|
+
typeof roundDuration !== 'number' ||
|
|
19
|
+
typeof validationSplit !== 'number' ||
|
|
20
|
+
(tokenizer !== undefined && typeof tokenizer !== 'string' && !(tokenizer instanceof PreTrainedTokenizer)) ||
|
|
21
|
+
(maxSequenceLength !== undefined && typeof maxSequenceLength !== 'number') ||
|
|
22
|
+
(aggregator !== undefined && typeof aggregator !== 'number') ||
|
|
23
|
+
(clippingRadius !== undefined && typeof clippingRadius !== 'number') ||
|
|
24
|
+
(decentralizedSecure !== undefined && typeof decentralizedSecure !== 'boolean') ||
|
|
25
|
+
(maxShareValue !== undefined && typeof maxShareValue !== 'number') ||
|
|
26
|
+
(minimumReadyPeers !== undefined && typeof minimumReadyPeers !== 'number') ||
|
|
27
|
+
(noiseScale !== undefined && typeof noiseScale !== 'number') ||
|
|
28
|
+
(IMAGE_H !== undefined && typeof IMAGE_H !== 'number') ||
|
|
29
|
+
(IMAGE_W !== undefined && typeof IMAGE_W !== 'number') ||
|
|
30
|
+
(LABEL_LIST !== undefined && !isStringArray(LABEL_LIST)) ||
|
|
31
|
+
(inputColumns !== undefined && !isStringArray(inputColumns)) ||
|
|
32
|
+
(outputColumns !== undefined && !isStringArray(outputColumns)) ||
|
|
33
|
+
(preprocessingFunctions !== undefined && !Array.isArray(preprocessingFunctions))) {
|
|
34
|
+
return false;
|
|
35
|
+
}
|
|
36
|
+
switch (dataType) {
|
|
37
|
+
case 'image': break;
|
|
38
|
+
case 'tabular': break;
|
|
39
|
+
case 'text': break;
|
|
40
|
+
default: return false;
|
|
41
|
+
}
|
|
42
|
+
// interdepences on data type
|
|
43
|
+
if (dataType === 'image') {
|
|
44
|
+
if (typeof IMAGE_H !== 'number' || typeof IMAGE_W !== 'number') {
|
|
45
|
+
return false;
|
|
46
|
+
}
|
|
47
|
+
}
|
|
48
|
+
else if (dataType in ['text', 'tabular']) {
|
|
49
|
+
if (!(Array.isArray(inputColumns) && inputColumns.every((e) => typeof e === 'string'))) {
|
|
50
|
+
return false;
|
|
51
|
+
}
|
|
52
|
+
if (!(Array.isArray(outputColumns) && outputColumns.every((e) => typeof e === 'string'))) {
|
|
53
|
+
return false;
|
|
54
|
+
}
|
|
55
|
+
}
|
|
56
|
+
switch (scheme) {
|
|
57
|
+
case 'decentralized': break;
|
|
58
|
+
case 'federated': break;
|
|
59
|
+
case 'local': break;
|
|
60
|
+
default: return false;
|
|
61
|
+
}
|
|
62
|
+
const repack = {
|
|
63
|
+
IMAGE_W,
|
|
64
|
+
IMAGE_H,
|
|
65
|
+
LABEL_LIST,
|
|
66
|
+
aggregator,
|
|
67
|
+
batchSize,
|
|
68
|
+
clippingRadius,
|
|
69
|
+
dataType,
|
|
70
|
+
decentralizedSecure,
|
|
71
|
+
epochs,
|
|
72
|
+
inputColumns,
|
|
73
|
+
maxShareValue,
|
|
74
|
+
minimumReadyPeers,
|
|
75
|
+
modelID,
|
|
76
|
+
noiseScale,
|
|
77
|
+
outputColumns,
|
|
78
|
+
preprocessingFunctions,
|
|
79
|
+
roundDuration,
|
|
80
|
+
scheme,
|
|
81
|
+
validationSplit,
|
|
82
|
+
tokenizer,
|
|
83
|
+
maxSequenceLength
|
|
84
|
+
};
|
|
85
|
+
const _correct = repack;
|
|
86
|
+
const _total = repack;
|
|
87
|
+
return true;
|
|
88
|
+
}
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
import type { data, Logger, Memory, Task, TrainingInformation } from '../index.js';
|
|
2
|
+
import { client as clients } from '../index.js';
|
|
3
|
+
import type { Aggregator } from '../aggregator/index.js';
|
|
4
|
+
import type { RoundLogs } from './trainer/trainer.js';
|
|
5
|
+
/**
 * Settings accepted as the second argument of the Disco constructor.
 * Every field is optional.
 */
export interface DiscoOptions {
|
|
6
|
+
client?: clients.Client;
|
|
7
|
+
aggregator?: Aggregator;
|
|
8
|
+
url?: string | URL;
|
|
9
|
+
/** Restricted to the values allowed for the `scheme` field of TrainingInformation. */
scheme?: TrainingInformation['scheme'];
|
|
10
|
+
logger?: Logger;
|
|
11
|
+
memory?: Memory;
|
|
12
|
+
}
|
|
13
|
+
/**
|
|
14
|
+
* Top-level class handling distributed training from a client's perspective. It is meant to be
|
|
15
|
+
* a convenient object providing a reduced yet complete API that wraps model training,
|
|
16
|
+
* communication with nodes, logs and model memory.
|
|
17
|
+
*/
|
|
18
|
+
export declare class Disco {
|
|
19
|
+
readonly task: Task;
|
|
20
|
+
readonly logger: Logger;
|
|
21
|
+
readonly memory: Memory;
|
|
22
|
+
private readonly client;
|
|
23
|
+
private readonly trainer;
|
|
24
|
+
/**
 * @param task The task handled by this object
 * @param options Optional settings, see DiscoOptions
 */
constructor(task: Task, options: DiscoOptions);
|
|
25
|
+
/**
|
|
26
|
+
* Starts a training instance for the Disco object's task on the provided data tuple.
|
|
27
|
+
* @param dataTuple The data tuple
* @returns An async generator whose items are RoundLogs extended with a numeric `participants` field
|
|
28
|
+
*/
|
|
29
|
+
fit(dataTuple: data.DataSplit): AsyncGenerator<RoundLogs & {
|
|
30
|
+
participants: number;
|
|
31
|
+
}>;
|
|
32
|
+
/**
|
|
33
|
+
* Stops the ongoing training instance without disconnecting the client.
|
|
34
|
+
*/
|
|
35
|
+
pause(): Promise<void>;
|
|
36
|
+
/**
|
|
37
|
+
* Completely stops the ongoing training instance.
|
|
38
|
+
*/
|
|
39
|
+
close(): Promise<void>;
|
|
40
|
+
}
|