@epfml/discojs 3.0.1-p20240805130603.0 → 3.0.1-p20240809121820.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/aggregator/base.js +5 -4
- package/dist/aggregator/mean.js +6 -3
- package/dist/client/decentralized/base.js +12 -10
- package/dist/client/decentralized/peer_pool.js +7 -5
- package/dist/client/event_connection.js +7 -3
- package/dist/client/federated/base.js +5 -3
- package/dist/dataset/data/preprocessing/text_preprocessing.js +2 -1
- package/dist/dataset/data/tabular_data.js +2 -3
- package/dist/index.d.ts +1 -2
- package/dist/index.js +0 -1
- package/dist/memory/base.d.ts +5 -9
- package/dist/memory/empty.d.ts +2 -2
- package/dist/memory/index.d.ts +1 -1
- package/dist/models/gpt/index.d.ts +7 -2
- package/dist/models/gpt/index.js +5 -4
- package/dist/models/gpt/model.js +12 -2
- package/dist/models/gpt/optimizers.js +0 -1
- package/dist/task/task_handler.js +3 -1
- package/dist/validation/validator.d.ts +3 -3
- package/dist/validation/validator.js +1 -1
- package/package.json +1 -1
- package/dist/types.d.ts +0 -6
- package/dist/types.js +0 -1
package/dist/aggregator/base.js
CHANGED
|
@@ -1,5 +1,7 @@
|
|
|
1
|
+
import createDebug from "debug";
|
|
1
2
|
import { Map, Set } from 'immutable';
|
|
2
3
|
import { EventEmitter } from '../utils/event_emitter.js';
|
|
4
|
+
const debug = createDebug("discojs:aggregator");
|
|
3
5
|
export var AggregationStep;
|
|
4
6
|
(function (AggregationStep) {
|
|
5
7
|
AggregationStep[AggregationStep["ADD"] = 0] = "ADD";
|
|
@@ -75,17 +77,16 @@ export class Base extends EventEmitter {
|
|
|
75
77
|
log(step, from) {
|
|
76
78
|
switch (step) {
|
|
77
79
|
case AggregationStep.ADD:
|
|
78
|
-
|
|
80
|
+
debug(`adding contribution from node ${from ?? '"unknown"'} for round (${this.communicationRound}, ${this.round})`);
|
|
79
81
|
break;
|
|
80
82
|
case AggregationStep.UPDATE:
|
|
81
83
|
if (from === undefined) {
|
|
82
84
|
return;
|
|
83
85
|
}
|
|
84
|
-
|
|
86
|
+
debug(`updating contribution from node ${from ?? '"unknown"'} for round (${this.communicationRound}, ${this.round})`);
|
|
85
87
|
break;
|
|
86
88
|
case AggregationStep.AGGREGATE:
|
|
87
|
-
|
|
88
|
-
console.log(`Buffer is full. Aggregating weights for round (${this.communicationRound}, ${this.round})\n`);
|
|
89
|
+
debug(`buffer full, aggregating weights for round (${this.communicationRound}, ${this.round})`);
|
|
89
90
|
break;
|
|
90
91
|
default: {
|
|
91
92
|
const _ = step;
|
package/dist/aggregator/mean.js
CHANGED
|
@@ -1,5 +1,7 @@
|
|
|
1
|
+
import createDebug from "debug";
|
|
1
2
|
import { AggregationStep, Base as Aggregator } from "./base.js";
|
|
2
3
|
import { aggregation } from "../index.js";
|
|
4
|
+
const debug = createDebug("discojs:aggregator:mean");
|
|
3
5
|
/**
|
|
4
6
|
* Mean aggregator whose aggregation step consists in computing the mean of the received weights.
|
|
5
7
|
*
|
|
@@ -51,7 +53,8 @@ export class MeanAggregator extends Aggregator {
|
|
|
51
53
|
else {
|
|
52
54
|
// Print a warning regarding the default behavior when thresholdType is not specified
|
|
53
55
|
if (thresholdType === undefined) {
|
|
54
|
-
|
|
56
|
+
// TODO enforce validity by splitting features instead of warning
|
|
57
|
+
debug("[WARN] Setting the aggregator's threshold to 100% of the nodes' contributions by default. " +
|
|
55
58
|
"To instead wait for a single contribution, set thresholdType = 'absolute'");
|
|
56
59
|
this.#thresholdType = 'relative';
|
|
57
60
|
}
|
|
@@ -72,9 +75,9 @@ export class MeanAggregator extends Aggregator {
|
|
|
72
75
|
throw new Error("only a single communication round");
|
|
73
76
|
if (!this.nodes.has(nodeId) || !this.isWithinRoundCutoff(round)) {
|
|
74
77
|
if (!this.nodes.has(nodeId))
|
|
75
|
-
|
|
78
|
+
debug(`contribution rejected because node ${nodeId} is not registered`);
|
|
76
79
|
if (!this.isWithinRoundCutoff(round))
|
|
77
|
-
|
|
80
|
+
debug(`contribution rejected because round ${round} is not within cutoff`);
|
|
78
81
|
return false;
|
|
79
82
|
}
|
|
80
83
|
this.log(this.contributions.hasIn([0, nodeId])
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import createDebug from "debug";
|
|
1
2
|
import { Map, Set } from 'immutable';
|
|
2
3
|
import { serialization } from "../../index.js";
|
|
3
4
|
import { Client } from '../index.js';
|
|
@@ -6,6 +7,7 @@ import { timeout } from '../utils.js';
|
|
|
6
7
|
import { WebSocketServer, waitMessage, waitMessageWithTimeout } from '../event_connection.js';
|
|
7
8
|
import { PeerPool } from './peer_pool.js';
|
|
8
9
|
import * as messages from './messages.js';
|
|
10
|
+
const debug = createDebug("discojs:client:decentralized");
|
|
9
11
|
/**
|
|
10
12
|
* Represents a decentralized client in a network of peers. Peers coordinate each other with the
|
|
11
13
|
* help of the network's server, yet only exchange payloads between each other. Communication
|
|
@@ -49,7 +51,7 @@ export class Base extends Client {
|
|
|
49
51
|
};
|
|
50
52
|
this.server.send(msg);
|
|
51
53
|
const peerIdMsg = await waitMessage(this.server, type.AssignNodeID);
|
|
52
|
-
|
|
54
|
+
debug(`[${peerIdMsg.id}] assigned id generated by server`);
|
|
53
55
|
if (this._ownId !== undefined) {
|
|
54
56
|
throw new Error('received id from server but was already received');
|
|
55
57
|
}
|
|
@@ -123,11 +125,11 @@ export class Base extends Client {
|
|
|
123
125
|
// this awaits the peer's weight update and adds it to
|
|
124
126
|
// our aggregator upon reception
|
|
125
127
|
(conn) => { this.receivePayloads(conn, round); });
|
|
126
|
-
|
|
128
|
+
debug(`[${this.ownId}] received peers for round ${round}: %o`, connections.keySeq().toJS());
|
|
127
129
|
this.connections = connections;
|
|
128
130
|
}
|
|
129
131
|
catch (e) {
|
|
130
|
-
|
|
132
|
+
debug(`[${this.ownId}] while beginning round: %o`, e);
|
|
131
133
|
this.aggregator.setNodes(Set(this.ownId));
|
|
132
134
|
this.connections = Map();
|
|
133
135
|
}
|
|
@@ -144,20 +146,20 @@ export class Base extends Client {
|
|
|
144
146
|
receivePayloads(connections, round) {
|
|
145
147
|
connections.forEach(async (connection, peerId) => {
|
|
146
148
|
let currentCommunicationRounds = 0;
|
|
147
|
-
|
|
149
|
+
debug(`waiting for peer ${peerId}`);
|
|
148
150
|
do {
|
|
149
151
|
try {
|
|
150
152
|
const message = await waitMessageWithTimeout(connection, type.Payload, 60_000, "Timeout waiting for a contribution from peer " + peerId);
|
|
151
153
|
const decoded = serialization.weights.decode(message.payload);
|
|
152
154
|
if (!this.aggregator.add(peerId, decoded, round, message.round)) {
|
|
153
|
-
|
|
155
|
+
debug(`[${this.ownId}] failed to add contribution from peer ${peerId}`);
|
|
154
156
|
}
|
|
155
157
|
}
|
|
156
158
|
catch (e) {
|
|
157
159
|
if (this.isDisconnected) {
|
|
158
160
|
return;
|
|
159
161
|
}
|
|
160
|
-
|
|
162
|
+
debug(`[${this.ownId}] while receiving payloads: %o`, e);
|
|
161
163
|
}
|
|
162
164
|
} while (++currentCommunicationRounds < this.aggregator.communicationRounds);
|
|
163
165
|
});
|
|
@@ -192,13 +194,13 @@ export class Base extends Client {
|
|
|
192
194
|
payload: encoded
|
|
193
195
|
};
|
|
194
196
|
peer.send(msg);
|
|
195
|
-
|
|
197
|
+
debug(`[${this.ownId}] send weight update to peer ${msg.peer}: %O`, msg);
|
|
196
198
|
}
|
|
197
199
|
}
|
|
198
200
|
}));
|
|
199
201
|
}
|
|
200
|
-
catch {
|
|
201
|
-
throw new Error('error while sending weights');
|
|
202
|
+
catch (cause) {
|
|
203
|
+
throw new Error('error while sending weights', { cause });
|
|
202
204
|
}
|
|
203
205
|
}
|
|
204
206
|
// Wait for aggregation before proceeding to the next communication round.
|
|
@@ -213,7 +215,7 @@ export class Base extends Client {
|
|
|
213
215
|
if (this.isDisconnected) {
|
|
214
216
|
return weights;
|
|
215
217
|
}
|
|
216
|
-
|
|
218
|
+
debug(`[${this.ownId}] while waiting for aggregation: %o`, e);
|
|
217
219
|
break;
|
|
218
220
|
}
|
|
219
221
|
// There is at least one communication round remaining
|
|
@@ -1,6 +1,8 @@
|
|
|
1
|
+
import createDebug from "debug";
|
|
1
2
|
import { Map } from 'immutable';
|
|
2
3
|
import { Peer } from './peer.js';
|
|
3
4
|
import { PeerConnection } from '../event_connection.js';
|
|
5
|
+
const debug = createDebug("discojs:client:decentralized:pool");
|
|
4
6
|
// TODO cleanup old peers
|
|
5
7
|
export class PeerPool {
|
|
6
8
|
id;
|
|
@@ -9,7 +11,7 @@ export class PeerPool {
|
|
|
9
11
|
this.id = id;
|
|
10
12
|
}
|
|
11
13
|
async shutdown() {
|
|
12
|
-
|
|
14
|
+
debug(`[${this.id}] is shutting down all its connections`);
|
|
13
15
|
// Add a timeout o.w. the promise hangs forever if the other peer is already disconnected
|
|
14
16
|
await Promise.race([
|
|
15
17
|
Promise.all(this.peers.valueSeq().map((peer) => peer.disconnect())),
|
|
@@ -18,7 +20,7 @@ export class PeerPool {
|
|
|
18
20
|
this.peers = Map();
|
|
19
21
|
}
|
|
20
22
|
signal(peerId, signal) {
|
|
21
|
-
|
|
23
|
+
debug(`[${this.id}] signals for %s`, peerId);
|
|
22
24
|
const peer = this.peers.get(peerId);
|
|
23
25
|
if (peer === undefined) {
|
|
24
26
|
throw new Error(`received signal for unknown peer: ${peerId}`);
|
|
@@ -31,17 +33,17 @@ export class PeerPool {
|
|
|
31
33
|
if (peersToConnect.contains(this.id)) {
|
|
32
34
|
throw new Error('peers to connect contains our id');
|
|
33
35
|
}
|
|
34
|
-
|
|
36
|
+
debug(`[${this.id}] is connecting peers: %o`, peersToConnect.toArray());
|
|
35
37
|
const newPeers = Map(peersToConnect
|
|
36
38
|
.filter((id) => !this.peers.has(id))
|
|
37
39
|
.map((id) => [id, new Peer(id, id < this.id)]));
|
|
38
|
-
|
|
40
|
+
debug(`[${this.id}] asked to connect new peers: %o`, newPeers.keySeq().toArray());
|
|
39
41
|
const newPeersConnections = newPeers.map((peer) => new PeerConnection(this.id, peer, signallingServer));
|
|
40
42
|
// adding peers to pool before connecting them because they must be set to call signal on them
|
|
41
43
|
this.peers = this.peers.merge(newPeersConnections);
|
|
42
44
|
clientHandle(this.peers);
|
|
43
45
|
await Promise.all(newPeersConnections.valueSeq().map((conn) => conn.connect()));
|
|
44
|
-
|
|
46
|
+
debug(`[${this.id}] knowns connected peers: %o`, this.peers.keySeq().toArray());
|
|
45
47
|
return this.peers
|
|
46
48
|
.filter((_, id) => peersToConnect.has(id));
|
|
47
49
|
}
|
|
@@ -1,9 +1,11 @@
|
|
|
1
|
-
import
|
|
2
|
-
import
|
|
1
|
+
import createDebug from "debug";
|
|
2
|
+
import WebSocket from "isomorphic-ws";
|
|
3
|
+
import msgpack from "msgpack-lite";
|
|
3
4
|
import * as decentralizedMessages from './decentralized/messages.js';
|
|
4
5
|
import { type } from './messages.js';
|
|
5
6
|
import { timeout } from './utils.js';
|
|
6
7
|
import { EventEmitter } from '../utils/event_emitter.js';
|
|
8
|
+
const debug = createDebug("discojs:client:connections");
|
|
7
9
|
export async function waitMessage(connection, type) {
|
|
8
10
|
return await new Promise((resolve) => {
|
|
9
11
|
// "once" is important because we can't resolve the same promise multiple times
|
|
@@ -41,7 +43,9 @@ export class PeerConnection extends EventEmitter {
|
|
|
41
43
|
}
|
|
42
44
|
this.emit(msg.type, msg);
|
|
43
45
|
});
|
|
44
|
-
this.peer.on(
|
|
46
|
+
this.peer.on("close", () => {
|
|
47
|
+
debug(`[${this._ownId}] peer ${this.peer.id} closed connection`);
|
|
48
|
+
});
|
|
45
49
|
await new Promise((resolve) => {
|
|
46
50
|
this.peer.on('connect', resolve);
|
|
47
51
|
});
|
|
@@ -1,8 +1,10 @@
|
|
|
1
|
+
import createDebug from "debug";
|
|
1
2
|
import { serialization, } from "../../index.js";
|
|
2
3
|
import { Base as Client } from "../base.js";
|
|
3
4
|
import { type } from "../messages.js";
|
|
4
5
|
import { waitMessageWithTimeout, WebSocketServer, } from "../event_connection.js";
|
|
5
6
|
import * as messages from "./messages.js";
|
|
7
|
+
const debug = createDebug("discojs:client:federated");
|
|
6
8
|
/**
|
|
7
9
|
* Client class that communicates with a centralized, federated server, when training
|
|
8
10
|
* a specific task in the federated setting.
|
|
@@ -53,7 +55,7 @@ export class Base extends Client {
|
|
|
53
55
|
};
|
|
54
56
|
this.server.send(msg);
|
|
55
57
|
const received = await waitMessageWithTimeout(this.server, type.AssignNodeID);
|
|
56
|
-
|
|
58
|
+
debug(`[${received.id}] assign id generated by the server`);
|
|
57
59
|
this._ownId = received.id;
|
|
58
60
|
}
|
|
59
61
|
/**
|
|
@@ -87,7 +89,7 @@ export class Base extends Client {
|
|
|
87
89
|
else {
|
|
88
90
|
// Unexpected case: for some reason, the server result is stale.
|
|
89
91
|
// We proceed to the next round without its result.
|
|
90
|
-
|
|
92
|
+
debug(`[${this.ownId}] server result is either stale or not received`);
|
|
91
93
|
this.aggregator.nextRound();
|
|
92
94
|
}
|
|
93
95
|
return await this.aggregationResult;
|
|
@@ -122,7 +124,7 @@ export class Base extends Client {
|
|
|
122
124
|
}
|
|
123
125
|
}
|
|
124
126
|
catch (e) {
|
|
125
|
-
|
|
127
|
+
debug(`[${this.ownId}] while receiving results: %o`, e);
|
|
126
128
|
}
|
|
127
129
|
}
|
|
128
130
|
}
|
|
@@ -35,7 +35,7 @@ const leftPadding = {
|
|
|
35
35
|
// to include one more token than maxSequenceLength in order to have the next token's label of the maxSequenceLength'th token
|
|
36
36
|
const maxLength = task.trainingInformation.maxSequenceLength ?? tokenizer.model_max_length;
|
|
37
37
|
const maxLengthPlusLabel = maxLength + 1;
|
|
38
|
-
let fixedLengthTokens = tf.
|
|
38
|
+
let fixedLengthTokens = tf.tensor1d(tokens, 'int32'); // cast tokens from float to int for gpt-tfjs
|
|
39
39
|
if (fixedLengthTokens.size > maxLengthPlusLabel) { // Should never happen because tokenization truncates inputs
|
|
40
40
|
throw Error("There are more tokens than expected after tokenization and truncation");
|
|
41
41
|
}
|
|
@@ -45,6 +45,7 @@ const leftPadding = {
|
|
|
45
45
|
}
|
|
46
46
|
// if tokens.size == maxLengthPlusLabel we can leave it as it is
|
|
47
47
|
// ys is a one-hot encoding of the next token (i.e. xs shifted by one)
|
|
48
|
+
// cast because oneHot isn't size-typing its return value
|
|
48
49
|
const ys = tf.oneHot(fixedLengthTokens.slice([1]), tokenizer.model.vocab.length + 1);
|
|
49
50
|
// remove the extra token now that ys is created
|
|
50
51
|
const xs = fixedLengthTokens.slice([0], maxLength);
|
|
@@ -13,9 +13,8 @@ export class TabularData extends Data {
|
|
|
13
13
|
try {
|
|
14
14
|
await dataset.iterator();
|
|
15
15
|
}
|
|
16
|
-
catch (
|
|
17
|
-
|
|
18
|
-
throw (e);
|
|
16
|
+
catch (cause) {
|
|
17
|
+
throw new Error('data input format not compatible with chosen task', { cause });
|
|
19
18
|
}
|
|
20
19
|
return new TabularData(dataset, task, size);
|
|
21
20
|
}
|
package/dist/index.d.ts
CHANGED
|
@@ -6,12 +6,11 @@ export * as client from './client/index.js';
|
|
|
6
6
|
export * as aggregator from './aggregator/index.js';
|
|
7
7
|
export { WeightsContainer, aggregation } from './weights/index.js';
|
|
8
8
|
export { Logger, ConsoleLogger } from './logging/index.js';
|
|
9
|
-
export { Memory, type ModelInfo, type
|
|
9
|
+
export { Memory, type ModelInfo, type ModelSource, Empty as EmptyMemory } from './memory/index.js';
|
|
10
10
|
export { Disco, RoundLogs } from './training/index.js';
|
|
11
11
|
export { Validator } from './validation/index.js';
|
|
12
12
|
export { Model, BatchLogs, EpochLogs, ValidationMetrics } from './models/index.js';
|
|
13
13
|
export * as models from './models/index.js';
|
|
14
14
|
export * from './task/index.js';
|
|
15
15
|
export * as defaultTasks from './default_tasks/index.js';
|
|
16
|
-
export * from './types.js';
|
|
17
16
|
export * as async_iterator from "./utils/async_iterator.js";
|
package/dist/index.js
CHANGED
|
@@ -13,5 +13,4 @@ export { Model, EpochLogs } from './models/index.js';
|
|
|
13
13
|
export * as models from './models/index.js';
|
|
14
14
|
export * from './task/index.js';
|
|
15
15
|
export * as defaultTasks from './default_tasks/index.js';
|
|
16
|
-
export * from './types.js';
|
|
17
16
|
export * as async_iterator from "./utils/async_iterator.js";
|
package/dist/memory/base.d.ts
CHANGED
|
@@ -1,8 +1,4 @@
|
|
|
1
1
|
import type { Model, TaskID } from '../index.js';
|
|
2
|
-
/**
|
|
3
|
-
* Model path which uniquely identifies a model in memory.
|
|
4
|
-
*/
|
|
5
|
-
export type Path = string;
|
|
6
2
|
/**
|
|
7
3
|
* Type of models stored in memory. Stored models can either be a model currently
|
|
8
4
|
* being trained ("working model") or a regular model saved in memory ("saved model").
|
|
@@ -21,10 +17,10 @@ export interface ModelInfo {
|
|
|
21
17
|
}
|
|
22
18
|
/**
|
|
23
19
|
* A model source uniquely identifies a model stored in memory.
|
|
24
|
-
* It can be in the form of either a model info object or
|
|
20
|
+
* It can be in the form of either a model info object or an ID
|
|
25
21
|
* (one-to-one mapping between the two)
|
|
26
22
|
*/
|
|
27
|
-
export type ModelSource = ModelInfo |
|
|
23
|
+
export type ModelSource = ModelInfo | string;
|
|
28
24
|
/**
|
|
29
25
|
* Represents a model memory system, providing functions to fetch, save, delete and update models.
|
|
30
26
|
* Stored models can either be a model currently being trained ("working model") or a regular model
|
|
@@ -67,7 +63,7 @@ export declare abstract class Memory {
|
|
|
67
63
|
* @param source The model source
|
|
68
64
|
* @returns The saved model's path
|
|
69
65
|
*/
|
|
70
|
-
abstract saveWorkingModel(source: ModelSource): Promise<
|
|
66
|
+
abstract saveWorkingModel(source: ModelSource): Promise<string | undefined>;
|
|
71
67
|
/**
|
|
72
68
|
* Saves the newly provided model to the given model source.
|
|
73
69
|
* Returns the saved model's path
|
|
@@ -75,7 +71,7 @@ export declare abstract class Memory {
|
|
|
75
71
|
* @param model The new model
|
|
76
72
|
* @returns The saved model's path
|
|
77
73
|
*/
|
|
78
|
-
abstract saveModel(source: ModelSource, model: Model): Promise<
|
|
74
|
+
abstract saveModel(source: ModelSource, model: Model): Promise<string | undefined>;
|
|
79
75
|
/**
|
|
80
76
|
* Moves the model identified by the model source to a file system. This is platform-dependent.
|
|
81
77
|
* @param source The model source
|
|
@@ -95,7 +91,7 @@ export declare abstract class Memory {
|
|
|
95
91
|
* @param source The model source
|
|
96
92
|
* @returns The model path
|
|
97
93
|
*/
|
|
98
|
-
abstract getModelMemoryPath(source: ModelSource):
|
|
94
|
+
abstract getModelMemoryPath(source: ModelSource): string | undefined;
|
|
99
95
|
/**
|
|
100
96
|
* Computes the model information corresponding to the given model source, be it a path or model information.
|
|
101
97
|
* This is used to easily switch between model path and information, which are both unique model identifiers
|
package/dist/memory/empty.d.ts
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import type { Model } from '../index.js';
|
|
2
|
-
import type { ModelInfo
|
|
2
|
+
import type { ModelInfo } from './base.js';
|
|
3
3
|
import { Memory } from './base.js';
|
|
4
4
|
/**
|
|
5
5
|
* Represents an empty model memory.
|
|
@@ -14,7 +14,7 @@ export declare class Empty extends Memory {
|
|
|
14
14
|
saveModel(): Promise<undefined>;
|
|
15
15
|
deleteModel(): Promise<void>;
|
|
16
16
|
downloadModel(): Promise<void>;
|
|
17
|
-
getModelMemoryPath():
|
|
17
|
+
getModelMemoryPath(): string;
|
|
18
18
|
getModelInfo(): ModelInfo;
|
|
19
19
|
duplicateSource(): Promise<undefined>;
|
|
20
20
|
}
|
package/dist/memory/index.d.ts
CHANGED
|
@@ -1,2 +1,2 @@
|
|
|
1
1
|
export { Empty } from './empty.js';
|
|
2
|
-
export { Memory, type ModelInfo, type
|
|
2
|
+
export { Memory, type ModelInfo, type ModelSource } from './base.js';
|
|
@@ -4,7 +4,6 @@
|
|
|
4
4
|
import * as tf from '@tensorflow/tfjs';
|
|
5
5
|
import { PreTrainedTokenizer } from '@xenova/transformers';
|
|
6
6
|
import { WeightsContainer } from '../../index.js';
|
|
7
|
-
import type { Dataset } from '../../dataset/index.js';
|
|
8
7
|
import { BatchLogs, Model, EpochLogs } from "../index.js";
|
|
9
8
|
import type { Prediction, Sample } from '../model.js';
|
|
10
9
|
import { type GPTConfig } from './config.js';
|
|
@@ -25,7 +24,13 @@ export declare class GPT extends Model {
|
|
|
25
24
|
* @param epochs the number of passes of the training dataset
|
|
26
25
|
* @param tracker
|
|
27
26
|
*/
|
|
28
|
-
train(trainingData: Dataset
|
|
27
|
+
train(trainingData: tf.data.Dataset<{
|
|
28
|
+
xs: tf.Tensor2D;
|
|
29
|
+
ys: tf.Tensor3D;
|
|
30
|
+
}>, validationData?: tf.data.Dataset<{
|
|
31
|
+
xs: tf.Tensor2D;
|
|
32
|
+
ys: tf.Tensor3D;
|
|
33
|
+
}>): AsyncGenerator<BatchLogs, EpochLogs>;
|
|
29
34
|
predict(input: Sample): Promise<Prediction>;
|
|
30
35
|
generate(input: string, tokenizer: PreTrainedTokenizer, newTokens?: number): Promise<string>;
|
|
31
36
|
get config(): Required<GPTConfig>;
|
package/dist/models/gpt/index.js
CHANGED
|
@@ -1,13 +1,15 @@
|
|
|
1
1
|
/**
|
|
2
2
|
* this code is taken from gpt-tfjs with modifications from @peacefulotter and @lukemovement
|
|
3
3
|
**/
|
|
4
|
+
import createDebug from "debug";
|
|
5
|
+
import { List } from 'immutable';
|
|
4
6
|
import * as tf from '@tensorflow/tfjs';
|
|
5
7
|
import { WeightsContainer } from '../../index.js';
|
|
6
8
|
import { Model, EpochLogs } from "../index.js";
|
|
7
9
|
import { GPTForCausalLM } from './model.js';
|
|
8
10
|
import { DEFAULT_CONFIG } from './config.js';
|
|
9
11
|
import evaluate from './evaluate.js';
|
|
10
|
-
|
|
12
|
+
const debug = createDebug("discojs:models:gpt");
|
|
11
13
|
export class GPT extends Model {
|
|
12
14
|
model;
|
|
13
15
|
#maxBatchCount;
|
|
@@ -126,8 +128,7 @@ export class GPT extends Model {
|
|
|
126
128
|
this.model.optimizer.dispose();
|
|
127
129
|
}
|
|
128
130
|
const disposeResults = this.model.dispose();
|
|
129
|
-
if (disposeResults.refCountAfterDispose > 0)
|
|
130
|
-
|
|
131
|
-
}
|
|
131
|
+
if (disposeResults.refCountAfterDispose > 0)
|
|
132
|
+
debug("model not disposed correctly: %o", disposeResults);
|
|
132
133
|
}
|
|
133
134
|
}
|
package/dist/models/gpt/model.js
CHANGED
|
@@ -1,8 +1,10 @@
|
|
|
1
|
+
import createDebug from "debug";
|
|
1
2
|
import * as tf from '@tensorflow/tfjs';
|
|
2
3
|
import { getModelSizes, DEFAULT_CONFIG } from './config.js';
|
|
3
4
|
import { getCustomAdam, clipByGlobalNormObj } from './optimizers.js';
|
|
4
5
|
import evaluate from './evaluate.js';
|
|
5
6
|
import { GPTArchitecture } from './layers.js';
|
|
7
|
+
const debug = createDebug("discojs:models:gpt");
|
|
6
8
|
/**
|
|
7
9
|
* GPTModel extends tf.LayersModel and overrides tfjs' default training loop
|
|
8
10
|
*
|
|
@@ -87,10 +89,18 @@ class GPTModel extends tf.LayersModel {
|
|
|
87
89
|
this.config.evaluateEvery !== undefined &&
|
|
88
90
|
iteration % this.config.evaluateEvery == 0) {
|
|
89
91
|
const iterationLogs = await evaluate(this, evalDataset, this.config.maxEvalBatches);
|
|
90
|
-
|
|
92
|
+
debug('evaluation metrics: %O', iterationLogs);
|
|
91
93
|
}
|
|
92
94
|
const memory = tf.memory().numBytes / 1024 / 1024 / 1024;
|
|
93
|
-
|
|
95
|
+
debug("training metrics: %O", {
|
|
96
|
+
epoch,
|
|
97
|
+
iteration,
|
|
98
|
+
loss,
|
|
99
|
+
memory,
|
|
100
|
+
allocated: tf.memory().numTensors,
|
|
101
|
+
preprocessingTime,
|
|
102
|
+
weightUpdateTime,
|
|
103
|
+
});
|
|
94
104
|
iteration++;
|
|
95
105
|
next = await iterator.next();
|
|
96
106
|
}
|
|
@@ -1,7 +1,9 @@
|
|
|
1
1
|
import axios from 'axios';
|
|
2
|
+
import createDebug from "debug";
|
|
2
3
|
import { Map } from 'immutable';
|
|
3
4
|
import { serialization } from '../index.js';
|
|
4
5
|
import { isTask } from './task.js';
|
|
6
|
+
const debug = createDebug("discojs:task:handlers");
|
|
5
7
|
const TASK_ENDPOINT = 'tasks';
|
|
6
8
|
export async function pushTask(url, task, model) {
|
|
7
9
|
await axios.post(url.href + TASK_ENDPOINT, {
|
|
@@ -19,7 +21,7 @@ export async function fetchTasks(url) {
|
|
|
19
21
|
else if (!tasks.every(isTask)) {
|
|
20
22
|
for (const task of tasks) {
|
|
21
23
|
if (!isTask(task)) {
|
|
22
|
-
|
|
24
|
+
debug("task has invalid format: :O", task);
|
|
23
25
|
}
|
|
24
26
|
}
|
|
25
27
|
throw new Error('invalid tasks response, the task object received is not well formatted');
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import type { data, Model, Task, Logger, client as clients, Memory, ModelSource
|
|
1
|
+
import type { data, Model, Task, Logger, client as clients, Memory, ModelSource } from '../index.js';
|
|
2
2
|
export declare class Validator {
|
|
3
3
|
readonly task: Task;
|
|
4
4
|
readonly logger: Logger;
|
|
@@ -13,10 +13,10 @@ export declare class Validator {
|
|
|
13
13
|
test(data: data.Data): AsyncGenerator<Array<{
|
|
14
14
|
groundTruth: number;
|
|
15
15
|
pred: number;
|
|
16
|
-
features:
|
|
16
|
+
features: number[];
|
|
17
17
|
}>, void>;
|
|
18
18
|
inference(data: data.Data): AsyncGenerator<Array<{
|
|
19
|
-
features:
|
|
19
|
+
features: number[];
|
|
20
20
|
pred: number;
|
|
21
21
|
}>, void>;
|
|
22
22
|
getModel(): Promise<Model>;
|
|
@@ -57,7 +57,7 @@ export class Validator {
|
|
|
57
57
|
hits += List(pred).zip(List(ysLabel)).filter(([p, y]) => p === y).size;
|
|
58
58
|
this.rollingAccuracy = hits / this.size;
|
|
59
59
|
tf.dispose([xs, ys, yPredTensor]);
|
|
60
|
-
yield List(ysLabel).zip(List(pred), List(currentFeatures))
|
|
60
|
+
yield (List(ysLabel).zip(List(pred), List(currentFeatures)))
|
|
61
61
|
.map(([gt, p, f]) => ({ groundTruth: gt, pred: p, features: f }))
|
|
62
62
|
.toArray();
|
|
63
63
|
next = await iterator.next();
|
package/package.json
CHANGED
package/dist/types.d.ts
DELETED
|
@@ -1,6 +0,0 @@
|
|
|
1
|
-
import type { Map } from 'immutable';
|
|
2
|
-
import type { WeightsContainer } from './index.js';
|
|
3
|
-
import type { NodeID } from './client/index.js';
|
|
4
|
-
export type Path = string;
|
|
5
|
-
export type Features = number | number[] | number[][] | number[][][] | number[][][][] | number[][][][][];
|
|
6
|
-
export type Contributions = Map<NodeID, WeightsContainer>;
|
package/dist/types.js
DELETED
|
@@ -1 +0,0 @@
|
|
|
1
|
-
export {};
|