@epfml/discojs 3.0.1-p20250814105822.0 → 3.0.1-p20250924150135.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/aggregator/get.d.ts +3 -3
- package/dist/aggregator/get.js +1 -2
- package/dist/client/client.d.ts +6 -6
- package/dist/client/decentralized/decentralized_client.d.ts +1 -1
- package/dist/client/decentralized/peer_pool.d.ts +1 -1
- package/dist/client/federated/federated_client.d.ts +1 -1
- package/dist/client/local_client.d.ts +1 -1
- package/dist/client/utils.d.ts +2 -2
- package/dist/client/utils.js +19 -10
- package/dist/default_tasks/cifar10.d.ts +2 -2
- package/dist/default_tasks/cifar10.js +9 -8
- package/dist/default_tasks/lus_covid.d.ts +2 -2
- package/dist/default_tasks/lus_covid.js +9 -8
- package/dist/default_tasks/mnist.d.ts +2 -2
- package/dist/default_tasks/mnist.js +9 -8
- package/dist/default_tasks/simple_face.d.ts +2 -2
- package/dist/default_tasks/simple_face.js +9 -8
- package/dist/default_tasks/tinder_dog.d.ts +1 -1
- package/dist/default_tasks/tinder_dog.js +12 -10
- package/dist/default_tasks/titanic.d.ts +2 -2
- package/dist/default_tasks/titanic.js +20 -33
- package/dist/default_tasks/wikitext.d.ts +2 -2
- package/dist/default_tasks/wikitext.js +16 -13
- package/dist/index.d.ts +1 -1
- package/dist/index.js +1 -1
- package/dist/models/gpt/config.d.ts +2 -2
- package/dist/models/hellaswag.d.ts +2 -3
- package/dist/models/hellaswag.js +3 -4
- package/dist/models/index.d.ts +2 -3
- package/dist/models/index.js +2 -3
- package/dist/models/tokenizer.d.ts +24 -14
- package/dist/models/tokenizer.js +42 -21
- package/dist/processing/index.d.ts +4 -5
- package/dist/processing/index.js +16 -21
- package/dist/serialization/coder.d.ts +5 -1
- package/dist/serialization/coder.js +4 -1
- package/dist/serialization/index.d.ts +4 -0
- package/dist/serialization/index.js +1 -0
- package/dist/serialization/task.d.ts +5 -0
- package/dist/serialization/task.js +34 -0
- package/dist/task/display_information.d.ts +91 -14
- package/dist/task/display_information.js +34 -58
- package/dist/task/index.d.ts +5 -5
- package/dist/task/index.js +4 -3
- package/dist/task/task.d.ts +837 -10
- package/dist/task/task.js +49 -21
- package/dist/task/task_handler.d.ts +4 -4
- package/dist/task/task_handler.js +14 -18
- package/dist/task/task_provider.d.ts +3 -3
- package/dist/task/training_information.d.ts +157 -35
- package/dist/task/training_information.js +85 -110
- package/dist/training/disco.d.ts +8 -8
- package/dist/training/disco.js +2 -1
- package/dist/training/trainer.d.ts +3 -3
- package/dist/training/trainer.js +2 -1
- package/dist/types/index.d.ts +1 -0
- package/dist/validator.d.ts +4 -4
- package/dist/validator.js +7 -6
- package/package.json +4 -7
- package/dist/processing/text.d.ts +0 -21
- package/dist/processing/text.js +0 -36
- package/dist/task/data_example.d.ts +0 -5
- package/dist/task/data_example.js +0 -14
- package/dist/task/summary.d.ts +0 -5
- package/dist/task/summary.js +0 -13
package/dist/task/task.js
CHANGED
|
@@ -1,22 +1,50 @@
|
|
|
1
|
-
import {
|
|
2
|
-
import {
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
1
|
+
import { z } from "zod";
|
|
2
|
+
import { DisplayInformation } from "./display_information.js";
|
|
3
|
+
import { TrainingInformation } from "./training_information.js";
|
|
4
|
+
export var Task;
|
|
5
|
+
(function (Task) {
|
|
6
|
+
Task.baseSchema = z.object({
|
|
7
|
+
id: z.string(),
|
|
8
|
+
displayInformation: DisplayInformation.baseSchema,
|
|
9
|
+
trainingInformation: TrainingInformation.baseSchema,
|
|
10
|
+
});
|
|
11
|
+
Task.dataTypeToSchema = {
|
|
12
|
+
image: z.object({
|
|
13
|
+
dataType: z.literal("image"),
|
|
14
|
+
displayInformation: DisplayInformation.dataTypeToSchema.image,
|
|
15
|
+
trainingInformation: TrainingInformation.dataTypeToSchema.image,
|
|
16
|
+
}),
|
|
17
|
+
tabular: z.object({
|
|
18
|
+
dataType: z.literal("tabular"),
|
|
19
|
+
displayInformation: DisplayInformation.dataTypeToSchema.tabular,
|
|
20
|
+
trainingInformation: TrainingInformation.dataTypeToSchema.tabular,
|
|
21
|
+
}),
|
|
22
|
+
text: z.object({
|
|
23
|
+
dataType: z.literal("text"),
|
|
24
|
+
displayInformation: DisplayInformation.dataTypeToSchema.text,
|
|
25
|
+
trainingInformation: TrainingInformation.dataTypeToSchema.text,
|
|
26
|
+
}),
|
|
20
27
|
};
|
|
21
|
-
|
|
22
|
-
|
|
28
|
+
Task.networkToSchema = {
|
|
29
|
+
decentralized: z.object({
|
|
30
|
+
trainingInformation: TrainingInformation.networkToSchema.decentralized,
|
|
31
|
+
}),
|
|
32
|
+
federated: z.object({
|
|
33
|
+
trainingInformation: TrainingInformation.networkToSchema.federated,
|
|
34
|
+
}),
|
|
35
|
+
local: z.object({
|
|
36
|
+
trainingInformation: TrainingInformation.networkToSchema.local,
|
|
37
|
+
}),
|
|
38
|
+
};
|
|
39
|
+
Task.schema = Task.baseSchema
|
|
40
|
+
.and(z.union([
|
|
41
|
+
Task.dataTypeToSchema.image,
|
|
42
|
+
Task.dataTypeToSchema.tabular,
|
|
43
|
+
Task.dataTypeToSchema.text,
|
|
44
|
+
]))
|
|
45
|
+
.and(z.union([
|
|
46
|
+
Task.networkToSchema.decentralized,
|
|
47
|
+
Task.networkToSchema.federated,
|
|
48
|
+
Task.networkToSchema.local,
|
|
49
|
+
]));
|
|
50
|
+
})(Task || (Task = {}));
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import { Map } from "immutable";
|
|
2
|
-
import type { DataType, Model } from "../index.js";
|
|
3
|
-
import type { Task
|
|
4
|
-
export declare function pushTask<D extends DataType>(base: URL, task: Task<D>, model: Model<D>): Promise<void>;
|
|
5
|
-
export declare function fetchTasks(base: URL): Promise<Map<
|
|
2
|
+
import type { DataType, Model, Network } from "../index.js";
|
|
3
|
+
import type { Task } from "./task.js";
|
|
4
|
+
export declare function pushTask<D extends DataType>(base: URL, task: Task<D, Network>, model: Model<D>): Promise<void>;
|
|
5
|
+
export declare function fetchTasks(base: URL): Promise<Map<Task.ID, Task<DataType, Network>>>;
|
|
@@ -1,8 +1,5 @@
|
|
|
1
|
-
import
|
|
2
|
-
import { Map } from "immutable";
|
|
1
|
+
import { Map, Seq } from "immutable";
|
|
3
2
|
import { serialization } from "../index.js";
|
|
4
|
-
import { isTask } from "./task.js";
|
|
5
|
-
const debug = createDebug("discojs:task:handlers");
|
|
6
3
|
function urlToTasks(base) {
|
|
7
4
|
const ret = new URL(base);
|
|
8
5
|
ret.pathname += "tasks";
|
|
@@ -11,10 +8,10 @@ function urlToTasks(base) {
|
|
|
11
8
|
export async function pushTask(base, task, model) {
|
|
12
9
|
const response = await fetch(urlToTasks(base), {
|
|
13
10
|
method: "POST",
|
|
11
|
+
headers: { "Content-Type": "application/json" },
|
|
14
12
|
body: JSON.stringify({
|
|
15
|
-
task,
|
|
16
|
-
model: await serialization.model.encode(model),
|
|
17
|
-
weights: await serialization.weights.encode(model.weights),
|
|
13
|
+
task: serialization.task.serializeToJSON(task),
|
|
14
|
+
model: [...(await serialization.model.encode(model))],
|
|
18
15
|
}),
|
|
19
16
|
});
|
|
20
17
|
if (!response.ok)
|
|
@@ -24,17 +21,16 @@ export async function fetchTasks(base) {
|
|
|
24
21
|
const response = await fetch(urlToTasks(base));
|
|
25
22
|
if (!response.ok)
|
|
26
23
|
throw new Error(`fetch: HTTP status ${response.status}`);
|
|
27
|
-
const
|
|
28
|
-
if (!Array.isArray(
|
|
29
|
-
throw new Error("
|
|
24
|
+
const json = (await response.json());
|
|
25
|
+
if (!Array.isArray(json))
|
|
26
|
+
throw new Error("invalid tasks response: expected a JSON array");
|
|
27
|
+
const arr = json;
|
|
28
|
+
try {
|
|
29
|
+
return Map(Seq(await Promise.all(arr.map((t) => serialization.task.deserializeFromJSON(t)))).map((t) => [t.id, t]));
|
|
30
30
|
}
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
}
|
|
36
|
-
}
|
|
37
|
-
throw new Error("invalid tasks response, the task object received is not well formatted");
|
|
31
|
+
catch (cause) {
|
|
32
|
+
throw new Error("invalid tasks response: unable to parse all tasks", {
|
|
33
|
+
cause,
|
|
34
|
+
});
|
|
38
35
|
}
|
|
39
|
-
return Map(tasks.map((t) => [t.id, t]));
|
|
40
36
|
}
|
|
@@ -1,5 +1,5 @@
|
|
|
1
|
-
import type { DataType, Model, Task } from "../index.js";
|
|
2
|
-
export interface TaskProvider<D extends DataType> {
|
|
3
|
-
getTask(): Task<D
|
|
1
|
+
import type { DataType, Model, Network, Task } from "../index.js";
|
|
2
|
+
export interface TaskProvider<D extends DataType, N extends Network> {
|
|
3
|
+
getTask(): Promise<Task<D, N>>;
|
|
4
4
|
getModel(): Promise<Model<D>>;
|
|
5
5
|
}
|
|
@@ -1,38 +1,160 @@
|
|
|
1
|
-
import {
|
|
2
|
-
import { DataType } from "../index.js";
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
1
|
+
import { z } from "zod";
|
|
2
|
+
import type { DataType, Network } from "../index.js";
|
|
3
|
+
import { Tokenizer } from "../index.js";
|
|
4
|
+
export declare namespace TrainingInformation {
|
|
5
|
+
const baseSchema: z.ZodObject<{
|
|
6
|
+
epochs: z.ZodNumber;
|
|
7
|
+
roundDuration: z.ZodNumber;
|
|
8
|
+
validationSplit: z.ZodNumber;
|
|
9
|
+
batchSize: z.ZodNumber;
|
|
10
|
+
tensorBackend: z.ZodEnum<["gpt", "tfjs"]>;
|
|
11
|
+
}, "strip", z.ZodTypeAny, {
|
|
12
|
+
epochs: number;
|
|
13
|
+
roundDuration: number;
|
|
14
|
+
validationSplit: number;
|
|
15
|
+
batchSize: number;
|
|
16
|
+
tensorBackend: "gpt" | "tfjs";
|
|
17
|
+
}, {
|
|
18
|
+
epochs: number;
|
|
19
|
+
roundDuration: number;
|
|
20
|
+
validationSplit: number;
|
|
21
|
+
batchSize: number;
|
|
22
|
+
tensorBackend: "gpt" | "tfjs";
|
|
23
|
+
}>;
|
|
24
|
+
const dataTypeToSchema: {
|
|
25
|
+
image: z.ZodObject<{
|
|
26
|
+
LABEL_LIST: z.ZodArray<z.ZodString, "many">;
|
|
27
|
+
IMAGE_W: z.ZodNumber;
|
|
28
|
+
IMAGE_H: z.ZodNumber;
|
|
29
|
+
}, "strip", z.ZodTypeAny, {
|
|
30
|
+
LABEL_LIST: string[];
|
|
31
|
+
IMAGE_W: number;
|
|
32
|
+
IMAGE_H: number;
|
|
33
|
+
}, {
|
|
34
|
+
LABEL_LIST: string[];
|
|
35
|
+
IMAGE_W: number;
|
|
36
|
+
IMAGE_H: number;
|
|
37
|
+
}>;
|
|
38
|
+
tabular: z.ZodObject<{
|
|
39
|
+
inputColumns: z.ZodArray<z.ZodString, "many">;
|
|
40
|
+
outputColumn: z.ZodString;
|
|
41
|
+
}, "strip", z.ZodTypeAny, {
|
|
42
|
+
inputColumns: string[];
|
|
43
|
+
outputColumn: string;
|
|
44
|
+
}, {
|
|
45
|
+
inputColumns: string[];
|
|
46
|
+
outputColumn: string;
|
|
47
|
+
}>;
|
|
48
|
+
text: z.ZodObject<{
|
|
49
|
+
tokenizer: z.ZodType<Tokenizer, z.ZodTypeDef, Tokenizer>;
|
|
50
|
+
contextLength: z.ZodNumber;
|
|
51
|
+
}, "strip", z.ZodTypeAny, {
|
|
52
|
+
contextLength: number;
|
|
53
|
+
tokenizer: Tokenizer;
|
|
54
|
+
}, {
|
|
55
|
+
contextLength: number;
|
|
56
|
+
tokenizer: Tokenizer;
|
|
57
|
+
}>;
|
|
30
58
|
};
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
59
|
+
const networkToSchema: {
|
|
60
|
+
decentralized: z.ZodIntersection<z.ZodObject<{
|
|
61
|
+
scheme: z.ZodLiteral<"decentralized">;
|
|
62
|
+
} & {
|
|
63
|
+
privacy: z.ZodOptional<z.ZodEffects<z.ZodObject<{
|
|
64
|
+
clippingRadius: z.ZodOptional<z.ZodNumber>;
|
|
65
|
+
noiseScale: z.ZodOptional<z.ZodNumber>;
|
|
66
|
+
}, "strip", z.ZodTypeAny, {
|
|
67
|
+
clippingRadius?: number | undefined;
|
|
68
|
+
noiseScale?: number | undefined;
|
|
69
|
+
}, {
|
|
70
|
+
clippingRadius?: number | undefined;
|
|
71
|
+
noiseScale?: number | undefined;
|
|
72
|
+
}>, {
|
|
73
|
+
clippingRadius?: number | undefined;
|
|
74
|
+
noiseScale?: number | undefined;
|
|
75
|
+
} | undefined, {
|
|
76
|
+
clippingRadius?: number | undefined;
|
|
77
|
+
noiseScale?: number | undefined;
|
|
78
|
+
}>>;
|
|
79
|
+
minNbOfParticipants: z.ZodNumber;
|
|
80
|
+
}, "strip", z.ZodTypeAny, {
|
|
81
|
+
scheme: "decentralized";
|
|
82
|
+
minNbOfParticipants: number;
|
|
83
|
+
privacy?: {
|
|
84
|
+
clippingRadius?: number | undefined;
|
|
85
|
+
noiseScale?: number | undefined;
|
|
86
|
+
} | undefined;
|
|
87
|
+
}, {
|
|
88
|
+
scheme: "decentralized";
|
|
89
|
+
minNbOfParticipants: number;
|
|
90
|
+
privacy?: {
|
|
91
|
+
clippingRadius?: number | undefined;
|
|
92
|
+
noiseScale?: number | undefined;
|
|
93
|
+
} | undefined;
|
|
94
|
+
}>, z.ZodUnion<[z.ZodObject<{
|
|
95
|
+
aggregationStrategy: z.ZodLiteral<"mean">;
|
|
96
|
+
}, "strip", z.ZodTypeAny, {
|
|
97
|
+
aggregationStrategy: "mean";
|
|
98
|
+
}, {
|
|
99
|
+
aggregationStrategy: "mean";
|
|
100
|
+
}>, z.ZodObject<{
|
|
101
|
+
aggregationStrategy: z.ZodLiteral<"secure">;
|
|
102
|
+
maxShareValue: z.ZodDefault<z.ZodOptional<z.ZodNumber>>;
|
|
103
|
+
}, "strip", z.ZodTypeAny, {
|
|
104
|
+
aggregationStrategy: "secure";
|
|
105
|
+
maxShareValue: number;
|
|
106
|
+
}, {
|
|
107
|
+
aggregationStrategy: "secure";
|
|
108
|
+
maxShareValue?: number | undefined;
|
|
109
|
+
}>]>>;
|
|
110
|
+
federated: z.ZodObject<{
|
|
111
|
+
scheme: z.ZodLiteral<"federated">;
|
|
112
|
+
aggregationStrategy: z.ZodLiteral<"mean">;
|
|
113
|
+
} & {
|
|
114
|
+
privacy: z.ZodOptional<z.ZodEffects<z.ZodObject<{
|
|
115
|
+
clippingRadius: z.ZodOptional<z.ZodNumber>;
|
|
116
|
+
noiseScale: z.ZodOptional<z.ZodNumber>;
|
|
117
|
+
}, "strip", z.ZodTypeAny, {
|
|
118
|
+
clippingRadius?: number | undefined;
|
|
119
|
+
noiseScale?: number | undefined;
|
|
120
|
+
}, {
|
|
121
|
+
clippingRadius?: number | undefined;
|
|
122
|
+
noiseScale?: number | undefined;
|
|
123
|
+
}>, {
|
|
124
|
+
clippingRadius?: number | undefined;
|
|
125
|
+
noiseScale?: number | undefined;
|
|
126
|
+
} | undefined, {
|
|
127
|
+
clippingRadius?: number | undefined;
|
|
128
|
+
noiseScale?: number | undefined;
|
|
129
|
+
}>>;
|
|
130
|
+
minNbOfParticipants: z.ZodNumber;
|
|
131
|
+
}, "strip", z.ZodTypeAny, {
|
|
132
|
+
scheme: "federated";
|
|
133
|
+
minNbOfParticipants: number;
|
|
134
|
+
aggregationStrategy: "mean";
|
|
135
|
+
privacy?: {
|
|
136
|
+
clippingRadius?: number | undefined;
|
|
137
|
+
noiseScale?: number | undefined;
|
|
138
|
+
} | undefined;
|
|
139
|
+
}, {
|
|
140
|
+
scheme: "federated";
|
|
141
|
+
minNbOfParticipants: number;
|
|
142
|
+
aggregationStrategy: "mean";
|
|
143
|
+
privacy?: {
|
|
144
|
+
clippingRadius?: number | undefined;
|
|
145
|
+
noiseScale?: number | undefined;
|
|
146
|
+
} | undefined;
|
|
147
|
+
}>;
|
|
148
|
+
local: z.ZodObject<{
|
|
149
|
+
scheme: z.ZodLiteral<"local">;
|
|
150
|
+
aggregationStrategy: z.ZodLiteral<"mean">;
|
|
151
|
+
}, "strip", z.ZodTypeAny, {
|
|
152
|
+
scheme: "local";
|
|
153
|
+
aggregationStrategy: "mean";
|
|
154
|
+
}, {
|
|
155
|
+
scheme: "local";
|
|
156
|
+
aggregationStrategy: "mean";
|
|
157
|
+
}>;
|
|
35
158
|
};
|
|
36
159
|
}
|
|
37
|
-
export
|
|
38
|
-
export {};
|
|
160
|
+
export type TrainingInformation<D extends DataType, N extends Network> = z.infer<typeof TrainingInformation.baseSchema> & z.infer<(typeof TrainingInformation.dataTypeToSchema)[D]> & z.infer<(typeof TrainingInformation.networkToSchema)[N]>;
|
|
@@ -1,112 +1,87 @@
|
|
|
1
|
-
import {
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
noiseScale
|
|
1
|
+
import { z } from "zod";
|
|
2
|
+
import { Tokenizer } from "../index.js";
|
|
3
|
+
const nonLocalNetworkSchema = z.object({
|
|
4
|
+
// reduce training accuracy and improve privacy.
|
|
5
|
+
privacy: z
|
|
6
|
+
.object({
|
|
7
|
+
// maximum weights difference between each round
|
|
8
|
+
clippingRadius: z.number().optional(),
|
|
9
|
+
// variance of the Gaussian noise added to the shared weights.
|
|
10
|
+
noiseScale: z.number().optional(),
|
|
11
|
+
})
|
|
12
|
+
.transform((o) => o.clippingRadius === undefined && o.noiseScale === undefined
|
|
13
|
+
? undefined
|
|
14
|
+
: o)
|
|
15
|
+
.optional(),
|
|
16
|
+
// minimum number of participants required to train collaboratively
|
|
17
|
+
// In decentralized Learning the default is 3, in federated learning it is 2
|
|
18
|
+
minNbOfParticipants: z.number().positive().int(),
|
|
19
|
+
});
|
|
20
|
+
export var TrainingInformation;
|
|
21
|
+
(function (TrainingInformation) {
|
|
22
|
+
TrainingInformation.baseSchema = z.object({
|
|
23
|
+
// number of epochs to run training for
|
|
24
|
+
epochs: z.number().positive().int(),
|
|
25
|
+
// number of epochs between each weight sharing round.
|
|
26
|
+
// e.g.if 3 then weights are shared every 3 epochs (in the distributed setting).
|
|
27
|
+
roundDuration: z.number().positive().int(),
|
|
28
|
+
// fraction of data to keep for validation, note this only works for image data
|
|
29
|
+
validationSplit: z.number().min(0).max(1),
|
|
30
|
+
// batch size of training data
|
|
31
|
+
batchSize: z.number().positive().int(),
|
|
32
|
+
// Tensor framework used by the model
|
|
33
|
+
tensorBackend: z.enum(["gpt", "tfjs"]),
|
|
34
|
+
});
|
|
35
|
+
TrainingInformation.dataTypeToSchema = {
|
|
36
|
+
image: z.object({
|
|
37
|
+
// classes, e.g. if two class of images, one with dogs and one with cats, then we would
|
|
38
|
+
// define ['dogs', 'cats'].
|
|
39
|
+
LABEL_LIST: z.array(z.string()).min(1),
|
|
40
|
+
// height of image to resize to
|
|
41
|
+
IMAGE_W: z.number().positive().int(),
|
|
42
|
+
// width of image to resize to
|
|
43
|
+
IMAGE_H: z.number().positive().int(),
|
|
44
|
+
}),
|
|
45
|
+
tabular: z.object({
|
|
46
|
+
// the columns to be chosen as input data for the model
|
|
47
|
+
inputColumns: z.array(z.string()),
|
|
48
|
+
// the columns to be predicted by the model
|
|
49
|
+
outputColumn: z.string(),
|
|
50
|
+
}),
|
|
51
|
+
text: z.object({
|
|
52
|
+
// should be set with the name of a Transformers.js pre-trained tokenizer, e.g., 'Xenova/gpt2'.
|
|
53
|
+
tokenizer: z.instanceof(Tokenizer),
|
|
54
|
+
// the maximum length of a input string used as input to a GPT model. It is used during preprocessing to
|
|
55
|
+
// truncate strings to a maximum length. The default value is tokenizer.model_max_length
|
|
56
|
+
contextLength: z.number().positive().int(),
|
|
57
|
+
}),
|
|
13
58
|
};
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
break;
|
|
42
|
-
default:
|
|
43
|
-
return false;
|
|
44
|
-
}
|
|
45
|
-
switch (scheme) {
|
|
46
|
-
case "decentralized":
|
|
47
|
-
case "federated":
|
|
48
|
-
case "local":
|
|
49
|
-
break;
|
|
50
|
-
default:
|
|
51
|
-
return false;
|
|
52
|
-
}
|
|
53
|
-
const repack = {
|
|
54
|
-
aggregationStrategy,
|
|
55
|
-
batchSize,
|
|
56
|
-
epochs,
|
|
57
|
-
maxShareValue,
|
|
58
|
-
minNbOfParticipants,
|
|
59
|
-
privacy,
|
|
60
|
-
roundDuration,
|
|
61
|
-
scheme,
|
|
62
|
-
tensorBackend,
|
|
63
|
-
validationSplit,
|
|
59
|
+
TrainingInformation.networkToSchema = {
|
|
60
|
+
decentralized: z
|
|
61
|
+
.object({
|
|
62
|
+
scheme: z.literal("decentralized"),
|
|
63
|
+
})
|
|
64
|
+
.merge(nonLocalNetworkSchema)
|
|
65
|
+
.and(z.union([
|
|
66
|
+
z.object({
|
|
67
|
+
aggregationStrategy: z.literal("mean"),
|
|
68
|
+
}),
|
|
69
|
+
z.object({
|
|
70
|
+
aggregationStrategy: z.literal("secure"),
|
|
71
|
+
// Secure Aggregation: maximum absolute value of a number in a randomly generated share
|
|
72
|
+
// default is 100, must be a positive number, check the docs/PRIVACY.md file for more information on significance of maxShareValue selection
|
|
73
|
+
maxShareValue: z.number().positive().int().optional().default(100),
|
|
74
|
+
}),
|
|
75
|
+
])),
|
|
76
|
+
federated: z
|
|
77
|
+
.object({
|
|
78
|
+
scheme: z.literal("federated"),
|
|
79
|
+
aggregationStrategy: z.literal("mean"),
|
|
80
|
+
})
|
|
81
|
+
.merge(nonLocalNetworkSchema),
|
|
82
|
+
local: z.object({
|
|
83
|
+
scheme: z.literal("local"),
|
|
84
|
+
aggregationStrategy: z.literal("mean"),
|
|
85
|
+
}),
|
|
64
86
|
};
|
|
65
|
-
|
|
66
|
-
case "image": {
|
|
67
|
-
const { LABEL_LIST, IMAGE_W, IMAGE_H } = raw;
|
|
68
|
-
if (!(Array.isArray(LABEL_LIST) &&
|
|
69
|
-
LABEL_LIST.every((e) => typeof e === "string")) ||
|
|
70
|
-
typeof IMAGE_H !== "number" ||
|
|
71
|
-
typeof IMAGE_W !== "number")
|
|
72
|
-
return false;
|
|
73
|
-
const _ = {
|
|
74
|
-
...repack,
|
|
75
|
-
dataType,
|
|
76
|
-
LABEL_LIST,
|
|
77
|
-
IMAGE_W,
|
|
78
|
-
IMAGE_H,
|
|
79
|
-
};
|
|
80
|
-
return true;
|
|
81
|
-
}
|
|
82
|
-
case "tabular": {
|
|
83
|
-
const { inputColumns, outputColumn } = raw;
|
|
84
|
-
if (!(Array.isArray(inputColumns) &&
|
|
85
|
-
inputColumns.every((e) => typeof e === "string")) ||
|
|
86
|
-
typeof outputColumn !== "string")
|
|
87
|
-
return false;
|
|
88
|
-
const _ = {
|
|
89
|
-
...repack,
|
|
90
|
-
dataType,
|
|
91
|
-
inputColumns,
|
|
92
|
-
outputColumn,
|
|
93
|
-
};
|
|
94
|
-
return true;
|
|
95
|
-
}
|
|
96
|
-
case "text": {
|
|
97
|
-
const { contextLength, tokenizer, } = raw;
|
|
98
|
-
if ((typeof tokenizer !== "string" &&
|
|
99
|
-
!(tokenizer instanceof PreTrainedTokenizer)) ||
|
|
100
|
-
(typeof contextLength !== "number"))
|
|
101
|
-
return false;
|
|
102
|
-
const _ = {
|
|
103
|
-
...repack,
|
|
104
|
-
dataType,
|
|
105
|
-
contextLength,
|
|
106
|
-
tokenizer,
|
|
107
|
-
};
|
|
108
|
-
return true;
|
|
109
|
-
}
|
|
110
|
-
}
|
|
111
|
-
return false;
|
|
112
|
-
}
|
|
87
|
+
})(TrainingInformation || (TrainingInformation = {}));
|
package/dist/training/disco.d.ts
CHANGED
|
@@ -1,10 +1,10 @@
|
|
|
1
|
-
import { client as clients, BatchLogs, EpochLogs, Logger,
|
|
2
|
-
import type { DataFormat, DataType, Task } from "../index.js";
|
|
1
|
+
import { client as clients, BatchLogs, EpochLogs, Logger, Dataset } from "../index.js";
|
|
2
|
+
import type { DataFormat, DataType, Network, Task } from "../index.js";
|
|
3
3
|
import type { Aggregator } from "../aggregator/index.js";
|
|
4
4
|
import { EventEmitter } from "../utils/event_emitter.js";
|
|
5
5
|
import { RoundLogs, Trainer } from "./trainer.js";
|
|
6
|
-
interface DiscoConfig {
|
|
7
|
-
scheme:
|
|
6
|
+
interface DiscoConfig<N extends Network> {
|
|
7
|
+
scheme: N;
|
|
8
8
|
logger: Logger;
|
|
9
9
|
/**
|
|
10
10
|
* keep preprocessed dataset in memory while training
|
|
@@ -24,12 +24,12 @@ export type RoundStatus = 'not enough participants' | // Server notification to
|
|
|
24
24
|
* a convenient object providing a reduced yet complete API that wraps model training and
|
|
25
25
|
* communication with nodes.
|
|
26
26
|
*/
|
|
27
|
-
export declare class Disco<D extends DataType> extends EventEmitter<{
|
|
27
|
+
export declare class Disco<D extends DataType, N extends Network> extends EventEmitter<{
|
|
28
28
|
status: RoundStatus;
|
|
29
29
|
participants: number;
|
|
30
30
|
}> {
|
|
31
31
|
#private;
|
|
32
|
-
readonly trainer: Trainer<D>;
|
|
32
|
+
readonly trainer: Trainer<D, N>;
|
|
33
33
|
/**
|
|
34
34
|
* Connect to the given task and get ready to train.
|
|
35
35
|
*
|
|
@@ -37,10 +37,10 @@ export declare class Disco<D extends DataType> extends EventEmitter<{
|
|
|
37
37
|
* @param clientConfig client to connect with or parameters on how to create one.
|
|
38
38
|
* @param config the DiscoConfig
|
|
39
39
|
*/
|
|
40
|
-
constructor(task: Task<D>, clientConfig: clients.Client | URL | {
|
|
40
|
+
constructor(task: Task<D, N>, clientConfig: clients.Client<N> | URL | {
|
|
41
41
|
aggregator: Aggregator;
|
|
42
42
|
url: URL;
|
|
43
|
-
}, config: Partial<DiscoConfig
|
|
43
|
+
}, config: Partial<DiscoConfig<N>>);
|
|
44
44
|
/** Train on dataset, yielding logs of every round. */
|
|
45
45
|
trainByRound(dataset: Dataset<DataFormat.Raw[D]>): AsyncGenerator<RoundLogs>;
|
|
46
46
|
/** Train on dataset, yielding logs of every epoch. */
|
package/dist/training/disco.js
CHANGED
|
@@ -24,6 +24,7 @@ export class Disco extends EventEmitter {
|
|
|
24
24
|
constructor(task, clientConfig, config) {
|
|
25
25
|
super();
|
|
26
26
|
const { scheme, logger, preprocessOnce } = {
|
|
27
|
+
// cast as typescript is bad at generic
|
|
27
28
|
scheme: task.trainingInformation.scheme,
|
|
28
29
|
logger: new ConsoleLogger(),
|
|
29
30
|
preprocessOnce: false,
|
|
@@ -135,7 +136,7 @@ export class Disco extends EventEmitter {
|
|
|
135
136
|
}
|
|
136
137
|
async #preprocessSplitAndBatch(dataset) {
|
|
137
138
|
const { batchSize, validationSplit } = this.#task.trainingInformation;
|
|
138
|
-
let preprocessed =
|
|
139
|
+
let preprocessed = processing.preprocess(this.#task, dataset);
|
|
139
140
|
preprocessed = (this.#preprocessOnce
|
|
140
141
|
? new Dataset(await arrayFromAsync(preprocessed))
|
|
141
142
|
: preprocessed);
|
|
@@ -1,16 +1,16 @@
|
|
|
1
1
|
import { List } from "immutable";
|
|
2
|
-
import type { Batched, BatchLogs, Dataset, DataFormat, DataType, EpochLogs, Model, Task } from "../index.js";
|
|
2
|
+
import type { Batched, BatchLogs, Dataset, DataFormat, DataType, EpochLogs, Model, Task, Network } from "../index.js";
|
|
3
3
|
import { Client } from "../client/index.js";
|
|
4
4
|
export interface RoundLogs {
|
|
5
5
|
epochs: List<EpochLogs>;
|
|
6
6
|
participants: number;
|
|
7
7
|
}
|
|
8
8
|
/** Train a model and exchange with others **/
|
|
9
|
-
export declare class Trainer<D extends DataType> {
|
|
9
|
+
export declare class Trainer<D extends DataType, N extends Network> {
|
|
10
10
|
#private;
|
|
11
11
|
get model(): Model<D>;
|
|
12
12
|
set model(model: Model<D>);
|
|
13
|
-
constructor(task: Task<D>, client: Client);
|
|
13
|
+
constructor(task: Task<D, N>, client: Client<N>);
|
|
14
14
|
stopTraining(): Promise<void>;
|
|
15
15
|
train(dataset: Dataset<Batched<DataFormat.ModelEncoded[D]>>, validationDataset?: Dataset<Batched<DataFormat.ModelEncoded[D]>>): AsyncGenerator<AsyncGenerator<AsyncGenerator<BatchLogs, EpochLogs>, RoundLogs>, void>;
|
|
16
16
|
}
|
package/dist/training/trainer.js
CHANGED
|
@@ -22,7 +22,8 @@ export class Trainer {
|
|
|
22
22
|
this.#client = client;
|
|
23
23
|
this.#roundDuration = task.trainingInformation.roundDuration;
|
|
24
24
|
this.#epochs = task.trainingInformation.epochs;
|
|
25
|
-
|
|
25
|
+
if ("privacy" in task.trainingInformation)
|
|
26
|
+
this.#privacy = task.trainingInformation.privacy;
|
|
26
27
|
if (!Number.isInteger(this.#epochs / this.#roundDuration))
|
|
27
28
|
throw new Error(`round duration ${this.#roundDuration} doesn't divide number of epochs ${this.#epochs}`);
|
|
28
29
|
}
|
package/dist/types/index.d.ts
CHANGED
package/dist/validator.d.ts
CHANGED
|
@@ -1,10 +1,10 @@
|
|
|
1
|
-
import { Dataset, DataFormat, DataType, Model, Task } from "./index.js";
|
|
1
|
+
import type { Dataset, DataFormat, DataType, Model, Task, Network } from "./index.js";
|
|
2
2
|
export declare class Validator<D extends DataType> {
|
|
3
3
|
#private;
|
|
4
|
-
readonly task: Task<D>;
|
|
5
|
-
constructor(task: Task<D>, model: Model<D>);
|
|
4
|
+
readonly task: Task<D, Network>;
|
|
5
|
+
constructor(task: Task<D, Network>, model: Model<D>);
|
|
6
6
|
/** infer every line of the dataset and check that it is as labelled */
|
|
7
|
-
test(dataset: Dataset<DataFormat.Raw[D]>):
|
|
7
|
+
test(dataset: Dataset<DataFormat.Raw[D]>): Dataset<Record<"predicted" | "truth", DataFormat.Inferred[D]>>;
|
|
8
8
|
/** use the model to predict every line of the dataset */
|
|
9
9
|
infer(dataset: Dataset<DataFormat.RawWithoutLabel[D]>): AsyncGenerator<DataFormat.Inferred[D], void>;
|
|
10
10
|
}
|