@epfml/discojs 3.0.1-p20240805130603.0 → 3.0.1-p20240809121820.0

This diff shows the contents of two package versions that were publicly released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registry.
@@ -1,5 +1,7 @@
+import createDebug from "debug";
 import { Map, Set } from 'immutable';
 import { EventEmitter } from '../utils/event_emitter.js';
+const debug = createDebug("discojs:aggregator");
 export var AggregationStep;
 (function (AggregationStep) {
 AggregationStep[AggregationStep["ADD"] = 0] = "ADD";
@@ -75,17 +77,16 @@ export class Base extends EventEmitter {
 log(step, from) {
 switch (step) {
 case AggregationStep.ADD:
-console.log(`Adding contribution from node ${from ?? '"unknown"'} for round (${this.communicationRound}, ${this.round})`);
+debug(`adding contribution from node ${from ?? '"unknown"'} for round (${this.communicationRound}, ${this.round})`);
 break;
 case AggregationStep.UPDATE:
 if (from === undefined) {
 return;
 }
-console.log(`> Updating contribution from node ${from} for round (${this.communicationRound}, ${this.round})`);
+debug(`updating contribution from node ${from ?? '"unknown"'} for round (${this.communicationRound}, ${this.round})`);
 break;
 case AggregationStep.AGGREGATE:
-console.log('*'.repeat(80));
-console.log(`Buffer is full. Aggregating weights for round (${this.communicationRound}, ${this.round})\n`);
+debug(`buffer full, aggregating weights for round (${this.communicationRound}, ${this.round})`);
 break;
 default: {
 const _ = step;
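Most of the changes in this release replace console.* calls with namespaced loggers from the debug package. A minimal sketch of that pattern, with illustrative namespace and values (only createDebug, the printf-style formatters, and the DEBUG/localStorage switches come from the library; everything else is made up):

    import createDebug from "debug";

    // one logger per module, namespaced under "discojs:"
    const debug = createDebug("discojs:aggregator");

    // only printed when the namespace is enabled; %s/%d/%o/%O are debug's formatters
    debug("adding contribution from node %s for round %d", "some-node", 3);
    debug("current state: %O", { round: 3, contributions: 2 });

    // enable at runtime:
    //   DEBUG=discojs:* node app.js        (Node)
    //   localStorage.debug = "discojs:*"   (browser)

Unlike the removed console.log calls, these lines are silent by default and can be turned on per namespace.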
@@ -1,5 +1,7 @@
+import createDebug from "debug";
 import { AggregationStep, Base as Aggregator } from "./base.js";
 import { aggregation } from "../index.js";
+const debug = createDebug("discojs:aggregator:mean");
 /**
  * Mean aggregator whose aggregation step consists in computing the mean of the received weights.
  *
@@ -51,7 +53,8 @@ export class MeanAggregator extends Aggregator {
 else {
 // Print a warning regarding the default behavior when thresholdType is not specified
 if (thresholdType === undefined) {
-console.warn("[WARN] Setting the aggregator's threshold to 100% of the nodes' contributions by default. " +
+// TODO enforce validity by splitting features instead of warning
+debug("[WARN] Setting the aggregator's threshold to 100% of the nodes' contributions by default. " +
 "To instead wait for a single contribution, set thresholdType = 'absolute'");
 this.#thresholdType = 'relative';
 }
@@ -72,9 +75,9 @@ export class MeanAggregator extends Aggregator {
 throw new Error("only a single communication round");
 if (!this.nodes.has(nodeId) || !this.isWithinRoundCutoff(round)) {
 if (!this.nodes.has(nodeId))
-console.warn("Contribution rejected because node id is not registered");
+debug(`contribution rejected because node ${nodeId} is not registered`);
 if (!this.isWithinRoundCutoff(round))
-console.warn(`Contribution rejected because round ${round} is not within round cutoff`);
+debug(`contribution rejected because round ${round} is not within cutoff`);
 return false;
 }
 this.log(this.contributions.hasIn([0, nodeId])
@@ -1,3 +1,4 @@
+import createDebug from "debug";
 import { Map, Set } from 'immutable';
 import { serialization } from "../../index.js";
 import { Client } from '../index.js';
@@ -6,6 +7,7 @@ import { timeout } from '../utils.js';
 import { WebSocketServer, waitMessage, waitMessageWithTimeout } from '../event_connection.js';
 import { PeerPool } from './peer_pool.js';
 import * as messages from './messages.js';
+const debug = createDebug("discojs:client:decentralized");
 /**
  * Represents a decentralized client in a network of peers. Peers coordinate each other with the
  * help of the network's server, yet only exchange payloads between each other. Communication
@@ -49,7 +51,7 @@ export class Base extends Client {
 };
 this.server.send(msg);
 const peerIdMsg = await waitMessage(this.server, type.AssignNodeID);
-console.log(`[${peerIdMsg.id}] assigned id generated by server`);
+debug(`[${peerIdMsg.id}] assigned id generated by server`);
 if (this._ownId !== undefined) {
 throw new Error('received id from server but was already received');
 }
@@ -123,11 +125,11 @@ export class Base extends Client {
 // this awaits the peer's weight update and adds it to
 // our aggregator upon reception
 (conn) => { this.receivePayloads(conn, round); });
-console.log(`[${this.ownId}] received peers for round ${round}:`, connections.keySeq().toJS());
+debug(`[${this.ownId}] received peers for round ${round}: %o`, connections.keySeq().toJS());
 this.connections = connections;
 }
 catch (e) {
-console.error(e);
+debug(`[${this.ownId}] while beginning round: %o`, e);
 this.aggregator.setNodes(Set(this.ownId));
 this.connections = Map();
 }
@@ -144,20 +146,20 @@ export class Base extends Client {
 receivePayloads(connections, round) {
 connections.forEach(async (connection, peerId) => {
 let currentCommunicationRounds = 0;
-console.log(`waiting for peer ${peerId}`);
+debug(`waiting for peer ${peerId}`);
 do {
 try {
 const message = await waitMessageWithTimeout(connection, type.Payload, 60_000, "Timeout waiting for a contribution from peer " + peerId);
 const decoded = serialization.weights.decode(message.payload);
 if (!this.aggregator.add(peerId, decoded, round, message.round)) {
-console.warn(`[${this.ownId}] Failed to add contribution from peer ${peerId}`);
+debug(`[${this.ownId}] failed to add contribution from peer ${peerId}`);
 }
 }
 catch (e) {
 if (this.isDisconnected) {
 return;
 }
-console.error(e instanceof Error ? e.message : e);
+debug(`[${this.ownId}] while receiving payloads: %o`, e);
 }
 } while (++currentCommunicationRounds < this.aggregator.communicationRounds);
 });
@@ -192,13 +194,13 @@ export class Base extends Client {
 payload: encoded
 };
 peer.send(msg);
-console.log(`[${this.ownId}] send weight update to peer`, msg.peer, msg);
+debug(`[${this.ownId}] send weight update to peer ${msg.peer}: %O`, msg);
 }
 }
 }));
 }
-catch {
-throw new Error('error while sending weights');
+catch (cause) {
+throw new Error('error while sending weights', { cause });
 }
 }
 // Wait for aggregation before proceeding to the next communication round.
@@ -213,7 +215,7 @@
 if (this.isDisconnected) {
 return weights;
 }
-console.error(e);
+debug(`[${this.ownId}] while waiting for aggregation: %o`, e);
 break;
 }
 // There is at least one communication round remaining
@@ -1,6 +1,8 @@
+import createDebug from "debug";
 import { Map } from 'immutable';
 import { Peer } from './peer.js';
 import { PeerConnection } from '../event_connection.js';
+const debug = createDebug("discojs:client:decentralized:pool");
 // TODO cleanup old peers
 export class PeerPool {
 id;
@@ -9,7 +11,7 @@ export class PeerPool {
 this.id = id;
 }
 async shutdown() {
-console.info(`[${this.id}] is shutting down all its connections`);
+debug(`[${this.id}] is shutting down all its connections`);
 // Add a timeout o.w. the promise hangs forever if the other peer is already disconnected
 await Promise.race([
 Promise.all(this.peers.valueSeq().map((peer) => peer.disconnect())),
@@ -18,7 +20,7 @@ export class PeerPool {
 this.peers = Map();
 }
 signal(peerId, signal) {
-console.info(`[${this.id}] signals for`, peerId);
+debug(`[${this.id}] signals for %s`, peerId);
 const peer = this.peers.get(peerId);
 if (peer === undefined) {
 throw new Error(`received signal for unknown peer: ${peerId}`);
@@ -31,17 +33,17 @@ export class PeerPool {
 if (peersToConnect.contains(this.id)) {
 throw new Error('peers to connect contains our id');
 }
-console.info(`[${this.id}] is connecting peers:`, peersToConnect.toJS());
+debug(`[${this.id}] is connecting peers: %o`, peersToConnect.toArray());
 const newPeers = Map(peersToConnect
 .filter((id) => !this.peers.has(id))
 .map((id) => [id, new Peer(id, id < this.id)]));
-console.info(`[${this.id}] asked to connect new peers:`, newPeers.keySeq().toJS());
+debug(`[${this.id}] asked to connect new peers: %o`, newPeers.keySeq().toArray());
 const newPeersConnections = newPeers.map((peer) => new PeerConnection(this.id, peer, signallingServer));
 // adding peers to pool before connecting them because they must be set to call signal on them
 this.peers = this.peers.merge(newPeersConnections);
 clientHandle(this.peers);
 await Promise.all(newPeersConnections.valueSeq().map((conn) => conn.connect()));
-console.info(`[${this.id}] knowns connected peers:`, this.peers.keySeq().toJS());
+debug(`[${this.id}] knowns connected peers: %o`, this.peers.keySeq().toArray());
 return this.peers
 .filter((_, id) => peersToConnect.has(id));
 }
@@ -1,9 +1,11 @@
-import WebSocket from 'isomorphic-ws';
-import msgpack from 'msgpack-lite';
+import createDebug from "debug";
+import WebSocket from "isomorphic-ws";
+import msgpack from "msgpack-lite";
 import * as decentralizedMessages from './decentralized/messages.js';
 import { type } from './messages.js';
 import { timeout } from './utils.js';
 import { EventEmitter } from '../utils/event_emitter.js';
+const debug = createDebug("discojs:client:connections");
 export async function waitMessage(connection, type) {
 return await new Promise((resolve) => {
 // "once" is important because we can't resolve the same promise multiple times
@@ -41,7 +43,9 @@ export class PeerConnection extends EventEmitter {
 }
 this.emit(msg.type, msg);
 });
-this.peer.on('close', () => { console.warn('From', this._ownId, ': peer', this.peer.id, 'closed connection'); });
+this.peer.on("close", () => {
+debug(`[${this._ownId}] peer ${this.peer.id} closed connection`);
+});
 await new Promise((resolve) => {
 this.peer.on('connect', resolve);
 });
@@ -1,8 +1,10 @@
+import createDebug from "debug";
 import { serialization, } from "../../index.js";
 import { Base as Client } from "../base.js";
 import { type } from "../messages.js";
 import { waitMessageWithTimeout, WebSocketServer, } from "../event_connection.js";
 import * as messages from "./messages.js";
+const debug = createDebug("discojs:client:federated");
 /**
  * Client class that communicates with a centralized, federated server, when training
  * a specific task in the federated setting.
@@ -53,7 +55,7 @@ export class Base extends Client {
 };
 this.server.send(msg);
 const received = await waitMessageWithTimeout(this.server, type.AssignNodeID);
-console.info(`[${received.id}] assign id generated by the server`);
+debug(`[${received.id}] assign id generated by the server`);
 this._ownId = received.id;
 }
 /**
@@ -87,7 +89,7 @@ export class Base extends Client {
 else {
 // Unexpected case: for some reason, the server result is stale.
 // We proceed to the next round without its result.
-console.info(`[${this.ownId}] Server result is either stale or not received`);
+debug(`[${this.ownId}] server result is either stale or not received`);
 this.aggregator.nextRound();
 }
 return await this.aggregationResult;
@@ -122,7 +124,7 @@ export class Base extends Client {
 }
 }
 catch (e) {
-console.error(e);
+debug(`[${this.ownId}] while receiving results: %o`, e);
 }
 }
 }
@@ -35,7 +35,7 @@ const leftPadding = {
 // to include one more token than maxSequenceLength in order to have the next token's label of the maxSequenceLength'th token
 const maxLength = task.trainingInformation.maxSequenceLength ?? tokenizer.model_max_length;
 const maxLengthPlusLabel = maxLength + 1;
-let fixedLengthTokens = tf.tensor(tokens, undefined, 'int32'); // cast tokens from float to int for gpt-tfjs
+let fixedLengthTokens = tf.tensor1d(tokens, 'int32'); // cast tokens from float to int for gpt-tfjs
 if (fixedLengthTokens.size > maxLengthPlusLabel) { // Should never happen because tokenization truncates inputs
 throw Error("There are more tokens than expected after tokenization and truncation");
 }
@@ -45,6 +45,7 @@ const leftPadding = {
 }
 // if tokens.size == maxLengthPlusLabel we can leave it as it is
 // ys is a one-hot encoding of the next token (i.e. xs shifted by one)
+// cast because oneHot isn't size-typing its return value
 const ys = tf.oneHot(fixedLengthTokens.slice([1]), tokenizer.model.vocab.length + 1);
 // remove the extra token now that ys is created
 const xs = fixedLengthTokens.slice([0], maxLength);
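The preprocessing hunk above switches to tf.tensor1d with an explicit int32 dtype, since tf.oneHot only accepts integer indices, and then builds the labels by shifting the sequence left by one. A simplified sketch of that xs/ys construction, with invented token ids and vocabulary size:

    import * as tf from "@tensorflow/tfjs";

    const vocabSize = 8;          // invented; the code uses tokenizer.model.vocab.length + 1
    const tokens = [5, 2, 7, 1];  // maxLength + 1 ids; the last one is only used as a label

    // cast to int32 so the ids can index into the one-hot encoding
    const fixedLengthTokens = tf.tensor1d(tokens, "int32");

    // ys: one-hot of the next token, i.e. the sequence shifted by one -> shape [3, 8]
    const ys = tf.oneHot(fixedLengthTokens.slice([1]), vocabSize);

    // xs: drop the extra trailing token now that ys exists -> shape [3]
    const xs = fixedLengthTokens.slice([0], tokens.length - 1);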
@@ -13,9 +13,8 @@ export class TabularData extends Data {
 try {
 await dataset.iterator();
 }
-catch (e) {
-console.error('Data input format is not compatible with the chosen task.');
-throw (e);
+catch (cause) {
+throw new Error('data input format not compatible with chosen task', { cause });
 }
 return new TabularData(dataset, task, size);
 }
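Here, as in the decentralized client above, a logged-and-rethrown error becomes a single throw that chains the original failure through the ES2022 cause option. A minimal sketch of the pattern (the function name and probe callback are hypothetical):

    async function checkCompatibility(probe: () => Promise<void>): Promise<void> {
      try {
        await probe();
      } catch (cause) {
        // the original error stays reachable on the new error's .cause property
        throw new Error("data input format not compatible with chosen task", { cause });
      }
    }

Callers that catch the wrapping error can still inspect err.cause to see what actually failed, which the removed console.error + rethrow only approximated.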
package/dist/index.d.ts CHANGED
@@ -6,12 +6,11 @@ export * as client from './client/index.js';
 export * as aggregator from './aggregator/index.js';
 export { WeightsContainer, aggregation } from './weights/index.js';
 export { Logger, ConsoleLogger } from './logging/index.js';
-export { Memory, type ModelInfo, type Path, type ModelSource, Empty as EmptyMemory } from './memory/index.js';
+export { Memory, type ModelInfo, type ModelSource, Empty as EmptyMemory } from './memory/index.js';
 export { Disco, RoundLogs } from './training/index.js';
 export { Validator } from './validation/index.js';
 export { Model, BatchLogs, EpochLogs, ValidationMetrics } from './models/index.js';
 export * as models from './models/index.js';
 export * from './task/index.js';
 export * as defaultTasks from './default_tasks/index.js';
-export * from './types.js';
 export * as async_iterator from "./utils/async_iterator.js";
package/dist/index.js CHANGED
@@ -13,5 +13,4 @@ export { Model, EpochLogs } from './models/index.js';
 export * as models from './models/index.js';
 export * from './task/index.js';
 export * as defaultTasks from './default_tasks/index.js';
-export * from './types.js';
 export * as async_iterator from "./utils/async_iterator.js";
@@ -1,8 +1,4 @@
 import type { Model, TaskID } from '../index.js';
-/**
- * Model path which uniquely identifies a model in memory.
- */
-export type Path = string;
 /**
  * Type of models stored in memory. Stored models can either be a model currently
  * being trained ("working model") or a regular model saved in memory ("saved model").
@@ -21,10 +17,10 @@ export interface ModelInfo {
 }
 /**
  * A model source uniquely identifies a model stored in memory.
- * It can be in the form of either a model info object or a Path string
+ * It can be in the form of either a model info object or an ID
  * (one-to-one mapping between the two)
  */
-export type ModelSource = ModelInfo | Path;
+export type ModelSource = ModelInfo | string;
 /**
  * Represents a model memory system, providing functions to fetch, save, delete and update models.
  * Stored models can either be a model currently being trained ("working model") or a regular model
@@ -67,7 +63,7 @@ export declare abstract class Memory {
 * @param source The model source
 * @returns The saved model's path
 */
-abstract saveWorkingModel(source: ModelSource): Promise<Path | undefined>;
+abstract saveWorkingModel(source: ModelSource): Promise<string | undefined>;
 /**
 * Saves the newly provided model to the given model source.
 * Returns the saved model's path
@@ -75,7 +71,7 @@ export declare abstract class Memory {
 * @param model The new model
 * @returns The saved model's path
 */
-abstract saveModel(source: ModelSource, model: Model): Promise<Path | undefined>;
+abstract saveModel(source: ModelSource, model: Model): Promise<string | undefined>;
 /**
 * Moves the model identified by the model source to a file system. This is platform-dependent.
 * @param source The model source
@@ -95,7 +91,7 @@ export declare abstract class Memory {
 * @param source The model source
 * @returns The model path
 */
-abstract getModelMemoryPath(source: ModelSource): Path | undefined;
+abstract getModelMemoryPath(source: ModelSource): string | undefined;
 /**
 * Computes the model information corresponding to the given model source, be it a path or model information.
 * This is used to easily switch between model path and information, which are both unique model identifiers
@@ -1,5 +1,5 @@
 import type { Model } from '../index.js';
-import type { ModelInfo, Path } from './base.js';
+import type { ModelInfo } from './base.js';
 import { Memory } from './base.js';
 /**
  * Represents an empty model memory.
@@ -14,7 +14,7 @@ export declare class Empty extends Memory {
 saveModel(): Promise<undefined>;
 deleteModel(): Promise<void>;
 downloadModel(): Promise<void>;
-getModelMemoryPath(): Path;
+getModelMemoryPath(): string;
 getModelInfo(): ModelInfo;
 duplicateSource(): Promise<undefined>;
 }
@@ -1,2 +1,2 @@
 export { Empty } from './empty.js';
-export { Memory, type ModelInfo, type Path, type ModelSource } from './base.js';
+export { Memory, type ModelInfo, type ModelSource } from './base.js';
@@ -4,7 +4,6 @@
 import * as tf from '@tensorflow/tfjs';
 import { PreTrainedTokenizer } from '@xenova/transformers';
 import { WeightsContainer } from '../../index.js';
-import type { Dataset } from '../../dataset/index.js';
 import { BatchLogs, Model, EpochLogs } from "../index.js";
 import type { Prediction, Sample } from '../model.js';
 import { type GPTConfig } from './config.js';
@@ -25,7 +24,13 @@ export declare class GPT extends Model {
 * @param epochs the number of passes of the training dataset
 * @param tracker
 */
-train(trainingData: Dataset, validationData?: Dataset): AsyncGenerator<BatchLogs, EpochLogs>;
+train(trainingData: tf.data.Dataset<{
+xs: tf.Tensor2D;
+ys: tf.Tensor3D;
+}>, validationData?: tf.data.Dataset<{
+xs: tf.Tensor2D;
+ys: tf.Tensor3D;
+}>): AsyncGenerator<BatchLogs, EpochLogs>;
 predict(input: Sample): Promise<Prediction>;
 generate(input: string, tokenizer: PreTrainedTokenizer, newTokens?: number): Promise<string>;
 get config(): Required<GPTConfig>;
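The new train() signature pins down the expected element type: batched 2-D token inputs and 3-D one-hot labels. A hypothetical sketch of a dataset carrying that element shape (all sizes are invented for illustration):

    import * as tf from "@tensorflow/tfjs";

    const batchSize = 4, seqLength = 128, vocabSize = 50258; // invented sizes

    // one batch: xs is [batch, sequence], ys is [batch, sequence, vocab]
    const batch = {
      xs: tf.zeros([batchSize, seqLength], "int32") as tf.Tensor2D,
      ys: tf.zeros([batchSize, seqLength, vocabSize]) as tf.Tensor3D,
    };

    // a tf.data.Dataset of such elements matches the declared parameter type
    const trainingData = tf.data.array([batch]);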
@@ -1,13 +1,15 @@
 /**
  * this code is taken from gpt-tfjs with modifications from @peacefulotter and @lukemovement
  **/
+import createDebug from "debug";
+import { List } from 'immutable';
 import * as tf from '@tensorflow/tfjs';
 import { WeightsContainer } from '../../index.js';
 import { Model, EpochLogs } from "../index.js";
 import { GPTForCausalLM } from './model.js';
 import { DEFAULT_CONFIG } from './config.js';
 import evaluate from './evaluate.js';
-import { List } from 'immutable';
+const debug = createDebug("discojs:models:gpt");
 export class GPT extends Model {
 model;
 #maxBatchCount;
@@ -126,8 +128,7 @@ export class GPT extends Model {
 this.model.optimizer.dispose();
 }
 const disposeResults = this.model.dispose();
-if (disposeResults.refCountAfterDispose > 0) {
-console.error("The GPT model was not disposed correctly (refcount > 0)", disposeResults);
-}
+if (disposeResults.refCountAfterDispose > 0)
+debug("model not disposed correctly: %o", disposeResults);
 }
 }
@@ -1,8 +1,10 @@
+import createDebug from "debug";
 import * as tf from '@tensorflow/tfjs';
 import { getModelSizes, DEFAULT_CONFIG } from './config.js';
 import { getCustomAdam, clipByGlobalNormObj } from './optimizers.js';
 import evaluate from './evaluate.js';
 import { GPTArchitecture } from './layers.js';
+const debug = createDebug("discojs:models:gpt");
 /**
  * GPTModel extends tf.LayersModel and overrides tfjs' default training loop
  *
@@ -87,10 +89,18 @@ class GPTModel extends tf.LayersModel {
 this.config.evaluateEvery !== undefined &&
 iteration % this.config.evaluateEvery == 0) {
 const iterationLogs = await evaluate(this, evalDataset, this.config.maxEvalBatches);
-console.log(iterationLogs);
+debug('evaluation metrics: %O', iterationLogs);
 }
 const memory = tf.memory().numBytes / 1024 / 1024 / 1024;
-console.log(`Epoch: ${epoch}`, `\tStep: ${iteration} / ${this.config.maxIter}`, `\tLoss: ${loss.toFixed(3)}`, `\tMemory: ${memory.toFixed(2)} GB`, `\tNumber of tensors allocated: ${tf.memory().numTensors}`, `\tPreprocessing time: ${preprocessingTime.toFixed(0)} ms`, `\tWeight update time: ${weightUpdateTime.toFixed(0)} ms`);
+debug("training metrics: %O", {
+epoch,
+iteration,
+loss,
+memory,
+allocated: tf.memory().numTensors,
+preprocessingTime,
+weightUpdateTime,
+});
 iteration++;
 next = await iterator.next();
 }
@@ -38,7 +38,6 @@ class AdamW extends tf.AdamOptimizer {
 excludeFromWeightDecay;
 gradientClipNorm;
 constructor(params) {
-console.log('Using custom AdamW optimizer');
 const defaultParams = {
 learningRate: 0.1,
 beta1: 0.9,
@@ -1,7 +1,9 @@
 import axios from 'axios';
+import createDebug from "debug";
 import { Map } from 'immutable';
 import { serialization } from '../index.js';
 import { isTask } from './task.js';
+const debug = createDebug("discojs:task:handlers");
 const TASK_ENDPOINT = 'tasks';
 export async function pushTask(url, task, model) {
 await axios.post(url.href + TASK_ENDPOINT, {
@@ -19,7 +21,7 @@ export async function fetchTasks(url) {
 else if (!tasks.every(isTask)) {
 for (const task of tasks) {
 if (!isTask(task)) {
-console.error("task has invalid format:", task);
+debug("task has invalid format: :O", task);
 }
 }
 throw new Error('invalid tasks response, the task object received is not well formatted');
@@ -1,4 +1,4 @@
-import type { data, Model, Task, Logger, client as clients, Memory, ModelSource, Features } from '../index.js';
+import type { data, Model, Task, Logger, client as clients, Memory, ModelSource } from '../index.js';
 export declare class Validator {
 readonly task: Task;
 readonly logger: Logger;
@@ -13,10 +13,10 @@ export declare class Validator {
 test(data: data.Data): AsyncGenerator<Array<{
 groundTruth: number;
 pred: number;
-features: Features;
+features: number[];
 }>, void>;
 inference(data: data.Data): AsyncGenerator<Array<{
-features: Features;
+features: number[];
 pred: number;
 }>, void>;
 getModel(): Promise<Model>;
@@ -57,7 +57,7 @@ export class Validator {
 hits += List(pred).zip(List(ysLabel)).filter(([p, y]) => p === y).size;
 this.rollingAccuracy = hits / this.size;
 tf.dispose([xs, ys, yPredTensor]);
-yield List(ysLabel).zip(List(pred), List(currentFeatures))
+yield (List(ysLabel).zip(List(pred), List(currentFeatures)))
 .map(([gt, p, f]) => ({ groundTruth: gt, pred: p, features: f }))
 .toArray();
 next = await iterator.next();
package/package.json CHANGED
@@ -1,6 +1,6 @@
 {
 "name": "@epfml/discojs",
-"version": "3.0.1-p20240805130603.0",
+"version": "3.0.1-p20240809121820.0",
 "type": "module",
 "main": "dist/index.js",
 "types": "dist/index.d.ts",
package/dist/types.d.ts DELETED
@@ -1,6 +0,0 @@
-import type { Map } from 'immutable';
-import type { WeightsContainer } from './index.js';
-import type { NodeID } from './client/index.js';
-export type Path = string;
-export type Features = number | number[] | number[][] | number[][][] | number[][][][] | number[][][][][];
-export type Contributions = Map<NodeID, WeightsContainer>;
package/dist/types.js DELETED
@@ -1 +0,0 @@
-export {};