@epfml/discojs 1.0.0 → 2.1.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (224)
  1. package/README.md +28 -8
  2. package/dist/{async_buffer.d.ts → core/async_buffer.d.ts} +3 -3
  3. package/dist/{async_buffer.js → core/async_buffer.js} +5 -6
  4. package/dist/{async_informant.d.ts → core/async_informant.d.ts} +0 -0
  5. package/dist/{async_informant.js → core/async_informant.js} +0 -0
  6. package/dist/{client → core/client}/base.d.ts +4 -7
  7. package/dist/{client → core/client}/base.js +3 -2
  8. package/dist/core/client/decentralized/base.d.ts +32 -0
  9. package/dist/core/client/decentralized/base.js +212 -0
  10. package/dist/core/client/decentralized/clear_text.d.ts +14 -0
  11. package/dist/core/client/decentralized/clear_text.js +96 -0
  12. package/dist/{client → core/client}/decentralized/index.d.ts +0 -0
  13. package/dist/{client → core/client}/decentralized/index.js +0 -0
  14. package/dist/core/client/decentralized/messages.d.ts +41 -0
  15. package/dist/core/client/decentralized/messages.js +54 -0
  16. package/dist/core/client/decentralized/peer.d.ts +26 -0
  17. package/dist/core/client/decentralized/peer.js +210 -0
  18. package/dist/core/client/decentralized/peer_pool.d.ts +14 -0
  19. package/dist/core/client/decentralized/peer_pool.js +92 -0
  20. package/dist/core/client/decentralized/sec_agg.d.ts +22 -0
  21. package/dist/core/client/decentralized/sec_agg.js +190 -0
  22. package/dist/core/client/decentralized/secret_shares.d.ts +3 -0
  23. package/dist/core/client/decentralized/secret_shares.js +39 -0
  24. package/dist/core/client/decentralized/types.d.ts +2 -0
  25. package/dist/core/client/decentralized/types.js +7 -0
  26. package/dist/core/client/event_connection.d.ts +37 -0
  27. package/dist/core/client/event_connection.js +158 -0
  28. package/dist/core/client/federated/client.d.ts +37 -0
  29. package/dist/core/client/federated/client.js +273 -0
  30. package/dist/core/client/federated/index.d.ts +2 -0
  31. package/dist/core/client/federated/index.js +7 -0
  32. package/dist/core/client/federated/messages.d.ts +38 -0
  33. package/dist/core/client/federated/messages.js +25 -0
  34. package/dist/{client → core/client}/index.d.ts +2 -1
  35. package/dist/{client → core/client}/index.js +3 -3
  36. package/dist/{client → core/client}/local.d.ts +2 -2
  37. package/dist/{client → core/client}/local.js +0 -0
  38. package/dist/core/client/messages.d.ts +28 -0
  39. package/dist/core/client/messages.js +33 -0
  40. package/dist/core/client/utils.d.ts +2 -0
  41. package/dist/core/client/utils.js +19 -0
  42. package/dist/core/dataset/data/data.d.ts +11 -0
  43. package/dist/core/dataset/data/data.js +20 -0
  44. package/dist/core/dataset/data/data_split.d.ts +5 -0
  45. package/dist/{client/decentralized/types.js → core/dataset/data/data_split.js} +0 -0
  46. package/dist/core/dataset/data/image_data.d.ts +8 -0
  47. package/dist/core/dataset/data/image_data.js +64 -0
  48. package/dist/core/dataset/data/index.d.ts +5 -0
  49. package/dist/core/dataset/data/index.js +11 -0
  50. package/dist/core/dataset/data/preprocessing.d.ts +13 -0
  51. package/dist/core/dataset/data/preprocessing.js +33 -0
  52. package/dist/core/dataset/data/tabular_data.d.ts +8 -0
  53. package/dist/core/dataset/data/tabular_data.js +40 -0
  54. package/dist/{dataset → core/dataset}/data_loader/data_loader.d.ts +4 -11
  55. package/dist/{dataset → core/dataset}/data_loader/data_loader.js +0 -0
  56. package/dist/core/dataset/data_loader/image_loader.d.ts +17 -0
  57. package/dist/core/dataset/data_loader/image_loader.js +141 -0
  58. package/dist/core/dataset/data_loader/index.d.ts +3 -0
  59. package/dist/core/dataset/data_loader/index.js +9 -0
  60. package/dist/core/dataset/data_loader/tabular_loader.d.ts +29 -0
  61. package/dist/core/dataset/data_loader/tabular_loader.js +101 -0
  62. package/dist/core/dataset/dataset.d.ts +2 -0
  63. package/dist/{task/training_information.js → core/dataset/dataset.js} +0 -0
  64. package/dist/{dataset → core/dataset}/dataset_builder.d.ts +5 -5
  65. package/dist/{dataset → core/dataset}/dataset_builder.js +14 -10
  66. package/dist/core/dataset/index.d.ts +4 -0
  67. package/dist/core/dataset/index.js +14 -0
  68. package/dist/core/default_tasks/cifar10.d.ts +2 -0
  69. package/dist/core/default_tasks/cifar10.js +68 -0
  70. package/dist/core/default_tasks/geotags.d.ts +2 -0
  71. package/dist/core/default_tasks/geotags.js +69 -0
  72. package/dist/core/default_tasks/index.d.ts +6 -0
  73. package/dist/core/default_tasks/index.js +15 -0
  74. package/dist/core/default_tasks/lus_covid.d.ts +2 -0
  75. package/dist/core/default_tasks/lus_covid.js +96 -0
  76. package/dist/core/default_tasks/mnist.d.ts +2 -0
  77. package/dist/core/default_tasks/mnist.js +69 -0
  78. package/dist/core/default_tasks/simple_face.d.ts +2 -0
  79. package/dist/core/default_tasks/simple_face.js +53 -0
  80. package/dist/core/default_tasks/titanic.d.ts +2 -0
  81. package/dist/core/default_tasks/titanic.js +97 -0
  82. package/dist/core/index.d.ts +18 -0
  83. package/dist/core/index.js +39 -0
  84. package/dist/{informant → core/informant}/graph_informant.d.ts +0 -0
  85. package/dist/{informant → core/informant}/graph_informant.js +0 -0
  86. package/dist/{informant → core/informant}/index.d.ts +0 -0
  87. package/dist/{informant → core/informant}/index.js +0 -0
  88. package/dist/{informant → core/informant}/training_informant/base.d.ts +3 -3
  89. package/dist/{informant → core/informant}/training_informant/base.js +3 -2
  90. package/dist/{informant → core/informant}/training_informant/decentralized.d.ts +0 -0
  91. package/dist/{informant → core/informant}/training_informant/decentralized.js +0 -0
  92. package/dist/{informant → core/informant}/training_informant/federated.d.ts +0 -0
  93. package/dist/{informant → core/informant}/training_informant/federated.js +0 -0
  94. package/dist/{informant → core/informant}/training_informant/index.d.ts +0 -0
  95. package/dist/{informant → core/informant}/training_informant/index.js +0 -0
  96. package/dist/{informant → core/informant}/training_informant/local.d.ts +2 -2
  97. package/dist/{informant → core/informant}/training_informant/local.js +2 -2
  98. package/dist/{logging → core/logging}/console_logger.d.ts +0 -0
  99. package/dist/{logging → core/logging}/console_logger.js +0 -0
  100. package/dist/{logging → core/logging}/index.d.ts +0 -0
  101. package/dist/{logging → core/logging}/index.js +0 -0
  102. package/dist/{logging → core/logging}/logger.d.ts +0 -0
  103. package/dist/{logging → core/logging}/logger.js +0 -0
  104. package/dist/{logging → core/logging}/trainer_logger.d.ts +0 -0
  105. package/dist/{logging → core/logging}/trainer_logger.js +0 -0
  106. package/dist/{memory → core/memory}/base.d.ts +2 -2
  107. package/dist/{memory → core/memory}/base.js +0 -0
  108. package/dist/{memory → core/memory}/empty.d.ts +0 -0
  109. package/dist/{memory → core/memory}/empty.js +0 -0
  110. package/dist/core/memory/index.d.ts +3 -0
  111. package/dist/core/memory/index.js +9 -0
  112. package/dist/{memory → core/memory}/model_type.d.ts +0 -0
  113. package/dist/{memory → core/memory}/model_type.js +0 -0
  114. package/dist/{privacy.d.ts → core/privacy.d.ts} +2 -3
  115. package/dist/{privacy.js → core/privacy.js} +3 -16
  116. package/dist/{serialization → core/serialization}/index.d.ts +0 -0
  117. package/dist/{serialization → core/serialization}/index.js +0 -0
  118. package/dist/{serialization → core/serialization}/model.d.ts +0 -0
  119. package/dist/{serialization → core/serialization}/model.js +0 -0
  120. package/dist/core/serialization/weights.d.ts +5 -0
  121. package/dist/{serialization → core/serialization}/weights.js +11 -9
  122. package/dist/{task → core/task}/data_example.d.ts +0 -0
  123. package/dist/{task → core/task}/data_example.js +0 -0
  124. package/dist/core/task/digest.d.ts +5 -0
  125. package/dist/core/task/digest.js +18 -0
  126. package/dist/{task → core/task}/display_information.d.ts +5 -5
  127. package/dist/{task → core/task}/display_information.js +5 -10
  128. package/dist/{task → core/task}/index.d.ts +3 -0
  129. package/dist/core/task/index.js +15 -0
  130. package/dist/core/task/model_compile_data.d.ts +6 -0
  131. package/dist/core/task/model_compile_data.js +22 -0
  132. package/dist/{task → core/task}/summary.d.ts +0 -0
  133. package/dist/{task → core/task}/summary.js +0 -4
  134. package/dist/{task → core/task}/task.d.ts +4 -2
  135. package/dist/{task → core/task}/task.js +10 -7
  136. package/dist/core/task/task_handler.d.ts +5 -0
  137. package/dist/core/task/task_handler.js +53 -0
  138. package/dist/core/task/task_provider.d.ts +6 -0
  139. package/dist/core/task/task_provider.js +13 -0
  140. package/dist/{task → core/task}/training_information.d.ts +10 -14
  141. package/dist/core/task/training_information.js +66 -0
  142. package/dist/core/training/disco.d.ts +23 -0
  143. package/dist/core/training/disco.js +130 -0
  144. package/dist/{training → core/training}/index.d.ts +0 -0
  145. package/dist/{training → core/training}/index.js +0 -0
  146. package/dist/{training → core/training}/trainer/distributed_trainer.d.ts +1 -2
  147. package/dist/{training → core/training}/trainer/distributed_trainer.js +6 -5
  148. package/dist/{training → core/training}/trainer/local_trainer.d.ts +2 -2
  149. package/dist/{training → core/training}/trainer/local_trainer.js +0 -0
  150. package/dist/{training → core/training}/trainer/round_tracker.d.ts +0 -0
  151. package/dist/{training → core/training}/trainer/round_tracker.js +0 -0
  152. package/dist/{training → core/training}/trainer/trainer.d.ts +1 -2
  153. package/dist/{training → core/training}/trainer/trainer.js +2 -2
  154. package/dist/{training → core/training}/trainer/trainer_builder.d.ts +0 -0
  155. package/dist/{training → core/training}/trainer/trainer_builder.js +0 -0
  156. package/dist/core/training/training_schemes.d.ts +5 -0
  157. package/dist/{training → core/training}/training_schemes.js +2 -2
  158. package/dist/{types.d.ts → core/types.d.ts} +0 -0
  159. package/dist/{types.js → core/types.js} +0 -0
  160. package/dist/{validation → core/validation}/index.d.ts +0 -0
  161. package/dist/{validation → core/validation}/index.js +0 -0
  162. package/dist/{validation → core/validation}/validator.d.ts +5 -8
  163. package/dist/{validation → core/validation}/validator.js +9 -11
  164. package/dist/core/weights/aggregation.d.ts +7 -0
  165. package/dist/core/weights/aggregation.js +72 -0
  166. package/dist/core/weights/index.d.ts +2 -0
  167. package/dist/core/weights/index.js +7 -0
  168. package/dist/core/weights/weights_container.d.ts +19 -0
  169. package/dist/core/weights/weights_container.js +64 -0
  170. package/dist/dataset/data_loader/image_loader.d.ts +3 -15
  171. package/dist/dataset/data_loader/image_loader.js +12 -125
  172. package/dist/dataset/data_loader/index.d.ts +2 -3
  173. package/dist/dataset/data_loader/index.js +3 -5
  174. package/dist/dataset/data_loader/tabular_loader.d.ts +3 -28
  175. package/dist/dataset/data_loader/tabular_loader.js +11 -92
  176. package/dist/imports.d.ts +2 -0
  177. package/dist/imports.js +7 -0
  178. package/dist/index.d.ts +2 -19
  179. package/dist/index.js +3 -39
  180. package/dist/memory/index.d.ts +1 -3
  181. package/dist/memory/index.js +3 -7
  182. package/dist/memory/memory.d.ts +26 -0
  183. package/dist/memory/memory.js +160 -0
  184. package/package.json +13 -26
  185. package/dist/aggregation.d.ts +0 -5
  186. package/dist/aggregation.js +0 -33
  187. package/dist/client/decentralized/base.d.ts +0 -43
  188. package/dist/client/decentralized/base.js +0 -243
  189. package/dist/client/decentralized/clear_text.d.ts +0 -13
  190. package/dist/client/decentralized/clear_text.js +0 -78
  191. package/dist/client/decentralized/messages.d.ts +0 -37
  192. package/dist/client/decentralized/messages.js +0 -15
  193. package/dist/client/decentralized/sec_agg.d.ts +0 -18
  194. package/dist/client/decentralized/sec_agg.js +0 -169
  195. package/dist/client/decentralized/secret_shares.d.ts +0 -5
  196. package/dist/client/decentralized/secret_shares.js +0 -58
  197. package/dist/client/decentralized/types.d.ts +0 -1
  198. package/dist/client/federated.d.ts +0 -30
  199. package/dist/client/federated.js +0 -218
  200. package/dist/dataset/index.d.ts +0 -2
  201. package/dist/dataset/index.js +0 -7
  202. package/dist/model_actor.d.ts +0 -16
  203. package/dist/model_actor.js +0 -20
  204. package/dist/serialization/weights.d.ts +0 -5
  205. package/dist/task/index.js +0 -8
  206. package/dist/task/model_compile_data.d.ts +0 -6
  207. package/dist/task/model_compile_data.js +0 -12
  208. package/dist/tasks/cifar10.d.ts +0 -4
  209. package/dist/tasks/cifar10.js +0 -76
  210. package/dist/tasks/index.d.ts +0 -5
  211. package/dist/tasks/index.js +0 -9
  212. package/dist/tasks/lus_covid.d.ts +0 -4
  213. package/dist/tasks/lus_covid.js +0 -85
  214. package/dist/tasks/mnist.d.ts +0 -4
  215. package/dist/tasks/mnist.js +0 -58
  216. package/dist/tasks/simple_face.d.ts +0 -4
  217. package/dist/tasks/simple_face.js +0 -84
  218. package/dist/tasks/titanic.d.ts +0 -4
  219. package/dist/tasks/titanic.js +0 -88
  220. package/dist/tfjs.d.ts +0 -2
  221. package/dist/tfjs.js +0 -6
  222. package/dist/training/disco.d.ts +0 -14
  223. package/dist/training/disco.js +0 -70
  224. package/dist/training/training_schemes.d.ts +0 -5
@@ -0,0 +1,26 @@
1
+ /// <reference types="node" />
2
+ import SimplePeer, { SignalData } from 'simple-peer';
3
+ import { PeerID } from './types';
4
/** Events accepted by `Peer.on`, mapped to the listener signature of each. */
type Events = {
    close: () => void;
    connect: () => void;
    signal: (signal: SignalData) => void;
    data: (data: Buffer) => void;
};
/**
 * Wrapper around a SimplePeer connection that splits outgoing messages into
 * chunks and reassembles incoming chunks into whole messages.
 */
export declare class Peer {
    /** Identifier of the remote peer. */
    readonly id: PeerID;
    private readonly peer;
    private bufferSize?;
    private sendCounter;
    private sendQueue;
    private receiving;
    constructor(id: PeerID, opts?: SimplePeer.Options);
    /** Queues `msg` for sending, split into chunks. */
    send(msg: Buffer): void;
    private flush;
    /** Size, in bytes, of the chunks messages are split into. */
    get maxChunkSize(): number;
    private chunk;
    /** Destroys the underlying connection. */
    destroy(): void;
    /** Feeds signalling data coming from the remote side to the underlying connection. */
    signal(signal: SignalData): void;
    /** Registers `listener` for `event`; `data` listeners receive reassembled messages. */
    on<K extends keyof Events>(event: K, listener: Events[K]): void;
}
export {};
@@ -0,0 +1,210 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.Peer = void 0;
4
+ var tslib_1 = require("tslib");
5
+ var immutable_1 = require("immutable");
6
+ var simple_peer_1 = (0, tslib_1.__importDefault)(require("simple-peer"));
7
// Header prepended to every chunk (integers written big-endian):
//   bytes 0-1: message id (uint16)
//   byte  2  : chunk counter (uint8)
var HEADER_SIZE = 2 + 1;
// The first chunk (chunk counter == 0) additionally carries, in byte 3,
// the total number of chunks of the message (uint8).
var FIRST_HEADER_SIZE = HEADER_SIZE + 1;
// Delay, in milliseconds, between two attempts to flush the send queue.
var TICK = 10;
13
// Peer wraps a SimplePeer, adding message fragmentation.
//
// WebRTC implementations have various maximum message sizes,
// but with huge models our messages might be bigger.
// We split messages into chunks and reconstruct them
// on the other side.
//
// As WebRTC's DataChannel is not a stream, we need to
// reorder messages, so we use a header on each chunk
// with a message id and chunk counter. The first chunk
// (chunk counter == 0) also carries the total number of chunks.
//
// see feross/simple-peer#393 for more info
var Peer = /** @class */ (function () {
    /**
     * @param id identifier of the remote peer
     * @param opts options passed as-is to the SimplePeer constructor
     */
    function Peer(id, opts) {
        this.sendCounter = 0;
        this.sendQueue = (0, immutable_1.List)();
        this.receiving = (0, immutable_1.Map)();
        this.id = id;
        this.peer = new simple_peer_1.default(opts);
    }
    /** Splits `msg` into chunks, appends them to the send queue and flushes it. */
    Peer.prototype.send = function (msg) {
        console.debug('sending message of size', msg.length);
        var chunks = this.chunk(msg);
        this.sendQueue = this.sendQueue.concat(chunks);
        this.flush();
    };
    /**
     * Sends queued chunks as long as the underlying peer's buffered amount leaves
     * room for the next one; otherwise tries again after TICK ms.
     * Throws if the remote max-message-size is not known yet (see `signal`).
     */
    Peer.prototype.flush = function () {
        var _this = this;
        if (this.bufferSize === undefined) {
            throw new Error('flush without known buffer size');
        }
        // iterative rather than recursive so long queues do not grow the stack
        for (;;) {
            var chunk = this.sendQueue.first();
            if (chunk === undefined) {
                return; // nothing to flush
            }
            var remainingBufferSize = this.bufferSize - this.peer.bufferSize;
            if (chunk.length > remainingBufferSize) {
                setTimeout(function () { return _this.flush(); }, TICK);
                return;
            }
            console.debug('sending chunk of size', chunk.length);
            this.sendQueue = this.sendQueue.shift();
            this.peer.send(chunk);
        }
    };
    Object.defineProperty(Peer.prototype, "maxChunkSize", {
        get: function () {
            if (this.bufferSize === undefined) {
                throw new Error('chunk without known buffer size');
            }
            // in the perfect world of bug-free implementations
            // we would return this.bufferSize
            // sadly, we are not there yet
            //
            // based on MDN, taking 16K seems to be a pretty safe
            // and widely supported buffer size
            return 16 * (1 << 10);
        },
        enumerable: false,
        configurable: true
    });
    /**
     * Splits `b` into framed chunks of at most `maxChunkSize` bytes each.
     * The first chunk carries FIRST_HEADER_SIZE header bytes; the others carry HEADER_SIZE.
     * Throws when the message id space (uint16) is exhausted or when the message
     * would need more than 0xFF chunks.
     */
    Peer.prototype.chunk = function (b) {
        var messageID = this.sendCounter;
        // ids are written as uint16, so 0xFFFF is the last usable one
        if (messageID > 0xFFFF) {
            throw new Error('too much messages sent to this peer');
        }
        this.sendCounter++;
        var firstPayloadSize = this.maxChunkSize - FIRST_HEADER_SIZE;
        var otherPayloadSize = this.maxChunkSize - HEADER_SIZE;
        // Whatever does not fit into the first chunk goes into the tail.
        // Comparing with the first chunk's payload size (not maxChunkSize)
        // is required: otherwise messages slightly shorter than maxChunkSize
        // would lose their last bytes. The guard is also needed because
        // Range(start, end) counts down when start > end.
        var tail = immutable_1.Seq.Indexed([]);
        if (b.length > firstPayloadSize) {
            tail = (0, immutable_1.Range)(firstPayloadSize, b.length, otherPayloadSize)
                .map(function (offset) { return b.subarray(offset, offset + otherPayloadSize); });
        }
        var totalChunkCount = 1 + tail.count();
        if (totalChunkCount > 0xFF) {
            throw new Error('too big message to even chunk it');
        }
        var firstChunk = Buffer.alloc(FIRST_HEADER_SIZE + Math.min(b.length, firstPayloadSize));
        firstChunk.writeUint16BE(messageID);
        firstChunk.writeUint8(0, 2);
        firstChunk.writeUint8(totalChunkCount, 3);
        b.copy(firstChunk, FIRST_HEADER_SIZE, 0, firstPayloadSize);
        return immutable_1.Seq.Indexed([firstChunk])
            .concat((0, immutable_1.Range)(1).zip(tail)
            .map(function (_a) {
            var id = _a[0], raw = _a[1];
            var chunk = Buffer.alloc(HEADER_SIZE + raw.length);
            chunk.writeUint16BE(messageID);
            chunk.writeUint8(id, 2);
            raw.copy(chunk, HEADER_SIZE, 0);
            return chunk;
        }));
    };
    /** Destroys the underlying SimplePeer. */
    Peer.prototype.destroy = function () {
        this.peer.destroy();
    };
    /**
     * Forwards `signal` to the underlying SimplePeer. On the first offer or answer,
     * records the remote `a=max-message-size` as the buffer size used by `flush`.
     */
    Peer.prototype.signal = function (signal) {
        // extract max buffer size
        if (signal.type === 'offer' || signal.type === 'answer') {
            if (signal.sdp === undefined) {
                throw new Error('signal answer|offer without session description');
            }
            if (this.bufferSize !== undefined) {
                throw new Error('buffer size set twice');
            }
            var match = signal.sdp.match(/a=max-message-size:(\d+)/);
            if (match === null) {
                // TODO default value instead?
                throw new Error('no max-message-size found in signal');
            }
            var max = parseInt(match[1], 10);
            if (isNaN(max)) {
                throw new Error("unable to parse max-message-size as int: " + match[1]);
            }
            // RFC 8841: a max-message-size of 0 means messages of any size are
            // accepted; keeping 0 would make `flush` wait forever
            this.bufferSize = max === 0 ? Infinity : max;
        }
        this.peer.signal(signal);
    };
    /**
     * Registers `listener` for `event`. Non-`data` events are forwarded as-is;
     * `data` listeners are called once per fully reassembled message.
     */
    Peer.prototype.on = function (event, listener) {
        var _this = this;
        if (event !== 'data') {
            this.peer.on(event, listener);
            return;
        }
        this.peer.on('data', function (data) {
            if (!Buffer.isBuffer(data) || data.length < HEADER_SIZE) {
                throw new Error('received invalid message type');
            }
            var messageID = data.readUint16BE();
            var chunkID = data.readUint8(2);
            var received = _this.receiving.get(messageID, {
                total: undefined,
                chunks: (0, immutable_1.Map)()
            });
            var total = received.total;
            var chunks = received.chunks;
            if (chunks.has(chunkID)) {
                throw new Error("chunk " + messageID + ":" + chunkID + " already received");
            }
            var chunk;
            if (chunkID !== 0) {
                chunk = Buffer.alloc(data.length - HEADER_SIZE);
                data.copy(chunk, 0, HEADER_SIZE);
            }
            else {
                if (data.length < FIRST_HEADER_SIZE) {
                    throw new Error('received invalid message type');
                }
                if (total !== undefined) {
                    throw new Error('first header received twice');
                }
                var readTotal_1 = data.readUint8(3);
                total = readTotal_1;
                chunk = Buffer.alloc(data.length - FIRST_HEADER_SIZE);
                data.copy(chunk, 0, FIRST_HEADER_SIZE);
                // valid chunk ids are 0 .. total - 1
                if (chunks.keySeq().some(function (id) { return id >= readTotal_1; })) {
                    throw new Error('received total of chunk but got now-out-of-bound chunks');
                }
            }
            // also rejects a first chunk announcing a total of 0, which could never complete
            if (total !== undefined && chunkID >= total) {
                throw new Error("chunk " + messageID + ":" + chunkID + " out of bound of " + total + " chunks");
            }
            _this.receiving = _this.receiving.set(messageID, {
                total: total,
                chunks: chunks.set(chunkID, chunk)
            });
            console.debug("got chunk " + messageID + ":" + chunkID + "/" + (total !== null && total !== void 0 ? total : 'unknown') + " of size " + chunk.length);
            // messages with all of their chunks, ordered by message id,
            // each reassembled by concatenating its chunks in chunk-id order
            var readyMessages = _this.receiving
                .filter(function (_a) {
                var total = _a.total, chunks = _a.chunks;
                return total !== undefined && chunks.size === total;
            })
                .sortBy(function (_, id) { return id; })
                .map(function (_a) {
                return Buffer.concat(_a.chunks.entrySeq().toList()
                    .sortBy(function (entry) { return entry[0]; })
                    .map(function (entry) { return entry[1]; })
                    .toArray());
            });
            _this.receiving = _this.receiving.deleteAll(readyMessages.keys());
            readyMessages
                .forEach(function (message) {
                console.debug(_this.peer.address().port, 'recved message of size', message.length);
                // listener is typed per event; here it is the `data` listener
                listener(message);
            });
        });
    };
    return Peer;
}());
exports.Peer = Peer;
@@ -0,0 +1,14 @@
1
+ import { Map, Set } from 'immutable';
2
+ import { SignalData } from 'simple-peer';
3
+ import { PeerID } from './types';
4
+ import { PeerConnection, EventConnection } from '../event_connection';
5
/**
 * Pool of the WebRTC peer connections of one client, keyed by remote peer id.
 */
export declare class PeerPool {
    private readonly id;
    private readonly wrtc?;
    private peers;
    private constructor();
    /** Creates a pool for the client `id`; use this instead of the private constructor. */
    static init(id: PeerID): Promise<PeerPool>;
    /** Disconnects every pooled peer and empties the pool. */
    shutdown(): void;
    /** Forwards `signal` to the connection of `peerID`. */
    signal(peerID: PeerID, signal: SignalData): void;
    /**
     * Opens connections to the peers of `peersToConnect` not yet pooled and
     * resolves with the connections of `peersToConnect`. `clientHandle` is
     * invoked before the new connections are established.
     */
    getPeers(
        peersToConnect: Set<PeerID>,
        signallingServer: EventConnection,
        clientHandle: (connections: Map<PeerID, PeerConnection>) => void
    ): Promise<Map<PeerID, PeerConnection>>;
}
@@ -0,0 +1,92 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.PeerPool = void 0;
4
+ var tslib_1 = require("tslib");
5
+ var immutable_1 = require("immutable");
6
+ var peer_1 = require("./peer");
7
+ var event_connection_1 = require("../event_connection");
8
// TODO cleanup old peers
/**
 * Pool of the peer connections of the client identified by `id`, keyed by remote peer id.
 */
var PeerPool = /** @class */ (function () {
    function PeerPool(id, wrtc) {
        this.id = id;
        this.wrtc = wrtc;
        this.peers = (0, immutable_1.Map)();
    }
    /**
     * Creates a pool, first trying to load `@koush/wrtc` (resolved from the current
     * directory). Any failure to resolve or load it is deliberately ignored, and the
     * pool is then created with `wrtc` undefined.
     */
    PeerPool.init = function (id) {
        return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
            var wrtc, path, e_1;
            return (0, tslib_1.__generator)(this, function (_a) {
                switch (_a.label) {
                    case 0:
                        _a.trys.push([0, 2, , 3]);
                        path = require.resolve('@koush/wrtc', { paths: ['.'] });
                        return [4 /*yield*/, Promise.resolve().then(function () { return (0, tslib_1.__importStar)(require(path)); })];
                    case 1:
                        wrtc = _a.sent();
                        return [3 /*break*/, 3];
                    case 2:
                        // optional dependency: fall through with wrtc undefined
                        e_1 = _a.sent();
                        return [3 /*break*/, 3];
                    case 3: return [2 /*return*/, new PeerPool(id, wrtc)];
                }
            });
        });
    };
    /** Disconnects every pooled peer and empties the pool. */
    PeerPool.prototype.shutdown = function () {
        console.debug(this.id, 'shutdown their peers');
        this.peers.forEach(function (peer) { return peer.disconnect(); });
        this.peers = (0, immutable_1.Map)();
    };
    /**
     * Forwards `signal` to the connection of `peerID`.
     * Throws if `peerID` is not in the pool.
     */
    PeerPool.prototype.signal = function (peerID, signal) {
        console.debug(this.id, 'signals for', peerID);
        var peer = this.peers.get(peerID);
        if (peer === undefined) {
            throw new Error("received signal for unknown peer: " + peerID);
        }
        peer.signal(signal);
    };
    /**
     * Opens connections to the peers of `peersToConnect` not already pooled, then
     * resolves with the pooled connections of `peersToConnect`.
     *
     * `clientHandle` is called with the connections of `peersToConnect` (the same
     * set that is returned, not every pooled peer) before connecting, so that
     * stale peers left in the pool from earlier calls are not waited on.
     * Rejects if `peersToConnect` contains our own id.
     */
    PeerPool.prototype.getPeers = function (peersToConnect, signallingServer, 
    // TODO as event?
    clientHandle) {
        return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
            var newPeers, newPeersConnections;
            var _this = this;
            return (0, tslib_1.__generator)(this, function (_a) {
                switch (_a.label) {
                    case 0:
                        if (peersToConnect.contains(this.id)) {
                            throw new Error('peers to connect contains our id');
                        }
                        console.debug(this.id, 'is connecting peers:', peersToConnect.toJS());
                        newPeers = (0, immutable_1.Map)(peersToConnect
                            .filter(function (id) { return !_this.peers.has(id); })
                            // of each pair, only the side whose own id is greater initiates
                            .map(function (id) { return [id, id < _this.id]; })
                            .map(function (_a) {
                            var _b = (0, tslib_1.__read)(_a, 2), id = _b[0], initiator = _b[1];
                            var p = new peer_1.Peer(id, { initiator: initiator, wrtc: _this.wrtc });
                            return [id, p];
                        }));
                        console.debug(this.id, 'asked to connect new peers:', newPeers.keySeq().toJS());
                        newPeersConnections = newPeers.map(function (peer, id) { return new event_connection_1.PeerConnection(_this.id, peer, signallingServer); });
                        // adding peers to pool before connecting them because they must be set to call signal on them
                        this.peers = this.peers.merge(newPeersConnections);
                        clientHandle(this.peers.filter(function (_, id) { return peersToConnect.has(id); }));
                        return [4 /*yield*/, Promise.all(Array.from(newPeersConnections.values()).map(function (connection) { return (0, tslib_1.__awaiter)(_this, void 0, void 0, function () { return (0, tslib_1.__generator)(this, function (_a) {
                                switch (_a.label) {
                                    case 0: return [4 /*yield*/, connection.connect()];
                                    case 1: return [2 /*return*/, _a.sent()];
                                }
                            }); }); }))];
                    case 1:
                        _a.sent();
                        console.debug(this.id, 'knowns connected peers:', this.peers.keySeq().toJS());
                        return [2 /*return*/, this.peers
                                .filter(function (_, id) { return peersToConnect.has(id); })];
                }
            });
        });
    };
    return PeerPool;
}());
exports.PeerPool = PeerPool;
@@ -0,0 +1,22 @@
1
+ import { List, Map } from 'immutable';
2
+ import { Task, TrainingInformant, WeightsContainer } from '../..';
3
+ import { Base } from './base';
4
+ import { PeerID } from './types';
5
+ import { PeerConnection } from '../event_connection';
6
/**
 * Decentralized client that utilizes secure aggregation so client updates remain private
 */
export declare class SecAgg extends Base {
    readonly url: URL;
    readonly task: Task;
    private readonly maxShareValue;
    private receivedShares?;
    private receivedPartialSums?;
    constructor(url: URL, task: Task);
    private sendShares;
    private sendPartialSums;
    /**
     * Runs one share / partial-sum exchange with `peers` and resolves with the
     * partial sums of every participant, our own first.
     */
    sendAndReceiveWeights(
        peers: Map<PeerID, PeerConnection>,
        noisyWeights: WeightsContainer,
        round: number,
        trainingInformant: TrainingInformant
    ): Promise<List<WeightsContainer>>;
    private receiveShares;
    private receivePartials;
    /** Starts waiting for the shares and partial sums sent by `peers`. */
    clientHandle(peers: Map<PeerID, PeerConnection>): void;
}
@@ -0,0 +1,190 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.SecAgg = void 0;
4
+ var tslib_1 = require("tslib");
5
+ var immutable_1 = require("immutable");
6
+ var __1 = require("../..");
7
+ var base_1 = require("./base");
8
+ var messages_1 = require("../messages");
9
+ var secret_shares = (0, tslib_1.__importStar)(require("./secret_shares"));
10
+ var event_connection_1 = require("../event_connection");
11
/**
 * Decentralized client that utilizes secure aggregation so client updates remain private.
 *
 * One exchange (see sendAndReceiveWeights):
 *  1. our noisy weights are split into peers.size + 1 additive shares; the
 *     first one is kept and each other one is sent to a different peer;
 *  2. our kept share is summed with the shares received from the peers and
 *     this partial sum is sent to every peer;
 *  3. the partial sums received from the peers, with ours first, are returned.
 *
 * clientHandle must have been called before sendAndReceiveWeights, which
 * otherwise rejects with 'no promise setup for receiving weights'.
 */
var SecAgg = /** @class */ (function (_super) {
    (0, tslib_1.__extends)(SecAgg, _super);
    function SecAgg(url, task) {
        var _a, _b;
        var _this = _super.call(this, url, task) || this;
        _this.url = url;
        _this.task = task;
        // passed to generateAllShares; 100 unless the task's trainingInformation sets it
        _this.maxShareValue = (_b = (_a = _this.task.trainingInformation) === null || _a === void 0 ? void 0 : _a.maxShareValue) !== null && _b !== void 0 ? _b : 100;
        return _this;
    }
    /*
     Encodes every share but the first (which we keep) and sends the i-th encoded
     share to the i-th peer of `peers`, in the map's iteration order.
     `round` and `trainingInformant` are currently unused.
     */
    SecAgg.prototype.sendShares = function (peers, weightShares, round, trainingInformant) {
        return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
            var encodedWeightShares, _a;
            var _this = this;
            return (0, tslib_1.__generator)(this, function (_b) {
                switch (_b.label) {
                    case 0:
                        _a = immutable_1.List;
                        return [4 /*yield*/, Promise.all(weightShares.rest().map(function (weights) { return (0, tslib_1.__awaiter)(_this, void 0, void 0, function () { return (0, tslib_1.__generator)(this, function (_a) {
                                switch (_a.label) {
                                    case 0: return [4 /*yield*/, __1.serialization.weights.encode(weights)];
                                    case 1: return [2 /*return*/, _a.sent()];
                                }
                            }); }); }))];
                    case 1:
                        encodedWeightShares = _a.apply(void 0, [_b.sent()]);
                        // Broadcast our weights to ith peer in the SERVER LIST OF PEERS (seen in signaling_server.ts)
                        peers
                            .entrySeq()
                            .toSeq()
                            .zip(encodedWeightShares)
                            .forEach(function (_a) {
                            var _b = (0, tslib_1.__read)(_a, 2), _c = (0, tslib_1.__read)(_b[0], 2), id = _c[0], peer = _c[1], weights = _b[1];
                            // no `return`: Immutable's forEach stops iterating if the callback returns false
                            _this.sendMessagetoPeer(peer, {
                                type: messages_1.type.Shares,
                                peer: id,
                                weights: weights
                            });
                        });
                        return [2 /*return*/];
                }
            });
        });
    };
    /*
     Encodes `partial` once and sends it to every peer of `peers`.
     */
    SecAgg.prototype.sendPartialSums = function (partial, peers) {
        return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
            var myEncodedSum;
            var _this = this;
            return (0, tslib_1.__generator)(this, function (_a) {
                switch (_a.label) {
                    case 0: return [4 /*yield*/, __1.serialization.weights.encode(partial)];
                    case 1:
                        myEncodedSum = _a.sent();
                        peers.forEach(function (peer, id) {
                            // no `return`: Immutable's forEach stops iterating if the callback returns false
                            _this.sendMessagetoPeer(peer, {
                                type: messages_1.type.PartialSums,
                                peer: id,
                                partials: myEncodedSum
                            });
                        });
                        return [2 /*return*/];
                }
            });
        });
    };
    /*
     Runs one exchange with `peers` (see the class comment) and resolves with all
     partial sums, ours first. Also reports the number of partial sums as
     currentNumberOfParticipants, then clears the receiving promises set by clientHandle.
     */
    SecAgg.prototype.sendAndReceiveWeights = function (peers, noisyWeights, round, trainingInformant) {
        return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
            var weightShares, shares, mySum, partials;
            return (0, tslib_1.__generator)(this, function (_a) {
                switch (_a.label) {
                    case 0:
                        if (!this.receivedShares || !this.receivedPartialSums) {
                            throw new Error('no promise setup for receiving weights');
                        }
                        weightShares = secret_shares.generateAllShares(noisyWeights, peers.size + 1, this.maxShareValue);
                        return [4 /*yield*/, this.sendShares(peers, weightShares, round, trainingInformant)];
                    case 1:
                        _a.sent();
                        return [4 /*yield*/, this.receivedShares];
                    case 2:
                        shares = _a.sent();
                        // add own share to list
                        shares = shares.insert(0, weightShares.first());
                        mySum = __1.aggregation.sum(shares);
                        // awaited rather than fire-and-forget so that a failure to encode or
                        // send our partial sum rejects this call instead of being lost
                        return [4 /*yield*/, this.sendPartialSums(mySum, peers)];
                    case 3:
                        _a.sent();
                        return [4 /*yield*/, this.receivedPartialSums];
                    case 4:
                        partials = _a.sent();
                        partials = partials.insert(0, mySum);
                        trainingInformant.update({
                            currentNumberOfParticipants: partials.size
                        });
                        // resets state
                        this.receivedPartialSums = undefined;
                        this.receivedShares = undefined;
                        return [2 /*return*/, partials];
                }
            });
        });
    };
    /*
     Waits for one Shares message from each peer and resolves with the decoded shares.
     */
    SecAgg.prototype.receiveShares = function (peers) {
        return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
            var sharesPromises, receivedShares, sharesMessages;
            var _this = this;
            return (0, tslib_1.__generator)(this, function (_a) {
                switch (_a.label) {
                    case 0:
                        sharesPromises = Array.from(peers.values()).map(function (peer) { return (0, tslib_1.__awaiter)(_this, void 0, void 0, function () { return (0, tslib_1.__generator)(this, function (_a) {
                            switch (_a.label) {
                                case 0: return [4 /*yield*/, (0, event_connection_1.waitMessage)(peer, messages_1.type.Shares)];
                                case 1: return [2 /*return*/, _a.sent()];
                            }
                        }); }); });
                        receivedShares = (0, immutable_1.List)();
                        return [4 /*yield*/, Promise.all(sharesPromises)];
                    case 1:
                        sharesMessages = _a.sent();
                        sharesMessages.forEach(function (message) {
                            receivedShares = receivedShares.push(__1.serialization.weights.decode(message.weights));
                        });
                        if (receivedShares.size < peers.size) {
                            throw new Error('Not enough shares received');
                        }
                        return [2 /*return*/, receivedShares];
                }
            });
        });
    };
    /*
     Waits for one PartialSums message from each peer and resolves with the decoded sums.
     */
    SecAgg.prototype.receivePartials = function (peers) {
        return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
            var partialsPromises, receivedPartials, partialMessages;
            var _this = this;
            return (0, tslib_1.__generator)(this, function (_a) {
                switch (_a.label) {
                    case 0:
                        partialsPromises = Array.from(peers.values()).map(function (peer) { return (0, tslib_1.__awaiter)(_this, void 0, void 0, function () { return (0, tslib_1.__generator)(this, function (_a) {
                            switch (_a.label) {
                                case 0: return [4 /*yield*/, (0, event_connection_1.waitMessage)(peer, messages_1.type.PartialSums)];
                                case 1: return [2 /*return*/, _a.sent()];
                            }
                        }); }); });
                        receivedPartials = (0, immutable_1.List)();
                        return [4 /*yield*/, Promise.all(partialsPromises)];
                    case 1:
                        partialMessages = _a.sent();
                        partialMessages.forEach(function (message) {
                            receivedPartials = receivedPartials.push(__1.serialization.weights.decode(message.partials));
                        });
                        if (receivedPartials.size < peers.size) {
                            throw new Error('Not enough partials received');
                        }
                        return [2 /*return*/, receivedPartials];
                }
            });
        });
    };
    /*
     handles received messages from signaling server: starts waiting for the
     shares and partial sums of `peers`, consumed later by sendAndReceiveWeights
     */
    SecAgg.prototype.clientHandle = function (peers) {
        this.receivedShares = this.receiveShares(peers);
        this.receivedPartialSums = this.receivePartials(peers);
        // These promises are awaited only later, so an early rejection would be
        // reported as unhandled (fatal by default on recent Node versions).
        // No-op handlers mark them as handled; awaiting them still rejects.
        this.receivedShares.catch(function () { });
        this.receivedPartialSums.catch(function () { });
    };
    return SecAgg;
}(base_1.Base));
exports.SecAgg = SecAgg;
@@ -0,0 +1,3 @@
1
+ import { List } from 'immutable';
2
+ import { WeightsContainer } from '../..';
3
/**
 * Splits `secret` into `nParticipants` additive shares whose sum is `secret`.
 *
 * @param secret - weights to share
 * @param nParticipants - number of shares to produce
 * @param maxShareValue - bound on the magnitude of the randomly drawn share values
 * @returns the list of shares
 */
export declare function generateAllShares(
  secret: WeightsContainer,
  nParticipants: number,
  maxShareValue: number
): List<WeightsContainer>;
@@ -0,0 +1,39 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.generateAllShares = void 0;
4
+ var tslib_1 = require("tslib");
5
+ var immutable_1 = require("immutable");
6
+ var crypto = (0, tslib_1.__importStar)(require("crypto"));
7
+ var __1 = require("../..");
8
// Exclusive upper bound (2^47) for the seeds drawn with crypto.randomInt in generateRandomShare.
var maxSeed = 0x800000000000;
9
/*
Builds the final share: the weights that make all N shares (N = number of ready clients)
add up to the secret, i.e. the secret minus the sum of the shares produced so far.
Throws when no share has been produced yet.
*/
function lastShare(currentShares, secret) {
    if (currentShares.size !== 0) {
        return secret.sub(__1.aggregation.sum(currentShares));
    }
    throw new Error('Need at least one current share to be able to subtract secret from');
}
18
/*
Generate N additive shares that aggregate to the secret weights array (where N is number of ready clients).

The first N-1 shares are random (see generateRandomShare); the last one is computed by lastShare
so that all N shares sum to the secret.
  secret         weights container to split
  nParticipants  N, the number of shares to produce (must be >= 1)
  maxShareValue  bound passed to generateRandomShare for the random shares
Returns an immutable List of N weights containers.
Throws if nParticipants < 1.
*/
function generateAllShares(secret, nParticipants, maxShareValue) {
    if (nParticipants < 1) {
        throw new Error('too few participants to genreate shares');
    }
    if (nParticipants === 1) {
        // With a single participant the only additive share is the secret itself. The general
        // path below would build zero random shares, and lastShare rejects an empty list, so a
        // value accepted by the guard above used to fail with an unrelated error.
        // Tensors are cloned so the share is a fresh container, as on the general path
        // (assumes the container holds tf tensors, as generateRandomShare does).
        return (0, immutable_1.List)([secret.map(function (t) { return t.clone(); })]);
    }
    var randomShares = (0, immutable_1.Range)(0, nParticipants - 1)
        .map(function () { return generateRandomShare(secret, maxShareValue); })
        .toList();
    return randomShares.push(lastShare(randomShares, secret));
}
31
+ exports.generateAllShares = generateAllShares;
32
/*
Generates one share with the same shape as the secret, populated with values drawn uniformly
between -maxShareValue and maxShareValue.

Each tensor of the container gets its own freshly drawn seed. Previously a single seed was
reused for the whole container. Because a seeded tf.randomUniform stream depends only on the
seed, tensors of equal shape received identical masks, and differently shaped tensors shared a
common prefix. The holder of the last share (secret minus the sum of the random shares) could
then subtract two such layers and recover the difference of the corresponding secret weights.

NOTE(review): tf.randomUniform's seeded generator is not a cryptographic PRNG; only the seeds
come from crypto.randomInt.
*/
function generateRandomShare(secret, maxShareValue) {
    return secret.map(function (t) {
        var seed = crypto.randomInt(maxSeed);
        return __1.tf.randomUniform(t.shape, -maxShareValue, maxShareValue, 'float32', seed);
    });
}
@@ -0,0 +1,2 @@
1
/** Identifier of a peer: a plain number. */
export type PeerID = number;

/**
 * Runtime type guard for {@link PeerID}.
 *
 * @param raw - any value
 * @returns true when `raw` is a number
 */
export declare function isPeerID(raw: unknown): raw is PeerID;
@@ -0,0 +1,7 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.isPeerID = void 0;
4
/**
 * Runtime check for PeerID values: accepts any JavaScript number (including NaN and
 * non-integers) and rejects everything else, bigints included.
 */
function isPeerID(raw) {
    switch (typeof raw) {
        case 'number':
            return true;
        default:
            return false;
    }
}
7
+ exports.isPeerID = isPeerID;
@@ -0,0 +1,37 @@
1
+ import { Peer } from './decentralized/peer';
2
+ import { PeerID } from './decentralized/types';
3
+ import { type, NarrowMessage, Message } from './messages';
4
+ import { SignalData } from 'simple-peer';
5
/**
 * A connection exchanging typed messages, implemented by PeerConnection and WebSocketServer.
 */
export interface EventConnection {
  /**
   * Registers `handler` for incoming messages of type `type`; the handler receives the
   * message narrowed to that type. NOTE(review): presumably fires for every such message.
   */
  on: <K extends type>(
    type: K,
    handler: (event: NarrowMessage<K>) => void
  ) => void;
  /**
   * Like `on`. NOTE(review): presumably fires only for the next such message — confirm.
   */
  once: <K extends type>(
    type: K,
    handler: (event: NarrowMessage<K>) => void
  ) => void;
  /** Sends `msg` over the connection. */
  send: <T extends Message>(msg: T) => void;
  /** Closes the connection. */
  disconnect: () => void;
}
11
/**
 * Waits for a message of the given `type` on `connection`.
 *
 * @returns a promise of the received message, narrowed to `type`
 * NOTE(review): no time bound is visible in the signature; prefer waitMessageWithTimeout when
 * the remote side may never answer.
 */
export declare function waitMessage<T extends type>(
  connection: EventConnection,
  type: T
): Promise<NarrowMessage<T>>;

/**
 * Like waitMessage, bounded by `timeoutMs` milliseconds.
 * NOTE(review): presumably rejects when the timeout elapses first — confirm against the
 * implementation.
 */
export declare function waitMessageWithTimeout<T extends type>(
  connection: EventConnection,
  type: T,
  timeoutMs: number
): Promise<NarrowMessage<T>>;
13
/**
 * EventConnection to a single remote peer (`peer`), identified locally by `selfId`.
 * NOTE(review): `signallingServer` is presumably used to relay the connection handshake.
 */
export declare class PeerConnection implements EventConnection {
  private readonly selfId;
  private readonly peer;
  private readonly signallingServer;
  private readonly eventEmitter;

  constructor(selfId: PeerID, peer: Peer, signallingServer: EventConnection);

  /** Establishes the connection; resolves once it is set up. */
  connect(): Promise<void>;

  /** Feeds simple-peer signalling data to this connection. */
  signal(signal: SignalData): void;

  on<K extends type>(
    type: K,
    handler: (event: NarrowMessage<K>) => void
  ): void;

  once<K extends type>(
    type: K,
    handler: (event: NarrowMessage<K>) => void
  ): void;

  send<T extends Message>(msg: T): void;

  disconnect(): void;
}
26
/**
 * EventConnection to a server at a URL (presumably over a WebSocket, given the name and the
 * `socket` field). The constructor is private: instances come from the static `connect`.
 */
export declare class WebSocketServer implements EventConnection {
  private readonly socket;
  private readonly eventEmitter;
  private readonly validateReceived?;
  private readonly validateSent?;

  private constructor();

  /**
   * Opens a connection to `url`.
   *
   * @param validateReceived - optional predicate; presumably checked on incoming messages
   * @param validateSent - optional predicate; presumably checked on outgoing messages
   * @returns a promise of the connected instance
   */
  static connect(
    url: URL,
    validateReceived?: (msg: any) => boolean,
    validateSent?: (msg: any) => boolean
  ): Promise<WebSocketServer>;

  disconnect(): void;

  on<K extends type>(
    type: K,
    handler: (event: NarrowMessage<K>) => void
  ): void;

  once<K extends type>(
    type: K,
    handler: (event: NarrowMessage<K>) => void
  ): void;

  send(msg: Message): void;
}