@epfml/discojs 3.0.1-p20241007204240.0 → 3.0.1-p20241024094708.0
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these versions exactly as they appear in their respective public registries.
- package/dist/aggregator/{base.d.ts → aggregator.d.ts} +24 -31
- package/dist/aggregator/{base.js → aggregator.js} +48 -36
- package/dist/aggregator/get.d.ts +2 -2
- package/dist/aggregator/get.js +4 -4
- package/dist/aggregator/index.d.ts +1 -4
- package/dist/aggregator/index.js +1 -1
- package/dist/aggregator/mean.d.ts +4 -4
- package/dist/aggregator/mean.js +5 -15
- package/dist/aggregator/secure.d.ts +4 -4
- package/dist/aggregator/secure.js +7 -17
- package/dist/client/client.d.ts +71 -17
- package/dist/client/client.js +118 -17
- package/dist/client/decentralized/decentralized_client.d.ts +11 -13
- package/dist/client/decentralized/decentralized_client.js +121 -84
- package/dist/client/decentralized/messages.d.ts +12 -6
- package/dist/client/decentralized/messages.js +9 -8
- package/dist/client/event_connection.js +2 -2
- package/dist/client/federated/federated_client.d.ts +1 -13
- package/dist/client/federated/federated_client.js +15 -94
- package/dist/client/federated/messages.d.ts +6 -11
- package/dist/client/local_client.d.ts +1 -0
- package/dist/client/local_client.js +3 -0
- package/dist/client/messages.d.ts +14 -7
- package/dist/client/messages.js +13 -11
- package/dist/default_tasks/cifar10.js +1 -1
- package/dist/default_tasks/lus_covid.js +1 -0
- package/dist/default_tasks/mnist.js +1 -1
- package/dist/default_tasks/simple_face.js +1 -0
- package/dist/default_tasks/titanic.js +1 -0
- package/dist/default_tasks/wikitext.js +1 -0
- package/dist/index.d.ts +0 -2
- package/dist/serialization/coder.d.ts +4 -0
- package/dist/serialization/coder.js +51 -0
- package/dist/serialization/index.d.ts +2 -0
- package/dist/serialization/index.js +1 -0
- package/dist/serialization/model.d.ts +1 -2
- package/dist/serialization/model.js +9 -24
- package/dist/serialization/weights.d.ts +2 -3
- package/dist/serialization/weights.js +15 -26
- package/dist/task/task_handler.d.ts +5 -5
- package/dist/task/task_handler.js +21 -15
- package/dist/task/training_information.d.ts +1 -2
- package/dist/task/training_information.js +6 -8
- package/dist/training/disco.d.ts +4 -1
- package/dist/training/trainer.js +1 -1
- package/dist/utils/event_emitter.d.ts +3 -3
- package/dist/utils/event_emitter.js +10 -9
- package/package.json +2 -3
|
@@ -1,22 +1,28 @@
|
|
|
1
|
-
import axios from 'axios';
|
|
2
1
|
import createDebug from "debug";
|
|
3
|
-
import { Map } from
|
|
4
|
-
import { serialization } from
|
|
5
|
-
import { isTask } from
|
|
2
|
+
import { Map } from "immutable";
|
|
3
|
+
import { serialization } from "../index.js";
|
|
4
|
+
import { isTask } from "./task.js";
|
|
6
5
|
const debug = createDebug("discojs:task:handlers");
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
6
|
+
function urlToTasks(base) {
|
|
7
|
+
const ret = new URL(base);
|
|
8
|
+
ret.pathname += "tasks";
|
|
9
|
+
return ret;
|
|
10
|
+
}
|
|
11
|
+
export async function pushTask(base, task, model) {
|
|
12
|
+
await fetch(urlToTasks(base), {
|
|
13
|
+
method: "POST",
|
|
14
|
+
body: JSON.stringify({
|
|
15
|
+
task,
|
|
16
|
+
model: await serialization.model.encode(model),
|
|
17
|
+
weights: await serialization.weights.encode(model.weights),
|
|
18
|
+
}),
|
|
13
19
|
});
|
|
14
20
|
}
|
|
15
|
-
export async function fetchTasks(
|
|
16
|
-
const response = await
|
|
17
|
-
const tasks = response.
|
|
21
|
+
export async function fetchTasks(base) {
|
|
22
|
+
const response = await fetch(urlToTasks(base));
|
|
23
|
+
const tasks = await response.json();
|
|
18
24
|
if (!Array.isArray(tasks)) {
|
|
19
|
-
throw new Error(
|
|
25
|
+
throw new Error("Expected to receive an array of Tasks when fetching tasks");
|
|
20
26
|
}
|
|
21
27
|
else if (!tasks.every(isTask)) {
|
|
22
28
|
for (const task of tasks) {
|
|
@@ -24,7 +30,7 @@ export async function fetchTasks(url) {
|
|
|
24
30
|
debug("task has invalid format: :O", task);
|
|
25
31
|
}
|
|
26
32
|
}
|
|
27
|
-
throw new Error(
|
|
33
|
+
throw new Error("invalid tasks response, the task object received is not well formatted");
|
|
28
34
|
}
|
|
29
35
|
return Map(tasks.map((t) => [t.id, t]));
|
|
30
36
|
}
|
|
@@ -18,10 +18,9 @@ export interface TrainingInformation {
|
|
|
18
18
|
LABEL_LIST?: string[];
|
|
19
19
|
scheme: 'decentralized' | 'federated' | 'local';
|
|
20
20
|
privacy?: Privacy;
|
|
21
|
-
decentralizedSecure?: boolean;
|
|
22
21
|
maxShareValue?: number;
|
|
23
22
|
minNbOfParticipants: number;
|
|
24
|
-
|
|
23
|
+
aggregationStrategy?: 'mean' | 'secure';
|
|
25
24
|
tokenizer?: string | PreTrainedTokenizer;
|
|
26
25
|
maxSequenceLength?: number;
|
|
27
26
|
tensorBackend: 'tfjs' | 'gpt';
|
|
@@ -24,7 +24,7 @@ export function isTrainingInformation(raw) {
|
|
|
24
24
|
if (typeof raw !== 'object' || raw === null) {
|
|
25
25
|
return false;
|
|
26
26
|
}
|
|
27
|
-
const { IMAGE_H, IMAGE_W, LABEL_LIST,
|
|
27
|
+
const { IMAGE_H, IMAGE_W, LABEL_LIST, aggregationStrategy, batchSize, dataType, privacy, epochs, inputColumns, maxShareValue, minNbOfParticipants, outputColumns, preprocessingFunctions, roundDuration, scheme, validationSplit, tokenizer, maxSequenceLength, tensorBackend } = raw;
|
|
28
28
|
if (typeof dataType !== 'string' ||
|
|
29
29
|
typeof epochs !== 'number' ||
|
|
30
30
|
typeof batchSize !== 'number' ||
|
|
@@ -33,8 +33,7 @@ export function isTrainingInformation(raw) {
|
|
|
33
33
|
typeof minNbOfParticipants !== 'number' ||
|
|
34
34
|
(tokenizer !== undefined && typeof tokenizer !== 'string' && !(tokenizer instanceof PreTrainedTokenizer)) ||
|
|
35
35
|
(maxSequenceLength !== undefined && typeof maxSequenceLength !== 'number') ||
|
|
36
|
-
(
|
|
37
|
-
(decentralizedSecure !== undefined && typeof decentralizedSecure !== 'boolean') ||
|
|
36
|
+
(aggregationStrategy !== undefined && typeof aggregationStrategy !== 'string') ||
|
|
38
37
|
(privacy !== undefined && !isPrivacy(privacy)) ||
|
|
39
38
|
(maxShareValue !== undefined && typeof maxShareValue !== 'number') ||
|
|
40
39
|
(IMAGE_H !== undefined && typeof IMAGE_H !== 'number') ||
|
|
@@ -45,8 +44,8 @@ export function isTrainingInformation(raw) {
|
|
|
45
44
|
(preprocessingFunctions !== undefined && !Array.isArray(preprocessingFunctions))) {
|
|
46
45
|
return false;
|
|
47
46
|
}
|
|
48
|
-
if (
|
|
49
|
-
switch (
|
|
47
|
+
if (aggregationStrategy !== undefined) {
|
|
48
|
+
switch (aggregationStrategy) {
|
|
50
49
|
case 'mean': break;
|
|
51
50
|
case 'secure': break;
|
|
52
51
|
default: return false;
|
|
@@ -58,7 +57,7 @@ export function isTrainingInformation(raw) {
|
|
|
58
57
|
case 'text': break;
|
|
59
58
|
default: return false;
|
|
60
59
|
}
|
|
61
|
-
//
|
|
60
|
+
// interdependencies on data type
|
|
62
61
|
if (dataType === 'image') {
|
|
63
62
|
if (typeof IMAGE_H !== 'number' || typeof IMAGE_W !== 'number') {
|
|
64
63
|
return false;
|
|
@@ -87,10 +86,9 @@ export function isTrainingInformation(raw) {
|
|
|
87
86
|
IMAGE_W,
|
|
88
87
|
IMAGE_H,
|
|
89
88
|
LABEL_LIST,
|
|
90
|
-
|
|
89
|
+
aggregationStrategy,
|
|
91
90
|
batchSize,
|
|
92
91
|
dataType,
|
|
93
|
-
decentralizedSecure,
|
|
94
92
|
privacy,
|
|
95
93
|
epochs,
|
|
96
94
|
inputColumns,
|
package/dist/training/disco.d.ts
CHANGED
|
@@ -7,7 +7,10 @@ interface DiscoConfig {
|
|
|
7
7
|
scheme: TrainingInformation["scheme"];
|
|
8
8
|
logger: Logger;
|
|
9
9
|
}
|
|
10
|
-
export type RoundStatus =
|
|
10
|
+
export type RoundStatus = 'not enough participants' | // Server notification to wait for more participants
|
|
11
|
+
'updating model' | // fetching/aggregating local updates into a global model
|
|
12
|
+
'local training' | // Training the model locally
|
|
13
|
+
'connecting to peers';
|
|
11
14
|
/**
|
|
12
15
|
* Top-level class handling distributed training from a client's perspective. It is meant to be
|
|
13
16
|
* a convenient object providing a reduced yet complete API that wraps model training and
|
package/dist/training/trainer.js
CHANGED
|
@@ -1,11 +1,11 @@
|
|
|
1
|
-
type Listener<T> = (_: T) => void
|
|
1
|
+
type Listener<T> = (_: T) => void | Promise<void>;
|
|
2
2
|
/**
|
|
3
3
|
* Call handlers on given events
|
|
4
4
|
*
|
|
5
5
|
* @typeParam I object/mapping from event name to emitted value type
|
|
6
6
|
*/
|
|
7
7
|
export declare class EventEmitter<I extends Record<string, unknown>> {
|
|
8
|
-
private
|
|
8
|
+
#private;
|
|
9
9
|
/**
|
|
10
10
|
* @param initialListeners object/mapping of event name to listener, as if using `on` on created instance
|
|
11
11
|
*/
|
|
@@ -13,7 +13,7 @@ export declare class EventEmitter<I extends Record<string, unknown>> {
|
|
|
13
13
|
[E in keyof I]?: Listener<I[E]>;
|
|
14
14
|
});
|
|
15
15
|
/**
|
|
16
|
-
* Register listener to call on event
|
|
16
|
+
* Register listener to call on event.
|
|
17
17
|
*
|
|
18
18
|
* @param event event name to listen to
|
|
19
19
|
* @param listener handler to call
|
|
@@ -6,7 +6,8 @@ import { List } from 'immutable';
|
|
|
6
6
|
* @typeParam I object/mapping from event name to emitted value type
|
|
7
7
|
*/
|
|
8
8
|
export class EventEmitter {
|
|
9
|
-
|
|
9
|
+
// List of callbacks to run per event
|
|
10
|
+
#listeners = {};
|
|
10
11
|
/**
|
|
11
12
|
* @param initialListeners object/mapping of event name to listener, as if using `on` on created instance
|
|
12
13
|
*/
|
|
@@ -19,14 +20,14 @@ export class EventEmitter {
|
|
|
19
20
|
}
|
|
20
21
|
}
|
|
21
22
|
/**
|
|
22
|
-
* Register listener to call on event
|
|
23
|
+
* Register listener to call on event.
|
|
23
24
|
*
|
|
24
25
|
* @param event event name to listen to
|
|
25
26
|
* @param listener handler to call
|
|
26
27
|
*/
|
|
27
28
|
on(event, listener) {
|
|
28
|
-
const eventListeners = this
|
|
29
|
-
this
|
|
29
|
+
const eventListeners = this.#listeners[event] ?? List();
|
|
30
|
+
this.#listeners[event] = eventListeners.push([false, listener]);
|
|
30
31
|
}
|
|
31
32
|
/**
|
|
32
33
|
* Register listener to call once on next event
|
|
@@ -35,8 +36,8 @@ export class EventEmitter {
|
|
|
35
36
|
* @param listener handler to call next time
|
|
36
37
|
*/
|
|
37
38
|
once(event, listener) {
|
|
38
|
-
const eventListeners = this
|
|
39
|
-
this
|
|
39
|
+
const eventListeners = this.#listeners[event] ?? List();
|
|
40
|
+
this.#listeners[event] = eventListeners.push([true, listener]);
|
|
40
41
|
}
|
|
41
42
|
/**
|
|
42
43
|
* Send value to registered listeners of event name
|
|
@@ -45,9 +46,9 @@ export class EventEmitter {
|
|
|
45
46
|
* @param value what to call listeners with
|
|
46
47
|
*/
|
|
47
48
|
emit(event, value) {
|
|
48
|
-
const eventListeners = this
|
|
49
|
-
this
|
|
50
|
-
eventListeners.forEach(([_, listener]) => { listener(value); });
|
|
49
|
+
const eventListeners = this.#listeners[event] ?? List();
|
|
50
|
+
this.#listeners[event] = eventListeners.filterNot(([once]) => once);
|
|
51
|
+
eventListeners.forEach(async ([_, listener]) => { await listener(value); });
|
|
51
52
|
}
|
|
52
53
|
}
|
|
53
54
|
/** `EventEmitter` for all events */
|
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "@epfml/discojs",
|
|
3
|
-
"version": "3.0.1-
|
|
3
|
+
"version": "3.0.1-p20241024094708.0",
|
|
4
4
|
"type": "module",
|
|
5
5
|
"main": "dist/index.js",
|
|
6
6
|
"types": "dist/index.d.ts",
|
|
@@ -19,13 +19,12 @@
|
|
|
19
19
|
},
|
|
20
20
|
"homepage": "https://github.com/epfml/disco#readme",
|
|
21
21
|
"dependencies": {
|
|
22
|
+
"@msgpack/msgpack": "^3.0.0-beta2",
|
|
22
23
|
"@tensorflow/tfjs": "4",
|
|
23
24
|
"@xenova/transformers": "2",
|
|
24
|
-
"axios": "1",
|
|
25
25
|
"immutable": "4",
|
|
26
26
|
"isomorphic-wrtc": "1",
|
|
27
27
|
"isomorphic-ws": "5",
|
|
28
|
-
"msgpack-lite": "0.1",
|
|
29
28
|
"simple-peer": "9",
|
|
30
29
|
"tslib": "2",
|
|
31
30
|
"ws": "8"
|