@epfml/discojs 2.0.0 → 2.1.2-p20240506085037.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/aggregator/base.d.ts +180 -0
- package/dist/aggregator/base.js +236 -0
- package/dist/aggregator/get.d.ts +16 -0
- package/dist/aggregator/get.js +31 -0
- package/dist/aggregator/index.d.ts +7 -0
- package/dist/aggregator/index.js +4 -0
- package/dist/aggregator/mean.d.ts +23 -0
- package/dist/aggregator/mean.js +69 -0
- package/dist/aggregator/secure.d.ts +27 -0
- package/dist/aggregator/secure.js +91 -0
- package/dist/async_informant.d.ts +15 -0
- package/dist/async_informant.js +42 -0
- package/dist/client/base.d.ts +76 -0
- package/dist/client/base.js +88 -0
- package/dist/client/decentralized/base.d.ts +32 -0
- package/dist/client/decentralized/base.js +192 -0
- package/dist/client/decentralized/index.d.ts +2 -0
- package/dist/client/decentralized/index.js +2 -0
- package/dist/client/decentralized/messages.d.ts +28 -0
- package/dist/client/decentralized/messages.js +44 -0
- package/dist/client/decentralized/peer.d.ts +40 -0
- package/dist/client/decentralized/peer.js +189 -0
- package/dist/client/decentralized/peer_pool.d.ts +12 -0
- package/dist/client/decentralized/peer_pool.js +44 -0
- package/dist/client/event_connection.d.ts +34 -0
- package/dist/client/event_connection.js +105 -0
- package/dist/client/federated/base.d.ts +54 -0
- package/dist/client/federated/base.js +151 -0
- package/dist/client/federated/index.d.ts +2 -0
- package/dist/client/federated/index.js +2 -0
- package/dist/client/federated/messages.d.ts +30 -0
- package/dist/client/federated/messages.js +24 -0
- package/dist/client/index.d.ts +8 -0
- package/dist/client/index.js +8 -0
- package/dist/client/local.d.ts +3 -0
- package/dist/client/local.js +3 -0
- package/dist/client/messages.d.ts +30 -0
- package/dist/client/messages.js +26 -0
- package/dist/client/types.d.ts +2 -0
- package/dist/client/types.js +4 -0
- package/dist/client/utils.d.ts +2 -0
- package/dist/client/utils.js +7 -0
- package/dist/dataset/data/data.d.ts +48 -0
- package/dist/dataset/data/data.js +72 -0
- package/dist/dataset/data/data_split.d.ts +8 -0
- package/dist/dataset/data/data_split.js +1 -0
- package/dist/dataset/data/image_data.d.ts +11 -0
- package/dist/dataset/data/image_data.js +38 -0
- package/dist/dataset/data/index.d.ts +6 -0
- package/dist/dataset/data/index.js +5 -0
- package/dist/dataset/data/preprocessing/base.d.ts +16 -0
- package/dist/dataset/data/preprocessing/base.js +1 -0
- package/dist/dataset/data/preprocessing/image_preprocessing.d.ts +13 -0
- package/dist/dataset/data/preprocessing/image_preprocessing.js +40 -0
- package/dist/dataset/data/preprocessing/index.d.ts +4 -0
- package/dist/dataset/data/preprocessing/index.js +3 -0
- package/dist/dataset/data/preprocessing/tabular_preprocessing.d.ts +13 -0
- package/dist/dataset/data/preprocessing/tabular_preprocessing.js +45 -0
- package/dist/dataset/data/preprocessing/text_preprocessing.d.ts +13 -0
- package/dist/dataset/data/preprocessing/text_preprocessing.js +85 -0
- package/dist/dataset/data/tabular_data.d.ts +11 -0
- package/dist/dataset/data/tabular_data.js +25 -0
- package/dist/dataset/data/text_data.d.ts +11 -0
- package/dist/dataset/data/text_data.js +14 -0
- package/dist/{core/dataset → dataset}/data_loader/data_loader.d.ts +3 -5
- package/dist/dataset/data_loader/data_loader.js +2 -0
- package/dist/dataset/data_loader/image_loader.d.ts +20 -3
- package/dist/dataset/data_loader/image_loader.js +98 -23
- package/dist/dataset/data_loader/index.d.ts +5 -2
- package/dist/dataset/data_loader/index.js +4 -7
- package/dist/dataset/data_loader/tabular_loader.d.ts +34 -3
- package/dist/dataset/data_loader/tabular_loader.js +75 -15
- package/dist/dataset/data_loader/text_loader.d.ts +14 -0
- package/dist/dataset/data_loader/text_loader.js +25 -0
- package/dist/dataset/dataset.d.ts +5 -0
- package/dist/dataset/dataset.js +1 -0
- package/dist/dataset/dataset_builder.d.ts +60 -0
- package/dist/dataset/dataset_builder.js +142 -0
- package/dist/dataset/index.d.ts +5 -0
- package/dist/dataset/index.js +3 -0
- package/dist/default_tasks/cifar10/index.d.ts +2 -0
- package/dist/default_tasks/cifar10/index.js +60 -0
- package/dist/default_tasks/cifar10/model.d.ts +434 -0
- package/dist/default_tasks/cifar10/model.js +2385 -0
- package/dist/default_tasks/geotags/index.d.ts +2 -0
- package/dist/default_tasks/geotags/index.js +65 -0
- package/dist/default_tasks/geotags/model.d.ts +593 -0
- package/dist/default_tasks/geotags/model.js +4715 -0
- package/dist/default_tasks/index.d.ts +8 -0
- package/dist/default_tasks/index.js +8 -0
- package/dist/default_tasks/lus_covid.d.ts +2 -0
- package/dist/default_tasks/lus_covid.js +89 -0
- package/dist/default_tasks/mnist.d.ts +2 -0
- package/dist/default_tasks/mnist.js +61 -0
- package/dist/default_tasks/simple_face/index.d.ts +2 -0
- package/dist/default_tasks/simple_face/index.js +48 -0
- package/dist/default_tasks/simple_face/model.d.ts +513 -0
- package/dist/default_tasks/simple_face/model.js +4301 -0
- package/dist/default_tasks/skin_mnist.d.ts +2 -0
- package/dist/default_tasks/skin_mnist.js +80 -0
- package/dist/default_tasks/titanic.d.ts +2 -0
- package/dist/default_tasks/titanic.js +88 -0
- package/dist/default_tasks/wikitext.d.ts +2 -0
- package/dist/default_tasks/wikitext.js +38 -0
- package/dist/index.d.ts +18 -2
- package/dist/index.js +18 -6
- package/dist/{core/informant → informant}/graph_informant.d.ts +1 -1
- package/dist/informant/graph_informant.js +20 -0
- package/dist/informant/index.d.ts +1 -0
- package/dist/informant/index.js +1 -0
- package/dist/{core/logging → logging}/console_logger.d.ts +2 -2
- package/dist/logging/console_logger.js +22 -0
- package/dist/logging/index.d.ts +2 -0
- package/dist/logging/index.js +1 -0
- package/dist/{core/logging → logging}/logger.d.ts +3 -3
- package/dist/logging/logger.js +1 -0
- package/dist/memory/base.d.ts +119 -0
- package/dist/memory/base.js +9 -0
- package/dist/memory/empty.d.ts +20 -0
- package/dist/memory/empty.js +43 -0
- package/dist/memory/index.d.ts +3 -1
- package/dist/memory/index.js +3 -5
- package/dist/memory/model_type.d.ts +9 -0
- package/dist/memory/model_type.js +10 -0
- package/dist/{core/privacy.d.ts → privacy.d.ts} +1 -1
- package/dist/{core/privacy.js → privacy.js} +11 -16
- package/dist/serialization/index.d.ts +2 -0
- package/dist/serialization/index.js +2 -0
- package/dist/serialization/model.d.ts +5 -0
- package/dist/serialization/model.js +67 -0
- package/dist/{core/serialization → serialization}/weights.d.ts +2 -2
- package/dist/serialization/weights.js +37 -0
- package/dist/task/data_example.js +14 -0
- package/dist/task/digest.d.ts +5 -0
- package/dist/task/digest.js +14 -0
- package/dist/{core/task → task}/display_information.d.ts +5 -3
- package/dist/task/display_information.js +46 -0
- package/dist/task/index.d.ts +7 -0
- package/dist/task/index.js +5 -0
- package/dist/task/label_type.d.ts +9 -0
- package/dist/task/label_type.js +28 -0
- package/dist/task/summary.js +13 -0
- package/dist/task/task.d.ts +12 -0
- package/dist/task/task.js +22 -0
- package/dist/task/task_handler.d.ts +5 -0
- package/dist/task/task_handler.js +20 -0
- package/dist/task/task_provider.d.ts +5 -0
- package/dist/task/task_provider.js +1 -0
- package/dist/{core/task → task}/training_information.d.ts +9 -10
- package/dist/task/training_information.js +88 -0
- package/dist/training/disco.d.ts +40 -0
- package/dist/training/disco.js +107 -0
- package/dist/training/index.d.ts +2 -0
- package/dist/training/index.js +1 -0
- package/dist/training/trainer/distributed_trainer.d.ts +20 -0
- package/dist/training/trainer/distributed_trainer.js +36 -0
- package/dist/training/trainer/local_trainer.d.ts +12 -0
- package/dist/training/trainer/local_trainer.js +19 -0
- package/dist/training/trainer/trainer.d.ts +33 -0
- package/dist/training/trainer/trainer.js +52 -0
- package/dist/{core/training → training}/trainer/trainer_builder.d.ts +5 -7
- package/dist/training/trainer/trainer_builder.js +43 -0
- package/dist/types.d.ts +8 -0
- package/dist/types.js +1 -0
- package/dist/utils/event_emitter.d.ts +40 -0
- package/dist/utils/event_emitter.js +57 -0
- package/dist/validation/index.d.ts +1 -0
- package/dist/validation/index.js +1 -0
- package/dist/validation/validator.d.ts +28 -0
- package/dist/validation/validator.js +132 -0
- package/dist/weights/aggregation.d.ts +21 -0
- package/dist/weights/aggregation.js +44 -0
- package/dist/weights/index.d.ts +2 -0
- package/dist/weights/index.js +2 -0
- package/dist/weights/weights_container.d.ts +68 -0
- package/dist/weights/weights_container.js +96 -0
- package/package.json +25 -16
- package/README.md +0 -53
- package/dist/core/async_buffer.d.ts +0 -41
- package/dist/core/async_buffer.js +0 -97
- package/dist/core/async_informant.d.ts +0 -20
- package/dist/core/async_informant.js +0 -69
- package/dist/core/client/base.d.ts +0 -33
- package/dist/core/client/base.js +0 -35
- package/dist/core/client/decentralized/base.d.ts +0 -32
- package/dist/core/client/decentralized/base.js +0 -212
- package/dist/core/client/decentralized/clear_text.d.ts +0 -14
- package/dist/core/client/decentralized/clear_text.js +0 -96
- package/dist/core/client/decentralized/index.d.ts +0 -4
- package/dist/core/client/decentralized/index.js +0 -9
- package/dist/core/client/decentralized/messages.d.ts +0 -41
- package/dist/core/client/decentralized/messages.js +0 -54
- package/dist/core/client/decentralized/peer.d.ts +0 -26
- package/dist/core/client/decentralized/peer.js +0 -210
- package/dist/core/client/decentralized/peer_pool.d.ts +0 -14
- package/dist/core/client/decentralized/peer_pool.js +0 -92
- package/dist/core/client/decentralized/sec_agg.d.ts +0 -22
- package/dist/core/client/decentralized/sec_agg.js +0 -190
- package/dist/core/client/decentralized/secret_shares.d.ts +0 -3
- package/dist/core/client/decentralized/secret_shares.js +0 -39
- package/dist/core/client/decentralized/types.d.ts +0 -2
- package/dist/core/client/decentralized/types.js +0 -7
- package/dist/core/client/event_connection.d.ts +0 -37
- package/dist/core/client/event_connection.js +0 -158
- package/dist/core/client/federated/client.d.ts +0 -37
- package/dist/core/client/federated/client.js +0 -273
- package/dist/core/client/federated/index.d.ts +0 -2
- package/dist/core/client/federated/index.js +0 -7
- package/dist/core/client/federated/messages.d.ts +0 -38
- package/dist/core/client/federated/messages.js +0 -25
- package/dist/core/client/index.d.ts +0 -5
- package/dist/core/client/index.js +0 -11
- package/dist/core/client/local.d.ts +0 -8
- package/dist/core/client/local.js +0 -36
- package/dist/core/client/messages.d.ts +0 -28
- package/dist/core/client/messages.js +0 -33
- package/dist/core/client/utils.d.ts +0 -2
- package/dist/core/client/utils.js +0 -19
- package/dist/core/dataset/data/data.d.ts +0 -11
- package/dist/core/dataset/data/data.js +0 -20
- package/dist/core/dataset/data/data_split.d.ts +0 -5
- package/dist/core/dataset/data/data_split.js +0 -2
- package/dist/core/dataset/data/image_data.d.ts +0 -8
- package/dist/core/dataset/data/image_data.js +0 -64
- package/dist/core/dataset/data/index.d.ts +0 -5
- package/dist/core/dataset/data/index.js +0 -11
- package/dist/core/dataset/data/preprocessing.d.ts +0 -13
- package/dist/core/dataset/data/preprocessing.js +0 -33
- package/dist/core/dataset/data/tabular_data.d.ts +0 -8
- package/dist/core/dataset/data/tabular_data.js +0 -40
- package/dist/core/dataset/data_loader/data_loader.js +0 -10
- package/dist/core/dataset/data_loader/image_loader.d.ts +0 -17
- package/dist/core/dataset/data_loader/image_loader.js +0 -141
- package/dist/core/dataset/data_loader/index.d.ts +0 -3
- package/dist/core/dataset/data_loader/index.js +0 -9
- package/dist/core/dataset/data_loader/tabular_loader.d.ts +0 -29
- package/dist/core/dataset/data_loader/tabular_loader.js +0 -101
- package/dist/core/dataset/dataset.d.ts +0 -2
- package/dist/core/dataset/dataset.js +0 -2
- package/dist/core/dataset/dataset_builder.d.ts +0 -18
- package/dist/core/dataset/dataset_builder.js +0 -96
- package/dist/core/dataset/index.d.ts +0 -4
- package/dist/core/dataset/index.js +0 -14
- package/dist/core/index.d.ts +0 -18
- package/dist/core/index.js +0 -41
- package/dist/core/informant/graph_informant.js +0 -23
- package/dist/core/informant/index.d.ts +0 -3
- package/dist/core/informant/index.js +0 -9
- package/dist/core/informant/training_informant/base.d.ts +0 -31
- package/dist/core/informant/training_informant/base.js +0 -83
- package/dist/core/informant/training_informant/decentralized.d.ts +0 -5
- package/dist/core/informant/training_informant/decentralized.js +0 -22
- package/dist/core/informant/training_informant/federated.d.ts +0 -14
- package/dist/core/informant/training_informant/federated.js +0 -32
- package/dist/core/informant/training_informant/index.d.ts +0 -4
- package/dist/core/informant/training_informant/index.js +0 -11
- package/dist/core/informant/training_informant/local.d.ts +0 -6
- package/dist/core/informant/training_informant/local.js +0 -20
- package/dist/core/logging/console_logger.js +0 -33
- package/dist/core/logging/index.d.ts +0 -3
- package/dist/core/logging/index.js +0 -9
- package/dist/core/logging/logger.js +0 -9
- package/dist/core/logging/trainer_logger.d.ts +0 -24
- package/dist/core/logging/trainer_logger.js +0 -59
- package/dist/core/memory/base.d.ts +0 -22
- package/dist/core/memory/base.js +0 -9
- package/dist/core/memory/empty.d.ts +0 -14
- package/dist/core/memory/empty.js +0 -75
- package/dist/core/memory/index.d.ts +0 -3
- package/dist/core/memory/index.js +0 -9
- package/dist/core/memory/model_type.d.ts +0 -4
- package/dist/core/memory/model_type.js +0 -9
- package/dist/core/serialization/index.d.ts +0 -2
- package/dist/core/serialization/index.js +0 -6
- package/dist/core/serialization/model.d.ts +0 -5
- package/dist/core/serialization/model.js +0 -55
- package/dist/core/serialization/weights.js +0 -64
- package/dist/core/task/data_example.js +0 -24
- package/dist/core/task/display_information.js +0 -49
- package/dist/core/task/index.d.ts +0 -3
- package/dist/core/task/index.js +0 -8
- package/dist/core/task/model_compile_data.d.ts +0 -6
- package/dist/core/task/model_compile_data.js +0 -22
- package/dist/core/task/summary.js +0 -19
- package/dist/core/task/task.d.ts +0 -10
- package/dist/core/task/task.js +0 -31
- package/dist/core/task/training_information.js +0 -66
- package/dist/core/tasks/cifar10.d.ts +0 -3
- package/dist/core/tasks/cifar10.js +0 -65
- package/dist/core/tasks/geotags.d.ts +0 -3
- package/dist/core/tasks/geotags.js +0 -67
- package/dist/core/tasks/index.d.ts +0 -6
- package/dist/core/tasks/index.js +0 -10
- package/dist/core/tasks/lus_covid.d.ts +0 -3
- package/dist/core/tasks/lus_covid.js +0 -87
- package/dist/core/tasks/mnist.d.ts +0 -3
- package/dist/core/tasks/mnist.js +0 -60
- package/dist/core/tasks/simple_face.d.ts +0 -2
- package/dist/core/tasks/simple_face.js +0 -41
- package/dist/core/tasks/titanic.d.ts +0 -3
- package/dist/core/tasks/titanic.js +0 -88
- package/dist/core/training/disco.d.ts +0 -23
- package/dist/core/training/disco.js +0 -130
- package/dist/core/training/index.d.ts +0 -2
- package/dist/core/training/index.js +0 -7
- package/dist/core/training/trainer/distributed_trainer.d.ts +0 -20
- package/dist/core/training/trainer/distributed_trainer.js +0 -65
- package/dist/core/training/trainer/local_trainer.d.ts +0 -11
- package/dist/core/training/trainer/local_trainer.js +0 -34
- package/dist/core/training/trainer/round_tracker.d.ts +0 -30
- package/dist/core/training/trainer/round_tracker.js +0 -47
- package/dist/core/training/trainer/trainer.d.ts +0 -65
- package/dist/core/training/trainer/trainer.js +0 -160
- package/dist/core/training/trainer/trainer_builder.js +0 -95
- package/dist/core/training/training_schemes.d.ts +0 -5
- package/dist/core/training/training_schemes.js +0 -10
- package/dist/core/types.d.ts +0 -4
- package/dist/core/types.js +0 -2
- package/dist/core/validation/index.d.ts +0 -1
- package/dist/core/validation/index.js +0 -5
- package/dist/core/validation/validator.d.ts +0 -17
- package/dist/core/validation/validator.js +0 -104
- package/dist/core/weights/aggregation.d.ts +0 -8
- package/dist/core/weights/aggregation.js +0 -96
- package/dist/core/weights/index.d.ts +0 -2
- package/dist/core/weights/index.js +0 -7
- package/dist/core/weights/weights_container.d.ts +0 -19
- package/dist/core/weights/weights_container.js +0 -64
- package/dist/imports.d.ts +0 -2
- package/dist/imports.js +0 -7
- package/dist/memory/memory.d.ts +0 -26
- package/dist/memory/memory.js +0 -160
- package/dist/{core/task → task}/data_example.d.ts +1 -1
- package/dist/{core/task → task}/summary.d.ts +1 -1
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
import { Map, List, Range } from 'immutable';
|
|
2
|
+
import * as tf from '@tensorflow/tfjs';
|
|
3
|
+
import { AggregationStep, Base as Aggregator } from './base.js';
|
|
4
|
+
import { aggregation } from '../index.js';
|
|
5
|
+
/**
 * Aggregator implementing secure multi-party computation for decentralized learning.
 * An aggregation consists of two communication rounds:
 *  - first, nodes communicate their secret shares to each other;
 *  - then, they sum their received shares and communicate the result.
 * Finally, nodes are able to average the received partial sums to establish the aggregation result.
 */
export class SecureAggregator extends Aggregator {
    /**
     * Bound of the uniform distribution random shares are drawn from, i.e. values lie in
     * (-maxShareValue, maxShareValue).
     */
    maxShareValue;
    /**
     * @param model model whose `weights` receive the final aggregation result when defined
     * @param maxShareValue bound of the random shares' values (defaults to 100)
     */
    constructor(model, maxShareValue = 100) {
        // NOTE(review): (0, 2) presumably are the round cutoff and the number of communication
        // rounds per aggregation (two, matching aggregate()) — confirm against the base class.
        super(model, 0, 2);
        this.maxShareValue = maxShareValue;
    }
    /**
     * Aggregates the contributions of the current communication round and emits the result.
     * - Communication round 0: sums the received shares into this node's partial sum.
     * - Communication round 1: averages the received partial sums and, if a model is set,
     *   installs the result as its weights.
     * @throws Error if the communication round is neither 0 nor 1
     */
    aggregate() {
        this.log(AggregationStep.AGGREGATE);
        if (this.communicationRound === 0) {
            // Sum the received shares
            const result = aggregation.sum(this.contributions.get(0)?.values());
            this.emit(result);
        }
        else if (this.communicationRound === 1) {
            // Average the received partial sums
            const result = aggregation.avg(this.contributions.get(1)?.values());
            if (this.model !== undefined) {
                this.model.weights = result;
            }
            this.emit(result);
        }
        else {
            throw new Error('communication round is out of bounds');
        }
    }
    /**
     * Stores a node's contribution for the given communication round and aggregates once every
     * node has contributed to the current communication round.
     * @param nodeId sender of the contribution; must belong to the current node set
     * @param contribution the received share (round 0) or partial sum (round 1)
     * @param round training round the contribution belongs to
     * @param communicationRound communication round (0 or 1) the contribution belongs to
     * @returns true if the contribution was accepted, false otherwise
     */
    add(nodeId, contribution, round, communicationRound) {
        if (this.nodes.has(nodeId) && this.isWithinRoundCutoff(round)) {
            this.log(this.contributions.hasIn([communicationRound, nodeId]) ? AggregationStep.UPDATE : AggregationStep.ADD, nodeId);
            this.contributions = this.contributions.setIn([communicationRound, nodeId], contribution);
            this.informant?.update();
            if (this.isFull()) {
                this.aggregate();
            }
            return true;
        }
        return false;
    }
    /**
     * @returns whether every node has contributed to the current communication round
     */
    isFull() {
        const contribs = this.contributions.get(this.communicationRound);
        if (contribs === undefined) {
            return false;
        }
        return contribs.size === this.nodes.size;
    }
    /**
     * Builds the payload to send to each node for the current communication round.
     * - Round 0: one additive share of `weights` per node.
     * - Otherwise: `weights` (this node's partial sum) for every node.
     * @returns a map from node ID to the payload destined to that node
     */
    makePayloads(weights) {
        if (this.communicationRound === 0) {
            const shares = this.generateAllShares(weights);
            // Arbitrarily assign our shares to the available nodes
            return Map(List(this.nodes).zip(shares));
        }
        else {
            // Send our partial sum to every node
            return this.nodes.toMap().map(() => weights);
        }
    }
    /**
     * Generate N additive shares that aggregate to the secret weights array, where N is the number of peers.
     * @throws Error if there are no nodes to generate shares for
     */
    generateAllShares(secret) {
        if (this.nodes.size === 0) {
            throw new Error('too few participants to generate shares');
        }
        // Generate N-1 shares
        const shares = Range(0, this.nodes.size - 1)
            .map(() => this.generateRandomShare(secret))
            .toList();
        // The last share completes the sum
        return shares.push(secret.sub(aggregation.sum(shares)));
    }
    /**
     * Generates one share in the same shape as the secret that is populated with values randomly chosen from
     * a uniform distribution between (-maxShareValue, maxShareValue).
     *
     * Every tensor gets its own freshly drawn seed. Reusing a single seed for the whole container
     * would give equal-shaped tensors identical masks, and the holder of the completing share
     * (`secret - sum(shares)`) could then recover the difference between those secret tensors.
     *
     * NOTE(review): `tf.randomUniform` uses a seeded, non-cryptographic PRNG; only the seeds come
     * from the Web Crypto API.
     */
    generateRandomShare(secret) {
        // Keep seeds within Number's exactly-representable integer range.
        const MAX_SEED_BITS = 47;
        const freshSeed = () => {
            const random = crypto.getRandomValues(new BigInt64Array(1))[0];
            return Number(BigInt.asUintN(MAX_SEED_BITS, random));
        };
        return secret.map((t) => tf.randomUniform(t.shape, -this.maxShareValue, this.maxShareValue, 'float32', freshSeed()));
    }
}
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
import type { AggregatorBase } from './aggregator/index.js';
|
|
2
|
+
/**
 * Participation statistics sampled from an aggregator, one snapshot per aggregator round.
 */
export declare class AsyncInformant<T> {
    /** Aggregator whose round and size are sampled. */
    private readonly aggregator;
    /** @param aggregator the aggregator to observe */
    constructor(aggregator: AggregatorBase<T>);
    /** Refreshes the statistics from the aggregator's current round and size. */
    update(): void;
    /** Last aggregator round observed by `update`. */
    get round(): number;
    /** Aggregator size recorded at the latest snapshot. */
    get currentNumberOfParticipants(): number;
    /** Sum of the sizes recorded at every snapshot so far. */
    get totalNumberOfParticipants(): number;
    /** Average number of participants over the previously recorded rounds. */
    get averageNumberOfParticipants(): number;
    /** All four statistics bundled in a single record. */
    getAllStatistics(): Record<'round' | 'currentNumberOfParticipants' | 'totalNumberOfParticipants' | 'averageNumberOfParticipants', number>;
    private _round;
    private _currentNumberOfParticipants;
    private _totalNumberOfParticipants;
    private _averageNumberOfParticipants;
}
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
/**
 * Collects participation statistics from an aggregator, sampled when the aggregator
 * reaches a new round.
 */
export class AsyncInformant {
    /** Aggregator whose `round` and `size` are sampled. */
    aggregator;
    _round = 0;
    _currentNumberOfParticipants = 0;
    _totalNumberOfParticipants = 0;
    _averageNumberOfParticipants = 0;
    /**
     * @param aggregator aggregator exposing `round` and `size`
     */
    constructor(aggregator) {
        this.aggregator = aggregator;
    }
    /**
     * Refreshes the statistics. The snapshot branch runs when the aggregator has moved past the
     * last observed round, or while the observed round is still 0. It then:
     * - records the aggregator's size as the current number of participants;
     * - recomputes the average over the previously recorded rounds;
     * - adds the snapshot to the running total.
     * Otherwise only the observed round is refreshed.
     */
    update() {
        if (this.round === 0 || this.round < this.aggregator.round) {
            this._round = this.aggregator.round;
            this._currentNumberOfParticipants = this.aggregator.size;
            // Average over the rounds recorded before this one. At round 0 there are none, and
            // dividing by the round would publish NaN (0 / 0) or Infinity, so report 0 instead.
            // NOTE(review): the `this.round === 0` condition re-enters this branch on every update
            // during round 0, so the total may count round 0 several times — confirm intent.
            this._averageNumberOfParticipants = this.round > 0
                ? this.totalNumberOfParticipants / this.round
                : 0;
            this._totalNumberOfParticipants += this.currentNumberOfParticipants;
        }
        else {
            this._round = this.aggregator.round;
        }
    }
    // Getter functions
    /** Last aggregator round observed by `update`. */
    get round() {
        return this._round;
    }
    /** Aggregator size recorded at the latest snapshot. */
    get currentNumberOfParticipants() {
        return this._currentNumberOfParticipants;
    }
    /** Sum of the sizes recorded at every snapshot. */
    get totalNumberOfParticipants() {
        return this._totalNumberOfParticipants;
    }
    /** Average number of participants over the previously recorded rounds (0 before any). */
    get averageNumberOfParticipants() {
        return this._averageNumberOfParticipants;
    }
    /** @returns all four statistics bundled in a single record */
    getAllStatistics() {
        return {
            round: this.round,
            currentNumberOfParticipants: this.currentNumberOfParticipants,
            totalNumberOfParticipants: this.totalNumberOfParticipants,
            averageNumberOfParticipants: this.averageNumberOfParticipants
        };
    }
}
|
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
import type { Set } from 'immutable';
|
|
2
|
+
import type { Model, Task, WeightsContainer } from '../index.js';
|
|
3
|
+
import type { NodeID } from './types.js';
|
|
4
|
+
import type { EventConnection } from './event_connection.js';
|
|
5
|
+
import type { Aggregator } from '../aggregator/index.js';
|
|
6
|
+
/**
 * Abstract root of every Disco client: a node of the network that takes part in a single
 * task and communicates with other nodes, whether peers or a server.
 */
export declare abstract class Base {
    /** URL of the network server this client connects to. */
    readonly url: URL;
    /** Task this client takes part in. */
    readonly task: Task;
    /** Aggregator combining the contributions this client receives. */
    readonly aggregator: Aggregator;
    /** Identifier assigned to this node by the network's server. */
    protected _ownId?: NodeID;
    /** Connection to the network's server. */
    protected _server?: EventConnection;
    /** The aggregator's result, produced after aggregation. */
    protected aggregationResult?: Promise<WeightsContainer>;
    /**
     * @param url URL of the network server to connect to
     * @param task task the client takes part in
     * @param aggregator aggregator used by the client
     */
    constructor(url: URL, task: Task, aggregator: Aggregator);
    /** Connects the client to its network server; the base implementation does nothing. */
    connect(): Promise<void>;
    /** Disconnects the client from its network server; the base implementation does nothing. */
    disconnect(): Promise<void>;
    /**
     * Downloads and decodes the most recent model the server holds for this client's task.
     * @returns The latest model
     */
    getLatestModel(): Promise<Model>;
    /**
     * Hook invoked when a training round begins.
     * @param _weights The most recent local weight updates
     * @param _round The current training round
     */
    onRoundBeginCommunication(_weights: WeightsContainer, _round: number): Promise<void>;
    /**
     * Hook invoked at the end of every training round.
     * @param _weights The most recent local weight updates
     * @param _round The current training round
     */
    onRoundEndCommunication(_weights: WeightsContainer, _round: number): Promise<void>;
    /** Nodes currently known to the aggregator. */
    get nodes(): Set<NodeID>;
    /** This node's identifier; throws when none has been assigned yet. */
    get ownId(): NodeID;
    /** Connection to the network's server; throws when it is not set. */
    get server(): EventConnection;
}
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
import axios from 'axios';
|
|
2
|
+
import { serialization } from '../index.js';
|
|
3
|
+
/**
 * Abstract root of every Disco client: a node of the network that takes part in a single
 * task and communicates with other nodes, whether peers or a server.
 */
export class Base {
    url;
    task;
    aggregator;
    /** Identifier assigned to this node by the network's server; `ownId` throws while unset. */
    _ownId;
    /** Connection to the network's server; `server` throws while unset. */
    _server;
    /** The aggregator's result, produced after aggregation. */
    aggregationResult;
    /**
     * @param url URL of the network server to connect to
     * @param task task the client takes part in
     * @param aggregator aggregator used by the client
     */
    constructor(url, task, aggregator) {
        this.url = url;
        this.task = task;
        this.aggregator = aggregator;
    }
    /** Connects the client to its network server. Subclasses override this; here it is a no-op. */
    async connect() { }
    /** Disconnects the client from its network server. Subclasses override this; here it is a no-op. */
    async disconnect() { }
    /**
     * Downloads `tasks/<task id>/model.json` relative to the server URL and decodes it.
     * @returns The latest model
     */
    async getLatestModel() {
        // Work on a copy so the client's own URL is never mutated.
        const modelUrl = new URL('', this.url.href);
        const hasTrailingSlash = modelUrl.pathname.endsWith('/');
        if (!hasTrailingSlash) {
            modelUrl.pathname += '/';
        }
        modelUrl.pathname += `tasks/${this.task.id}/model.json`;
        const { data } = await axios.get(modelUrl.href, { responseType: 'arraybuffer' });
        const bytes = new Uint8Array(data);
        return await serialization.model.decode(bytes);
    }
    /**
     * Hook invoked when a training round begins; a no-op by default.
     * @param _weights The most recent local weight updates
     * @param _round The current training round
     */
    async onRoundBeginCommunication(_weights, _round) { }
    /**
     * Hook invoked at the end of every training round; a no-op by default.
     * @param _weights The most recent local weight updates
     * @param _round The current training round
     */
    async onRoundEndCommunication(_weights, _round) { }
    /** Nodes currently known to the aggregator. */
    get nodes() {
        const { aggregator } = this;
        return aggregator.nodes;
    }
    /** This node's identifier. Throws when none has been assigned yet. */
    get ownId() {
        const id = this._ownId;
        if (id !== undefined) {
            return id;
        }
        throw new Error('the node is not connected');
    }
    /** Connection to the network's server. Throws when it is not set. */
    get server() {
        const connection = this._server;
        if (connection !== undefined) {
            return connection;
        }
        throw new Error('server undefined, not connected');
    }
}
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
import { type WeightsContainer } from '../../index.js';
|
|
2
|
+
import { Client } from '../index.js';
|
|
3
|
+
import { type PeerConnection } from '../event_connection.js';
|
|
4
|
+
import * as messages from './messages.js';
|
|
5
|
+
/**
 * Client of a decentralized network. Peers rely on the network's server only for coordination
 * and exchange their payloads directly with each other: the server is reached over WebSockets,
 * while peer-to-peer traffic uses WebRTC.
 */
export declare class Base extends Client {
    /** Pool of peers to communicate with during the current training round. */
    private pool?;
    /** Peer connections; presumably those of the current round (confirm in the implementation). */
    private connections?;
    /** Announces to the server that this client is ready for the round, then waits for its peers. */
    private waitForPeers;
    /** Creates the server WebSocket and handles the messages this client receives from it. */
    private connectServer;
    /** Handles payloads received from a peer. */
    private receivePayloads;
    /** Sends a message to the given peer. */
    protected sendMessagetoPeer(peer: PeerConnection, msg: messages.PeerMessage): void;
    connect(): Promise<void>;
    disconnect(): Promise<void>;
    onRoundBeginCommunication(_: WeightsContainer, round: number): Promise<void>;
    onRoundEndCommunication(weights: WeightsContainer, round: number): Promise<void>;
}
|
|
@@ -0,0 +1,192 @@
|
|
|
1
|
+
import { Map, Set } from 'immutable';
|
|
2
|
+
import { serialization } from '../../index.js';
|
|
3
|
+
import { Client } from '../index.js';
|
|
4
|
+
import { type } from '../messages.js';
|
|
5
|
+
import { timeout } from '../utils.js';
|
|
6
|
+
import { WebSocketServer, waitMessage, waitMessageWithTimeout } from '../event_connection.js';
|
|
7
|
+
import { PeerPool } from './peer_pool.js';
|
|
8
|
+
import * as messages from './messages.js';
|
|
9
|
+
/**
|
|
10
|
+
* Represents a decentralized client in a network of peers. Peers coordinate each other with the
|
|
11
|
+
* help of the network's server, yet only exchange payloads between each other. Communication
|
|
12
|
+
* with the server is based off regular WebSockets, whereas peer-to-peer communication uses
|
|
13
|
+
* WebRTC for Node.js.
|
|
14
|
+
*/
|
|
15
|
+
export class Base extends Client {
|
|
16
|
+
/**
|
|
17
|
+
* The pool of peers to communicate with during the current training round.
|
|
18
|
+
*/
|
|
19
|
+
pool;
|
|
20
|
+
connections;
|
|
21
|
+
/**
|
|
22
|
+
* Send message to server that this client is ready for the next training round.
|
|
23
|
+
*/
|
|
24
|
+
async waitForPeers(round) {
|
|
25
|
+
console.info(`[${this.ownId}] is ready for round`, round);
|
|
26
|
+
// Broadcast our readiness
|
|
27
|
+
const readyMessage = { type: type.PeerIsReady };
|
|
28
|
+
if (this.server === undefined) {
|
|
29
|
+
throw new Error('server undefined, could not connect peers');
|
|
30
|
+
}
|
|
31
|
+
this.server.send(readyMessage);
|
|
32
|
+
// Wait for peers to be connected before sending any update information
|
|
33
|
+
try {
|
|
34
|
+
const receivedMessage = await waitMessageWithTimeout(this.server, type.PeersForRound);
|
|
35
|
+
if (this.nodes.size > 0) {
|
|
36
|
+
throw new Error('got new peer list from server but was already received for this round');
|
|
37
|
+
}
|
|
38
|
+
const peers = Set(receivedMessage.peers);
|
|
39
|
+
console.info(`[${this.ownId}] received peers for round:`, peers.toJS());
|
|
40
|
+
if (this.ownId !== undefined && peers.has(this.ownId)) {
|
|
41
|
+
throw new Error('received peer list contains our own id');
|
|
42
|
+
}
|
|
43
|
+
this.aggregator.setNodes(peers.add(this.ownId));
|
|
44
|
+
if (this.pool === undefined) {
|
|
45
|
+
throw new Error('waiting for peers but peer pool is undefined');
|
|
46
|
+
}
|
|
47
|
+
const connections = await this.pool.getPeers(peers, this.server,
|
|
48
|
+
// Init receipt of peers weights
|
|
49
|
+
(conn) => { this.receivePayloads(conn, round); });
|
|
50
|
+
console.info(`[${this.ownId}] received peers for round ${round}:`, connections.keySeq().toJS());
|
|
51
|
+
return connections;
|
|
52
|
+
}
|
|
53
|
+
catch (e) {
|
|
54
|
+
console.error(e);
|
|
55
|
+
this.aggregator.setNodes(Set(this.ownId));
|
|
56
|
+
return Map();
|
|
57
|
+
}
|
|
58
|
+
}
|
|
59
|
+
sendMessagetoPeer(peer, msg) {
|
|
60
|
+
console.info(`[${this.ownId}] send message to peer`, msg.peer, msg);
|
|
61
|
+
peer.send(msg);
|
|
62
|
+
}
|
|
63
|
+
/**
|
|
64
|
+
* Creation of the WebSocket for the server, connection of client to that WebSocket,
|
|
65
|
+
* deals with message reception from the decentralized client's perspective (messages received by client).
|
|
66
|
+
*/
|
|
67
|
+
async connectServer(url) {
|
|
68
|
+
const server = await WebSocketServer.connect(url, messages.isMessageFromServer, messages.isMessageToServer);
|
|
69
|
+
server.on(type.SignalForPeer, (event) => {
|
|
70
|
+
console.info(`[${this.ownId}] received signal from`, event.peer);
|
|
71
|
+
if (this.pool === undefined) {
|
|
72
|
+
throw new Error('received signal but peer pool is undefined');
|
|
73
|
+
}
|
|
74
|
+
this.pool.signal(event.peer, event.signal);
|
|
75
|
+
});
|
|
76
|
+
return server;
|
|
77
|
+
}
|
|
78
|
+
async connect() {
|
|
79
|
+
const serverURL = new URL('', this.url.href);
|
|
80
|
+
switch (this.url.protocol) {
|
|
81
|
+
case 'http:':
|
|
82
|
+
serverURL.protocol = 'ws:';
|
|
83
|
+
break;
|
|
84
|
+
case 'https:':
|
|
85
|
+
serverURL.protocol = 'wss:';
|
|
86
|
+
break;
|
|
87
|
+
default:
|
|
88
|
+
throw new Error(`unknown protocol: ${this.url.protocol}`);
|
|
89
|
+
}
|
|
90
|
+
serverURL.pathname += `deai/${this.task.id}`;
|
|
91
|
+
this._server = await this.connectServer(serverURL);
|
|
92
|
+
const msg = {
|
|
93
|
+
type: type.ClientConnected
|
|
94
|
+
};
|
|
95
|
+
this.server.send(msg);
|
|
96
|
+
const peerIdMsg = await waitMessage(this.server, type.AssignNodeID);
|
|
97
|
+
console.info(`[${peerIdMsg.id}] assigned id generated by server`);
|
|
98
|
+
if (this._ownId !== undefined) {
|
|
99
|
+
throw new Error('received id from server but was already received');
|
|
100
|
+
}
|
|
101
|
+
this._ownId = peerIdMsg.id;
|
|
102
|
+
this.pool = new PeerPool(peerIdMsg.id);
|
|
103
|
+
}
|
|
104
|
+
async disconnect() {
|
|
105
|
+
// Disconnect from peers
|
|
106
|
+
await this.pool?.shutdown();
|
|
107
|
+
this.pool = undefined;
|
|
108
|
+
if (this.connections !== undefined) {
|
|
109
|
+
const peers = this.connections.keySeq().toSet();
|
|
110
|
+
this.aggregator.setNodes(this.aggregator.nodes.subtract(peers));
|
|
111
|
+
}
|
|
112
|
+
// Disconnect from server
|
|
113
|
+
await this.server?.disconnect();
|
|
114
|
+
this._server = undefined;
|
|
115
|
+
this._ownId = undefined;
|
|
116
|
+
return Promise.resolve();
|
|
117
|
+
}
|
|
118
|
+
async onRoundBeginCommunication(_, round) {
|
|
119
|
+
// Reset peers list at each round of training to make sure client works with an updated peers
|
|
120
|
+
// list, maintained by the server. Adds any received weights to the aggregator.
|
|
121
|
+
this.connections = await this.waitForPeers(round);
|
|
122
|
+
// Store the promise for the current round's aggregation result.
|
|
123
|
+
this.aggregationResult = this.aggregator.receiveResult();
|
|
124
|
+
}
|
|
125
|
+
async onRoundEndCommunication(weights, round) {
|
|
126
|
+
let result = weights;
|
|
127
|
+
// Perform the required communication rounds. Each communication round consists in sending our local payload,
|
|
128
|
+
// followed by an aggregation step triggered by the receipt of other payloads, and handled by the aggregator.
|
|
129
|
+
// A communication round's payload is the aggregation result of the previous communication round. The first
|
|
130
|
+
// communication round simply sends our training result, i.e. model weights updates. This scheme allows for
|
|
131
|
+
// the aggregator to define any complex multi-round aggregation mechanism.
|
|
132
|
+
for (let r = 0; r < this.aggregator.communicationRounds; r++) {
|
|
133
|
+
// Generate our payloads for this communication round and send them to all ready connected peers
|
|
134
|
+
if (this.connections !== undefined) {
|
|
135
|
+
const payloads = this.aggregator.makePayloads(result);
|
|
136
|
+
try {
|
|
137
|
+
await Promise.all(payloads.map(async (payload, id) => {
|
|
138
|
+
if (id === this.ownId) {
|
|
139
|
+
this.aggregator.add(this.ownId, payload, round, r);
|
|
140
|
+
}
|
|
141
|
+
else {
|
|
142
|
+
const connection = this.connections?.get(id);
|
|
143
|
+
if (connection !== undefined) {
|
|
144
|
+
const encoded = await serialization.weights.encode(payload);
|
|
145
|
+
this.sendMessagetoPeer(connection, {
|
|
146
|
+
type: type.Payload,
|
|
147
|
+
peer: id,
|
|
148
|
+
round: r,
|
|
149
|
+
payload: encoded
|
|
150
|
+
});
|
|
151
|
+
}
|
|
152
|
+
}
|
|
153
|
+
}));
|
|
154
|
+
}
|
|
155
|
+
catch {
|
|
156
|
+
throw new Error('error while sending weights');
|
|
157
|
+
}
|
|
158
|
+
}
|
|
159
|
+
if (this.aggregationResult === undefined) {
|
|
160
|
+
throw new TypeError('aggregation result promise is undefined');
|
|
161
|
+
}
|
|
162
|
+
// Wait for aggregation before proceeding to the next communication round.
|
|
163
|
+
// The current result will be used as payload for the eventual next communication round.
|
|
164
|
+
result = await Promise.race([this.aggregationResult, timeout()]);
|
|
165
|
+
// There is at least one communication round remaining
|
|
166
|
+
if (r < this.aggregator.communicationRounds - 1) {
|
|
167
|
+
// Reuse the aggregation result
|
|
168
|
+
this.aggregationResult = this.aggregator.receiveResult();
|
|
169
|
+
}
|
|
170
|
+
}
|
|
171
|
+
// Reset the peers list for the next round
|
|
172
|
+
this.aggregator.resetNodes();
|
|
173
|
+
}
|
|
174
|
+
receivePayloads(connections, round) {
|
|
175
|
+
console.info(`[${this.ownId}] Accepting new contributions for round ${round}`);
|
|
176
|
+
connections.forEach(async (connection, peerId) => {
|
|
177
|
+
let receivedPayloads = 0;
|
|
178
|
+
do {
|
|
179
|
+
try {
|
|
180
|
+
const message = await waitMessageWithTimeout(connection, type.Payload);
|
|
181
|
+
const decoded = serialization.weights.decode(message.payload);
|
|
182
|
+
if (!this.aggregator.add(peerId, decoded, round, message.round)) {
|
|
183
|
+
console.warn(`[${this.ownId}] Failed to add contribution from peer ${peerId}`);
|
|
184
|
+
}
|
|
185
|
+
}
|
|
186
|
+
catch (e) {
|
|
187
|
+
console.warn(e instanceof Error ? e.message : e);
|
|
188
|
+
}
|
|
189
|
+
} while (++receivedPayloads < this.aggregator.communicationRounds);
|
|
190
|
+
});
|
|
191
|
+
}
|
|
192
|
+
}
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
import { weights } from '../../serialization/index.js';
|
|
2
|
+
import { type SignalData } from './peer.js';
|
|
3
|
+
import { type NodeID } from '../types.js';
|
|
4
|
+
import { type, type ClientConnected, type AssignNodeID } from '../messages.js';
|
|
5
|
+
export interface SignalForPeer {
|
|
6
|
+
type: type.SignalForPeer;
|
|
7
|
+
peer: NodeID;
|
|
8
|
+
signal: SignalData;
|
|
9
|
+
}
|
|
10
|
+
export interface PeerIsReady {
|
|
11
|
+
type: type.PeerIsReady;
|
|
12
|
+
}
|
|
13
|
+
export interface PeersForRound {
|
|
14
|
+
type: type.PeersForRound;
|
|
15
|
+
peers: NodeID[];
|
|
16
|
+
}
|
|
17
|
+
export interface Payload {
|
|
18
|
+
type: type.Payload;
|
|
19
|
+
peer: NodeID;
|
|
20
|
+
round: number;
|
|
21
|
+
payload: weights.Encoded;
|
|
22
|
+
}
|
|
23
|
+
export type MessageFromServer = AssignNodeID | SignalForPeer | PeersForRound;
|
|
24
|
+
export type MessageToServer = ClientConnected | SignalForPeer | PeerIsReady;
|
|
25
|
+
export type PeerMessage = Payload;
|
|
26
|
+
export declare function isMessageFromServer(o: unknown): o is MessageFromServer;
|
|
27
|
+
export declare function isMessageToServer(o: unknown): o is MessageToServer;
|
|
28
|
+
export declare function isPeerMessage(o: unknown): o is PeerMessage;
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
import { weights } from '../../serialization/index.js';
|
|
2
|
+
import { isNodeID } from '../types.js';
|
|
3
|
+
import { type, hasMessageType } from '../messages.js';
|
|
4
|
+
export function isMessageFromServer(o) {
|
|
5
|
+
if (!hasMessageType(o)) {
|
|
6
|
+
return false;
|
|
7
|
+
}
|
|
8
|
+
switch (o.type) {
|
|
9
|
+
case type.AssignNodeID:
|
|
10
|
+
return 'id' in o && isNodeID(o.id);
|
|
11
|
+
case type.SignalForPeer:
|
|
12
|
+
return 'peer' in o && isNodeID(o.peer) &&
|
|
13
|
+
'signal' in o; // TODO check signal content?
|
|
14
|
+
case type.PeersForRound:
|
|
15
|
+
return 'peers' in o && Array.isArray(o.peers) && o.peers.every(isNodeID);
|
|
16
|
+
}
|
|
17
|
+
return false;
|
|
18
|
+
}
|
|
19
|
+
export function isMessageToServer(o) {
|
|
20
|
+
if (!hasMessageType(o)) {
|
|
21
|
+
return false;
|
|
22
|
+
}
|
|
23
|
+
switch (o.type) {
|
|
24
|
+
case type.ClientConnected:
|
|
25
|
+
return true;
|
|
26
|
+
case type.SignalForPeer:
|
|
27
|
+
return 'peer' in o && isNodeID(o.peer) &&
|
|
28
|
+
'signal' in o; // TODO check signal content?
|
|
29
|
+
case type.PeerIsReady:
|
|
30
|
+
return true;
|
|
31
|
+
}
|
|
32
|
+
return false;
|
|
33
|
+
}
|
|
34
|
+
export function isPeerMessage(o) {
|
|
35
|
+
if (!hasMessageType(o)) {
|
|
36
|
+
return false;
|
|
37
|
+
}
|
|
38
|
+
switch (o.type) {
|
|
39
|
+
case type.Payload:
|
|
40
|
+
return ('peer' in o && isNodeID(o.peer) &&
|
|
41
|
+
'payload' in o && weights.isEncoded(o.payload));
|
|
42
|
+
}
|
|
43
|
+
return false;
|
|
44
|
+
}
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
/// <reference types="node" resolution-mode="require"/>
|
|
2
|
+
import type { NodeID } from '../types.js';
|
|
3
|
+
export type SignalData = {
|
|
4
|
+
type: 'answer' | 'offer' | 'pranswer' | 'rollback';
|
|
5
|
+
sdp?: string;
|
|
6
|
+
} | {
|
|
7
|
+
type: 'transceiverRequest';
|
|
8
|
+
transceiverRequest: {
|
|
9
|
+
kind: string;
|
|
10
|
+
};
|
|
11
|
+
} | {
|
|
12
|
+
type: 'renegotiate';
|
|
13
|
+
renegotiate: true;
|
|
14
|
+
} | {
|
|
15
|
+
type: 'candidate';
|
|
16
|
+
candidate: RTCIceCandidate;
|
|
17
|
+
};
|
|
18
|
+
interface Events {
|
|
19
|
+
'close': () => void;
|
|
20
|
+
'connect': () => void;
|
|
21
|
+
'signal': (signal: SignalData) => void;
|
|
22
|
+
'data': (data: Buffer) => void;
|
|
23
|
+
}
|
|
24
|
+
export declare class Peer {
|
|
25
|
+
readonly id: NodeID;
|
|
26
|
+
private readonly peer;
|
|
27
|
+
private bufferSize?;
|
|
28
|
+
private sendCounter;
|
|
29
|
+
private sendQueue;
|
|
30
|
+
private receiving;
|
|
31
|
+
constructor(id: NodeID, initiator?: boolean);
|
|
32
|
+
send(msg: Buffer): void;
|
|
33
|
+
private flush;
|
|
34
|
+
get maxChunkSize(): number;
|
|
35
|
+
private chunk;
|
|
36
|
+
destroy(): Promise<void>;
|
|
37
|
+
signal(signal: SignalData): void;
|
|
38
|
+
on<K extends keyof Events>(event: K, listener: Events[K]): void;
|
|
39
|
+
}
|
|
40
|
+
export {};
|