@epfml/discojs 3.0.1-p20250814105822.0 → 3.0.1-p20250924150135.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. package/dist/aggregator/get.d.ts +3 -3
  2. package/dist/aggregator/get.js +1 -2
  3. package/dist/client/client.d.ts +6 -6
  4. package/dist/client/decentralized/decentralized_client.d.ts +1 -1
  5. package/dist/client/decentralized/peer_pool.d.ts +1 -1
  6. package/dist/client/federated/federated_client.d.ts +1 -1
  7. package/dist/client/local_client.d.ts +1 -1
  8. package/dist/client/utils.d.ts +2 -2
  9. package/dist/client/utils.js +19 -10
  10. package/dist/default_tasks/cifar10.d.ts +2 -2
  11. package/dist/default_tasks/cifar10.js +9 -8
  12. package/dist/default_tasks/lus_covid.d.ts +2 -2
  13. package/dist/default_tasks/lus_covid.js +9 -8
  14. package/dist/default_tasks/mnist.d.ts +2 -2
  15. package/dist/default_tasks/mnist.js +9 -8
  16. package/dist/default_tasks/simple_face.d.ts +2 -2
  17. package/dist/default_tasks/simple_face.js +9 -8
  18. package/dist/default_tasks/tinder_dog.d.ts +1 -1
  19. package/dist/default_tasks/tinder_dog.js +12 -10
  20. package/dist/default_tasks/titanic.d.ts +2 -2
  21. package/dist/default_tasks/titanic.js +20 -33
  22. package/dist/default_tasks/wikitext.d.ts +2 -2
  23. package/dist/default_tasks/wikitext.js +16 -13
  24. package/dist/index.d.ts +1 -1
  25. package/dist/index.js +1 -1
  26. package/dist/models/gpt/config.d.ts +2 -2
  27. package/dist/models/hellaswag.d.ts +2 -3
  28. package/dist/models/hellaswag.js +3 -4
  29. package/dist/models/index.d.ts +2 -3
  30. package/dist/models/index.js +2 -3
  31. package/dist/models/tokenizer.d.ts +24 -14
  32. package/dist/models/tokenizer.js +42 -21
  33. package/dist/processing/index.d.ts +4 -5
  34. package/dist/processing/index.js +16 -21
  35. package/dist/serialization/coder.d.ts +5 -1
  36. package/dist/serialization/coder.js +4 -1
  37. package/dist/serialization/index.d.ts +4 -0
  38. package/dist/serialization/index.js +1 -0
  39. package/dist/serialization/task.d.ts +5 -0
  40. package/dist/serialization/task.js +34 -0
  41. package/dist/task/display_information.d.ts +91 -14
  42. package/dist/task/display_information.js +34 -58
  43. package/dist/task/index.d.ts +5 -5
  44. package/dist/task/index.js +4 -3
  45. package/dist/task/task.d.ts +837 -10
  46. package/dist/task/task.js +49 -21
  47. package/dist/task/task_handler.d.ts +4 -4
  48. package/dist/task/task_handler.js +14 -18
  49. package/dist/task/task_provider.d.ts +3 -3
  50. package/dist/task/training_information.d.ts +157 -35
  51. package/dist/task/training_information.js +85 -110
  52. package/dist/training/disco.d.ts +8 -8
  53. package/dist/training/disco.js +2 -1
  54. package/dist/training/trainer.d.ts +3 -3
  55. package/dist/training/trainer.js +2 -1
  56. package/dist/types/index.d.ts +1 -0
  57. package/dist/validator.d.ts +4 -4
  58. package/dist/validator.js +7 -6
  59. package/package.json +4 -7
  60. package/dist/processing/text.d.ts +0 -21
  61. package/dist/processing/text.js +0 -36
  62. package/dist/task/data_example.d.ts +0 -5
  63. package/dist/task/data_example.js +0 -14
  64. package/dist/task/summary.d.ts +0 -5
  65. package/dist/task/summary.js +0 -13
package/dist/task/task.js CHANGED
@@ -1,22 +1,50 @@
1
- import { isDisplayInformation } from './display_information.js';
2
- import { isTrainingInformation } from './training_information.js';
3
- export function isTaskID(obj) {
4
- return typeof obj === 'string';
5
- }
6
- export function isTask(raw) {
7
- if (typeof raw !== 'object' || raw === null) {
8
- return false;
9
- }
10
- const { id, displayInformation, trainingInformation } = raw;
11
- if (!isTaskID(id) ||
12
- !isDisplayInformation(displayInformation) ||
13
- !isTrainingInformation(trainingInformation)) {
14
- return false;
15
- }
16
- const _ = {
17
- id,
18
- displayInformation,
19
- trainingInformation,
1
+ import { z } from "zod";
2
+ import { DisplayInformation } from "./display_information.js";
3
+ import { TrainingInformation } from "./training_information.js";
4
+ export var Task;
5
+ (function (Task) {
6
+ Task.baseSchema = z.object({
7
+ id: z.string(),
8
+ displayInformation: DisplayInformation.baseSchema,
9
+ trainingInformation: TrainingInformation.baseSchema,
10
+ });
11
+ Task.dataTypeToSchema = {
12
+ image: z.object({
13
+ dataType: z.literal("image"),
14
+ displayInformation: DisplayInformation.dataTypeToSchema.image,
15
+ trainingInformation: TrainingInformation.dataTypeToSchema.image,
16
+ }),
17
+ tabular: z.object({
18
+ dataType: z.literal("tabular"),
19
+ displayInformation: DisplayInformation.dataTypeToSchema.tabular,
20
+ trainingInformation: TrainingInformation.dataTypeToSchema.tabular,
21
+ }),
22
+ text: z.object({
23
+ dataType: z.literal("text"),
24
+ displayInformation: DisplayInformation.dataTypeToSchema.text,
25
+ trainingInformation: TrainingInformation.dataTypeToSchema.text,
26
+ }),
20
27
  };
21
- return true;
22
- }
28
+ Task.networkToSchema = {
29
+ decentralized: z.object({
30
+ trainingInformation: TrainingInformation.networkToSchema.decentralized,
31
+ }),
32
+ federated: z.object({
33
+ trainingInformation: TrainingInformation.networkToSchema.federated,
34
+ }),
35
+ local: z.object({
36
+ trainingInformation: TrainingInformation.networkToSchema.local,
37
+ }),
38
+ };
39
+ Task.schema = Task.baseSchema
40
+ .and(z.union([
41
+ Task.dataTypeToSchema.image,
42
+ Task.dataTypeToSchema.tabular,
43
+ Task.dataTypeToSchema.text,
44
+ ]))
45
+ .and(z.union([
46
+ Task.networkToSchema.decentralized,
47
+ Task.networkToSchema.federated,
48
+ Task.networkToSchema.local,
49
+ ]));
50
+ })(Task || (Task = {}));
@@ -1,5 +1,5 @@
1
1
  import { Map } from "immutable";
2
- import type { DataType, Model } from "../index.js";
3
- import type { Task, TaskID } from "./task.js";
4
- export declare function pushTask<D extends DataType>(base: URL, task: Task<D>, model: Model<D>): Promise<void>;
5
- export declare function fetchTasks(base: URL): Promise<Map<TaskID, Task<DataType>>>;
2
+ import type { DataType, Model, Network } from "../index.js";
3
+ import type { Task } from "./task.js";
4
+ export declare function pushTask<D extends DataType>(base: URL, task: Task<D, Network>, model: Model<D>): Promise<void>;
5
+ export declare function fetchTasks(base: URL): Promise<Map<Task.ID, Task<DataType, Network>>>;
@@ -1,8 +1,5 @@
1
- import createDebug from "debug";
2
- import { Map } from "immutable";
1
+ import { Map, Seq } from "immutable";
3
2
  import { serialization } from "../index.js";
4
- import { isTask } from "./task.js";
5
- const debug = createDebug("discojs:task:handlers");
6
3
  function urlToTasks(base) {
7
4
  const ret = new URL(base);
8
5
  ret.pathname += "tasks";
@@ -11,10 +8,10 @@ function urlToTasks(base) {
11
8
  export async function pushTask(base, task, model) {
12
9
  const response = await fetch(urlToTasks(base), {
13
10
  method: "POST",
11
+ headers: { "Content-Type": "application/json" },
14
12
  body: JSON.stringify({
15
- task,
16
- model: await serialization.model.encode(model),
17
- weights: await serialization.weights.encode(model.weights),
13
+ task: serialization.task.serializeToJSON(task),
14
+ model: [...(await serialization.model.encode(model))],
18
15
  }),
19
16
  });
20
17
  if (!response.ok)
@@ -24,17 +21,16 @@ export async function fetchTasks(base) {
24
21
  const response = await fetch(urlToTasks(base));
25
22
  if (!response.ok)
26
23
  throw new Error(`fetch: HTTP status ${response.status}`);
27
- const tasks = await response.json();
28
- if (!Array.isArray(tasks)) {
29
- throw new Error("Expected to receive an array of Tasks when fetching tasks");
24
+ const json = (await response.json());
25
+ if (!Array.isArray(json))
26
+ throw new Error("invalid tasks response: expected a JSON array");
27
+ const arr = json;
28
+ try {
29
+ return Map(Seq(await Promise.all(arr.map((t) => serialization.task.deserializeFromJSON(t)))).map((t) => [t.id, t]));
30
30
  }
31
- else if (!tasks.every(isTask)) {
32
- for (const task of tasks) {
33
- if (!isTask(task)) {
34
- debug("task has invalid format: :O", task);
35
- }
36
- }
37
- throw new Error("invalid tasks response, the task object received is not well formatted");
31
+ catch (cause) {
32
+ throw new Error("invalid tasks response: unable to parse all tasks", {
33
+ cause,
34
+ });
38
35
  }
39
- return Map(tasks.map((t) => [t.id, t]));
40
36
  }
@@ -1,5 +1,5 @@
1
- import type { DataType, Model, Task } from "../index.js";
2
- export interface TaskProvider<D extends DataType> {
3
- getTask(): Task<D>;
1
+ import type { DataType, Model, Network, Task } from "../index.js";
2
+ export interface TaskProvider<D extends DataType, N extends Network> {
3
+ getTask(): Promise<Task<D, N>>;
4
4
  getModel(): Promise<Model<D>>;
5
5
  }
@@ -1,38 +1,160 @@
1
- import { PreTrainedTokenizer } from "@xenova/transformers";
2
- import { DataType } from "../index.js";
3
- interface Privacy {
4
- clippingRadius?: number;
5
- noiseScale?: number;
6
- }
7
- export type TrainingInformation<D extends DataType> = {
8
- epochs: number;
9
- roundDuration: number;
10
- validationSplit: number;
11
- batchSize: number;
12
- scheme: "decentralized" | "federated" | "local";
13
- privacy?: Privacy;
14
- maxShareValue?: number;
15
- minNbOfParticipants: number;
16
- aggregationStrategy?: "mean" | "secure";
17
- tensorBackend: "tfjs" | "gpt";
18
- } & DataTypeToTrainingInformation[D];
19
- interface DataTypeToTrainingInformation {
20
- image: {
21
- dataType: "image";
22
- LABEL_LIST: string[];
23
- IMAGE_H: number;
24
- IMAGE_W: number;
25
- };
26
- tabular: {
27
- dataType: "tabular";
28
- inputColumns: string[];
29
- outputColumn: string;
1
+ import { z } from "zod";
2
+ import type { DataType, Network } from "../index.js";
3
+ import { Tokenizer } from "../index.js";
4
+ export declare namespace TrainingInformation {
5
+ const baseSchema: z.ZodObject<{
6
+ epochs: z.ZodNumber;
7
+ roundDuration: z.ZodNumber;
8
+ validationSplit: z.ZodNumber;
9
+ batchSize: z.ZodNumber;
10
+ tensorBackend: z.ZodEnum<["gpt", "tfjs"]>;
11
+ }, "strip", z.ZodTypeAny, {
12
+ epochs: number;
13
+ roundDuration: number;
14
+ validationSplit: number;
15
+ batchSize: number;
16
+ tensorBackend: "gpt" | "tfjs";
17
+ }, {
18
+ epochs: number;
19
+ roundDuration: number;
20
+ validationSplit: number;
21
+ batchSize: number;
22
+ tensorBackend: "gpt" | "tfjs";
23
+ }>;
24
+ const dataTypeToSchema: {
25
+ image: z.ZodObject<{
26
+ LABEL_LIST: z.ZodArray<z.ZodString, "many">;
27
+ IMAGE_W: z.ZodNumber;
28
+ IMAGE_H: z.ZodNumber;
29
+ }, "strip", z.ZodTypeAny, {
30
+ LABEL_LIST: string[];
31
+ IMAGE_W: number;
32
+ IMAGE_H: number;
33
+ }, {
34
+ LABEL_LIST: string[];
35
+ IMAGE_W: number;
36
+ IMAGE_H: number;
37
+ }>;
38
+ tabular: z.ZodObject<{
39
+ inputColumns: z.ZodArray<z.ZodString, "many">;
40
+ outputColumn: z.ZodString;
41
+ }, "strip", z.ZodTypeAny, {
42
+ inputColumns: string[];
43
+ outputColumn: string;
44
+ }, {
45
+ inputColumns: string[];
46
+ outputColumn: string;
47
+ }>;
48
+ text: z.ZodObject<{
49
+ tokenizer: z.ZodType<Tokenizer, z.ZodTypeDef, Tokenizer>;
50
+ contextLength: z.ZodNumber;
51
+ }, "strip", z.ZodTypeAny, {
52
+ contextLength: number;
53
+ tokenizer: Tokenizer;
54
+ }, {
55
+ contextLength: number;
56
+ tokenizer: Tokenizer;
57
+ }>;
30
58
  };
31
- text: {
32
- dataType: "text";
33
- tokenizer: string | PreTrainedTokenizer;
34
- contextLength: number;
59
+ const networkToSchema: {
60
+ decentralized: z.ZodIntersection<z.ZodObject<{
61
+ scheme: z.ZodLiteral<"decentralized">;
62
+ } & {
63
+ privacy: z.ZodOptional<z.ZodEffects<z.ZodObject<{
64
+ clippingRadius: z.ZodOptional<z.ZodNumber>;
65
+ noiseScale: z.ZodOptional<z.ZodNumber>;
66
+ }, "strip", z.ZodTypeAny, {
67
+ clippingRadius?: number | undefined;
68
+ noiseScale?: number | undefined;
69
+ }, {
70
+ clippingRadius?: number | undefined;
71
+ noiseScale?: number | undefined;
72
+ }>, {
73
+ clippingRadius?: number | undefined;
74
+ noiseScale?: number | undefined;
75
+ } | undefined, {
76
+ clippingRadius?: number | undefined;
77
+ noiseScale?: number | undefined;
78
+ }>>;
79
+ minNbOfParticipants: z.ZodNumber;
80
+ }, "strip", z.ZodTypeAny, {
81
+ scheme: "decentralized";
82
+ minNbOfParticipants: number;
83
+ privacy?: {
84
+ clippingRadius?: number | undefined;
85
+ noiseScale?: number | undefined;
86
+ } | undefined;
87
+ }, {
88
+ scheme: "decentralized";
89
+ minNbOfParticipants: number;
90
+ privacy?: {
91
+ clippingRadius?: number | undefined;
92
+ noiseScale?: number | undefined;
93
+ } | undefined;
94
+ }>, z.ZodUnion<[z.ZodObject<{
95
+ aggregationStrategy: z.ZodLiteral<"mean">;
96
+ }, "strip", z.ZodTypeAny, {
97
+ aggregationStrategy: "mean";
98
+ }, {
99
+ aggregationStrategy: "mean";
100
+ }>, z.ZodObject<{
101
+ aggregationStrategy: z.ZodLiteral<"secure">;
102
+ maxShareValue: z.ZodDefault<z.ZodOptional<z.ZodNumber>>;
103
+ }, "strip", z.ZodTypeAny, {
104
+ aggregationStrategy: "secure";
105
+ maxShareValue: number;
106
+ }, {
107
+ aggregationStrategy: "secure";
108
+ maxShareValue?: number | undefined;
109
+ }>]>>;
110
+ federated: z.ZodObject<{
111
+ scheme: z.ZodLiteral<"federated">;
112
+ aggregationStrategy: z.ZodLiteral<"mean">;
113
+ } & {
114
+ privacy: z.ZodOptional<z.ZodEffects<z.ZodObject<{
115
+ clippingRadius: z.ZodOptional<z.ZodNumber>;
116
+ noiseScale: z.ZodOptional<z.ZodNumber>;
117
+ }, "strip", z.ZodTypeAny, {
118
+ clippingRadius?: number | undefined;
119
+ noiseScale?: number | undefined;
120
+ }, {
121
+ clippingRadius?: number | undefined;
122
+ noiseScale?: number | undefined;
123
+ }>, {
124
+ clippingRadius?: number | undefined;
125
+ noiseScale?: number | undefined;
126
+ } | undefined, {
127
+ clippingRadius?: number | undefined;
128
+ noiseScale?: number | undefined;
129
+ }>>;
130
+ minNbOfParticipants: z.ZodNumber;
131
+ }, "strip", z.ZodTypeAny, {
132
+ scheme: "federated";
133
+ minNbOfParticipants: number;
134
+ aggregationStrategy: "mean";
135
+ privacy?: {
136
+ clippingRadius?: number | undefined;
137
+ noiseScale?: number | undefined;
138
+ } | undefined;
139
+ }, {
140
+ scheme: "federated";
141
+ minNbOfParticipants: number;
142
+ aggregationStrategy: "mean";
143
+ privacy?: {
144
+ clippingRadius?: number | undefined;
145
+ noiseScale?: number | undefined;
146
+ } | undefined;
147
+ }>;
148
+ local: z.ZodObject<{
149
+ scheme: z.ZodLiteral<"local">;
150
+ aggregationStrategy: z.ZodLiteral<"mean">;
151
+ }, "strip", z.ZodTypeAny, {
152
+ scheme: "local";
153
+ aggregationStrategy: "mean";
154
+ }, {
155
+ scheme: "local";
156
+ aggregationStrategy: "mean";
157
+ }>;
35
158
  };
36
159
  }
37
- export declare function isTrainingInformation(raw: unknown): raw is TrainingInformation<DataType>;
38
- export {};
160
+ export type TrainingInformation<D extends DataType, N extends Network> = z.infer<typeof TrainingInformation.baseSchema> & z.infer<(typeof TrainingInformation.dataTypeToSchema)[D]> & z.infer<(typeof TrainingInformation.networkToSchema)[N]>;
@@ -1,112 +1,87 @@
1
- import { PreTrainedTokenizer } from "@xenova/transformers";
2
- function isPrivacy(raw) {
3
- if (typeof raw !== "object" || raw === null) {
4
- return false;
5
- }
6
- const { clippingRadius, noiseScale, } = raw;
7
- if ((clippingRadius !== undefined && typeof clippingRadius !== "number") ||
8
- (noiseScale !== undefined && typeof noiseScale !== "number"))
9
- return false;
10
- const _ = {
11
- clippingRadius,
12
- noiseScale,
1
+ import { z } from "zod";
2
+ import { Tokenizer } from "../index.js";
3
+ const nonLocalNetworkSchema = z.object({
4
+ // reduce training accuracy and improve privacy.
5
+ privacy: z
6
+ .object({
7
+ // maximum weights difference between each round
8
+ clippingRadius: z.number().optional(),
9
+ // variance of the Gaussian noise added to the shared weights.
10
+ noiseScale: z.number().optional(),
11
+ })
12
+ .transform((o) => o.clippingRadius === undefined && o.noiseScale === undefined
13
+ ? undefined
14
+ : o)
15
+ .optional(),
16
+ // minimum number of participants required to train collaboratively
17
+ // In decentralized Learning the default is 3, in federated learning it is 2
18
+ minNbOfParticipants: z.number().positive().int(),
19
+ });
20
+ export var TrainingInformation;
21
+ (function (TrainingInformation) {
22
+ TrainingInformation.baseSchema = z.object({
23
+ // number of epochs to run training for
24
+ epochs: z.number().positive().int(),
25
+ // number of epochs between each weight sharing round.
26
+ // e.g.if 3 then weights are shared every 3 epochs (in the distributed setting).
27
+ roundDuration: z.number().positive().int(),
28
+ // fraction of data to keep for validation, note this only works for image data
29
+ validationSplit: z.number().min(0).max(1),
30
+ // batch size of training data
31
+ batchSize: z.number().positive().int(),
32
+ // Tensor framework used by the model
33
+ tensorBackend: z.enum(["gpt", "tfjs"]),
34
+ });
35
+ TrainingInformation.dataTypeToSchema = {
36
+ image: z.object({
37
+ // classes, e.g. if two class of images, one with dogs and one with cats, then we would
38
+ // define ['dogs', 'cats'].
39
+ LABEL_LIST: z.array(z.string()).min(1),
40
+ // height of image to resize to
41
+ IMAGE_W: z.number().positive().int(),
42
+ // width of image to resize to
43
+ IMAGE_H: z.number().positive().int(),
44
+ }),
45
+ tabular: z.object({
46
+ // the columns to be chosen as input data for the model
47
+ inputColumns: z.array(z.string()),
48
+ // the columns to be predicted by the model
49
+ outputColumn: z.string(),
50
+ }),
51
+ text: z.object({
52
+ // should be set with the name of a Transformers.js pre-trained tokenizer, e.g., 'Xenova/gpt2'.
53
+ tokenizer: z.instanceof(Tokenizer),
54
+ // the maximum length of a input string used as input to a GPT model. It is used during preprocessing to
55
+ // truncate strings to a maximum length. The default value is tokenizer.model_max_length
56
+ contextLength: z.number().positive().int(),
57
+ }),
13
58
  };
14
- return true;
15
- }
16
- export function isTrainingInformation(raw) {
17
- if (typeof raw !== "object" || raw === null) {
18
- return false;
19
- }
20
- const { aggregationStrategy, batchSize, dataType, privacy, epochs, maxShareValue, minNbOfParticipants, roundDuration, scheme, validationSplit, tensorBackend, } = raw;
21
- if (typeof epochs !== "number" ||
22
- typeof batchSize !== "number" ||
23
- typeof roundDuration !== "number" ||
24
- typeof validationSplit !== "number" ||
25
- typeof minNbOfParticipants !== "number" ||
26
- (privacy !== undefined && !isPrivacy(privacy)) ||
27
- (maxShareValue !== undefined && typeof maxShareValue !== "number")) {
28
- return false;
29
- }
30
- switch (aggregationStrategy) {
31
- case undefined:
32
- case "mean":
33
- case "secure":
34
- break;
35
- default:
36
- return false;
37
- }
38
- switch (tensorBackend) {
39
- case "tfjs":
40
- case "gpt":
41
- break;
42
- default:
43
- return false;
44
- }
45
- switch (scheme) {
46
- case "decentralized":
47
- case "federated":
48
- case "local":
49
- break;
50
- default:
51
- return false;
52
- }
53
- const repack = {
54
- aggregationStrategy,
55
- batchSize,
56
- epochs,
57
- maxShareValue,
58
- minNbOfParticipants,
59
- privacy,
60
- roundDuration,
61
- scheme,
62
- tensorBackend,
63
- validationSplit,
59
+ TrainingInformation.networkToSchema = {
60
+ decentralized: z
61
+ .object({
62
+ scheme: z.literal("decentralized"),
63
+ })
64
+ .merge(nonLocalNetworkSchema)
65
+ .and(z.union([
66
+ z.object({
67
+ aggregationStrategy: z.literal("mean"),
68
+ }),
69
+ z.object({
70
+ aggregationStrategy: z.literal("secure"),
71
+ // Secure Aggregation: maximum absolute value of a number in a randomly generated share
72
+ // default is 100, must be a positive number, check the docs/PRIVACY.md file for more information on significance of maxShareValue selection
73
+ maxShareValue: z.number().positive().int().optional().default(100),
74
+ }),
75
+ ])),
76
+ federated: z
77
+ .object({
78
+ scheme: z.literal("federated"),
79
+ aggregationStrategy: z.literal("mean"),
80
+ })
81
+ .merge(nonLocalNetworkSchema),
82
+ local: z.object({
83
+ scheme: z.literal("local"),
84
+ aggregationStrategy: z.literal("mean"),
85
+ }),
64
86
  };
65
- switch (dataType) {
66
- case "image": {
67
- const { LABEL_LIST, IMAGE_W, IMAGE_H } = raw;
68
- if (!(Array.isArray(LABEL_LIST) &&
69
- LABEL_LIST.every((e) => typeof e === "string")) ||
70
- typeof IMAGE_H !== "number" ||
71
- typeof IMAGE_W !== "number")
72
- return false;
73
- const _ = {
74
- ...repack,
75
- dataType,
76
- LABEL_LIST,
77
- IMAGE_W,
78
- IMAGE_H,
79
- };
80
- return true;
81
- }
82
- case "tabular": {
83
- const { inputColumns, outputColumn } = raw;
84
- if (!(Array.isArray(inputColumns) &&
85
- inputColumns.every((e) => typeof e === "string")) ||
86
- typeof outputColumn !== "string")
87
- return false;
88
- const _ = {
89
- ...repack,
90
- dataType,
91
- inputColumns,
92
- outputColumn,
93
- };
94
- return true;
95
- }
96
- case "text": {
97
- const { contextLength, tokenizer, } = raw;
98
- if ((typeof tokenizer !== "string" &&
99
- !(tokenizer instanceof PreTrainedTokenizer)) ||
100
- (typeof contextLength !== "number"))
101
- return false;
102
- const _ = {
103
- ...repack,
104
- dataType,
105
- contextLength,
106
- tokenizer,
107
- };
108
- return true;
109
- }
110
- }
111
- return false;
112
- }
87
+ })(TrainingInformation || (TrainingInformation = {}));
@@ -1,10 +1,10 @@
1
- import { client as clients, BatchLogs, EpochLogs, Logger, TrainingInformation, Dataset } from "../index.js";
2
- import type { DataFormat, DataType, Task } from "../index.js";
1
+ import { client as clients, BatchLogs, EpochLogs, Logger, Dataset } from "../index.js";
2
+ import type { DataFormat, DataType, Network, Task } from "../index.js";
3
3
  import type { Aggregator } from "../aggregator/index.js";
4
4
  import { EventEmitter } from "../utils/event_emitter.js";
5
5
  import { RoundLogs, Trainer } from "./trainer.js";
6
- interface DiscoConfig {
7
- scheme: TrainingInformation<DataType>["scheme"];
6
+ interface DiscoConfig<N extends Network> {
7
+ scheme: N;
8
8
  logger: Logger;
9
9
  /**
10
10
  * keep preprocessed dataset in memory while training
@@ -24,12 +24,12 @@ export type RoundStatus = 'not enough participants' | // Server notification to
24
24
  * a convenient object providing a reduced yet complete API that wraps model training and
25
25
  * communication with nodes.
26
26
  */
27
- export declare class Disco<D extends DataType> extends EventEmitter<{
27
+ export declare class Disco<D extends DataType, N extends Network> extends EventEmitter<{
28
28
  status: RoundStatus;
29
29
  participants: number;
30
30
  }> {
31
31
  #private;
32
- readonly trainer: Trainer<D>;
32
+ readonly trainer: Trainer<D, N>;
33
33
  /**
34
34
  * Connect to the given task and get ready to train.
35
35
  *
@@ -37,10 +37,10 @@ export declare class Disco<D extends DataType> extends EventEmitter<{
37
37
  * @param clientConfig client to connect with or parameters on how to create one.
38
38
  * @param config the DiscoConfig
39
39
  */
40
- constructor(task: Task<D>, clientConfig: clients.Client | URL | {
40
+ constructor(task: Task<D, N>, clientConfig: clients.Client<N> | URL | {
41
41
  aggregator: Aggregator;
42
42
  url: URL;
43
- }, config: Partial<DiscoConfig>);
43
+ }, config: Partial<DiscoConfig<N>>);
44
44
  /** Train on dataset, yielding logs of every round. */
45
45
  trainByRound(dataset: Dataset<DataFormat.Raw[D]>): AsyncGenerator<RoundLogs>;
46
46
  /** Train on dataset, yielding logs of every epoch. */
@@ -24,6 +24,7 @@ export class Disco extends EventEmitter {
24
24
  constructor(task, clientConfig, config) {
25
25
  super();
26
26
  const { scheme, logger, preprocessOnce } = {
27
+ // cast as typescript is bad at generic
27
28
  scheme: task.trainingInformation.scheme,
28
29
  logger: new ConsoleLogger(),
29
30
  preprocessOnce: false,
@@ -135,7 +136,7 @@ export class Disco extends EventEmitter {
135
136
  }
136
137
  async #preprocessSplitAndBatch(dataset) {
137
138
  const { batchSize, validationSplit } = this.#task.trainingInformation;
138
- let preprocessed = await processing.preprocess(this.#task, dataset);
139
+ let preprocessed = processing.preprocess(this.#task, dataset);
139
140
  preprocessed = (this.#preprocessOnce
140
141
  ? new Dataset(await arrayFromAsync(preprocessed))
141
142
  : preprocessed);
@@ -1,16 +1,16 @@
1
1
  import { List } from "immutable";
2
- import type { Batched, BatchLogs, Dataset, DataFormat, DataType, EpochLogs, Model, Task } from "../index.js";
2
+ import type { Batched, BatchLogs, Dataset, DataFormat, DataType, EpochLogs, Model, Task, Network } from "../index.js";
3
3
  import { Client } from "../client/index.js";
4
4
  export interface RoundLogs {
5
5
  epochs: List<EpochLogs>;
6
6
  participants: number;
7
7
  }
8
8
  /** Train a model and exchange with others **/
9
- export declare class Trainer<D extends DataType> {
9
+ export declare class Trainer<D extends DataType, N extends Network> {
10
10
  #private;
11
11
  get model(): Model<D>;
12
12
  set model(model: Model<D>);
13
- constructor(task: Task<D>, client: Client);
13
+ constructor(task: Task<D, N>, client: Client<N>);
14
14
  stopTraining(): Promise<void>;
15
15
  train(dataset: Dataset<Batched<DataFormat.ModelEncoded[D]>>, validationDataset?: Dataset<Batched<DataFormat.ModelEncoded[D]>>): AsyncGenerator<AsyncGenerator<AsyncGenerator<BatchLogs, EpochLogs>, RoundLogs>, void>;
16
16
  }
@@ -22,7 +22,8 @@ export class Trainer {
22
22
  this.#client = client;
23
23
  this.#roundDuration = task.trainingInformation.roundDuration;
24
24
  this.#epochs = task.trainingInformation.epochs;
25
- this.#privacy = task.trainingInformation.privacy;
25
+ if ("privacy" in task.trainingInformation)
26
+ this.#privacy = task.trainingInformation.privacy;
26
27
  if (!Number.isInteger(this.#epochs / this.#roundDuration))
27
28
  throw new Error(`round duration ${this.#roundDuration} doesn't divide number of epochs ${this.#epochs}`);
28
29
  }
@@ -1,2 +1,3 @@
1
1
  export * as DataFormat from "./data_format.js";
2
2
  export type DataType = "image" | "tabular" | "text";
3
+ export type Network = "decentralized" | "federated" | "local";
@@ -1,10 +1,10 @@
1
- import { Dataset, DataFormat, DataType, Model, Task } from "./index.js";
1
+ import type { Dataset, DataFormat, DataType, Model, Task, Network } from "./index.js";
2
2
  export declare class Validator<D extends DataType> {
3
3
  #private;
4
- readonly task: Task<D>;
5
- constructor(task: Task<D>, model: Model<D>);
4
+ readonly task: Task<D, Network>;
5
+ constructor(task: Task<D, Network>, model: Model<D>);
6
6
  /** infer every line of the dataset and check that it is as labelled */
7
- test(dataset: Dataset<DataFormat.Raw[D]>): Promise<Dataset<Record<"predicted" | "truth", DataFormat.Inferred[D]>>>;
7
+ test(dataset: Dataset<DataFormat.Raw[D]>): Dataset<Record<"predicted" | "truth", DataFormat.Inferred[D]>>;
8
8
  /** use the model to predict every line of the dataset */
9
9
  infer(dataset: Dataset<DataFormat.RawWithoutLabel[D]>): AsyncGenerator<DataFormat.Inferred[D], void>;
10
10
  }