@epfml/discojs 3.0.1-p20241007204240.0 → 3.0.1-p20241024094708.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48) hide show
  1. package/dist/aggregator/{base.d.ts → aggregator.d.ts} +24 -31
  2. package/dist/aggregator/{base.js → aggregator.js} +48 -36
  3. package/dist/aggregator/get.d.ts +2 -2
  4. package/dist/aggregator/get.js +4 -4
  5. package/dist/aggregator/index.d.ts +1 -4
  6. package/dist/aggregator/index.js +1 -1
  7. package/dist/aggregator/mean.d.ts +4 -4
  8. package/dist/aggregator/mean.js +5 -15
  9. package/dist/aggregator/secure.d.ts +4 -4
  10. package/dist/aggregator/secure.js +7 -17
  11. package/dist/client/client.d.ts +71 -17
  12. package/dist/client/client.js +118 -17
  13. package/dist/client/decentralized/decentralized_client.d.ts +11 -13
  14. package/dist/client/decentralized/decentralized_client.js +121 -84
  15. package/dist/client/decentralized/messages.d.ts +12 -6
  16. package/dist/client/decentralized/messages.js +9 -8
  17. package/dist/client/event_connection.js +2 -2
  18. package/dist/client/federated/federated_client.d.ts +1 -13
  19. package/dist/client/federated/federated_client.js +15 -94
  20. package/dist/client/federated/messages.d.ts +6 -11
  21. package/dist/client/local_client.d.ts +1 -0
  22. package/dist/client/local_client.js +3 -0
  23. package/dist/client/messages.d.ts +14 -7
  24. package/dist/client/messages.js +13 -11
  25. package/dist/default_tasks/cifar10.js +1 -1
  26. package/dist/default_tasks/lus_covid.js +1 -0
  27. package/dist/default_tasks/mnist.js +1 -1
  28. package/dist/default_tasks/simple_face.js +1 -0
  29. package/dist/default_tasks/titanic.js +1 -0
  30. package/dist/default_tasks/wikitext.js +1 -0
  31. package/dist/index.d.ts +0 -2
  32. package/dist/serialization/coder.d.ts +4 -0
  33. package/dist/serialization/coder.js +51 -0
  34. package/dist/serialization/index.d.ts +2 -0
  35. package/dist/serialization/index.js +1 -0
  36. package/dist/serialization/model.d.ts +1 -2
  37. package/dist/serialization/model.js +9 -24
  38. package/dist/serialization/weights.d.ts +2 -3
  39. package/dist/serialization/weights.js +15 -26
  40. package/dist/task/task_handler.d.ts +5 -5
  41. package/dist/task/task_handler.js +21 -15
  42. package/dist/task/training_information.d.ts +1 -2
  43. package/dist/task/training_information.js +6 -8
  44. package/dist/training/disco.d.ts +4 -1
  45. package/dist/training/trainer.js +1 -1
  46. package/dist/utils/event_emitter.d.ts +3 -3
  47. package/dist/utils/event_emitter.js +10 -9
  48. package/package.json +2 -3
@@ -6,25 +6,13 @@ import { Client } from "../client.js";
6
6
  */
7
7
  export declare class FederatedClient extends Client {
8
8
  #private;
9
- get nbOfParticipants(): number;
10
- /**
11
- * Opens a new WebSocket connection with the server and listens to new messages over the channel
12
- */
13
- private connectServer;
9
+ getNbOfParticipants(): number;
14
10
  /**
15
11
  * Initializes the connection to the server, gets our node ID
16
12
  * as well as the latest training information: latest global model, current round and
17
13
  * whether we are waiting for more participants.
18
14
  */
19
15
  connect(): Promise<Model>;
20
- /**
21
- * Method called when the server notifies the client that there aren't enough
22
- * participants (anymore) to start/continue training
23
- * The method creates a promise that will resolve once the server notifies
24
- * the client that the training can resume via a subsequent EnoughParticipants message
25
- * @returns a promise which resolves when enough participants joined the session
26
- */
27
- private waitForMoreParticipants;
28
16
  /**
29
17
  * Disconnection process when user quits the task.
30
18
  */
@@ -1,6 +1,6 @@
1
1
  import createDebug from "debug";
2
2
  import { serialization } from "../../index.js";
3
- import { Client } from "../client.js";
3
+ import { Client, shortenId } from "../client.js";
4
4
  import { type } from "../messages.js";
5
5
  import { waitMessage, waitMessageWithTimeout, WebSocketServer, } from "../event_connection.js";
6
6
  import * as messages from "./messages.js";
@@ -19,38 +19,10 @@ export class FederatedClient extends Client {
19
19
  // Total number of other federated contributors, including this client, excluding the server
20
20
  // E.g., if 3 users are training a federated model, nbOfParticipants is 3
21
21
  #nbOfParticipants = 1;
22
- /**
23
- * When the server notifies clients to pause and wait until more
24
- * participants join, we rely on this promise to wait
25
- * until the server signals that the training can resume
26
- */
27
- #promiseForMoreParticipants = undefined;
28
- /**
29
- * When the server notifies the client that they can resume training
30
- * after waiting for more participants, we want to be able to display what
31
- * we were doing before waiting (training locally or updating our model).
32
- * We use this attribute to store the status to rollback to when we stop waiting
33
- */
34
- #previousStatus = undefined;
35
- /**
36
- * Whether the client should wait until more
37
- * participants join the session, i.e. a promise has been created
38
- */
39
- get #waitingForMoreParticipants() {
40
- return this.#promiseForMoreParticipants !== undefined;
41
- }
42
22
  // the number of participants excluding the server
43
- get nbOfParticipants() {
23
+ getNbOfParticipants() {
44
24
  return this.#nbOfParticipants;
45
25
  }
46
- /**
47
- * Opens a new WebSocket connection with the server and listens to new messages over the channel
48
- */
49
- async connectServer(url) {
50
- const server = await WebSocketServer.connect(url, messages.isMessageFederated, // can only receive federated message types from the server
51
- messages.isMessageFederated);
52
- return server;
53
- }
54
26
  /**
55
27
  * Initializes the connection to the server, gets our node ID
56
28
  * as well as the latest training information: latest global model, current round and
@@ -70,31 +42,12 @@ export class FederatedClient extends Client {
70
42
  throw new Error(`unknown protocol: ${this.url.protocol}`);
71
43
  }
72
44
  serverURL.pathname += `federated/${this.task.id}`;
73
- this._server = await this.connectServer(serverURL);
74
- // Setup an event callback if the server signals that we should
75
- // wait for more participants
76
- this.server.on(type.WaitingForMoreParticipants, () => {
77
- debug(`[${id.slice(0, 4)}] received WaitingForMoreParticipants message from server`);
78
- // Display the waiting status right away
79
- this.emit("status", "Waiting for more participants");
80
- // Upon receiving a WaitingForMoreParticipants message,
81
- // the client will await for this promise to resolve before sending its
82
- // local weight update
83
- this.#promiseForMoreParticipants = this.waitForMoreParticipants();
84
- });
85
- // As an example assume we need at least 2 participants to train,
86
- // When two participants join almost at the same time, the server
87
- // sends a NewFederatedNodeInfo with waitForMoreParticipants=true to the first participant
88
- // and directly follows with an EnoughParticipants message when the 2nd participant joins
89
- // However, the EnoughParticipants can arrive before the NewFederatedNodeInfo (which is much bigger)
90
- // so we check whether we received the EnoughParticipants before being assigned a node ID
45
+ // Opens a new WebSocket connection with the server and listens to new messages over the channel
46
+ this._server = await WebSocketServer.connect(serverURL, messages.isMessageFederated, // can only receive federated message types from the server
47
+ messages.isMessageFederated);
48
+ // c.f. setupServerCallbacks doc for explanation
91
49
  let receivedEnoughParticipants = false;
92
- this.server.once(type.EnoughParticipants, () => {
93
- if (this._ownId === undefined) {
94
- debug(`Received EnoughParticipants message from server before the NewFederatedNodeInfo message`);
95
- receivedEnoughParticipants = true;
96
- }
97
- });
50
+ this.setupServerCallbacks(() => receivedEnoughParticipants = true);
98
51
  this.aggregator.registerNode(SERVER_NODE_ID);
99
52
  const msg = {
100
53
  type: type.ClientConnected,
@@ -109,40 +62,21 @@ export class FederatedClient extends Client {
109
62
  if (waitForMoreParticipants && !receivedEnoughParticipants) {
110
63
  // Create a promise that resolves when enough participants join
111
64
  // The client will await this promise before sending its local weight update
112
- this.#promiseForMoreParticipants = this.waitForMoreParticipants();
65
+ this.promiseForMoreParticipants = this.createPromiseForMoreParticipants();
113
66
  }
114
67
  if (this._ownId !== undefined) {
115
68
  throw new Error('received id from server but was already received');
116
69
  }
117
70
  this._ownId = id;
118
- debug(`[${id.slice(0, 4)}] joined session at round ${round} `);
71
+ debug(`[${shortenId(id)}] joined session at round ${round} `);
119
72
  this.aggregator.setRound(round);
120
73
  this.#nbOfParticipants = nbOfParticipants;
121
74
  // Upon connecting, the server answers with a boolean
122
75
  // which indicates whether there are enough participants or not
123
- debug(`[${this.ownId.slice(0, 4)}] upon connecting, wait for participant flag %o`, this.#waitingForMoreParticipants);
76
+ debug(`[${shortenId(this.ownId)}] upon connecting, wait for participant flag %o`, this.waitingForMoreParticipants);
124
77
  model.weights = serialization.weights.decode(payload);
125
78
  return model;
126
79
  }
127
- /**
128
- * Method called when the server notifies the client that there aren't enough
129
- * participants (anymore) to start/continue training
130
- * The method creates a promise that will resolve once the server notifies
131
- * the client that the training can resume via a subsequent EnoughParticipants message
132
- * @returns a promise which resolves when enough participants joined the session
133
- */
134
- async waitForMoreParticipants() {
135
- return new Promise((resolve) => {
136
- // "once" is important because we can't resolve the same promise multiple times
137
- this.server.once(type.EnoughParticipants, () => {
138
- debug(`[${this.ownId.slice(0, 4)}] received EnoughParticipants message from server`);
139
- // Emit the last status emitted before waiting if defined
140
- if (this.#previousStatus !== undefined)
141
- this.emit("status", this.#previousStatus);
142
- resolve();
143
- });
144
- });
145
- }
146
80
  /**
147
81
  * Disconnection process when user quits the task.
148
82
  */
@@ -155,10 +89,7 @@ export class FederatedClient extends Client {
155
89
  onRoundBeginCommunication() {
156
90
  // Prepare the result promise for the incoming round
157
91
  this.aggregationResult = new Promise((resolve) => this.aggregator.once('aggregation', resolve));
158
- // Save the status in case participants leave and we switch to waiting for more participants
159
- // Once enough new participants join we can display the previous status again
160
- this.#previousStatus = "Training the model on the data you connected";
161
- this.emit("status", this.#previousStatus);
92
+ this.saveAndEmit("local training");
162
93
  return Promise.resolve();
163
94
  }
164
95
  /**
@@ -176,18 +107,8 @@ export class FederatedClient extends Client {
176
107
  throw new Error("local aggregation result was not set");
177
108
  }
178
109
  // First we check if we are waiting for more participants before sending our weight update
179
- if (this.#waitingForMoreParticipants) {
180
- // wait for the promise to resolve, which takes as long as it takes for new participants to join
181
- debug(`[${this.ownId.slice(0, 4)}] is awaiting the promise for more participants`);
182
- this.emit("status", "Waiting for more participants");
183
- await this.#promiseForMoreParticipants;
184
- // Make sure to set the promise back to undefined once resolved
185
- this.#promiseForMoreParticipants = undefined;
186
- }
187
- // Save the status in case participants leave and we switch to waiting for more participants
188
- // Once enough new participants join we can display the previous status again
189
- this.#previousStatus = "Updating the model with other participants' models";
190
- this.emit("status", this.#previousStatus);
110
+ await this.waitForParticipantsIfNeeded();
111
+ this.saveAndEmit("updating model");
191
112
  // Send our local contribution to the server
192
113
  // and receive the server global update for this round as an answer to our contribution
193
114
  const payloadToServer = this.aggregator.makePayloads(weights).first();
@@ -198,9 +119,9 @@ export class FederatedClient extends Client {
198
119
  };
199
120
  // Need to await the resulting global model right after sending our local contribution
200
121
  // to make sure we don't miss it
201
- debug(`[${this.ownId.slice(0, 4)}] sent its local update to the server for round ${this.aggregator.round}`);
122
+ debug(`[${shortenId(this.ownId)}] sent its local update to the server for round ${this.aggregator.round}`);
202
123
  this.server.send(msg);
203
- debug(`[${this.ownId.slice(0, 4)}] is waiting for server update for round ${this.aggregator.round + 1}`);
124
+ debug(`[${shortenId(this.ownId)}] is waiting for server update for round ${this.aggregator.round + 1}`);
204
125
  const { payload: payloadFromServer, round: serverRound, nbOfParticipants } = await waitMessage(this.server, type.ReceiveServerPayload); // Wait indefinitely for the server update
205
126
  this.#nbOfParticipants = nbOfParticipants; // Save the current participants
206
127
  const serverResult = serialization.weights.decode(payloadFromServer);
@@ -1,30 +1,25 @@
1
- import { type weights } from '../../serialization/index.js';
1
+ import type { serialization } from "../../index.js";
2
2
  import { type NodeID } from '..//types.js';
3
- import { type, type ClientConnected } from '../messages.js';
3
+ import { type } from '../messages.js';
4
+ import type { ClientConnected, WaitingForMoreParticipants, EnoughParticipants } from '../messages.js';
4
5
  export type MessageFederated = ClientConnected | NewFederatedNodeInfo | SendPayload | ReceiveServerPayload | WaitingForMoreParticipants | EnoughParticipants;
5
6
  export interface NewFederatedNodeInfo {
6
7
  type: type.NewFederatedNodeInfo;
7
8
  id: NodeID;
8
9
  waitForMoreParticipants: boolean;
9
- payload: weights.Encoded;
10
+ payload: serialization.Encoded;
10
11
  round: number;
11
12
  nbOfParticipants: number;
12
13
  }
13
14
  export interface SendPayload {
14
15
  type: type.SendPayload;
15
- payload: weights.Encoded;
16
+ payload: serialization.Encoded;
16
17
  round: number;
17
18
  }
18
19
  export interface ReceiveServerPayload {
19
20
  type: type.ReceiveServerPayload;
20
- payload: weights.Encoded;
21
+ payload: serialization.Encoded;
21
22
  round: number;
22
23
  nbOfParticipants: number;
23
24
  }
24
- export interface EnoughParticipants {
25
- type: type.EnoughParticipants;
26
- }
27
- export interface WaitingForMoreParticipants {
28
- type: type.WaitingForMoreParticipants;
29
- }
30
25
  export declare function isMessageFederated(raw: unknown): raw is MessageFederated;
@@ -5,6 +5,7 @@ import { Client } from "./client.js";
5
5
  * with anyone. Thus LocalClient doesn't do anything during communication
6
6
  */
7
7
  export declare class LocalClient extends Client {
8
+ getNbOfParticipants(): number;
8
9
  onRoundBeginCommunication(): Promise<void>;
9
10
  onRoundEndCommunication(weights: WeightsContainer): Promise<WeightsContainer>;
10
11
  }
@@ -4,6 +4,9 @@ import { Client } from "./client.js";
4
4
  * with anyone. Thus LocalClient doesn't do anything during communication
5
5
  */
6
6
  export class LocalClient extends Client {
7
+ getNbOfParticipants() {
8
+ return 1;
9
+ }
7
10
  onRoundBeginCommunication() {
8
11
  return Promise.resolve();
9
12
  }
@@ -3,19 +3,26 @@ import type * as federated from './federated/messages.js';
3
3
  export declare enum type {
4
4
  ClientConnected = 0,
5
5
  NewDecentralizedNodeInfo = 1,
6
- SignalForPeer = 2,
6
+ JoinRound = 2,
7
7
  PeerIsReady = 3,
8
8
  PeersForRound = 4,
9
- Payload = 5,
10
- NewFederatedNodeInfo = 6,
11
- WaitingForMoreParticipants = 7,
12
- EnoughParticipants = 8,
13
- SendPayload = 9,
14
- ReceiveServerPayload = 10
9
+ SignalForPeer = 5,
10
+ Payload = 6,
11
+ NewFederatedNodeInfo = 7,
12
+ WaitingForMoreParticipants = 8,
13
+ EnoughParticipants = 9,
14
+ SendPayload = 10,
15
+ ReceiveServerPayload = 11
15
16
  }
16
17
  export interface ClientConnected {
17
18
  type: type.ClientConnected;
18
19
  }
20
+ export interface EnoughParticipants {
21
+ type: type.EnoughParticipants;
22
+ }
23
+ export interface WaitingForMoreParticipants {
24
+ type: type.WaitingForMoreParticipants;
25
+ }
19
26
  export type Message = decentralized.MessageFromServer | decentralized.MessageToServer | decentralized.PeerMessage | federated.MessageFederated;
20
27
  export type NarrowMessage<D> = Extract<Message, {
21
28
  type: D;
@@ -9,34 +9,36 @@ export var type;
9
9
  // answers with its peer id and also tells the client whether we are waiting
10
10
  // for more participants before starting training
11
11
  type[type["NewDecentralizedNodeInfo"] = 1] = "NewDecentralizedNodeInfo";
12
- // Message forwarded by the server from a client to another client
13
- // to establish a peer-to-peer (WebRTC) connection
14
- type[type["SignalForPeer"] = 2] = "SignalForPeer";
12
+ // Message sent by peers to the server to signal they want to
13
+ // join the next round
14
+ type[type["JoinRound"] = 2] = "JoinRound";
15
15
  // Message sent by nodes to server signaling they are ready to
16
16
  // start the next round
17
17
  type[type["PeerIsReady"] = 3] = "PeerIsReady";
18
18
  // Sent by the server to participating peers containing the list
19
19
  // of peers for the round
20
20
  type[type["PeersForRound"] = 4] = "PeersForRound";
21
+ // Message forwarded by the server from a client to another client
22
+ // to establish a peer-to-peer (WebRTC) connection
23
+ type[type["SignalForPeer"] = 5] = "SignalForPeer";
21
24
  // The weight update
22
- type[type["Payload"] = 5] = "Payload";
25
+ type[type["Payload"] = 6] = "Payload";
23
26
  /* Federated */
24
27
  // The server answers the ClientConnected message with the necessary information
25
28
  // to start training: node id, latest model global weights, current round etc
26
- type[type["NewFederatedNodeInfo"] = 6] = "NewFederatedNodeInfo";
29
+ type[type["NewFederatedNodeInfo"] = 7] = "NewFederatedNodeInfo";
27
30
  // Message sent by server to notify clients that there are not enough
28
31
  // participants to continue training
29
- type[type["WaitingForMoreParticipants"] = 7] = "WaitingForMoreParticipants";
32
+ type[type["WaitingForMoreParticipants"] = 8] = "WaitingForMoreParticipants";
30
33
  // Message sent by server to notify clients that there are now enough
31
34
  // participants to start training collaboratively
32
- type[type["EnoughParticipants"] = 8] = "EnoughParticipants";
33
- type[type["SendPayload"] = 9] = "SendPayload";
34
- type[type["ReceiveServerPayload"] = 10] = "ReceiveServerPayload";
35
+ type[type["EnoughParticipants"] = 9] = "EnoughParticipants";
36
+ type[type["SendPayload"] = 10] = "SendPayload";
37
+ type[type["ReceiveServerPayload"] = 11] = "ReceiveServerPayload";
35
38
  })(type || (type = {}));
36
39
  export function hasMessageType(raw) {
37
- if (typeof raw !== 'object' || raw === null) {
40
+ if (typeof raw !== 'object' || raw === null)
38
41
  return false;
39
- }
40
42
  const o = raw;
41
43
  if (!('type' in o && typeof o.type === 'number' && o.type in type)) {
42
44
  return false;
@@ -29,8 +29,8 @@ export const cifar10 = {
29
29
  IMAGE_W: 224,
30
30
  LABEL_LIST: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'],
31
31
  scheme: 'decentralized',
32
+ aggregationStrategy: 'mean',
32
33
  privacy: { clippingRadius: 20, noiseScale: 1 },
33
- decentralizedSecure: true,
34
34
  minNbOfParticipants: 3,
35
35
  maxShareValue: 100,
36
36
  tensorBackend: 'tfjs'
@@ -28,6 +28,7 @@ export const lusCovid = {
28
28
  LABEL_LIST: ['COVID-Positive', 'COVID-Negative'],
29
29
  dataType: 'image',
30
30
  scheme: 'federated',
31
+ aggregationStrategy: 'mean',
31
32
  minNbOfParticipants: 2,
32
33
  tensorBackend: 'tfjs'
33
34
  }
@@ -29,7 +29,7 @@ export const mnist = {
29
29
  preprocessingFunctions: [data.ImagePreprocessing.Resize, data.ImagePreprocessing.Normalize],
30
30
  LABEL_LIST: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'],
31
31
  scheme: 'decentralized',
32
- decentralizedSecure: true,
32
+ aggregationStrategy: 'secure',
33
33
  minNbOfParticipants: 3,
34
34
  maxShareValue: 100,
35
35
  tensorBackend: 'tfjs'
@@ -28,6 +28,7 @@ export const simpleFace = {
28
28
  IMAGE_W: 200,
29
29
  LABEL_LIST: ['child', 'adult'],
30
30
  scheme: 'federated',
31
+ aggregationStrategy: 'mean',
31
32
  minNbOfParticipants: 2,
32
33
  tensorBackend: 'tfjs'
33
34
  }
@@ -62,6 +62,7 @@ export const titanic = {
62
62
  'Survived'
63
63
  ],
64
64
  scheme: 'federated',
65
+ aggregationStrategy: 'mean',
65
66
  minNbOfParticipants: 2,
66
67
  tensorBackend: 'tfjs'
67
68
  }
@@ -25,6 +25,7 @@ export const wikitext = {
25
25
  dataType: 'text',
26
26
  preprocessingFunctions: [data.TextPreprocessing.Tokenize, data.TextPreprocessing.LeftPadding],
27
27
  scheme: 'federated',
28
+ aggregationStrategy: 'mean',
28
29
  minNbOfParticipants: 2,
29
30
  epochs: 6,
30
31
  // Unused by wikitext because data already comes split
package/dist/index.d.ts CHANGED
@@ -1,7 +1,5 @@
1
1
  export * as data from './dataset/index.js';
2
2
  export * as serialization from './serialization/index.js';
3
- export { Encoded as EncodedModel } from './serialization/model.js';
4
- export { Encoded as EncodedWeights } from './serialization/weights.js';
5
3
  export * as training from './training/index.js';
6
4
  export * as privacy from './privacy.js';
7
5
  export * as client from './client/index.js';
@@ -0,0 +1,4 @@
1
+ export type Encoded = Uint8Array;
2
+ export declare function isEncoded(raw: unknown): raw is Encoded;
3
+ export declare function encode(serialized: unknown): Encoded;
4
+ export declare function decode(encoded: Encoded): unknown;
@@ -0,0 +1,51 @@
1
+ import * as msgpack from "@msgpack/msgpack";
2
+ export function isEncoded(raw) {
3
+ if (!(raw instanceof Uint8Array))
4
+ return false;
5
+ const _ = raw;
6
+ return true;
7
+ }
8
+ // create a new buffer instead of referencing the backing one
9
+ function copy(arr) {
10
+ // `Buffer.slice` (subclass of Uint8Array on Node) doesn't copy
11
+ // thus doesn't respect Liskov substitution principle
12
+ // https://nodejs.org/api/buffer.html#bufslicestart-end
13
+ // here we call the correct implementation
14
+ return Uint8Array.prototype.slice.call(arr);
15
+ }
16
+ // to avoid mapping every ArrayBuffer to Uint8Array,
17
+ // we register our own convertors for the type we know are needed
18
+ // type id are arbitrally taken from msgpack-lite
19
+ // https://www.npmjs.com/package/msgpack-lite#extension-types
20
+ const CODEC = new msgpack.ExtensionCodec();
21
+ // used by TFJS's weights
22
+ CODEC.register({
23
+ type: 0x17,
24
+ encode(obj) {
25
+ if (!(obj instanceof Float32Array))
26
+ return null;
27
+ return new Uint8Array(obj.buffer, obj.byteOffset, obj.byteLength);
28
+ },
29
+ decode: (raw) =>
30
+ // to reinterpred uint8 into float32, it needs to be 4-bytes aligned
31
+ // but the given buffer might not be so we need to copy it.
32
+ new Float32Array(copy(raw).buffer),
33
+ });
34
+ // used by TFJS's saved model
35
+ CODEC.register({
36
+ type: 0x1a,
37
+ encode(obj) {
38
+ if (!(obj instanceof ArrayBuffer))
39
+ return null;
40
+ return new Uint8Array(obj);
41
+ },
42
+ decode: (raw) =>
43
+ // need to copy as backing ArrayBuffer might be larger
44
+ copy(raw),
45
+ });
46
+ export function encode(serialized) {
47
+ return msgpack.encode(serialized, { extensionCodec: CODEC });
48
+ }
49
+ export function decode(encoded) {
50
+ return msgpack.decode(encoded, { extensionCodec: CODEC });
51
+ }
@@ -1,2 +1,4 @@
1
1
  export * as model from './model.js';
2
2
  export * as weights from './weights.js';
3
+ export type { Encoded } from "./coder.js";
4
+ export { isEncoded } from "./coder.js";
@@ -1,2 +1,3 @@
1
1
  export * as model from './model.js';
2
2
  export * as weights from './weights.js';
3
+ export { isEncoded } from "./coder.js";
@@ -1,5 +1,4 @@
1
1
  import type { Model } from '../index.js';
2
- export type Encoded = Uint8Array;
3
- export declare function isEncoded(raw: unknown): raw is Encoded;
2
+ import { Encoded } from "./coder.js";
4
3
  export declare function encode(model: Model): Promise<Encoded>;
5
4
  export declare function decode(encoded: unknown): Promise<Model>;
@@ -1,38 +1,29 @@
1
- import msgpack from 'msgpack-lite';
2
1
  import { models, serialization } from '../index.js';
2
+ import * as coder from "./coder.js";
3
+ import { isEncoded } from "./coder.js";
3
4
  const Type = {
4
5
  TFJS: 0,
5
6
  GPT: 1
6
7
  };
7
- export function isEncoded(raw) {
8
- return raw instanceof Uint8Array;
9
- }
10
8
  export async function encode(model) {
11
- let encoded;
12
9
  switch (true) {
13
10
  case model instanceof models.TFJS: {
14
11
  const serialized = await model.serialize();
15
- encoded = msgpack.encode([Type.TFJS, serialized]);
16
- break;
12
+ return coder.encode([Type.TFJS, serialized]);
17
13
  }
18
14
  case model instanceof models.GPT: {
19
15
  const { weights, config } = model.serialize();
20
16
  const serializedWeights = await serialization.weights.encode(weights);
21
- encoded = msgpack.encode([Type.GPT, serializedWeights, config]);
22
- break;
17
+ return coder.encode([Type.GPT, serializedWeights, config]);
23
18
  }
24
19
  default:
25
20
  throw new Error("unknown model type");
26
21
  }
27
- // Node's Buffer extends Node's Uint8Array, which might not be the same
28
- // as the browser's Uint8Array. we ensure here that it is.
29
- return new Uint8Array(encoded);
30
22
  }
31
23
  export async function decode(encoded) {
32
- if (!isEncoded(encoded)) {
24
+ if (!isEncoded(encoded))
33
25
  throw new Error("Invalid encoding, raw encoding isn't an instance of Uint8Array");
34
- }
35
- const raw = msgpack.decode(encoded);
26
+ const raw = coder.decode(encoded);
36
27
  if (!Array.isArray(raw) || raw.length < 2) {
37
28
  throw new Error("invalid encoding, encoding isn't an array or doesn't contain enough values");
38
29
  }
@@ -59,15 +50,9 @@ export async function decode(encoded) {
59
50
  else {
60
51
  throw new Error('invalid encoding, gpt-tfjs model encoding should be an array of length 2 or 3');
61
52
  }
62
- if (!Array.isArray(rawModel)) {
63
- throw new Error('invalid encoding, gpt-tfjs model weights should be an array');
64
- }
65
- const arr = rawModel;
66
- if (arr.some((r) => typeof r !== 'number')) {
67
- throw new Error("invalid encoding, gpt-tfjs weights should be numbers");
68
- }
69
- const nums = arr;
70
- const weights = serialization.weights.decode(nums);
53
+ if (!isEncoded(rawModel))
54
+ throw new Error("invalid encoding, gpt-tfjs model weights should be an encoding of its weights");
55
+ const weights = serialization.weights.decode(rawModel);
71
56
  return models.GPT.deserialize({ weights, config });
72
57
  }
73
58
  default:
@@ -1,5 +1,4 @@
1
- import { WeightsContainer } from '../index.js';
2
- export type Encoded = number[];
3
- export declare function isEncoded(raw: unknown): raw is Encoded;
1
+ import { WeightsContainer } from "../index.js";
2
+ import { Encoded } from "./coder.js";
4
3
  export declare function encode(weights: WeightsContainer): Promise<Encoded>;
5
4
  export declare function decode(encoded: Encoded): WeightsContainer;
@@ -1,37 +1,26 @@
1
- import * as msgpack from 'msgpack-lite';
2
- import * as tf from '@tensorflow/tfjs';
3
- import { WeightsContainer } from '../index.js';
1
+ import * as tf from "@tensorflow/tfjs";
2
+ import { WeightsContainer } from "../index.js";
3
+ import * as coder from "./coder.js";
4
4
  function isSerialized(raw) {
5
- if (typeof raw !== 'object' || raw === null) {
5
+ if (typeof raw !== "object" || raw === null)
6
6
  return false;
7
- }
8
7
  const { shape, data } = raw;
9
- if (!(Array.isArray(shape) && shape.every((e) => typeof e === 'number')) ||
10
- !(Array.isArray(data) && data.every((e) => typeof e === 'number'))) {
8
+ if (!(Array.isArray(shape) && shape.every((e) => typeof e === "number")) ||
9
+ !(data instanceof Float32Array))
11
10
  return false;
12
- }
13
- const _ = {
14
- shape: shape,
15
- data: data,
16
- };
11
+ const _ = { shape, data };
17
12
  return true;
18
13
  }
19
- export function isEncoded(raw) {
20
- return Array.isArray(raw) && raw.every((e) => typeof e === 'number');
21
- }
22
14
  export async function encode(weights) {
23
- const serialized = await Promise.all(weights.weights.map(async (t) => {
24
- return {
25
- shape: t.shape,
26
- data: [...await t.data()]
27
- };
28
- }));
29
- return [...msgpack.encode(serialized).values()];
15
+ const serialized = await Promise.all(weights.weights.map(async (t) => ({
16
+ shape: t.shape,
17
+ data: await t.data(),
18
+ })));
19
+ return coder.encode(serialized);
30
20
  }
31
21
  export function decode(encoded) {
32
- const raw = msgpack.decode(encoded);
33
- if (!(Array.isArray(raw) && raw.every(isSerialized))) {
34
- throw new Error('expected to decode an array of serialized weights');
35
- }
22
+ const raw = coder.decode(encoded);
23
+ if (!(Array.isArray(raw) && raw.every(isSerialized)))
24
+ throw new Error("expected to decode an array of serialized weights");
36
25
  return new WeightsContainer(raw.map((w) => tf.tensor(w.data, w.shape)));
37
26
  }
@@ -1,5 +1,5 @@
1
- import { Map } from 'immutable';
2
- import type { Model } from '../index.js';
3
- import type { Task, TaskID } from './task.js';
4
- export declare function pushTask(url: URL, task: Task, model: Model): Promise<void>;
5
- export declare function fetchTasks(url: URL): Promise<Map<TaskID, Task>>;
1
+ import { Map } from "immutable";
2
+ import type { Model } from "../index.js";
3
+ import type { Task, TaskID } from "./task.js";
4
+ export declare function pushTask(base: URL, task: Task, model: Model): Promise<void>;
5
+ export declare function fetchTasks(base: URL): Promise<Map<TaskID, Task>>;