@epfml/discojs 3.0.1-p20241025115642.0 → 3.0.1-p20241028120035.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/aggregator/get.d.ts +3 -3
- package/dist/client/client.d.ts +5 -5
- package/dist/client/decentralized/decentralized_client.d.ts +2 -2
- package/dist/client/federated/federated_client.d.ts +2 -2
- package/dist/client/utils.d.ts +2 -2
- package/dist/dataset/dataset.d.ts +9 -2
- package/dist/dataset/dataset.js +83 -36
- package/dist/dataset/image.d.ts +5 -0
- package/dist/dataset/image.js +6 -1
- package/dist/dataset/index.d.ts +0 -1
- package/dist/dataset/index.js +0 -1
- package/dist/dataset/types.d.ts +2 -0
- package/dist/default_tasks/cifar10.d.ts +1 -1
- package/dist/default_tasks/cifar10.js +2 -3
- package/dist/default_tasks/lus_covid.d.ts +1 -1
- package/dist/default_tasks/lus_covid.js +2 -3
- package/dist/default_tasks/mnist.d.ts +1 -1
- package/dist/default_tasks/mnist.js +2 -4
- package/dist/default_tasks/simple_face.d.ts +1 -1
- package/dist/default_tasks/simple_face.js +2 -3
- package/dist/default_tasks/titanic.d.ts +1 -1
- package/dist/default_tasks/titanic.js +3 -6
- package/dist/default_tasks/wikitext.d.ts +1 -1
- package/dist/default_tasks/wikitext.js +1 -2
- package/dist/index.d.ts +4 -5
- package/dist/index.js +4 -5
- package/dist/models/gpt/index.d.ts +13 -16
- package/dist/models/gpt/index.js +62 -43
- package/dist/models/gpt/model.d.ts +1 -15
- package/dist/models/gpt/model.js +1 -75
- package/dist/models/model.d.ts +7 -12
- package/dist/models/tfjs.d.ts +10 -8
- package/dist/models/tfjs.js +106 -44
- package/dist/models/tokenizer.d.ts +1 -1
- package/dist/privacy.js +1 -1
- package/dist/processing/image.d.ts +18 -0
- package/dist/processing/image.js +75 -0
- package/dist/processing/index.d.ts +8 -0
- package/dist/processing/index.js +106 -0
- package/dist/processing/tabular.d.ts +19 -0
- package/dist/processing/tabular.js +33 -0
- package/dist/processing/text.d.ts +11 -0
- package/dist/processing/text.js +33 -0
- package/dist/serialization/model.d.ts +3 -3
- package/dist/serialization/model.js +19 -6
- package/dist/task/task.d.ts +4 -3
- package/dist/task/task.js +5 -3
- package/dist/task/task_handler.d.ts +3 -3
- package/dist/task/task_provider.d.ts +4 -4
- package/dist/task/training_information.d.ts +25 -16
- package/dist/task/training_information.js +76 -72
- package/dist/training/disco.d.ts +20 -12
- package/dist/training/disco.js +32 -13
- package/dist/training/trainer.d.ts +6 -7
- package/dist/training/trainer.js +6 -6
- package/dist/types/data_format.d.ts +40 -0
- package/dist/types/index.d.ts +2 -0
- package/dist/types/index.js +1 -0
- package/dist/validator.d.ts +10 -0
- package/dist/validator.js +30 -0
- package/package.json +4 -2
- package/dist/dataset/data/data.d.ts +0 -47
- package/dist/dataset/data/data.js +0 -88
- package/dist/dataset/data/data_split.d.ts +0 -8
- package/dist/dataset/data/helpers.d.ts +0 -10
- package/dist/dataset/data/helpers.js +0 -97
- package/dist/dataset/data/image_data.d.ts +0 -11
- package/dist/dataset/data/image_data.js +0 -43
- package/dist/dataset/data/index.d.ts +0 -5
- package/dist/dataset/data/index.js +0 -5
- package/dist/dataset/data/preprocessing/base.d.ts +0 -16
- package/dist/dataset/data/preprocessing/base.js +0 -1
- package/dist/dataset/data/preprocessing/image_preprocessing.d.ts +0 -13
- package/dist/dataset/data/preprocessing/image_preprocessing.js +0 -42
- package/dist/dataset/data/preprocessing/index.d.ts +0 -4
- package/dist/dataset/data/preprocessing/index.js +0 -3
- package/dist/dataset/data/preprocessing/tabular_preprocessing.d.ts +0 -13
- package/dist/dataset/data/preprocessing/tabular_preprocessing.js +0 -45
- package/dist/dataset/data/preprocessing/text_preprocessing.d.ts +0 -13
- package/dist/dataset/data/preprocessing/text_preprocessing.js +0 -100
- package/dist/dataset/data/tabular_data.d.ts +0 -11
- package/dist/dataset/data/tabular_data.js +0 -24
- package/dist/dataset/data/text_data.d.ts +0 -11
- package/dist/dataset/data/text_data.js +0 -14
- package/dist/processing.d.ts +0 -35
- package/dist/processing.js +0 -89
- package/dist/types.d.ts +0 -3
- package/dist/types.js +0 -1
- package/dist/validation/index.d.ts +0 -1
- package/dist/validation/index.js +0 -1
- package/dist/validation/validator.d.ts +0 -10
- package/dist/validation/validator.js +0 -113
- /package/dist/{dataset/data/data_split.js → types/data_format.js} +0 -0
package/dist/aggregator/get.d.ts
CHANGED
@@ -1,7 +1,7 @@
-import type { Task } from '../index.js';
+import type { DataType, Task } from '../index.js';
 import { aggregator } from '../index.js';
 type AggregatorOptions = Partial<{
-    scheme: Task[
+    scheme: Task<DataType>["trainingInformation"]["scheme"];
     roundCutOff: number;
     threshold: number;
     thresholdType: 'relative' | 'absolute';
@@ -24,5 +24,5 @@ type AggregatorOptions = Partial<{
  * @param options Options passed down to the aggregator's constructor
  * @returns The aggregator
  */
-export declare function getAggregator(task: Task
+export declare function getAggregator(task: Task<DataType>, options?: AggregatorOptions): aggregator.Aggregator;
 export {};
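Note: the `Task<DataType>["trainingInformation"]["scheme"]` type appearing throughout this release is a TypeScript indexed-access type. A minimal sketch of the mechanism, using simplified stand-in types rather than the real discojs definitions:

```ts
// Simplified stand-ins (hypothetical, not the actual discojs types).
type TrainingInformation = { scheme: "federated" | "decentralized" | "local" };
type Task<D> = { dataType: D; trainingInformation: TrainingInformation };

// Indexing into the type extracts the nested field's type:
// this resolves to "federated" | "decentralized" | "local".
type Scheme = Task<string>["trainingInformation"]["scheme"];

const s: Scheme = "federated"; // type-checks
```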
package/dist/client/client.d.ts
CHANGED
@@ -1,4 +1,4 @@
-import type { Model, Task, WeightsContainer
+import type { DataType, Model, RoundStatus, Task, WeightsContainer } from "../index.js";
 import type { NodeID } from './types.js';
 import type { EventConnection } from './event_connection.js';
 import type { Aggregator } from '../aggregator/index.js';
@@ -11,7 +11,7 @@ export declare abstract class Client extends EventEmitter<{
     'status': RoundStatus;
 }> {
     readonly url: URL;
-    readonly task: Task
+    readonly task: Task<DataType>;
     readonly aggregator: Aggregator;
     protected _ownId?: NodeID;
     protected _server?: EventConnection;
@@ -30,7 +30,7 @@ export declare abstract class Client extends EventEmitter<{
     */
    private previousStatus;
    constructor(url: URL, // The network server's URL to connect to
-    task: Task
+    task: Task<DataType>, // The client's corresponding task
    aggregator: Aggregator);
    /**
     * Communication callback called at the beginning of every training round.
@@ -47,7 +47,7 @@ export declare abstract class Client extends EventEmitter<{
     * This method is overriden by the federated and decentralized clients
     * By default, it fetches and returns the server's base model
     */
-    connect(): Promise<Model
+    connect(): Promise<Model<DataType>>;
    /**
     * Handles the disconnection process of the client from any sort of network server.
     */
@@ -94,7 +94,7 @@ export declare abstract class Client extends EventEmitter<{
     * Fetches the latest model available on the network's server, for the adequate task.
     * @returns The latest model
     */
-    getLatestModel(): Promise<Model
+    getLatestModel(): Promise<Model<DataType>>;
    /**
     * Number of contributors to a collaborative session
     * If decentralized, it should be the number of peers

package/dist/client/decentralized/decentralized_client.d.ts
CHANGED
@@ -1,4 +1,4 @@
-import type { Model, WeightsContainer } from "../../index.js";
+import type { DataType, Model, WeightsContainer } from "../../index.js";
 import { Client } from '../client.js';
 /**
  * Represents a decentralized client in a network of peers. Peers coordinate each other with the
@@ -18,7 +18,7 @@ export declare class DecentralizedClient extends Client {
     * create peer-to-peer WebRTC connections with peers. The server is used to exchange
     * peers network information.
     */
-    connect(): Promise<Model
+    connect(): Promise<Model<DataType>>;
    disconnect(): Promise<void>;
    /**
     * At the beginning of a round, each peer tells the server it is ready to proceed

package/dist/client/federated/federated_client.d.ts
CHANGED
@@ -1,4 +1,4 @@
-import type { Model, WeightsContainer } from "../../index.js";
+import type { DataType, Model, WeightsContainer } from "../../index.js";
 import { Client } from "../client.js";
 /**
  * Client class that communicates with a centralized, federated server, when training
@@ -12,7 +12,7 @@ export declare class FederatedClient extends Client {
     * as well as the latest training information: latest global model, current round and
     * whether we are waiting for more participants.
     */
-    connect(): Promise<Model
+    connect(): Promise<Model<DataType>>;
    /**
     * Disconnection process when user quits the task.
     */
package/dist/client/utils.d.ts
CHANGED
@@ -1,4 +1,4 @@
-import type { Task } from '../index.js';
+import type { DataType, Task } from '../index.js';
 import { client as clients, type aggregator } from '../index.js';
 export declare function timeout(ms?: number, errorMsg?: string): Promise<never>;
-export declare function getClient(trainingScheme:
+export declare function getClient(trainingScheme: Task<DataType>["trainingInformation"]["scheme"], serverURL: URL, task: Task<DataType>, aggregator: aggregator.Aggregator): clients.Client;
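Taken together with `getAggregator` above, the new signatures thread `DataType` through client construction. A sketch of how a caller might use them; the import paths are assumptions, since the diff shows the declarations but not where they are re-exported:

```ts
import type { DataType, Task } from "@epfml/discojs";
// Assumed paths; adjust to wherever getAggregator/getClient are exported.
import { getAggregator } from "@epfml/discojs/dist/aggregator/get.js";
import { getClient } from "@epfml/discojs/dist/client/utils.js";

async function join(task: Task<DataType>, serverURL: URL) {
  const aggregator = getAggregator(task);
  const client = getClient(
    task.trainingInformation.scheme, // Task<DataType>["trainingInformation"]["scheme"]
    serverURL,
    task,
    aggregator,
  );
  // connect() now resolves to the generically typed Model<DataType>.
  const model = await client.connect();
  return { client, model };
}
```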
package/dist/dataset/dataset.d.ts
CHANGED
@@ -1,4 +1,4 @@
-import {
+import { Batched } from "./types.js";
 type DatasetLike<T> = AsyncIterable<T> | Iterable<T> | (() => AsyncIterator<T, void>) | (() => Iterator<T, void>);
 /** Immutable series of data */
 export declare class Dataset<T> implements AsyncIterable<T> {
@@ -31,7 +31,9 @@ export declare class Dataset<T> implements AsyncIterable<T> {
     *
     * @param size count of element per chunk
     */
-    batch(size: number): Dataset<
+    batch(size: number): Dataset<Batched<T>>;
+    /** Flatten chunks */
+    unbatch<U>(this: Dataset<Batched<U>>): Dataset<U>;
    /** Join side-by-side
     *
     * Stops as soon as one runs out
@@ -44,5 +46,10 @@ export declare class Dataset<T> implements AsyncIterable<T> {
     * This is a costly operation as we need to go through the whole Dataset.
     */
    size(): Promise<number>;
+    /** Try to keep generated elements to avoid recomputing
+     *
+     * Drops everything when memory pressure is applied.
+     */
+    cached(): Dataset<T>;
 }
 export {};
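The declaration now pairs `batch` with an inverse `unbatch` and adds `cached`. A sketch of the round-trip, assuming the constructor accepts any `DatasetLike` as the type alias suggests:

```ts
import { Dataset } from "@epfml/discojs";

async function demo(): Promise<void> {
  const numbers = new Dataset([1, 2, 3, 4, 5]);

  // batch(2) yields Batched<number> chunks [1,2], [3,4], [5];
  // unbatch() flattens them back into single elements.
  const roundTripped = numbers.batch(2).unbatch();

  // cached() weakly retains generated elements so that a second
  // iteration may replay them instead of recomputing.
  const cached = roundTripped.cached();
  for await (const n of cached) console.log(n); // computes
  for await (const n of cached) console.log(n); // may read from the cache
}
```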
package/dist/dataset/dataset.js
CHANGED
@@ -1,4 +1,6 @@
-import
+import createDebug from "debug";
+import { List, Range } from "immutable";
+const debug = createDebug("discojs:dataset");
 /** Immutable series of data */
 export class Dataset {
     #content;
@@ -32,13 +34,10 @@ export class Dataset {
      * @param mapper how to change each element
      */
     map(mapper) {
-        const content = {
-            [Symbol.asyncIterator]: () => this.#content(),
-        };
         return new Dataset(async function* () {
-            for await (const e of
+            for await (const e of this)
                 yield await mapper(e);
-        });
+        }.bind(this));
     }
     /** Combine with another Dataset.
      *
@@ -47,13 +46,10 @@ export class Dataset {
     chain(other) {
         if (!(other instanceof Dataset))
             other = new Dataset(other);
-        const self = {
-            [Symbol.asyncIterator]: () => this.#content(),
-        };
         return new Dataset(async function* () {
-            yield*
+            yield* this;
             yield* other;
-        });
+        }.bind(this));
     }
     /** Divide into two based on given ratio
      *
@@ -62,16 +58,13 @@ export class Dataset {
     split(ratio) {
         if (ratio < 0 || ratio > 1)
             throw new Error("ratio out of range");
-        const content = {
-            [Symbol.asyncIterator]: () => this.#content(),
-        };
         // to avoid using random sampling or knowing the size beforehand,
         // we compute the actual ratio and make it converge towards the wanted one
         return [
             new Dataset(async function* () {
                 let yielded_by_other = 0;
                 let total_size = 0;
-                for await (const e of
+                for await (const e of this) {
                     total_size++;
                     if (yielded_by_other / total_size >= ratio) {
                         yield e;
@@ -80,18 +73,18 @@ export class Dataset {
                         yielded_by_other++;
                     }
                 }
-            }),
+            }.bind(this)),
             new Dataset(async function* () {
                 let yielded = 0;
                 let total_size = 0;
-                for await (const e of
+                for await (const e of this) {
                     total_size++;
                     if (yielded / total_size < ratio) {
                         yielded++;
                         yield e;
                     }
                 }
-            }),
+            }.bind(this)),
         ];
     }
     /** Slice into chunks
@@ -103,21 +96,30 @@ export class Dataset {
     batch(size) {
         if (size <= 0 || !Number.isInteger(size))
             throw new Error("invalid size");
-        const content = {
-            [Symbol.asyncIterator]: () => this.#content(),
-        };
         return new Dataset(async function* () {
-
-            for
-            batch =
-
-
-
-
-
+            const iter = this[Symbol.asyncIterator]();
+            for (;;) {
+                const batch = List(await Promise.all(Range(0, size).map(() => iter.next()))).flatMap((res) => {
+                    if (res.done)
+                        return [];
+                    else
+                        return [res.value];
+                });
+                if (batch.isEmpty())
+                    break;
                 yield batch;
-
+                // iterator couldn't generate more
+                if (batch.size < size)
+                    break;
+            }
+        }.bind(this));
+    }
+    /** Flatten chunks */
+    unbatch() {
+        return new Dataset(async function* () {
+            for await (const batch of this)
+                yield* batch;
+        }.bind(this));
    }
    /** Join side-by-side
     *
@@ -128,11 +130,8 @@ export class Dataset {
     zip(other) {
         if (!(other instanceof Dataset))
             other = new Dataset(other);
-        const content = {
-            [Symbol.asyncIterator]: () => this.#content(),
-        };
         return new Dataset(async function* () {
-            const left =
+            const left = this[Symbol.asyncIterator]();
             const right = other[Symbol.asyncIterator]();
             while (true) {
                 const [l, r] = await Promise.all([left.next(), right.next()]);
@@ -140,7 +139,7 @@ export class Dataset {
                     return;
                 yield [l.value, r.value];
             }
-        });
+        }.bind(this));
    }
    /** Compute size
     *
@@ -152,4 +151,52 @@ export class Dataset {
         ret++;
         return ret;
     }
+    /** Try to keep generated elements to avoid recomputing
+     *
+     * Drops everything when memory pressure is applied.
+     */
+    cached() {
+        return new CachingDataset(this.#content);
+    }
+}
+/**
+ * Avoid recomputing the parent dataset, without hogging memory
+ *
+ * As dataset operations can be time-consuming, this keeps a weak reference to
+ * the generated elements so that a second iteration might yield theses directly.
+ **/
+class CachingDataset extends Dataset {
+    // potential reference to all elements
+    // tristate: undefined == empty, [false, _] == filling, [true, _] == filled
+    #cache = new WeakRef([false, List()]);
+    [Symbol.asyncIterator]() {
+        const cached = this.#cache.deref();
+        if (cached !== undefined && cached[0]) {
+            debug("valid cache, reading from it");
+            // eslint-disable-next-line @typescript-eslint/require-await
+            return (async function* () {
+                yield* cached[1];
+            })();
+        }
+        debug("cache invalid, reading from dataset");
+        this.#cache = new WeakRef([false, List()]);
+        const parentContent = {
+            [Symbol.asyncIterator]: () => super[Symbol.asyncIterator](),
+        };
+        return async function* () {
+            for await (const e of parentContent) {
+                yield e;
+                const caching = this.#cache.deref();
+                if (caching !== undefined)
+                    caching[1] = caching[1].push(e);
+            }
+            const caching = this.#cache.deref();
+            if (caching === undefined) {
+                debug("cache evicted while filling");
+                return;
+            }
+            debug("cache filled");
+            caching[0] = true;
+        }.bind(this)();
+    }
 }
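The refactor above replaces the temporary `{ [Symbol.asyncIterator]: ... }` wrapper objects with async generator functions bound to the dataset itself. A standalone sketch of the pattern, simplified and not the discojs class:

```ts
// A bound async generator lets `this` refer to the source series,
// avoiding a captured wrapper object per operation.
class Series<T> implements AsyncIterable<T> {
  constructor(private readonly gen: () => AsyncIterator<T, void>) {}

  [Symbol.asyncIterator](): AsyncIterator<T, void> {
    return this.gen();
  }

  map<U>(mapper: (t: T) => U): Series<U> {
    return new Series(
      async function* (this: Series<T>) {
        for await (const e of this) yield mapper(e);
      }.bind(this), // `this` in the generator is the source Series
    );
  }
}
```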
package/dist/dataset/image.d.ts
CHANGED
@@ -1,6 +1,11 @@
 /**
  * Raw image with type level dimensions.
  *
+ * Per convention, `data` layout is as follow
+ * `height` chunk each containing
+ * `width` chunk each containing
+ * a chunk of `depth` bytes
+ *
  * @typeParam D depth of the image
  * @typeParam W width, positive and integral
  * @typeParam H height, positive and integral
package/dist/dataset/image.js
CHANGED
@@ -1,6 +1,11 @@
 /**
  * Raw image with type level dimensions.
  *
+ * Per convention, `data` layout is as follow
+ * `height` chunk each containing
+ * `width` chunk each containing
+ * a chunk of `depth` bytes
+ *
  * @typeParam D depth of the image
  * @typeParam W width, positive and integral
  * @typeParam H height, positive and integral
@@ -16,6 +21,6 @@ export class Image {
         this.height = height;
         this.depth = depth;
         if (data.length != width * height * depth)
-            throw new Error("data isn't of
+            throw new Error("data isn't of expected size");
     }
 }
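The newly documented layout is row-major HWC: rows of `width` pixels, each pixel occupying `depth` consecutive bytes. A small sketch of indexing under that convention; the helper name is ours, not part of the package:

```ts
// Byte offset of channel `c` of the pixel at column `x`, row `y`,
// given the documented height/width/depth layout.
function pixelOffset(width: number, depth: number, x: number, y: number, c: number): number {
  return (y * width + x) * depth + c;
}

// e.g. 224x224 RGB: the green byte (c=1) of pixel (x=3, y=2)
// sits at (2 * 224 + 3) * 3 + 1 === 1354.
console.log(pixelOffset(224, 3, 3, 2, 1)); // 1354
```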
package/dist/dataset/index.d.ts
CHANGED
package/dist/dataset/index.js
CHANGED
package/dist/dataset/types.d.ts
CHANGED
package/dist/default_tasks/cifar10.d.ts
CHANGED
@@ -1,2 +1,2 @@
 import type { TaskProvider } from '../index.js';
-export declare const cifar10: TaskProvider
+export declare const cifar10: TaskProvider<'image'>;

package/dist/default_tasks/cifar10.js
CHANGED
@@ -1,5 +1,5 @@
 import * as tf from '@tensorflow/tfjs';
-import {
+import { models } from '../index.js';
 import baseModel from '../models/mobileNet_v1_025_224.js';
 export const cifar10 = {
     getTask() {
@@ -24,7 +24,6 @@ export const cifar10 = {
             validationSplit: 0.2,
             batchSize: 10,
             dataType: 'image',
-            preprocessingFunctions: [data.ImagePreprocessing.Resize, data.ImagePreprocessing.Normalize],
             IMAGE_H: 224,
             IMAGE_W: 224,
             LABEL_LIST: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'],
@@ -55,6 +54,6 @@ export const cifar10 = {
             loss: 'categoricalCrossentropy',
             metrics: ['accuracy']
         });
-        return new models.TFJS(model);
+        return new models.TFJS('image', model);
     }
 };

package/dist/default_tasks/lus_covid.d.ts
CHANGED
@@ -1,2 +1,2 @@
 import type { TaskProvider } from '../index.js';
-export declare const lusCovid: TaskProvider
+export declare const lusCovid: TaskProvider<'image'>;

package/dist/default_tasks/lus_covid.js
CHANGED
@@ -1,5 +1,5 @@
 import * as tf from '@tensorflow/tfjs';
-import {
+import { models } from '../index.js';
 export const lusCovid = {
     getTask() {
         return {
@@ -24,7 +24,6 @@ export const lusCovid = {
             batchSize: 5,
             IMAGE_H: 100,
             IMAGE_W: 100,
-            preprocessingFunctions: [data.ImagePreprocessing.Resize, data.ImagePreprocessing.Normalize],
             LABEL_LIST: ['COVID-Positive', 'COVID-Negative'],
             dataType: 'image',
             scheme: 'federated',
@@ -82,6 +81,6 @@ export const lusCovid = {
             loss: 'binaryCrossentropy',
             metrics: ['accuracy']
         });
-        return Promise.resolve(new models.TFJS(model));
+        return Promise.resolve(new models.TFJS('image', model));
     }
 };

package/dist/default_tasks/mnist.d.ts
CHANGED
@@ -1,2 +1,2 @@
 import type { TaskProvider } from '../index.js';
-export declare const mnist: TaskProvider
+export declare const mnist: TaskProvider<'image'>;

package/dist/default_tasks/mnist.js
CHANGED
@@ -1,5 +1,5 @@
 import * as tf from '@tensorflow/tfjs';
-import {
+import { models } from '../index.js';
 export const mnist = {
     getTask() {
         return {
@@ -25,8 +25,6 @@ export const mnist = {
             dataType: 'image',
             IMAGE_H: 28,
             IMAGE_W: 28,
-            // Images should already be at the right size but resizing just in case
-            preprocessingFunctions: [data.ImagePreprocessing.Resize, data.ImagePreprocessing.Normalize],
             LABEL_LIST: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'],
             scheme: 'decentralized',
             aggregationStrategy: 'secure',
@@ -58,6 +56,6 @@ export const mnist = {
             loss: 'categoricalCrossentropy',
             metrics: ['accuracy']
         });
-        return Promise.resolve(new models.TFJS(model));
+        return Promise.resolve(new models.TFJS('image', model));
     }
 };

package/dist/default_tasks/simple_face.d.ts
CHANGED
@@ -1,2 +1,2 @@
 import type { TaskProvider } from '../index.js';
-export declare const simpleFace: TaskProvider
+export declare const simpleFace: TaskProvider<'image'>;

package/dist/default_tasks/simple_face.js
CHANGED
@@ -1,5 +1,5 @@
 import * as tf from '@tensorflow/tfjs';
-import {
+import { models } from '../index.js';
 import baseModel from '../models/mobileNetV2_35_alpha_2_classes.js';
 export const simpleFace = {
     getTask() {
@@ -22,7 +22,6 @@ export const simpleFace = {
             roundDuration: 1,
             validationSplit: 0.2,
             batchSize: 10,
-            preprocessingFunctions: [data.ImagePreprocessing.Normalize],
             dataType: 'image',
             IMAGE_H: 200,
             IMAGE_W: 200,
@@ -43,6 +42,6 @@ export const simpleFace = {
             loss: 'categoricalCrossentropy',
             metrics: ['accuracy']
         });
-        return new models.TFJS(model);
+        return new models.TFJS('image', model);
     }
 };

package/dist/default_tasks/titanic.d.ts
CHANGED
@@ -1,2 +1,2 @@
 import type { TaskProvider } from '../index.js';
-export declare const titanic: TaskProvider
+export declare const titanic: TaskProvider<'tabular'>;

package/dist/default_tasks/titanic.js
CHANGED
@@ -1,5 +1,5 @@
 import * as tf from '@tensorflow/tfjs';
-import {
+import { models } from '../index.js';
 export const titanic = {
     getTask() {
         return {
@@ -49,7 +49,6 @@ export const titanic = {
             roundDuration: 2,
             validationSplit: 0.2,
             batchSize: 30,
-            preprocessingFunctions: [data.TabularPreprocessing.Sanitize],
             dataType: 'tabular',
             inputColumns: [
                 'Age',
@@ -58,9 +57,7 @@ export const titanic = {
                 'Fare',
                 'Pclass'
             ],
-            outputColumns: [
-                'Survived'
-            ],
+            outputColumn: 'Survived',
             scheme: 'federated',
             aggregationStrategy: 'mean',
             minNbOfParticipants: 2,
@@ -84,6 +81,6 @@ export const titanic = {
             loss: 'binaryCrossentropy',
             metrics: ['accuracy']
         });
-        return Promise.resolve(new models.TFJS(model));
+        return Promise.resolve(new models.TFJS('tabular', model));
     }
 };

package/dist/default_tasks/wikitext.d.ts
CHANGED
@@ -1,2 +1,2 @@
 import type { TaskProvider } from '../index.js';
-export declare const wikitext: TaskProvider
+export declare const wikitext: TaskProvider<'text'>;

package/dist/default_tasks/wikitext.js
CHANGED
@@ -1,4 +1,4 @@
-import {
+import { models } from '../index.js';
 export const wikitext = {
     getTask() {
         return {
@@ -23,7 +23,6 @@ export const wikitext = {
             },
             trainingInformation: {
                 dataType: 'text',
-                preprocessingFunctions: [data.TextPreprocessing.Tokenize, data.TextPreprocessing.LeftPadding],
                 scheme: 'federated',
                 aggregationStrategy: 'mean',
                 minNbOfParticipants: 2,
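Across all default tasks, the wrapped model now declares its data type up front: `new models.TFJS('image', model)` rather than `new models.TFJS(model)`, matching the `TaskProvider<'image'>` / `<'tabular'>` / `<'text'>` declarations. A sketch following the same pattern with a toy network:

```ts
import * as tf from "@tensorflow/tfjs";
import { models } from "@epfml/discojs";

const layers = tf.sequential({
  layers: [
    tf.layers.dense({ inputShape: [4], units: 8, activation: "relu" }),
    tf.layers.dense({ units: 2, activation: "softmax" }),
  ],
});
layers.compile({ optimizer: "sgd", loss: "categoricalCrossentropy" });

// Presumably typed as models.TFJS<'tabular'>, specializing
// train/predict to tabular inputs and outputs.
const model = new models.TFJS("tabular", layers);
```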
package/dist/index.d.ts
CHANGED
@@ -7,14 +7,13 @@ export * as aggregator from './aggregator/index.js';
 export { WeightsContainer, aggregation } from './weights/index.js';
 export { Logger, ConsoleLogger } from './logging/index.js';
 export { Disco, RoundLogs, RoundStatus } from './training/index.js';
-export { Validator } from './
+export { Validator } from './validator.js';
 export { Model, BatchLogs, EpochLogs, ValidationMetrics } from './models/index.js';
 export * as models from './models/index.js';
 export * from './task/index.js';
 export * as defaultTasks from './default_tasks/index.js';
 export * as async_iterator from "./utils/async_iterator.js";
 export { EventEmitter } from "./utils/event_emitter.js";
-export
-export * from "./
-export * from "./
-export * as processing from "./processing.js";
+export * from "./dataset/index.js";
+export * from "./types/index.js";
+export * as processing from "./processing/index.js";
package/dist/index.js
CHANGED
@@ -7,14 +7,13 @@ export * as aggregator from './aggregator/index.js';
 export { WeightsContainer, aggregation } from './weights/index.js';
 export { ConsoleLogger } from './logging/index.js';
 export { Disco } from './training/index.js';
-export { Validator } from './
+export { Validator } from './validator.js';
 export { Model, EpochLogs } from './models/index.js';
 export * as models from './models/index.js';
 export * from './task/index.js';
 export * as defaultTasks from './default_tasks/index.js';
 export * as async_iterator from "./utils/async_iterator.js";
 export { EventEmitter } from "./utils/event_emitter.js";
-export
-export * from "./
-export * from "./
-export * as processing from "./processing.js";
+export * from "./dataset/index.js";
+export * from "./types/index.js";
+export * as processing from "./processing/index.js";

package/dist/models/gpt/index.d.ts
CHANGED
@@ -1,17 +1,20 @@
 /**
  * this code is taken from gpt-tfjs with modifications from @peacefulotter and @lukemovement
  **/
-import * as tf from
-import {
-import { WeightsContainer } from
+import * as tf from "@tensorflow/tfjs";
+import type { Batched, Dataset, DataFormat } from "../../index.js";
+import { WeightsContainer } from "../../index.js";
 import { BatchLogs, Model, EpochLogs } from "../index.js";
-import
-import { type GPTConfig } from './config.js';
+import { type GPTConfig } from "./config.js";
 export type GPTSerialization = {
     weights: WeightsContainer;
     config?: GPTConfig;
 };
-export declare class GPT extends Model {
+interface PredictConfig {
+    temperature: number;
+    doSample: boolean;
+}
+export declare class GPT extends Model<"text"> {
     #private;
     private readonly model;
     constructor(partialConfig?: GPTConfig, layersModel?: tf.LayersModel);
@@ -24,20 +27,14 @@ export declare class GPT extends Model {
     * @param epochs the number of passes of the training dataset
     * @param tracker
     */
-    train(
-        xs: tf.Tensor2D;
-        ys: tf.Tensor3D;
-    }>, validationData?: tf.data.Dataset<{
-        xs: tf.Tensor2D;
-        ys: tf.Tensor3D;
-    }>): AsyncGenerator<BatchLogs, EpochLogs>;
-    predict(input: Sample): Promise<Prediction>;
-    generate(input: string, tokenizer: PreTrainedTokenizer, newTokens?: number): Promise<string>;
+    train(trainingDataset: Dataset<Batched<DataFormat.ModelEncoded["text"]>>, validationDataset?: Dataset<Batched<DataFormat.ModelEncoded["text"]>>): AsyncGenerator<BatchLogs, EpochLogs>;
+    predict(batch: Batched<DataFormat.ModelEncoded["text"][0]>, options?: Partial<PredictConfig>): Promise<Batched<DataFormat.ModelEncoded["text"][1]>>;
     get config(): Required<GPTConfig>;
     get weights(): WeightsContainer;
     set weights(ws: WeightsContainer);
-    static deserialize(data: GPTSerialization): Model
+    static deserialize(data: GPTSerialization): Model<"text">;
     serialize(): GPTSerialization;
     extract(): tf.LayersModel;
     [Symbol.dispose](): void;
 }
+export {};
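The new `train`/`predict` signatures index into `DataFormat.ModelEncoded`: each data type maps to an `[input, output]` pair, and `[0]`/`[1]` select its halves. A sketch with hypothetical stand-ins, since the real encoding of `"text"` is not shown in this diff:

```ts
import { List } from "immutable";

// Hypothetical stand-ins: Batched as an immutable List, and "text"
// encoded as a token sequence in, next token out.
type Batched<T> = List<T>;
interface ModelEncoded {
  text: [List<number>, number];
}

// predict takes a batch of inputs ([0]) and yields a batch of outputs ([1]):
// here, List<List<number>> in, Promise<List<number>> out.
declare function predict(
  batch: Batched<ModelEncoded["text"][0]>,
): Promise<Batched<ModelEncoded["text"][1]>>;
```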