@epfml/discojs 3.0.1-p20241014092014.0 → 3.0.1-p20241024094708.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/client/client.js +3 -3
- package/dist/client/decentralized/messages.d.ts +2 -2
- package/dist/client/decentralized/messages.js +2 -2
- package/dist/client/event_connection.js +2 -2
- package/dist/client/federated/messages.d.ts +4 -4
- package/dist/index.d.ts +0 -2
- package/dist/serialization/coder.d.ts +4 -0
- package/dist/serialization/coder.js +51 -0
- package/dist/serialization/index.d.ts +2 -0
- package/dist/serialization/index.js +1 -0
- package/dist/serialization/model.d.ts +1 -2
- package/dist/serialization/model.js +9 -24
- package/dist/serialization/weights.d.ts +2 -3
- package/dist/serialization/weights.js +15 -26
- package/dist/task/task_handler.d.ts +5 -5
- package/dist/task/task_handler.js +21 -15
- package/package.json +2 -3
package/dist/client/client.js
CHANGED
|
@@ -1,5 +1,4 @@
|
|
|
1
1
|
import createDebug from "debug";
|
|
2
|
-
import axios from 'axios';
|
|
3
2
|
import { serialization } from '../index.js';
|
|
4
3
|
import { EventEmitter } from '../utils/event_emitter.js';
|
|
5
4
|
import { type } from "./messages.js";
|
|
@@ -149,8 +148,9 @@ export class Client extends EventEmitter {
|
|
|
149
148
|
url.pathname += '/';
|
|
150
149
|
}
|
|
151
150
|
url.pathname += `tasks/${this.task.id}/model.json`;
|
|
152
|
-
const response = await
|
|
153
|
-
|
|
151
|
+
const response = await fetch(url);
|
|
152
|
+
const encoded = new Uint8Array(await response.arrayBuffer());
|
|
153
|
+
return await serialization.model.decode(encoded);
|
|
154
154
|
}
|
|
155
155
|
get ownId() {
|
|
156
156
|
if (this._ownId === undefined) {
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import {
|
|
1
|
+
import { serialization } from "../../index.js";
|
|
2
2
|
import { type SignalData } from './peer.js';
|
|
3
3
|
import { type NodeID } from '../types.js';
|
|
4
4
|
import { type } from '../messages.js';
|
|
@@ -29,7 +29,7 @@ export interface Payload {
|
|
|
29
29
|
peer: NodeID;
|
|
30
30
|
aggregationRound: number;
|
|
31
31
|
communicationRound: number;
|
|
32
|
-
payload:
|
|
32
|
+
payload: serialization.Encoded;
|
|
33
33
|
}
|
|
34
34
|
export type MessageFromServer = NewDecentralizedNodeInfo | SignalForPeer | PeersForRound | WaitingForMoreParticipants | EnoughParticipants;
|
|
35
35
|
export type MessageToServer = ClientConnected | SignalForPeer | PeerIsReady | JoinRound;
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import {
|
|
1
|
+
import { serialization } from "../../index.js";
|
|
2
2
|
import { isNodeID } from '../types.js';
|
|
3
3
|
import { type, hasMessageType } from '../messages.js';
|
|
4
4
|
export function isMessageFromServer(o) {
|
|
@@ -41,7 +41,7 @@ export function isPeerMessage(o) {
|
|
|
41
41
|
switch (o.type) {
|
|
42
42
|
case type.Payload:
|
|
43
43
|
return ('peer' in o && isNodeID(o.peer) &&
|
|
44
|
-
'payload' in o &&
|
|
44
|
+
'payload' in o && serialization.isEncoded(o.payload));
|
|
45
45
|
}
|
|
46
46
|
return false;
|
|
47
47
|
}
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import createDebug from "debug";
|
|
2
2
|
import WebSocket from "isomorphic-ws";
|
|
3
|
-
import msgpack from "msgpack
|
|
3
|
+
import * as msgpack from "@msgpack/msgpack";
|
|
4
4
|
import * as decentralizedMessages from './decentralized/messages.js';
|
|
5
5
|
import { type } from './messages.js';
|
|
6
6
|
import { timeout } from './utils.js';
|
|
@@ -57,7 +57,7 @@ export class PeerConnection extends EventEmitter {
|
|
|
57
57
|
if (!decentralizedMessages.isPeerMessage(msg)) {
|
|
58
58
|
throw new Error(`can't send this type of message: ${JSON.stringify(msg)}`);
|
|
59
59
|
}
|
|
60
|
-
this.peer.send(msgpack.encode(msg));
|
|
60
|
+
this.peer.send(Buffer.from(msgpack.encode(msg)));
|
|
61
61
|
}
|
|
62
62
|
async disconnect() {
|
|
63
63
|
await this.peer.destroy();
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import {
|
|
1
|
+
import type { serialization } from "../../index.js";
|
|
2
2
|
import { type NodeID } from '..//types.js';
|
|
3
3
|
import { type } from '../messages.js';
|
|
4
4
|
import type { ClientConnected, WaitingForMoreParticipants, EnoughParticipants } from '../messages.js';
|
|
@@ -7,18 +7,18 @@ export interface NewFederatedNodeInfo {
|
|
|
7
7
|
type: type.NewFederatedNodeInfo;
|
|
8
8
|
id: NodeID;
|
|
9
9
|
waitForMoreParticipants: boolean;
|
|
10
|
-
payload:
|
|
10
|
+
payload: serialization.Encoded;
|
|
11
11
|
round: number;
|
|
12
12
|
nbOfParticipants: number;
|
|
13
13
|
}
|
|
14
14
|
export interface SendPayload {
|
|
15
15
|
type: type.SendPayload;
|
|
16
|
-
payload:
|
|
16
|
+
payload: serialization.Encoded;
|
|
17
17
|
round: number;
|
|
18
18
|
}
|
|
19
19
|
export interface ReceiveServerPayload {
|
|
20
20
|
type: type.ReceiveServerPayload;
|
|
21
|
-
payload:
|
|
21
|
+
payload: serialization.Encoded;
|
|
22
22
|
round: number;
|
|
23
23
|
nbOfParticipants: number;
|
|
24
24
|
}
|
package/dist/index.d.ts
CHANGED
|
@@ -1,7 +1,5 @@
|
|
|
1
1
|
export * as data from './dataset/index.js';
|
|
2
2
|
export * as serialization from './serialization/index.js';
|
|
3
|
-
export { Encoded as EncodedModel } from './serialization/model.js';
|
|
4
|
-
export { Encoded as EncodedWeights } from './serialization/weights.js';
|
|
5
3
|
export * as training from './training/index.js';
|
|
6
4
|
export * as privacy from './privacy.js';
|
|
7
5
|
export * as client from './client/index.js';
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
import * as msgpack from "@msgpack/msgpack";
|
|
2
|
+
export function isEncoded(raw) {
|
|
3
|
+
if (!(raw instanceof Uint8Array))
|
|
4
|
+
return false;
|
|
5
|
+
const _ = raw;
|
|
6
|
+
return true;
|
|
7
|
+
}
|
|
8
|
+
// create a new buffer instead of referencing the backing one
|
|
9
|
+
function copy(arr) {
|
|
10
|
+
// `Buffer.slice` (subclass of Uint8Array on Node) doesn't copy
|
|
11
|
+
// thus doesn't respect Liskov substitution principle
|
|
12
|
+
// https://nodejs.org/api/buffer.html#bufslicestart-end
|
|
13
|
+
// here we call the correct implementation
|
|
14
|
+
return Uint8Array.prototype.slice.call(arr);
|
|
15
|
+
}
|
|
16
|
+
// to avoid mapping every ArrayBuffer to Uint8Array,
|
|
17
|
+
// we register our own converters for the types we know are needed
|
|
18
|
+
// type ids are arbitrarily taken from msgpack-lite
|
|
19
|
+
// https://www.npmjs.com/package/msgpack-lite#extension-types
|
|
20
|
+
const CODEC = new msgpack.ExtensionCodec();
|
|
21
|
+
// used by TFJS's weights
|
|
22
|
+
CODEC.register({
|
|
23
|
+
type: 0x17,
|
|
24
|
+
encode(obj) {
|
|
25
|
+
if (!(obj instanceof Float32Array))
|
|
26
|
+
return null;
|
|
27
|
+
return new Uint8Array(obj.buffer, obj.byteOffset, obj.byteLength);
|
|
28
|
+
},
|
|
29
|
+
decode: (raw) =>
|
|
30
|
+
// to reinterpret uint8 as float32, the data needs to be 4-byte aligned
|
|
31
|
+
// but the given buffer might not be so we need to copy it.
|
|
32
|
+
new Float32Array(copy(raw).buffer),
|
|
33
|
+
});
|
|
34
|
+
// used by TFJS's saved model
|
|
35
|
+
CODEC.register({
|
|
36
|
+
type: 0x1a,
|
|
37
|
+
encode(obj) {
|
|
38
|
+
if (!(obj instanceof ArrayBuffer))
|
|
39
|
+
return null;
|
|
40
|
+
return new Uint8Array(obj);
|
|
41
|
+
},
|
|
42
|
+
decode: (raw) =>
|
|
43
|
+
// need to copy as backing ArrayBuffer might be larger
|
|
44
|
+
copy(raw),
|
|
45
|
+
});
|
|
46
|
+
export function encode(serialized) {
|
|
47
|
+
return msgpack.encode(serialized, { extensionCodec: CODEC });
|
|
48
|
+
}
|
|
49
|
+
export function decode(encoded) {
|
|
50
|
+
return msgpack.decode(encoded, { extensionCodec: CODEC });
|
|
51
|
+
}
|
|
@@ -1,5 +1,4 @@
|
|
|
1
1
|
import type { Model } from '../index.js';
|
|
2
|
-
|
|
3
|
-
export declare function isEncoded(raw: unknown): raw is Encoded;
|
|
2
|
+
import { Encoded } from "./coder.js";
|
|
4
3
|
export declare function encode(model: Model): Promise<Encoded>;
|
|
5
4
|
export declare function decode(encoded: unknown): Promise<Model>;
|
|
@@ -1,38 +1,29 @@
|
|
|
1
|
-
import msgpack from 'msgpack-lite';
|
|
2
1
|
import { models, serialization } from '../index.js';
|
|
2
|
+
import * as coder from "./coder.js";
|
|
3
|
+
import { isEncoded } from "./coder.js";
|
|
3
4
|
const Type = {
|
|
4
5
|
TFJS: 0,
|
|
5
6
|
GPT: 1
|
|
6
7
|
};
|
|
7
|
-
export function isEncoded(raw) {
|
|
8
|
-
return raw instanceof Uint8Array;
|
|
9
|
-
}
|
|
10
8
|
export async function encode(model) {
|
|
11
|
-
let encoded;
|
|
12
9
|
switch (true) {
|
|
13
10
|
case model instanceof models.TFJS: {
|
|
14
11
|
const serialized = await model.serialize();
|
|
15
|
-
|
|
16
|
-
break;
|
|
12
|
+
return coder.encode([Type.TFJS, serialized]);
|
|
17
13
|
}
|
|
18
14
|
case model instanceof models.GPT: {
|
|
19
15
|
const { weights, config } = model.serialize();
|
|
20
16
|
const serializedWeights = await serialization.weights.encode(weights);
|
|
21
|
-
|
|
22
|
-
break;
|
|
17
|
+
return coder.encode([Type.GPT, serializedWeights, config]);
|
|
23
18
|
}
|
|
24
19
|
default:
|
|
25
20
|
throw new Error("unknown model type");
|
|
26
21
|
}
|
|
27
|
-
// Node's Buffer extends Node's Uint8Array, which might not be the same
|
|
28
|
-
// as the browser's Uint8Array. we ensure here that it is.
|
|
29
|
-
return new Uint8Array(encoded);
|
|
30
22
|
}
|
|
31
23
|
export async function decode(encoded) {
|
|
32
|
-
if (!isEncoded(encoded))
|
|
24
|
+
if (!isEncoded(encoded))
|
|
33
25
|
throw new Error("Invalid encoding, raw encoding isn't an instance of Uint8Array");
|
|
34
|
-
|
|
35
|
-
const raw = msgpack.decode(encoded);
|
|
26
|
+
const raw = coder.decode(encoded);
|
|
36
27
|
if (!Array.isArray(raw) || raw.length < 2) {
|
|
37
28
|
throw new Error("invalid encoding, encoding isn't an array or doesn't contain enough values");
|
|
38
29
|
}
|
|
@@ -59,15 +50,9 @@ export async function decode(encoded) {
|
|
|
59
50
|
else {
|
|
60
51
|
throw new Error('invalid encoding, gpt-tfjs model encoding should be an array of length 2 or 3');
|
|
61
52
|
}
|
|
62
|
-
if (!
|
|
63
|
-
throw new Error(
|
|
64
|
-
|
|
65
|
-
const arr = rawModel;
|
|
66
|
-
if (arr.some((r) => typeof r !== 'number')) {
|
|
67
|
-
throw new Error("invalid encoding, gpt-tfjs weights should be numbers");
|
|
68
|
-
}
|
|
69
|
-
const nums = arr;
|
|
70
|
-
const weights = serialization.weights.decode(nums);
|
|
53
|
+
if (!isEncoded(rawModel))
|
|
54
|
+
throw new Error("invalid encoding, gpt-tfjs model weights should be an encoding of its weights");
|
|
55
|
+
const weights = serialization.weights.decode(rawModel);
|
|
71
56
|
return models.GPT.deserialize({ weights, config });
|
|
72
57
|
}
|
|
73
58
|
default:
|
|
@@ -1,5 +1,4 @@
|
|
|
1
|
-
import { WeightsContainer } from
|
|
2
|
-
|
|
3
|
-
export declare function isEncoded(raw: unknown): raw is Encoded;
|
|
1
|
+
import { WeightsContainer } from "../index.js";
|
|
2
|
+
import { Encoded } from "./coder.js";
|
|
4
3
|
export declare function encode(weights: WeightsContainer): Promise<Encoded>;
|
|
5
4
|
export declare function decode(encoded: Encoded): WeightsContainer;
|
|
@@ -1,37 +1,26 @@
|
|
|
1
|
-
import * as
|
|
2
|
-
import
|
|
3
|
-
import
|
|
1
|
+
import * as tf from "@tensorflow/tfjs";
|
|
2
|
+
import { WeightsContainer } from "../index.js";
|
|
3
|
+
import * as coder from "./coder.js";
|
|
4
4
|
function isSerialized(raw) {
|
|
5
|
-
if (typeof raw !==
|
|
5
|
+
if (typeof raw !== "object" || raw === null)
|
|
6
6
|
return false;
|
|
7
|
-
}
|
|
8
7
|
const { shape, data } = raw;
|
|
9
|
-
if (!(Array.isArray(shape) && shape.every((e) => typeof e ===
|
|
10
|
-
!(
|
|
8
|
+
if (!(Array.isArray(shape) && shape.every((e) => typeof e === "number")) ||
|
|
9
|
+
!(data instanceof Float32Array))
|
|
11
10
|
return false;
|
|
12
|
-
}
|
|
13
|
-
const _ = {
|
|
14
|
-
shape: shape,
|
|
15
|
-
data: data,
|
|
16
|
-
};
|
|
11
|
+
const _ = { shape, data };
|
|
17
12
|
return true;
|
|
18
13
|
}
|
|
19
|
-
export function isEncoded(raw) {
|
|
20
|
-
return Array.isArray(raw) && raw.every((e) => typeof e === 'number');
|
|
21
|
-
}
|
|
22
14
|
export async function encode(weights) {
|
|
23
|
-
const serialized = await Promise.all(weights.weights.map(async (t) => {
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
}));
|
|
29
|
-
return [...msgpack.encode(serialized).values()];
|
|
15
|
+
const serialized = await Promise.all(weights.weights.map(async (t) => ({
|
|
16
|
+
shape: t.shape,
|
|
17
|
+
data: await t.data(),
|
|
18
|
+
})));
|
|
19
|
+
return coder.encode(serialized);
|
|
30
20
|
}
|
|
31
21
|
export function decode(encoded) {
|
|
32
|
-
const raw =
|
|
33
|
-
if (!(Array.isArray(raw) && raw.every(isSerialized)))
|
|
34
|
-
throw new Error(
|
|
35
|
-
}
|
|
22
|
+
const raw = coder.decode(encoded);
|
|
23
|
+
if (!(Array.isArray(raw) && raw.every(isSerialized)))
|
|
24
|
+
throw new Error("expected to decode an array of serialized weights");
|
|
36
25
|
return new WeightsContainer(raw.map((w) => tf.tensor(w.data, w.shape)));
|
|
37
26
|
}
|
|
@@ -1,5 +1,5 @@
|
|
|
1
|
-
import { Map } from
|
|
2
|
-
import type { Model } from
|
|
3
|
-
import type { Task, TaskID } from
|
|
4
|
-
export declare function pushTask(
|
|
5
|
-
export declare function fetchTasks(
|
|
1
|
+
import { Map } from "immutable";
|
|
2
|
+
import type { Model } from "../index.js";
|
|
3
|
+
import type { Task, TaskID } from "./task.js";
|
|
4
|
+
export declare function pushTask(base: URL, task: Task, model: Model): Promise<void>;
|
|
5
|
+
export declare function fetchTasks(base: URL): Promise<Map<TaskID, Task>>;
|
|
@@ -1,22 +1,28 @@
|
|
|
1
|
-
import axios from 'axios';
|
|
2
1
|
import createDebug from "debug";
|
|
3
|
-
import { Map } from
|
|
4
|
-
import { serialization } from
|
|
5
|
-
import { isTask } from
|
|
2
|
+
import { Map } from "immutable";
|
|
3
|
+
import { serialization } from "../index.js";
|
|
4
|
+
import { isTask } from "./task.js";
|
|
6
5
|
const debug = createDebug("discojs:task:handlers");
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
6
|
+
function urlToTasks(base) {
|
|
7
|
+
const ret = new URL(base);
|
|
8
|
+
ret.pathname += "tasks";
|
|
9
|
+
return ret;
|
|
10
|
+
}
|
|
11
|
+
export async function pushTask(base, task, model) {
|
|
12
|
+
await fetch(urlToTasks(base), {
|
|
13
|
+
method: "POST",
|
|
14
|
+
body: JSON.stringify({
|
|
15
|
+
task,
|
|
16
|
+
model: await serialization.model.encode(model),
|
|
17
|
+
weights: await serialization.weights.encode(model.weights),
|
|
18
|
+
}),
|
|
13
19
|
});
|
|
14
20
|
}
|
|
15
|
-
export async function fetchTasks(
|
|
16
|
-
const response = await
|
|
17
|
-
const tasks = response.
|
|
21
|
+
export async function fetchTasks(base) {
|
|
22
|
+
const response = await fetch(urlToTasks(base));
|
|
23
|
+
const tasks = await response.json();
|
|
18
24
|
if (!Array.isArray(tasks)) {
|
|
19
|
-
throw new Error(
|
|
25
|
+
throw new Error("Expected to receive an array of Tasks when fetching tasks");
|
|
20
26
|
}
|
|
21
27
|
else if (!tasks.every(isTask)) {
|
|
22
28
|
for (const task of tasks) {
|
|
@@ -24,7 +30,7 @@ export async function fetchTasks(url) {
|
|
|
24
30
|
debug("task has invalid format: :O", task);
|
|
25
31
|
}
|
|
26
32
|
}
|
|
27
|
-
throw new Error(
|
|
33
|
+
throw new Error("invalid tasks response, the task object received is not well formatted");
|
|
28
34
|
}
|
|
29
35
|
return Map(tasks.map((t) => [t.id, t]));
|
|
30
36
|
}
|
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "@epfml/discojs",
|
|
3
|
-
"version": "3.0.1-
|
|
3
|
+
"version": "3.0.1-p20241024094708.0",
|
|
4
4
|
"type": "module",
|
|
5
5
|
"main": "dist/index.js",
|
|
6
6
|
"types": "dist/index.d.ts",
|
|
@@ -19,13 +19,12 @@
|
|
|
19
19
|
},
|
|
20
20
|
"homepage": "https://github.com/epfml/disco#readme",
|
|
21
21
|
"dependencies": {
|
|
22
|
+
"@msgpack/msgpack": "^3.0.0-beta2",
|
|
22
23
|
"@tensorflow/tfjs": "4",
|
|
23
24
|
"@xenova/transformers": "2",
|
|
24
|
-
"axios": "1",
|
|
25
25
|
"immutable": "4",
|
|
26
26
|
"isomorphic-wrtc": "1",
|
|
27
27
|
"isomorphic-ws": "5",
|
|
28
|
-
"msgpack-lite": "0.1",
|
|
29
28
|
"simple-peer": "9",
|
|
30
29
|
"tslib": "2",
|
|
31
30
|
"ws": "8"
|