@epfml/discojs 2.1.1 → 2.1.2-p20240506085037.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/aggregator/base.d.ts +180 -0
- package/dist/aggregator/base.js +236 -0
- package/dist/aggregator/get.d.ts +16 -0
- package/dist/aggregator/get.js +31 -0
- package/dist/aggregator/index.d.ts +7 -0
- package/dist/aggregator/index.js +4 -0
- package/dist/aggregator/mean.d.ts +23 -0
- package/dist/aggregator/mean.js +69 -0
- package/dist/aggregator/secure.d.ts +27 -0
- package/dist/aggregator/secure.js +91 -0
- package/dist/async_informant.d.ts +15 -0
- package/dist/async_informant.js +42 -0
- package/dist/client/base.d.ts +76 -0
- package/dist/client/base.js +88 -0
- package/dist/client/decentralized/base.d.ts +32 -0
- package/dist/client/decentralized/base.js +192 -0
- package/dist/client/decentralized/index.d.ts +2 -0
- package/dist/client/decentralized/index.js +2 -0
- package/dist/client/decentralized/messages.d.ts +28 -0
- package/dist/client/decentralized/messages.js +44 -0
- package/dist/client/decentralized/peer.d.ts +40 -0
- package/dist/client/decentralized/peer.js +189 -0
- package/dist/client/decentralized/peer_pool.d.ts +12 -0
- package/dist/client/decentralized/peer_pool.js +44 -0
- package/dist/client/event_connection.d.ts +34 -0
- package/dist/client/event_connection.js +105 -0
- package/dist/client/federated/base.d.ts +54 -0
- package/dist/client/federated/base.js +151 -0
- package/dist/client/federated/index.d.ts +2 -0
- package/dist/client/federated/index.js +2 -0
- package/dist/client/federated/messages.d.ts +30 -0
- package/dist/client/federated/messages.js +24 -0
- package/dist/client/index.d.ts +8 -0
- package/dist/client/index.js +8 -0
- package/dist/client/local.d.ts +3 -0
- package/dist/client/local.js +3 -0
- package/dist/client/messages.d.ts +30 -0
- package/dist/client/messages.js +26 -0
- package/dist/client/types.d.ts +2 -0
- package/dist/client/types.js +4 -0
- package/dist/client/utils.d.ts +2 -0
- package/dist/client/utils.js +7 -0
- package/dist/dataset/data/data.d.ts +48 -0
- package/dist/dataset/data/data.js +72 -0
- package/dist/dataset/data/data_split.d.ts +8 -0
- package/dist/dataset/data/data_split.js +1 -0
- package/dist/dataset/data/image_data.d.ts +11 -0
- package/dist/dataset/data/image_data.js +38 -0
- package/dist/dataset/data/index.d.ts +6 -0
- package/dist/dataset/data/index.js +5 -0
- package/dist/dataset/data/preprocessing/base.d.ts +16 -0
- package/dist/dataset/data/preprocessing/base.js +1 -0
- package/dist/dataset/data/preprocessing/image_preprocessing.d.ts +13 -0
- package/dist/dataset/data/preprocessing/image_preprocessing.js +40 -0
- package/dist/dataset/data/preprocessing/index.d.ts +4 -0
- package/dist/dataset/data/preprocessing/index.js +3 -0
- package/dist/dataset/data/preprocessing/tabular_preprocessing.d.ts +13 -0
- package/dist/dataset/data/preprocessing/tabular_preprocessing.js +45 -0
- package/dist/dataset/data/preprocessing/text_preprocessing.d.ts +13 -0
- package/dist/dataset/data/preprocessing/text_preprocessing.js +85 -0
- package/dist/dataset/data/tabular_data.d.ts +11 -0
- package/dist/dataset/data/tabular_data.js +25 -0
- package/dist/dataset/data/text_data.d.ts +11 -0
- package/dist/dataset/data/text_data.js +14 -0
- package/dist/{core/dataset → dataset}/data_loader/data_loader.d.ts +3 -5
- package/dist/dataset/data_loader/data_loader.js +2 -0
- package/dist/dataset/data_loader/image_loader.d.ts +20 -3
- package/dist/dataset/data_loader/image_loader.js +98 -23
- package/dist/dataset/data_loader/index.d.ts +5 -2
- package/dist/dataset/data_loader/index.js +4 -7
- package/dist/dataset/data_loader/tabular_loader.d.ts +34 -3
- package/dist/dataset/data_loader/tabular_loader.js +75 -15
- package/dist/dataset/data_loader/text_loader.d.ts +14 -0
- package/dist/dataset/data_loader/text_loader.js +25 -0
- package/dist/dataset/dataset.d.ts +5 -0
- package/dist/dataset/dataset.js +1 -0
- package/dist/dataset/dataset_builder.d.ts +60 -0
- package/dist/dataset/dataset_builder.js +142 -0
- package/dist/dataset/index.d.ts +5 -0
- package/dist/dataset/index.js +3 -0
- package/dist/default_tasks/cifar10/index.d.ts +2 -0
- package/dist/{core/default_tasks/cifar10.js → default_tasks/cifar10/index.js} +28 -36
- package/dist/default_tasks/cifar10/model.d.ts +434 -0
- package/dist/default_tasks/cifar10/model.js +2385 -0
- package/dist/default_tasks/geotags/index.d.ts +2 -0
- package/dist/default_tasks/geotags/index.js +65 -0
- package/dist/default_tasks/geotags/model.d.ts +593 -0
- package/dist/default_tasks/geotags/model.js +4715 -0
- package/dist/default_tasks/index.d.ts +8 -0
- package/dist/default_tasks/index.js +8 -0
- package/dist/default_tasks/lus_covid.d.ts +2 -0
- package/dist/default_tasks/lus_covid.js +89 -0
- package/dist/default_tasks/mnist.d.ts +2 -0
- package/dist/{core/default_tasks → default_tasks}/mnist.js +26 -34
- package/dist/default_tasks/simple_face/index.d.ts +2 -0
- package/dist/{core/default_tasks/simple_face.js → default_tasks/simple_face/index.js} +17 -22
- package/dist/default_tasks/simple_face/model.d.ts +513 -0
- package/dist/default_tasks/simple_face/model.js +4301 -0
- package/dist/default_tasks/skin_mnist.d.ts +2 -0
- package/dist/default_tasks/skin_mnist.js +80 -0
- package/dist/default_tasks/titanic.d.ts +2 -0
- package/dist/{core/default_tasks → default_tasks}/titanic.js +24 -33
- package/dist/default_tasks/wikitext.d.ts +2 -0
- package/dist/default_tasks/wikitext.js +38 -0
- package/dist/index.d.ts +18 -2
- package/dist/index.js +18 -6
- package/dist/{core/informant → informant}/graph_informant.d.ts +1 -1
- package/dist/informant/graph_informant.js +20 -0
- package/dist/informant/index.d.ts +1 -0
- package/dist/informant/index.js +1 -0
- package/dist/{core/logging → logging}/console_logger.d.ts +2 -2
- package/dist/logging/console_logger.js +22 -0
- package/dist/logging/index.d.ts +2 -0
- package/dist/logging/index.js +1 -0
- package/dist/{core/logging → logging}/logger.d.ts +3 -3
- package/dist/logging/logger.js +1 -0
- package/dist/memory/base.d.ts +119 -0
- package/dist/memory/base.js +9 -0
- package/dist/memory/empty.d.ts +20 -0
- package/dist/memory/empty.js +43 -0
- package/dist/memory/index.d.ts +3 -1
- package/dist/memory/index.js +3 -5
- package/dist/memory/model_type.d.ts +9 -0
- package/dist/memory/model_type.js +10 -0
- package/dist/{core/privacy.d.ts → privacy.d.ts} +1 -1
- package/dist/{core/privacy.js → privacy.js} +11 -16
- package/dist/serialization/index.d.ts +2 -0
- package/dist/serialization/index.js +2 -0
- package/dist/serialization/model.d.ts +5 -0
- package/dist/serialization/model.js +67 -0
- package/dist/{core/serialization → serialization}/weights.d.ts +2 -2
- package/dist/serialization/weights.js +37 -0
- package/dist/task/data_example.js +14 -0
- package/dist/task/digest.js +14 -0
- package/dist/{core/task → task}/display_information.d.ts +5 -3
- package/dist/task/display_information.js +46 -0
- package/dist/task/index.d.ts +7 -0
- package/dist/task/index.js +5 -0
- package/dist/task/label_type.d.ts +9 -0
- package/dist/task/label_type.js +28 -0
- package/dist/task/summary.js +13 -0
- package/dist/{core/task → task}/task.d.ts +7 -7
- package/dist/task/task.js +22 -0
- package/dist/task/task_handler.d.ts +5 -0
- package/dist/task/task_handler.js +20 -0
- package/dist/task/task_provider.d.ts +5 -0
- package/dist/task/task_provider.js +1 -0
- package/dist/{core/task → task}/training_information.d.ts +9 -10
- package/dist/task/training_information.js +88 -0
- package/dist/training/disco.d.ts +40 -0
- package/dist/training/disco.js +107 -0
- package/dist/training/index.d.ts +2 -0
- package/dist/training/index.js +1 -0
- package/dist/training/trainer/distributed_trainer.d.ts +20 -0
- package/dist/training/trainer/distributed_trainer.js +36 -0
- package/dist/training/trainer/local_trainer.d.ts +12 -0
- package/dist/training/trainer/local_trainer.js +19 -0
- package/dist/training/trainer/trainer.d.ts +33 -0
- package/dist/training/trainer/trainer.js +52 -0
- package/dist/{core/training → training}/trainer/trainer_builder.d.ts +5 -7
- package/dist/training/trainer/trainer_builder.js +43 -0
- package/dist/types.d.ts +8 -0
- package/dist/types.js +1 -0
- package/dist/utils/event_emitter.d.ts +40 -0
- package/dist/utils/event_emitter.js +57 -0
- package/dist/validation/index.d.ts +1 -0
- package/dist/validation/index.js +1 -0
- package/dist/validation/validator.d.ts +28 -0
- package/dist/validation/validator.js +132 -0
- package/dist/weights/aggregation.d.ts +21 -0
- package/dist/weights/aggregation.js +44 -0
- package/dist/weights/index.d.ts +2 -0
- package/dist/weights/index.js +2 -0
- package/dist/weights/weights_container.d.ts +68 -0
- package/dist/weights/weights_container.js +96 -0
- package/package.json +24 -15
- package/README.md +0 -53
- package/dist/core/async_buffer.d.ts +0 -41
- package/dist/core/async_buffer.js +0 -97
- package/dist/core/async_informant.d.ts +0 -20
- package/dist/core/async_informant.js +0 -69
- package/dist/core/client/base.d.ts +0 -33
- package/dist/core/client/base.js +0 -35
- package/dist/core/client/decentralized/base.d.ts +0 -32
- package/dist/core/client/decentralized/base.js +0 -212
- package/dist/core/client/decentralized/clear_text.d.ts +0 -14
- package/dist/core/client/decentralized/clear_text.js +0 -96
- package/dist/core/client/decentralized/index.d.ts +0 -4
- package/dist/core/client/decentralized/index.js +0 -9
- package/dist/core/client/decentralized/messages.d.ts +0 -41
- package/dist/core/client/decentralized/messages.js +0 -54
- package/dist/core/client/decentralized/peer.d.ts +0 -26
- package/dist/core/client/decentralized/peer.js +0 -210
- package/dist/core/client/decentralized/peer_pool.d.ts +0 -14
- package/dist/core/client/decentralized/peer_pool.js +0 -92
- package/dist/core/client/decentralized/sec_agg.d.ts +0 -22
- package/dist/core/client/decentralized/sec_agg.js +0 -190
- package/dist/core/client/decentralized/secret_shares.d.ts +0 -3
- package/dist/core/client/decentralized/secret_shares.js +0 -39
- package/dist/core/client/decentralized/types.d.ts +0 -2
- package/dist/core/client/decentralized/types.js +0 -7
- package/dist/core/client/event_connection.d.ts +0 -37
- package/dist/core/client/event_connection.js +0 -158
- package/dist/core/client/federated/client.d.ts +0 -37
- package/dist/core/client/federated/client.js +0 -273
- package/dist/core/client/federated/index.d.ts +0 -2
- package/dist/core/client/federated/index.js +0 -7
- package/dist/core/client/federated/messages.d.ts +0 -38
- package/dist/core/client/federated/messages.js +0 -25
- package/dist/core/client/index.d.ts +0 -5
- package/dist/core/client/index.js +0 -11
- package/dist/core/client/local.d.ts +0 -8
- package/dist/core/client/local.js +0 -36
- package/dist/core/client/messages.d.ts +0 -28
- package/dist/core/client/messages.js +0 -33
- package/dist/core/client/utils.d.ts +0 -2
- package/dist/core/client/utils.js +0 -19
- package/dist/core/dataset/data/data.d.ts +0 -11
- package/dist/core/dataset/data/data.js +0 -20
- package/dist/core/dataset/data/data_split.d.ts +0 -5
- package/dist/core/dataset/data/data_split.js +0 -2
- package/dist/core/dataset/data/image_data.d.ts +0 -8
- package/dist/core/dataset/data/image_data.js +0 -64
- package/dist/core/dataset/data/index.d.ts +0 -5
- package/dist/core/dataset/data/index.js +0 -11
- package/dist/core/dataset/data/preprocessing.d.ts +0 -13
- package/dist/core/dataset/data/preprocessing.js +0 -33
- package/dist/core/dataset/data/tabular_data.d.ts +0 -8
- package/dist/core/dataset/data/tabular_data.js +0 -40
- package/dist/core/dataset/data_loader/data_loader.js +0 -10
- package/dist/core/dataset/data_loader/image_loader.d.ts +0 -17
- package/dist/core/dataset/data_loader/image_loader.js +0 -141
- package/dist/core/dataset/data_loader/index.d.ts +0 -3
- package/dist/core/dataset/data_loader/index.js +0 -9
- package/dist/core/dataset/data_loader/tabular_loader.d.ts +0 -29
- package/dist/core/dataset/data_loader/tabular_loader.js +0 -101
- package/dist/core/dataset/dataset.d.ts +0 -2
- package/dist/core/dataset/dataset.js +0 -2
- package/dist/core/dataset/dataset_builder.d.ts +0 -18
- package/dist/core/dataset/dataset_builder.js +0 -96
- package/dist/core/dataset/index.d.ts +0 -4
- package/dist/core/dataset/index.js +0 -14
- package/dist/core/default_tasks/cifar10.d.ts +0 -2
- package/dist/core/default_tasks/geotags.d.ts +0 -2
- package/dist/core/default_tasks/geotags.js +0 -69
- package/dist/core/default_tasks/index.d.ts +0 -6
- package/dist/core/default_tasks/index.js +0 -15
- package/dist/core/default_tasks/lus_covid.d.ts +0 -2
- package/dist/core/default_tasks/lus_covid.js +0 -96
- package/dist/core/default_tasks/mnist.d.ts +0 -2
- package/dist/core/default_tasks/simple_face.d.ts +0 -2
- package/dist/core/default_tasks/titanic.d.ts +0 -2
- package/dist/core/index.d.ts +0 -18
- package/dist/core/index.js +0 -39
- package/dist/core/informant/graph_informant.js +0 -23
- package/dist/core/informant/index.d.ts +0 -3
- package/dist/core/informant/index.js +0 -9
- package/dist/core/informant/training_informant/base.d.ts +0 -31
- package/dist/core/informant/training_informant/base.js +0 -83
- package/dist/core/informant/training_informant/decentralized.d.ts +0 -5
- package/dist/core/informant/training_informant/decentralized.js +0 -22
- package/dist/core/informant/training_informant/federated.d.ts +0 -14
- package/dist/core/informant/training_informant/federated.js +0 -32
- package/dist/core/informant/training_informant/index.d.ts +0 -4
- package/dist/core/informant/training_informant/index.js +0 -11
- package/dist/core/informant/training_informant/local.d.ts +0 -6
- package/dist/core/informant/training_informant/local.js +0 -20
- package/dist/core/logging/console_logger.js +0 -33
- package/dist/core/logging/index.d.ts +0 -3
- package/dist/core/logging/index.js +0 -9
- package/dist/core/logging/logger.js +0 -9
- package/dist/core/logging/trainer_logger.d.ts +0 -24
- package/dist/core/logging/trainer_logger.js +0 -59
- package/dist/core/memory/base.d.ts +0 -22
- package/dist/core/memory/base.js +0 -9
- package/dist/core/memory/empty.d.ts +0 -14
- package/dist/core/memory/empty.js +0 -75
- package/dist/core/memory/index.d.ts +0 -3
- package/dist/core/memory/index.js +0 -9
- package/dist/core/memory/model_type.d.ts +0 -4
- package/dist/core/memory/model_type.js +0 -9
- package/dist/core/serialization/index.d.ts +0 -2
- package/dist/core/serialization/index.js +0 -6
- package/dist/core/serialization/model.d.ts +0 -5
- package/dist/core/serialization/model.js +0 -55
- package/dist/core/serialization/weights.js +0 -64
- package/dist/core/task/data_example.js +0 -24
- package/dist/core/task/digest.js +0 -18
- package/dist/core/task/display_information.js +0 -49
- package/dist/core/task/index.d.ts +0 -6
- package/dist/core/task/index.js +0 -15
- package/dist/core/task/model_compile_data.d.ts +0 -6
- package/dist/core/task/model_compile_data.js +0 -22
- package/dist/core/task/summary.js +0 -19
- package/dist/core/task/task.js +0 -35
- package/dist/core/task/task_handler.d.ts +0 -5
- package/dist/core/task/task_handler.js +0 -53
- package/dist/core/task/task_provider.d.ts +0 -6
- package/dist/core/task/task_provider.js +0 -13
- package/dist/core/task/training_information.js +0 -66
- package/dist/core/training/disco.d.ts +0 -23
- package/dist/core/training/disco.js +0 -130
- package/dist/core/training/index.d.ts +0 -2
- package/dist/core/training/index.js +0 -7
- package/dist/core/training/trainer/distributed_trainer.d.ts +0 -20
- package/dist/core/training/trainer/distributed_trainer.js +0 -65
- package/dist/core/training/trainer/local_trainer.d.ts +0 -11
- package/dist/core/training/trainer/local_trainer.js +0 -34
- package/dist/core/training/trainer/round_tracker.d.ts +0 -30
- package/dist/core/training/trainer/round_tracker.js +0 -47
- package/dist/core/training/trainer/trainer.d.ts +0 -65
- package/dist/core/training/trainer/trainer.js +0 -160
- package/dist/core/training/trainer/trainer_builder.js +0 -95
- package/dist/core/training/training_schemes.d.ts +0 -5
- package/dist/core/training/training_schemes.js +0 -10
- package/dist/core/types.d.ts +0 -4
- package/dist/core/types.js +0 -2
- package/dist/core/validation/index.d.ts +0 -1
- package/dist/core/validation/index.js +0 -5
- package/dist/core/validation/validator.d.ts +0 -17
- package/dist/core/validation/validator.js +0 -104
- package/dist/core/weights/aggregation.d.ts +0 -7
- package/dist/core/weights/aggregation.js +0 -72
- package/dist/core/weights/index.d.ts +0 -2
- package/dist/core/weights/index.js +0 -7
- package/dist/core/weights/weights_container.d.ts +0 -19
- package/dist/core/weights/weights_container.js +0 -64
- package/dist/imports.d.ts +0 -2
- package/dist/imports.js +0 -7
- package/dist/memory/memory.d.ts +0 -26
- package/dist/memory/memory.js +0 -160
- package/dist/{core/task → task}/data_example.d.ts +1 -1
- package/dist/{core/task → task}/digest.d.ts +0 -0
- package/dist/{core/task → task}/summary.d.ts +1 -1
|
@@ -0,0 +1,189 @@
|
|
|
1
|
+
import { List, Map, Range, Seq } from 'immutable';
|
|
2
|
+
import wrtc from 'isomorphic-wrtc';
|
|
3
|
+
import SimplePeer from 'simple-peer';
|
|
4
|
+
// Header of the first chunk of a message:
//   message id (2 bytes) + chunk counter, always 0 (1 byte) + total chunk count (1 byte)
const FIRST_HEADER_SIZE = 4;
// Header of every following chunk:
//   message id (2 bytes) + chunk counter (1 byte)
const HEADER_SIZE = 3;
// Delay, in milliseconds, between two attempts to flush the send queue
const TICK = 10;
|
|
10
|
+
/**
 * Peer wraps a SimplePeer, adding message fragmentation.
 *
 * WebRTC implementations have various maximum message sizes, but with huge
 * models our messages might be bigger. We split messages into chunks and
 * reconstruct them on the other side.
 *
 * As WebRTC's DataChannel is not a stream, we need to reorder chunks, so each
 * chunk carries a header with a message id and a chunk counter. The first
 * chunk (chunk counter == 0) also carries the total number of chunks.
 *
 * See feross/simple-peer#393 for more info.
 */
export class Peer {
    id;
    peer;
    // max-message-size read from the offer/answer given to `signal`;
    // undefined until then, Infinity when the remote announced no limit
    bufferSize;
    // id of the next message to send, written as a uint16 in chunk headers
    sendCounter = 0;
    // chunks waiting to be handed to the underlying SimplePeer
    sendQueue = List();
    // partially received messages: message id -> { total, chunks }
    receiving = Map();
    constructor(id, initiator = false) {
        this.id = id;
        this.peer = new SimplePeer({ wrtc, initiator });
    }
    /**
     * Splits `msg` into chunks, queues them and starts flushing.
     *
     * @throws if the remote max-message-size is not known yet (see `signal`)
     */
    send(msg) {
        const chunks = this.chunk(msg);
        this.sendQueue = this.sendQueue.concat(chunks);
        this.flush();
    }
    /**
     * Hands queued chunks to the underlying peer as long as its buffer has room
     * for them; otherwise retries after TICK milliseconds.
     *
     * @throws if the remote max-message-size is not known yet
     */
    flush() {
        if (this.bufferSize === undefined) {
            throw new Error('flush without known buffer size');
        }
        let chunk = this.sendQueue.first();
        while (chunk !== undefined) {
            const remainingBufferSize = this.bufferSize - this.peer.bufferSize;
            if (chunk.length > remainingBufferSize) {
                // not enough room yet, wait for the buffer to drain
                setTimeout(() => { this.flush(); }, TICK);
                return;
            }
            this.sendQueue = this.sendQueue.shift();
            this.peer.send(chunk);
            chunk = this.sendQueue.first();
        }
    }
    /**
     * Size, headers included, of the chunks produced by `chunk`.
     *
     * @throws if the remote max-message-size is not known yet
     */
    get maxChunkSize() {
        if (this.bufferSize === undefined) {
            throw new Error('chunk without known buffer size');
        }
        // in the perfect world of bug-free implementations
        // we would return this.bufferSize
        // sadly, we are not there yet
        //
        // based on MDN, taking 16K seems to be a pretty safe
        // and widely supported buffer size
        return 16 * (1 << 10);
    }
    /**
     * Splits `b` into header-prefixed chunks of at most `maxChunkSize` bytes.
     *
     * Each call consumes one message id.
     *
     * @param b the message to split (a Buffer)
     * @returns a lazy indexed sequence of chunks, the first one carrying the total chunk count
     * @throws if message ids are exhausted (more than 0x10000 messages) or
     *         if the message needs more than 0xFF chunks
     */
    chunk(b) {
        const messageID = this.sendCounter;
        // ids are written as uint16, so 0xFFFF is the last usable one
        if (messageID > 0xFFFF) {
            throw new Error('too much messages sent to this peer');
        }
        this.sendCounter++;
        const maxChunkSize = this.maxChunkSize;
        const firstPayloadSize = maxChunkSize - FIRST_HEADER_SIZE;
        const otherPayloadSize = maxChunkSize - HEADER_SIZE;
        // Everything that does not fit in the first chunk's payload goes into
        // the tail. Comparing against the payload size (and not maxChunkSize)
        // avoids dropping the last bytes of messages whose length lies in
        // (firstPayloadSize, maxChunkSize].
        // Guarded because Range(start, end) with start > end counts downwards
        // and still yields a value.
        let tail = Seq.Indexed([]);
        if (b.length > firstPayloadSize) {
            tail = Range(firstPayloadSize, b.length, otherPayloadSize)
                .map((offset) => b.subarray(offset, offset + otherPayloadSize));
        }
        const totalChunkCount = 1 + tail.count();
        // the total is written as a uint8
        if (totalChunkCount > 0xFF) {
            throw new Error('too big message to even chunk it');
        }
        const firstChunk = Buffer.alloc(FIRST_HEADER_SIZE + Math.min(b.length, firstPayloadSize));
        firstChunk.writeUint16BE(messageID);
        firstChunk.writeUint8(0, 2);
        firstChunk.writeUint8(totalChunkCount, 3);
        b.copy(firstChunk, FIRST_HEADER_SIZE, 0, firstPayloadSize);
        const otherChunks = Range(1).zip(tail)
            .map(([id, raw]) => {
                const chunk = Buffer.alloc(HEADER_SIZE + raw.length);
                chunk.writeUint16BE(messageID);
                chunk.writeUint8(id, 2);
                raw.copy(chunk, HEADER_SIZE, 0);
                return chunk;
            });
        return Seq.Indexed([firstChunk]).concat(otherChunks);
    }
    /** Destroys the underlying peer; resolves on 'close', rejects on 'error'. */
    async destroy() {
        return new Promise((resolve, reject) => {
            this.peer.once('error', reject);
            this.peer.once('close', resolve);
            this.peer.destroy();
        });
    }
    /**
     * Forwards signalling data to the underlying peer, first recording the
     * remote max-message-size found in offers and answers.
     *
     * @throws if an offer/answer has no session description or no parsable
     *         max-message-size, or if the size was already recorded
     */
    signal(signal) {
        // extract max buffer size
        if (signal.type === 'offer' || signal.type === 'answer') {
            if (signal.sdp === undefined) {
                throw new Error('signal answer|offer without session description');
            }
            if (this.bufferSize !== undefined) {
                throw new Error('buffer size set twice');
            }
            const match = signal.sdp.match(/a=max-message-size:(\d+)/);
            if (match === null) {
                // TODO default value instead?
                throw new Error('no max-message-size found in signal');
            }
            const max = parseInt(match[1], 10);
            if (isNaN(max)) {
                throw new Error(`unable to parse max-message-size as int: ${match[1]}`);
            }
            // RFC 8841: a max-message-size of 0 means messages of any size are
            // accepted; taken literally it would make `flush` wait forever
            this.bufferSize = max === 0 ? Infinity : max;
        }
        this.peer.signal(signal);
    }
    /**
     * Registers an event listener on the underlying peer.
     *
     * For 'data', the listener receives whole reassembled messages instead of
     * raw chunks; malformed, duplicated or out-of-bound chunks throw.
     * NOTE(review): reassembly state is shared through `this.receiving`, so a
     * second 'data' listener would see every chunk as already received.
     */
    on(event, listener) {
        if (event !== 'data') {
            this.peer.on(event, listener);
            return;
        }
        this.peer.on('data', (data) => {
            if (!Buffer.isBuffer(data) || data.length < HEADER_SIZE) {
                throw new Error('received invalid message type');
            }
            const messageID = data.readUint16BE();
            const chunkID = data.readUint8(2);
            const received = this.receiving.get(messageID, {
                total: undefined,
                chunks: Map()
            });
            let total = received.total;
            const chunks = received.chunks;
            if (chunks.has(chunkID)) {
                throw new Error(`chunk ${messageID}:${chunkID} already received`);
            }
            let chunk;
            if (chunkID !== 0) {
                // chunk ids run from 0 to total - 1
                if (total !== undefined && chunkID >= total) {
                    throw new Error(`chunk ${messageID}:${chunkID} is out of bound of its ${total} chunks`);
                }
                chunk = Buffer.alloc(data.length - HEADER_SIZE);
                data.copy(chunk, 0, HEADER_SIZE);
            }
            else {
                if (data.length < FIRST_HEADER_SIZE) {
                    throw new Error('received invalid message type');
                }
                if (total !== undefined) {
                    throw new Error('first header received twice');
                }
                const readTotal = data.readUint8(3);
                if (readTotal === 0) {
                    throw new Error(`message ${messageID} announces zero chunks`);
                }
                total = readTotal;
                chunk = Buffer.alloc(data.length - FIRST_HEADER_SIZE);
                data.copy(chunk, 0, FIRST_HEADER_SIZE);
                // an id equal to the total is already out of bound
                if (chunks.keySeq().some((id) => id >= readTotal)) {
                    throw new Error('received total of chunk but got now-out-of-bound chunks');
                }
            }
            const updatedChunks = chunks.set(chunkID, chunk);
            if (total === undefined || updatedChunks.size !== total) {
                this.receiving = this.receiving.set(messageID, {
                    total,
                    chunks: updatedChunks
                });
                return;
            }
            // complete: ids are distinct and below total, hence exactly 0..total-1;
            // only the message updated by this chunk can have become complete
            this.receiving = this.receiving.delete(messageID);
            const message = Buffer.concat(updatedChunks.sortBy((_, id) => id).valueSeq().toArray());
            listener(message);
        });
    }
}
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
import { Map, type Set } from 'immutable';
|
|
2
|
+
import { type SignalData } from './peer.js';
|
|
3
|
+
import { type NodeID } from '../types.js';
|
|
4
|
+
import { PeerConnection, type EventConnection } from '../event_connection.js';
|
|
5
|
+
/**
 * Pool of connections from this node to its peers, indexed by peer id.
 */
export declare class PeerPool {
    private readonly id;
    private peers;
    constructor(id: NodeID);
    /** Disconnects every known peer, then empties the pool. */
    shutdown(): Promise<void>;
    /**
     * Forwards signalling data to the connection of `peerId`.
     *
     * @throws if `peerId` is not in the pool
     */
    signal(peerId: NodeID, signal: SignalData): void;
    /**
     * Connects to the peers of `peersToConnect` that are not already in the pool.
     *
     * `clientHandle` is called with the whole pool once the new connections
     * are registered, before they are established.
     *
     * @returns the pool's connections restricted to `peersToConnect`
     * @throws if `peersToConnect` contains this node's own id
     */
    getPeers(
        peersToConnect: Set<NodeID>,
        signallingServer: EventConnection,
        clientHandle: (connections: Map<NodeID, PeerConnection>) => void
    ): Promise<Map<NodeID, PeerConnection>>;
}
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
import { Map } from 'immutable';
|
|
2
|
+
import { Peer } from './peer.js';
|
|
3
|
+
import { PeerConnection } from '../event_connection.js';
|
|
4
|
+
/**
 * Connections from this node to its peers, indexed by peer id.
 *
 * TODO cleanup old peers
 */
export class PeerPool {
    id;
    peers = Map();
    constructor(id) {
        this.id = id;
    }
    /** Disconnects every known peer, then empties the pool. */
    async shutdown() {
        console.info(`[${this.id}] shutdown their peers`);
        const disconnections = this.peers
            .valueSeq()
            .map((connection) => connection.disconnect());
        await Promise.all(disconnections);
        this.peers = Map();
    }
    /**
     * Forwards signalling data to the connection of `peerId`.
     *
     * @throws if `peerId` is not in the pool
     */
    signal(peerId, signal) {
        console.info(`[${this.id}] signals for`, peerId);
        const target = this.peers.get(peerId);
        if (target === undefined) {
            throw new Error(`received signal for unknown peer: ${peerId}`);
        }
        target.signal(signal);
    }
    /**
     * Connects to the peers of `peersToConnect` not yet in the pool and
     * resolves once all those new connections are established.
     *
     * @param peersToConnect ids of the wanted peers (must not contain our id)
     * @param signallingServer connection used to relay WebRTC signals
     * @param clientHandle called with the whole pool right after the new
     *        connections are registered, before they are established (TODO as event?)
     * @returns the pool's connections restricted to `peersToConnect`
     */
    async getPeers(peersToConnect, signallingServer, clientHandle) {
        if (peersToConnect.contains(this.id)) {
            throw new Error('peers to connect contains our id');
        }
        console.info(`[${this.id}] is connecting peers:`, peersToConnect.toJS());
        const unknownIds = peersToConnect.filter((peerId) => !this.peers.has(peerId));
        // we initiate towards peers whose id sorts before ours; assuming every
        // node applies the same rule, exactly one side of each pair initiates
        const freshPeers = Map(unknownIds.map((peerId) => [peerId, new Peer(peerId, peerId < this.id)]));
        console.info(`[${this.id}] asked to connect new peers:`, freshPeers.keySeq().toJS());
        const freshConnections = freshPeers.map((peer) => new PeerConnection(this.id, peer, signallingServer));
        // register before connecting: incoming signals are routed through
        // `signal`, which only knows peers already in the pool
        this.peers = this.peers.merge(freshConnections);
        clientHandle(this.peers);
        const establishing = freshConnections
            .valueSeq()
            .map((connection) => connection.connect());
        await Promise.all(establishing);
        console.info(`[${this.id}] knowns connected peers:`, this.peers.keySeq().toJS());
        return this.peers.filter((_, peerId) => peersToConnect.has(peerId));
    }
}
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
import type { Peer, SignalData } from './decentralized/peer.js';
|
|
2
|
+
import { type NodeID } from './types.js';
|
|
3
|
+
import { type, type NarrowMessage, type Message } from './messages.js';
|
|
4
|
+
import { EventEmitter } from '../utils/event_emitter.js';
|
|
5
|
+
/**
 * Message-oriented connection: handlers are registered per message type,
 * messages can be sent, and the connection can be closed.
 */
export interface EventConnection {
    /** Registers `handler` for received messages of type `type`. */
    on: <K extends type>(type: K, handler: (event: NarrowMessage<K>) => void) => void;
    /**
     * Registers `handler` for received messages of type `type`; expected to
     * fire at most once (`waitMessage` relies on it).
     */
    once: <K extends type>(type: K, handler: (event: NarrowMessage<K>) => void) => void;
    /** Sends `msg` over the connection. */
    send: <T extends Message>(msg: T) => void;
    /** Closes the connection. */
    disconnect: () => Promise<void>;
}
|
|
11
|
+
/** Resolves with the next message of type `type` received on `connection`. */
export declare function waitMessage<T extends type>(
    connection: EventConnection,
    type: T
): Promise<NarrowMessage<T>>;
/**
 * Same as `waitMessage`, raced against a timeout of `timeoutMs`.
 */
export declare function waitMessageWithTimeout<T extends type>(
    connection: EventConnection,
    type: T,
    timeoutMs?: number
): Promise<NarrowMessage<T>>;
|
|
13
|
+
/**
 * EventConnection over a direct WebRTC link to one peer; the link's
 * signalling data is relayed through `signallingServer`.
 */
export declare class PeerConnection extends EventEmitter<{
    [K in type]: NarrowMessage<K>;
}> implements EventConnection {
    private readonly _ownId;
    private readonly peer;
    private readonly signallingServer;
    constructor(_ownId: NodeID, peer: Peer, signallingServer: EventConnection);
    /** Wires the peer's events and resolves once the peer is connected. */
    connect(): Promise<void>;
    /** Feeds signalling data received for this peer into the WebRTC link. */
    signal(signal: SignalData): void;
    /** Sends a peer message; throws for any other kind of message. */
    send<T extends Message>(msg: T): void;
    /** Destroys the underlying peer. */
    disconnect(): Promise<void>;
}
|
|
25
|
+
/**
 * EventConnection to a server over a WebSocket, with msgpack-encoded messages.
 */
export declare class WebSocketServer extends EventEmitter<{
    [K in type]: NarrowMessage<K>;
}> implements EventConnection {
    private readonly socket;
    private readonly validateSent?;
    private constructor();
    /**
     * Opens a WebSocket to `url`; received messages are checked with
     * `validateReceived`, sent ones with `validateSent`.
     */
    static connect(
        url: URL,
        validateReceived: (msg: unknown) => msg is Message,
        validateSent: (msg: Message) => boolean
    ): Promise<WebSocketServer>;
    /** Closes the socket. */
    disconnect(): Promise<void>;
    /** Sends `msg`; throws if `validateSent` rejects it. */
    send(msg: Message): void;
}
|
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
import WebSocket from 'isomorphic-ws';
|
|
2
|
+
import msgpack from 'msgpack-lite';
|
|
3
|
+
import * as decentralizedMessages from './decentralized/messages.js';
|
|
4
|
+
import { type } from './messages.js';
|
|
5
|
+
import { timeout } from './utils.js';
|
|
6
|
+
import { EventEmitter } from '../utils/event_emitter.js';
|
|
7
|
+
/**
 * Resolves with the next event of the given type emitted by `connection`.
 *
 * The handler is registered through `once`, as a promise can only be resolved
 * a single time anyway.
 *
 * @param connection the connection to listen on
 * @param type the message type to wait for
 * @returns the received event
 */
export async function waitMessage(connection, type) {
    const received = await new Promise((resolve) => {
        connection.once(type, (message) => { resolve(message); });
    });
    return received;
}
|
|
15
|
+
/**
 * Like `waitMessage`, but settles with whichever of the message and
 * `timeout(timeoutMs)` settles first.
 *
 * NOTE(review): `timeout` comes from ./utils.js and presumably rejects once
 * the delay elapses — confirm. When it wins, the `once` handler registered by
 * `waitMessage` is left on the connection.
 */
export async function waitMessageWithTimeout(connection, type, timeoutMs) {
    const contenders = [waitMessage(connection, type), timeout(timeoutMs)];
    return await Promise.race(contenders);
}
|
|
18
|
+
/**
 * EventConnection over a direct WebRTC link to one peer.
 *
 * Messages are msgpack-encoded and must satisfy
 * `decentralizedMessages.isPeerMessage` both when sent and when received.
 * Signalling data produced by the peer is relayed to `signallingServer`
 * as `SignalForPeer` messages.
 */
export class PeerConnection extends EventEmitter {
    _ownId;
    peer;
    signallingServer;
    constructor(_ownId, peer, signallingServer) {
        super();
        this._ownId = _ownId;
        this.peer = peer;
        this.signallingServer = signallingServer;
    }
    /** Wires the peer's events, then resolves once the peer emits 'connect'. */
    async connect() {
        const relaySignal = (signal) => {
            this.signallingServer.send({
                type: type.SignalForPeer,
                peer: this.peer.id,
                signal
            });
        };
        const dispatchData = (data) => {
            const decoded = msgpack.decode(data);
            if (!decentralizedMessages.isPeerMessage(decoded)) {
                throw new Error(`invalid message received: ${JSON.stringify(decoded)}`);
            }
            this.emit(decoded.type, decoded);
        };
        this.peer.on('signal', relaySignal);
        this.peer.on('data', dispatchData);
        this.peer.on('close', () => { console.warn('peer', this.peer.id, 'closed connection'); });
        await new Promise((resolve) => {
            this.peer.on('connect', resolve);
        });
    }
    /** Feeds signalling data received for this peer into the WebRTC link. */
    signal(signal) {
        this.peer.signal(signal);
    }
    /** Encodes and sends a peer message; throws for any other kind of message. */
    send(msg) {
        if (decentralizedMessages.isPeerMessage(msg)) {
            this.peer.send(msgpack.encode(msg));
            return;
        }
        throw new Error(`can't send this type of message: ${JSON.stringify(msg)}`);
    }
    /** Destroys the underlying peer. */
    async disconnect() {
        await this.peer.destroy();
    }
}
|
|
62
|
+
/**
 * Message channel to a server over a WebSocket. Messages are msgpack-encoded in
 * both directions; incoming ones are validated and re-emitted under their `type`.
 */
export class WebSocketServer extends EventEmitter {
    socket;
    validateSent;
    /**
     * @param socket an already-created WebSocket
     * @param validateSent optional predicate checked by `send` before encoding a message
     */
    constructor(socket, validateSent) {
        super();
        this.socket = socket;
        this.validateSent = validateSent;
    }
    /**
     * Opens a WebSocket to `url` and resolves once it is open.
     *
     * @param url server address
     * @param validateReceived predicate every incoming decoded message must satisfy
     * @param validateSent optional predicate applied to outgoing messages
     * @returns the connected wrapper
     * Rejects with an Error if the socket reports an error before opening.
     */
    static async connect(url, validateReceived, validateSent) {
        const ws = new WebSocket(url);
        ws.binaryType = 'arraybuffer';
        const server = new WebSocketServer(ws, validateSent);
        ws.onmessage = (event) => {
            if (!(event.data instanceof ArrayBuffer)) {
                throw new Error('server did not send an ArrayBuffer');
            }
            const msg = msgpack.decode(new Uint8Array(event.data));
            // Validate message format
            if (!validateReceived(msg)) {
                throw new Error(`invalid message received: ${JSON.stringify(msg)}`);
            }
            server.emit(msg.type, msg);
        };
        return await new Promise((resolve, reject) => {
            ws.onerror = (err) => {
                // Not every error event carries a `message` (browser ones don't);
                // avoid reporting the literal string "undefined".
                reject(new Error(`Server unreachable: ${err.message ?? 'unknown error'}`));
            };
            ws.onopen = () => { resolve(server); };
        });
    }
    /**
     * Closes the socket.
     *
     * Listeners are attached with the WHATWG `addEventListener(type, fn, { once: true })`
     * API, consistent with the `on*` handlers used in `connect`. It is implemented both
     * by browser WebSockets (which lack Node-style `once`) and by the `ws` package.
     *
     * @returns a promise resolved once the socket is closed, or rejected if it errors first
     */
    disconnect() {
        return new Promise((resolve, reject) => {
            // An already-closed socket emits no further 'close' event: don't wait forever.
            if (this.socket.readyState === WebSocket.CLOSED) {
                resolve();
                return;
            }
            this.socket.addEventListener('close', () => { resolve(); }, { once: true });
            this.socket.addEventListener('error', (event) => {
                // `ws` error events wrap the original Error in `.error`; browser ones carry none.
                reject(event.error instanceof Error
                    ? event.error
                    : new Error('WebSocket error while disconnecting'));
            }, { once: true });
            this.socket.close();
        });
    }
    /**
     * Encodes and sends a message.
     * @throws Error if `validateSent` is set and rejects `msg`
     */
    send(msg) {
        if (this.validateSent !== undefined && !this.validateSent(msg)) {
            throw new Error(`can't send this type of message: ${JSON.stringify(msg)}`);
        }
        this.socket.send(msgpack.encode(msg));
    }
}
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
import { Map } from "immutable";
|
|
2
|
+
import { type MetadataKey, type MetadataValue, type WeightsContainer } from "../../index.js";
|
|
3
|
+
import { type NodeID } from "../types.js";
|
|
4
|
+
import { Base as Client } from "../base.js";
|
|
5
|
+
/**
 * Federated-setting client: while training a task, it talks to exactly one
 * centralized server instead of to other peers.
 */
export declare class Base extends Client {
    /**
     * Arbitrary node id standing for the federated server. The server is treated as a
     * node of the network, and in this setting it is the only node we talk to.
     */
    static readonly SERVER_NODE_ID = "federated-server-node-id";
    /**
     * Connects to the server and obtains our own node id from it.
     * TODO: In the federated setting, should return the current server-side round
     * for the task.
     */
    connect(): Promise<void>;
    /**
     * Tears the connection down when the user leaves the task.
     */
    disconnect(): Promise<void>;
    /**
     * Asks the federated server for the values it keeps for one metadata key.
     * @param key The metadata key
     * @returns Node id → metadata value, or undefined if the server sent none
     */
    receiveMetadataMap(key: MetadataKey): Promise<Map<NodeID, MetadataValue> | undefined>;
    onRoundBeginCommunication(): Promise<void>;
    onRoundEndCommunication(weights: WeightsContainer, round: number): Promise<void>;
    /** Metadata values indexed by node id. */
    private metadataMap?;
    /** Opens the WebSocket to the server and listens for its messages. */
    private connectServer;
    /** Sends our local weight update, then waits for the server's aggregated weights. */
    private sendPayloadAndReceiveResult;
    /** Waits for the server's latest result; catches up our round if we are behind. */
    private receiveResult;
}
|
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
import { Map } from "immutable";
|
|
2
|
+
import { serialization, } from "../../index.js";
|
|
3
|
+
import { Base as Client } from "../base.js";
|
|
4
|
+
import { type } from "../messages.js";
|
|
5
|
+
import { waitMessageWithTimeout, WebSocketServer, } from "../event_connection.js";
|
|
6
|
+
import * as messages from "./messages.js";
|
|
7
|
+
/**
 * Client class that communicates with a centralized, federated server, when training
 * a specific task in the federated setting.
 */
export class Base extends Client {
    /**
     * Arbitrary node id under which the federated server is registered in our aggregator.
     * The server acts as a node within the network; in the federated setting it is the
     * only node this client communicates with.
     */
    static SERVER_NODE_ID = "federated-server-node-id";
    /**
     * Map of metadata values for each node id, as last returned by `receiveMetadataMap`.
     */
    metadataMap;
    /**
     * Opens a new WebSocket connection with the server. Both incoming and outgoing
     * messages are checked with `messages.isMessageFederated`.
     *
     * @param url WebSocket URL of the task's federated endpoint
     */
    async connectServer(url) {
        return await WebSocketServer.connect(url, messages.isMessageFederated, messages.isMessageFederated);
    }
    /**
     * Initializes the connection to the server and gets our own node id.
     * TODO: In the federated setting, should return the current server-side round
     * for the task.
     *
     * @throws Error if `this.url` is neither http: nor https:
     */
    async connect() {
        const serverURL = new URL("", this.url.href);
        switch (this.url.protocol) {
            case "http:":
                serverURL.protocol = "ws:";
                break;
            case "https:":
                serverURL.protocol = "wss:";
                break;
            default:
                throw new Error(`unknown protocol: ${this.url.protocol}`);
        }
        // Append the task route as its own path segment. Without this guard a base URL
        // such as `http://host/api` would produce the path `/apifeai/<task>`.
        if (!serverURL.pathname.endsWith("/")) {
            serverURL.pathname += "/";
        }
        serverURL.pathname += `feai/${this.task.id}`;
        this._server = await this.connectServer(serverURL);
        this.aggregator.registerNode(Base.SERVER_NODE_ID);
        const msg = {
            type: type.ClientConnected,
        };
        this.server.send(msg);
        const received = await waitMessageWithTimeout(this.server, type.AssignNodeID);
        console.info(`[${received.id}] assign id generated by the server`);
        this._ownId = received.id;
    }
    /**
     * Disconnection process when user quits the task: closes the socket, forgets our
     * id and removes the server node from the aggregator.
     */
    async disconnect() {
        await this.server.disconnect();
        this._server = undefined;
        this._ownId = undefined;
        this.aggregator.setNodes(this.aggregator.nodes.delete(Base.SERVER_NODE_ID));
    }
    /**
     * Send a message containing our local weight updates to the federated server,
     * and wait for the server to reply with the most recent aggregated weights.
     *
     * @param payload The weight updates to send
     * @returns the server's weights, or undefined if stale or not received
     */
    async sendPayloadAndReceiveResult(payload) {
        const msg = {
            type: type.SendPayload,
            payload: await serialization.weights.encode(payload),
            round: this.aggregator.round,
        };
        this.server.send(msg);
        // It is important that the client immediately awaits the server result or it may miss it
        return await this.receiveResult();
    }
    /**
     * Waits for the server's result for its current (most recent) round.
     * Updates the aggregator's round if it's behind the server's.
     *
     * @returns the decoded server weights, or undefined when the result is stale
     *   or waiting/decoding failed (the error is logged, not rethrown)
     */
    async receiveResult() {
        try {
            const { payload, round } = await waitMessageWithTimeout(this.server, type.ReceiveServerPayload);
            // Store the server result only if it is not stale
            if (this.aggregator.round <= round) {
                const serverResult = serialization.weights.decode(payload);
                // Update the local round to match the server's
                if (this.aggregator.round < round) {
                    this.aggregator.setRound(round);
                }
                return serverResult;
            }
            return undefined;
        }
        catch (e) {
            // Best effort: the caller treats a missing result as "proceed without it".
            console.error(e);
            return undefined;
        }
    }
    /**
     * Fetch the metadata values maintained by the federated server, for a given metadata key.
     * The values are indexed by node id; entries whose value is undefined are dropped.
     *
     * @param key The metadata key
     * @returns The map of node id to metadata value, or undefined if the server sent none
     */
    async receiveMetadataMap(key) {
        this.metadataMap = undefined;
        const msg = {
            type: type.ReceiveServerMetadata,
            taskId: this.task.id,
            nodeId: this.ownId,
            round: this.aggregator.round,
            key,
        };
        this.server.send(msg);
        const received = await waitMessageWithTimeout(this.server, type.ReceiveServerMetadata);
        if (received.metadataMap !== undefined) {
            this.metadataMap = Map(received.metadataMap.filter(([_, v]) => v !== undefined));
        }
        return this.metadataMap;
    }
    /**
     * Prepares the promise for this round's aggregation result.
     */
    onRoundBeginCommunication() {
        // Prepare the result promise for the incoming round
        this.aggregationResult = this.aggregator.receiveResult();
        return Promise.resolve();
    }
    /**
     * Sends our contribution for `round` and feeds the server's answer to the aggregator.
     *
     * @param weights our locally trained weights
     * @param round the round these weights belong to
     * @throws Error if `onRoundBeginCommunication` was not called first
     */
    async onRoundEndCommunication(weights, round) {
        // NB: For now, we suppose a fully-federated setting.
        if (this.aggregationResult === undefined) {
            throw new Error("local aggregation result was not set");
        }
        // Send our local contribution to the server
        // and receive the most recent weights as an answer to our contribution
        const serverResult = await this.sendPayloadAndReceiveResult(this.aggregator.makePayloads(weights).first());
        if (serverResult !== undefined &&
            this.aggregator.add(Base.SERVER_NODE_ID, serverResult, round, 0)) {
            // Regular case: the server sends us its aggregation result which will serve our
            // own aggregation result.
        }
        else {
            // Unexpected case: for some reason, the server result is stale.
            // We proceed to the next round without its result.
            console.info(`[${this.ownId}] Server result is either stale or not received`);
            this.aggregator.nextRound();
        }
    }
}
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
import { type client, type MetadataKey, type MetadataValue } from '../../index.js';
|
|
2
|
+
import { type weights } from '../../serialization/index.js';
|
|
3
|
+
import { type, type AssignNodeID, type ClientConnected } from '../messages.js';
|
|
4
|
+
/** A local weight contribution for a given round (sent by the federated client). */
export interface SendPayload {
    type: type.SendPayload;
    payload: weights.Encoded;
    round: number;
}
/** Encoded weights together with the round they belong to (received by the federated client). */
export interface ReceiveServerPayload {
    type: type.ReceiveServerPayload;
    payload: weights.Encoded;
    round: number;
}
/** Statistics request; carries nothing but its tag. */
export interface RequestServerStatistics {
    type: type.RequestServerStatistics;
}
/** Named numeric statistics. */
export interface ReceiveServerStatistics {
    type: type.ReceiveServerStatistics;
    statistics: Record<string, number>;
}
/**
 * Metadata exchange for one key. The same shape is used for the request and the
 * reply; `metadataMap` holds (node id, value) pairs when present.
 */
export interface ReceiveServerMetadata {
    type: type.ReceiveServerMetadata;
    nodeId: client.NodeID;
    taskId: string;
    round: number;
    key: MetadataKey;
    metadataMap?: Array<[client.NodeID, MetadataValue | undefined]>;
}
/** Every message that may travel over a federated client/server connection. */
export type MessageFederated = ClientConnected | SendPayload | ReceiveServerPayload | RequestServerStatistics | ReceiveServerStatistics | ReceiveServerMetadata | AssignNodeID;
/** Type guard for {@link MessageFederated} (checks the message's `type` tag). */
export declare function isMessageFederated(raw: unknown): raw is MessageFederated;
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
import { type, hasMessageType } from '../messages.js';
|
|
2
|
+
/**
 * Tells whether `raw` carries one of the message types used in the federated setting.
 *
 * Only the `type` tag is inspected (after `hasMessageType`); the remaining fields
 * of the message are not validated here.
 *
 * @param raw any decoded value
 * @returns true for the federated message types, false otherwise
 */
export function isMessageFederated(raw) {
    if (!hasMessageType(raw)) {
        return false;
    }
    switch (raw.type) {
        case type.ClientConnected:
        case type.SendPayload:
        case type.ReceiveServerPayload:
        case type.RequestServerStatistics:
        case type.ReceiveServerStatistics:
        case type.ReceiveServerMetadata:
        case type.AssignNodeID:
            return true;
    }
    return false;
}
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
export { Base as Client } from './base.js';
|
|
2
|
+
export * from './types.js';
|
|
3
|
+
export * as aggregator from '../aggregator/index.js';
|
|
4
|
+
export * as decentralized from './decentralized/index.js';
|
|
5
|
+
export * as federated from './federated/index.js';
|
|
6
|
+
export * as messages from './messages.js';
|
|
7
|
+
export * as utils from './utils.js';
|
|
8
|
+
export { Local } from './local.js';
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
export { Base as Client } from './base.js';
|
|
2
|
+
export * from './types.js';
|
|
3
|
+
export * as aggregator from '../aggregator/index.js';
|
|
4
|
+
export * as decentralized from './decentralized/index.js';
|
|
5
|
+
export * as federated from './federated/index.js';
|
|
6
|
+
export * as messages from './messages.js';
|
|
7
|
+
export * as utils from './utils.js';
|
|
8
|
+
export { Local } from './local.js';
|