@epfml/discojs 3.0.1-p20241014092014.0 → 3.0.1-p20241025115642.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,5 +1,4 @@
1
1
  import createDebug from "debug";
2
- import axios from 'axios';
3
2
  import { serialization } from '../index.js';
4
3
  import { EventEmitter } from '../utils/event_emitter.js';
5
4
  import { type } from "./messages.js";
@@ -149,8 +148,9 @@ export class Client extends EventEmitter {
149
148
  url.pathname += '/';
150
149
  }
151
150
  url.pathname += `tasks/${this.task.id}/model.json`;
152
- const response = await axios.get(url.href, { responseType: 'arraybuffer' });
153
- return await serialization.model.decode(new Uint8Array(response.data));
151
+ const response = await fetch(url);
152
+ const encoded = new Uint8Array(await response.arrayBuffer());
153
+ return await serialization.model.decode(encoded);
154
154
  }
155
155
  get ownId() {
156
156
  if (this._ownId === undefined) {
@@ -1,4 +1,4 @@
1
- import { weights } from '../../serialization/index.js';
1
+ import { serialization } from "../../index.js";
2
2
  import { type SignalData } from './peer.js';
3
3
  import { type NodeID } from '../types.js';
4
4
  import { type } from '../messages.js';
@@ -29,7 +29,7 @@ export interface Payload {
29
29
  peer: NodeID;
30
30
  aggregationRound: number;
31
31
  communicationRound: number;
32
- payload: weights.Encoded;
32
+ payload: serialization.Encoded;
33
33
  }
34
34
  export type MessageFromServer = NewDecentralizedNodeInfo | SignalForPeer | PeersForRound | WaitingForMoreParticipants | EnoughParticipants;
35
35
  export type MessageToServer = ClientConnected | SignalForPeer | PeerIsReady | JoinRound;
@@ -1,4 +1,4 @@
1
- import { weights } from '../../serialization/index.js';
1
+ import { serialization } from "../../index.js";
2
2
  import { isNodeID } from '../types.js';
3
3
  import { type, hasMessageType } from '../messages.js';
4
4
  export function isMessageFromServer(o) {
@@ -41,7 +41,7 @@ export function isPeerMessage(o) {
41
41
  switch (o.type) {
42
42
  case type.Payload:
43
43
  return ('peer' in o && isNodeID(o.peer) &&
44
- 'payload' in o && weights.isEncoded(o.payload));
44
+ 'payload' in o && serialization.isEncoded(o.payload));
45
45
  }
46
46
  return false;
47
47
  }
@@ -1,6 +1,6 @@
1
1
  import createDebug from "debug";
2
2
  import WebSocket from "isomorphic-ws";
3
- import msgpack from "msgpack-lite";
3
+ import * as msgpack from "@msgpack/msgpack";
4
4
  import * as decentralizedMessages from './decentralized/messages.js';
5
5
  import { type } from './messages.js';
6
6
  import { timeout } from './utils.js';
@@ -57,7 +57,7 @@ export class PeerConnection extends EventEmitter {
57
57
  if (!decentralizedMessages.isPeerMessage(msg)) {
58
58
  throw new Error(`can't send this type of message: ${JSON.stringify(msg)}`);
59
59
  }
60
- this.peer.send(msgpack.encode(msg));
60
+ this.peer.send(Buffer.from(msgpack.encode(msg)));
61
61
  }
62
62
  async disconnect() {
63
63
  await this.peer.destroy();
@@ -1,4 +1,4 @@
1
- import { type weights } from '../../serialization/index.js';
1
+ import type { serialization } from "../../index.js";
2
2
  import { type NodeID } from '..//types.js';
3
3
  import { type } from '../messages.js';
4
4
  import type { ClientConnected, WaitingForMoreParticipants, EnoughParticipants } from '../messages.js';
@@ -7,18 +7,18 @@ export interface NewFederatedNodeInfo {
7
7
  type: type.NewFederatedNodeInfo;
8
8
  id: NodeID;
9
9
  waitForMoreParticipants: boolean;
10
- payload: weights.Encoded;
10
+ payload: serialization.Encoded;
11
11
  round: number;
12
12
  nbOfParticipants: number;
13
13
  }
14
14
  export interface SendPayload {
15
15
  type: type.SendPayload;
16
- payload: weights.Encoded;
16
+ payload: serialization.Encoded;
17
17
  round: number;
18
18
  }
19
19
  export interface ReceiveServerPayload {
20
20
  type: type.ReceiveServerPayload;
21
- payload: weights.Encoded;
21
+ payload: serialization.Encoded;
22
22
  round: number;
23
23
  nbOfParticipants: number;
24
24
  }
package/dist/index.d.ts CHANGED
@@ -1,7 +1,5 @@
1
1
  export * as data from './dataset/index.js';
2
2
  export * as serialization from './serialization/index.js';
3
- export { Encoded as EncodedModel } from './serialization/model.js';
4
- export { Encoded as EncodedWeights } from './serialization/weights.js';
5
3
  export * as training from './training/index.js';
6
4
  export * as privacy from './privacy.js';
7
5
  export * as client from './client/index.js';
@@ -0,0 +1,4 @@
1
+ export type Encoded = Uint8Array;
2
+ export declare function isEncoded(raw: unknown): raw is Encoded;
3
+ export declare function encode(serialized: unknown): Encoded;
4
+ export declare function decode(encoded: Encoded): unknown;
@@ -0,0 +1,51 @@
1
+ import * as msgpack from "@msgpack/msgpack";
2
+ export function isEncoded(raw) {
3
+ if (!(raw instanceof Uint8Array))
4
+ return false;
5
+ const _ = raw;
6
+ return true;
7
+ }
8
+ // create a new buffer instead of referencing the backing one
9
+ function copy(arr) {
10
+ // `Buffer.slice` (`Buffer` being a subclass of Uint8Array on Node) doesn't copy
11
+ // and thus doesn't respect the Liskov substitution principle
12
+ // https://nodejs.org/api/buffer.html#bufslicestart-end
13
+ // here we call the correct implementation
14
+ return Uint8Array.prototype.slice.call(arr);
15
+ }
16
+ // to avoid mapping every ArrayBuffer to Uint8Array,
17
+ // we register our own converters for the types we know are needed
18
+ // type ids are arbitrarily taken from msgpack-lite
19
+ // https://www.npmjs.com/package/msgpack-lite#extension-types
20
+ const CODEC = new msgpack.ExtensionCodec();
21
+ // used by TFJS's weights
22
+ CODEC.register({
23
+ type: 0x17,
24
+ encode(obj) {
25
+ if (!(obj instanceof Float32Array))
26
+ return null;
27
+ return new Uint8Array(obj.buffer, obj.byteOffset, obj.byteLength);
28
+ },
29
+ decode: (raw) =>
30
+ // to reinterpret uint8 as float32, the data needs to be 4-byte aligned
31
+ // but the given buffer might not be, so we need to copy it.
32
+ new Float32Array(copy(raw).buffer),
33
+ });
34
+ // used by TFJS's saved model
35
+ CODEC.register({
36
+ type: 0x1a,
37
+ encode(obj) {
38
+ if (!(obj instanceof ArrayBuffer))
39
+ return null;
40
+ return new Uint8Array(obj);
41
+ },
42
+ decode: (raw) =>
43
+ // need to copy as backing ArrayBuffer might be larger
44
+ copy(raw),
45
+ });
46
+ export function encode(serialized) {
47
+ return msgpack.encode(serialized, { extensionCodec: CODEC });
48
+ }
49
+ export function decode(encoded) {
50
+ return msgpack.decode(encoded, { extensionCodec: CODEC });
51
+ }
@@ -1,2 +1,4 @@
1
1
  export * as model from './model.js';
2
2
  export * as weights from './weights.js';
3
+ export type { Encoded } from "./coder.js";
4
+ export { isEncoded } from "./coder.js";
@@ -1,2 +1,3 @@
1
1
  export * as model from './model.js';
2
2
  export * as weights from './weights.js';
3
+ export { isEncoded } from "./coder.js";
@@ -1,5 +1,4 @@
1
1
  import type { Model } from '../index.js';
2
- export type Encoded = Uint8Array;
3
- export declare function isEncoded(raw: unknown): raw is Encoded;
2
+ import { Encoded } from "./coder.js";
4
3
  export declare function encode(model: Model): Promise<Encoded>;
5
4
  export declare function decode(encoded: unknown): Promise<Model>;
@@ -1,38 +1,29 @@
1
- import msgpack from 'msgpack-lite';
2
1
  import { models, serialization } from '../index.js';
2
+ import * as coder from "./coder.js";
3
+ import { isEncoded } from "./coder.js";
3
4
  const Type = {
4
5
  TFJS: 0,
5
6
  GPT: 1
6
7
  };
7
- export function isEncoded(raw) {
8
- return raw instanceof Uint8Array;
9
- }
10
8
  export async function encode(model) {
11
- let encoded;
12
9
  switch (true) {
13
10
  case model instanceof models.TFJS: {
14
11
  const serialized = await model.serialize();
15
- encoded = msgpack.encode([Type.TFJS, serialized]);
16
- break;
12
+ return coder.encode([Type.TFJS, serialized]);
17
13
  }
18
14
  case model instanceof models.GPT: {
19
15
  const { weights, config } = model.serialize();
20
16
  const serializedWeights = await serialization.weights.encode(weights);
21
- encoded = msgpack.encode([Type.GPT, serializedWeights, config]);
22
- break;
17
+ return coder.encode([Type.GPT, serializedWeights, config]);
23
18
  }
24
19
  default:
25
20
  throw new Error("unknown model type");
26
21
  }
27
- // Node's Buffer extends Node's Uint8Array, which might not be the same
28
- // as the browser's Uint8Array. we ensure here that it is.
29
- return new Uint8Array(encoded);
30
22
  }
31
23
  export async function decode(encoded) {
32
- if (!isEncoded(encoded)) {
24
+ if (!isEncoded(encoded))
33
25
  throw new Error("Invalid encoding, raw encoding isn't an instance of Uint8Array");
34
- }
35
- const raw = msgpack.decode(encoded);
26
+ const raw = coder.decode(encoded);
36
27
  if (!Array.isArray(raw) || raw.length < 2) {
37
28
  throw new Error("invalid encoding, encoding isn't an array or doesn't contain enough values");
38
29
  }
@@ -59,15 +50,9 @@ export async function decode(encoded) {
59
50
  else {
60
51
  throw new Error('invalid encoding, gpt-tfjs model encoding should be an array of length 2 or 3');
61
52
  }
62
- if (!Array.isArray(rawModel)) {
63
- throw new Error('invalid encoding, gpt-tfjs model weights should be an array');
64
- }
65
- const arr = rawModel;
66
- if (arr.some((r) => typeof r !== 'number')) {
67
- throw new Error("invalid encoding, gpt-tfjs weights should be numbers");
68
- }
69
- const nums = arr;
70
- const weights = serialization.weights.decode(nums);
53
+ if (!isEncoded(rawModel))
54
+ throw new Error("invalid encoding, gpt-tfjs model weights should be an encoding of its weights");
55
+ const weights = serialization.weights.decode(rawModel);
71
56
  return models.GPT.deserialize({ weights, config });
72
57
  }
73
58
  default:
@@ -1,5 +1,4 @@
1
- import { WeightsContainer } from '../index.js';
2
- export type Encoded = number[];
3
- export declare function isEncoded(raw: unknown): raw is Encoded;
1
+ import { WeightsContainer } from "../index.js";
2
+ import { Encoded } from "./coder.js";
4
3
  export declare function encode(weights: WeightsContainer): Promise<Encoded>;
5
4
  export declare function decode(encoded: Encoded): WeightsContainer;
@@ -1,37 +1,26 @@
1
- import * as msgpack from 'msgpack-lite';
2
- import * as tf from '@tensorflow/tfjs';
3
- import { WeightsContainer } from '../index.js';
1
+ import * as tf from "@tensorflow/tfjs";
2
+ import { WeightsContainer } from "../index.js";
3
+ import * as coder from "./coder.js";
4
4
  function isSerialized(raw) {
5
- if (typeof raw !== 'object' || raw === null) {
5
+ if (typeof raw !== "object" || raw === null)
6
6
  return false;
7
- }
8
7
  const { shape, data } = raw;
9
- if (!(Array.isArray(shape) && shape.every((e) => typeof e === 'number')) ||
10
- !(Array.isArray(data) && data.every((e) => typeof e === 'number'))) {
8
+ if (!(Array.isArray(shape) && shape.every((e) => typeof e === "number")) ||
9
+ !(data instanceof Float32Array))
11
10
  return false;
12
- }
13
- const _ = {
14
- shape: shape,
15
- data: data,
16
- };
11
+ const _ = { shape, data };
17
12
  return true;
18
13
  }
19
- export function isEncoded(raw) {
20
- return Array.isArray(raw) && raw.every((e) => typeof e === 'number');
21
- }
22
14
  export async function encode(weights) {
23
- const serialized = await Promise.all(weights.weights.map(async (t) => {
24
- return {
25
- shape: t.shape,
26
- data: [...await t.data()]
27
- };
28
- }));
29
- return [...msgpack.encode(serialized).values()];
15
+ const serialized = await Promise.all(weights.weights.map(async (t) => ({
16
+ shape: t.shape,
17
+ data: await t.data(),
18
+ })));
19
+ return coder.encode(serialized);
30
20
  }
31
21
  export function decode(encoded) {
32
- const raw = msgpack.decode(encoded);
33
- if (!(Array.isArray(raw) && raw.every(isSerialized))) {
34
- throw new Error('expected to decode an array of serialized weights');
35
- }
22
+ const raw = coder.decode(encoded);
23
+ if (!(Array.isArray(raw) && raw.every(isSerialized)))
24
+ throw new Error("expected to decode an array of serialized weights");
36
25
  return new WeightsContainer(raw.map((w) => tf.tensor(w.data, w.shape)));
37
26
  }
@@ -1,5 +1,5 @@
1
- import { Map } from 'immutable';
2
- import type { Model } from '../index.js';
3
- import type { Task, TaskID } from './task.js';
4
- export declare function pushTask(url: URL, task: Task, model: Model): Promise<void>;
5
- export declare function fetchTasks(url: URL): Promise<Map<TaskID, Task>>;
1
+ import { Map } from "immutable";
2
+ import type { Model } from "../index.js";
3
+ import type { Task, TaskID } from "./task.js";
4
+ export declare function pushTask(base: URL, task: Task, model: Model): Promise<void>;
5
+ export declare function fetchTasks(base: URL): Promise<Map<TaskID, Task>>;
@@ -1,22 +1,28 @@
1
- import axios from 'axios';
2
1
  import createDebug from "debug";
3
- import { Map } from 'immutable';
4
- import { serialization } from '../index.js';
5
- import { isTask } from './task.js';
2
+ import { Map } from "immutable";
3
+ import { serialization } from "../index.js";
4
+ import { isTask } from "./task.js";
6
5
  const debug = createDebug("discojs:task:handlers");
7
- const TASK_ENDPOINT = 'tasks';
8
- export async function pushTask(url, task, model) {
9
- await axios.post(url.href + TASK_ENDPOINT, {
10
- task,
11
- model: await serialization.model.encode(model),
12
- weights: await serialization.weights.encode(model.weights)
6
+ function urlToTasks(base) {
7
+ const ret = new URL(base);
8
+ ret.pathname += "tasks";
9
+ return ret;
10
+ }
11
+ export async function pushTask(base, task, model) {
12
+ await fetch(urlToTasks(base), {
13
+ method: "POST",
14
+ body: JSON.stringify({
15
+ task,
16
+ model: await serialization.model.encode(model),
17
+ weights: await serialization.weights.encode(model.weights),
18
+ }),
13
19
  });
14
20
  }
15
- export async function fetchTasks(url) {
16
- const response = await axios.get(new URL(TASK_ENDPOINT, url).href);
17
- const tasks = response.data;
21
+ export async function fetchTasks(base) {
22
+ const response = await fetch(urlToTasks(base));
23
+ const tasks = await response.json();
18
24
  if (!Array.isArray(tasks)) {
19
- throw new Error('Expected to receive an array of Tasks when fetching tasks');
25
+ throw new Error("Expected to receive an array of Tasks when fetching tasks");
20
26
  }
21
27
  else if (!tasks.every(isTask)) {
22
28
  for (const task of tasks) {
@@ -24,7 +30,7 @@ export async function fetchTasks(url) {
24
30
  debug("task has invalid format: :O", task);
25
31
  }
26
32
  }
27
- throw new Error('invalid tasks response, the task object received is not well formatted');
33
+ throw new Error("invalid tasks response, the task object received is not well formatted");
28
34
  }
29
35
  return Map(tasks.map((t) => [t.id, t]));
30
36
  }
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@epfml/discojs",
3
- "version": "3.0.1-p20241014092014.0",
3
+ "version": "3.0.1-p20241025115642.0",
4
4
  "type": "module",
5
5
  "main": "dist/index.js",
6
6
  "types": "dist/index.d.ts",
@@ -19,13 +19,12 @@
19
19
  },
20
20
  "homepage": "https://github.com/epfml/disco#readme",
21
21
  "dependencies": {
22
+ "@msgpack/msgpack": "^3.0.0-beta2",
22
23
  "@tensorflow/tfjs": "4",
23
24
  "@xenova/transformers": "2",
24
- "axios": "1",
25
25
  "immutable": "4",
26
26
  "isomorphic-wrtc": "1",
27
27
  "isomorphic-ws": "5",
28
- "msgpack-lite": "0.1",
29
28
  "simple-peer": "9",
30
29
  "tslib": "2",
31
30
  "ws": "8"