@epfml/discojs 3.0.1-p20241025115642.0 → 3.0.1-p20241107104659.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93)
  1. package/dist/aggregator/get.d.ts +3 -3
  2. package/dist/client/client.d.ts +5 -5
  3. package/dist/client/decentralized/decentralized_client.d.ts +2 -2
  4. package/dist/client/federated/federated_client.d.ts +2 -2
  5. package/dist/client/utils.d.ts +2 -2
  6. package/dist/dataset/dataset.d.ts +9 -2
  7. package/dist/dataset/dataset.js +83 -36
  8. package/dist/dataset/image.d.ts +5 -0
  9. package/dist/dataset/image.js +6 -1
  10. package/dist/dataset/index.d.ts +0 -1
  11. package/dist/dataset/index.js +0 -1
  12. package/dist/dataset/types.d.ts +2 -0
  13. package/dist/default_tasks/cifar10.d.ts +1 -1
  14. package/dist/default_tasks/cifar10.js +2 -3
  15. package/dist/default_tasks/lus_covid.d.ts +1 -1
  16. package/dist/default_tasks/lus_covid.js +2 -3
  17. package/dist/default_tasks/mnist.d.ts +1 -1
  18. package/dist/default_tasks/mnist.js +3 -5
  19. package/dist/default_tasks/simple_face.d.ts +1 -1
  20. package/dist/default_tasks/simple_face.js +2 -3
  21. package/dist/default_tasks/titanic.d.ts +1 -1
  22. package/dist/default_tasks/titanic.js +3 -6
  23. package/dist/default_tasks/wikitext.d.ts +1 -1
  24. package/dist/default_tasks/wikitext.js +1 -2
  25. package/dist/index.d.ts +4 -5
  26. package/dist/index.js +4 -5
  27. package/dist/models/gpt/index.d.ts +13 -16
  28. package/dist/models/gpt/index.js +62 -43
  29. package/dist/models/gpt/model.d.ts +1 -15
  30. package/dist/models/gpt/model.js +1 -75
  31. package/dist/models/model.d.ts +7 -12
  32. package/dist/models/tfjs.d.ts +10 -8
  33. package/dist/models/tfjs.js +106 -44
  34. package/dist/models/tokenizer.d.ts +1 -1
  35. package/dist/privacy.js +1 -1
  36. package/dist/processing/image.d.ts +18 -0
  37. package/dist/processing/image.js +75 -0
  38. package/dist/processing/index.d.ts +8 -0
  39. package/dist/processing/index.js +106 -0
  40. package/dist/processing/tabular.d.ts +19 -0
  41. package/dist/processing/tabular.js +33 -0
  42. package/dist/processing/text.d.ts +11 -0
  43. package/dist/processing/text.js +33 -0
  44. package/dist/serialization/model.d.ts +3 -3
  45. package/dist/serialization/model.js +19 -6
  46. package/dist/task/task.d.ts +4 -3
  47. package/dist/task/task.js +5 -3
  48. package/dist/task/task_handler.d.ts +3 -3
  49. package/dist/task/task_provider.d.ts +4 -4
  50. package/dist/task/training_information.d.ts +25 -16
  51. package/dist/task/training_information.js +76 -72
  52. package/dist/training/disco.d.ts +20 -12
  53. package/dist/training/disco.js +32 -13
  54. package/dist/training/trainer.d.ts +6 -7
  55. package/dist/training/trainer.js +6 -6
  56. package/dist/types/data_format.d.ts +40 -0
  57. package/dist/types/index.d.ts +2 -0
  58. package/dist/types/index.js +1 -0
  59. package/dist/validator.d.ts +10 -0
  60. package/dist/validator.js +30 -0
  61. package/package.json +4 -2
  62. package/dist/dataset/data/data.d.ts +0 -47
  63. package/dist/dataset/data/data.js +0 -88
  64. package/dist/dataset/data/data_split.d.ts +0 -8
  65. package/dist/dataset/data/helpers.d.ts +0 -10
  66. package/dist/dataset/data/helpers.js +0 -97
  67. package/dist/dataset/data/image_data.d.ts +0 -11
  68. package/dist/dataset/data/image_data.js +0 -43
  69. package/dist/dataset/data/index.d.ts +0 -5
  70. package/dist/dataset/data/index.js +0 -5
  71. package/dist/dataset/data/preprocessing/base.d.ts +0 -16
  72. package/dist/dataset/data/preprocessing/base.js +0 -1
  73. package/dist/dataset/data/preprocessing/image_preprocessing.d.ts +0 -13
  74. package/dist/dataset/data/preprocessing/image_preprocessing.js +0 -42
  75. package/dist/dataset/data/preprocessing/index.d.ts +0 -4
  76. package/dist/dataset/data/preprocessing/index.js +0 -3
  77. package/dist/dataset/data/preprocessing/tabular_preprocessing.d.ts +0 -13
  78. package/dist/dataset/data/preprocessing/tabular_preprocessing.js +0 -45
  79. package/dist/dataset/data/preprocessing/text_preprocessing.d.ts +0 -13
  80. package/dist/dataset/data/preprocessing/text_preprocessing.js +0 -100
  81. package/dist/dataset/data/tabular_data.d.ts +0 -11
  82. package/dist/dataset/data/tabular_data.js +0 -24
  83. package/dist/dataset/data/text_data.d.ts +0 -11
  84. package/dist/dataset/data/text_data.js +0 -14
  85. package/dist/processing.d.ts +0 -35
  86. package/dist/processing.js +0 -89
  87. package/dist/types.d.ts +0 -3
  88. package/dist/types.js +0 -1
  89. package/dist/validation/index.d.ts +0 -1
  90. package/dist/validation/index.js +0 -1
  91. package/dist/validation/validator.d.ts +0 -10
  92. package/dist/validation/validator.js +0 -113
  93. /package/dist/{dataset/data/data_split.js → types/data_format.js} +0 -0
package/dist/processing/image.js ADDED
@@ -0,0 +1,75 @@
+ import { Repeat, Seq } from "immutable";
+ import { createJimp } from "@jimp/core";
+ import * as jimpResize from "@jimp/plugin-resize";
+ import { Image } from "../index.js";
+ /** Image where intensity is represented in the range 0..1 */
+ export class NormalizedImage {
+     data;
+     width;
+     height;
+     depth;
+     // private as it doesn't check that array content is valid
+     constructor(data, width, height, depth) {
+         this.data = data;
+         this.width = width;
+         this.height = height;
+         this.depth = depth;
+         if (data.length != width * height * depth)
+             throw new Error("data isn't of expected size");
+     }
+     static from(image) {
+         return new NormalizedImage(Float32Array.from(image.data).map((v) => v / 255), image.width, image.height, image.depth);
+     }
+ }
+ /** Add a fully opaque alpha channel to an image */
+ function addAlpha(image) {
+     switch (image.depth) {
+         case 3:
+             return new Image(Uint8Array.from(
+             // we are adding a channel, so for every 3 bytes in the base image,
+             // we need to add a fourth. we choose to "expand" the last channel
+             // into two values: the channel's base value and the transparency.
+             // let's say we want to add a byte A to the bytestring RGB
+             // [R, G, B] -> [[R], [G], [B, A]] -> [R, G, B, A]
+             Seq(image.data).flatMap((v, i) => {
+                 const OPAQUE = 0xff;
+                 if (i % 3 !== 2)
+                     return [v];
+                 else
+                     return [v, OPAQUE];
+             })), image.width, image.height, 4);
+         case 4:
+             return new Image(image.data, image.width, image.height, image.depth);
+     }
+ }
+ export function removeAlpha(image) {
+     switch (image.depth) {
+         case 1:
+         case 3:
+             return new Image(image.data, image.width, image.height, image.depth);
+         case 4:
+             return new Image(image.data.filter((_, i) => i % 4 !== 3), image.width, image.height, 3);
+     }
+ }
+ export function expandToMulticolor(image) {
+     switch (image.depth) {
+         case 1:
+             return new Image(Uint8Array.from(Seq(image.data).flatMap((v) => Repeat(v, 3))), image.width, image.height, 3);
+         case 3:
+         case 4:
+             return new Image(image.data, image.width, image.height, image.depth);
+     }
+ }
+ export function resize(width, height, image) {
+     const Jimp = createJimp({
+         plugins: [jimpResize.methods],
+     });
+     const resized = new Jimp(addAlpha(expandToMulticolor(image))).resize({
+         w: width,
+         h: height,
+     });
+     return new Image(new Uint8Array(resized.bitmap.data), width, height, 4);
+ }
+ export function normalize(image) {
+     return NormalizedImage.from(image);
+ }
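
Taken together, the new module forms a small pipeline: `resize` first expands the image so Jimp gets RGBA, `removeAlpha` drops the extra channel again, and `normalize` rescales intensities to 0..1. A minimal sketch of that chain, assuming the package root re-exports `Image` and a `processing` namespace (as the relative imports above suggest):

```ts
import { Image, processing } from "@epfml/discojs";

const rgba = new Uint8Array(64 * 64 * 4); // raw RGBA bytes, e.g. a decoded PNG
const original = new Image(rgba, 64, 64, 4);

const resized = processing.resize(32, 32, original); // Jimp resize, always returns depth 4
const rgb = processing.removeAlpha(resized);         // depth 4 -> depth 3
const normalized = processing.normalize(rgb);        // Uint8Array 0..255 -> Float32Array 0..1
```
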
package/dist/processing/index.d.ts ADDED
@@ -0,0 +1,8 @@
+ /** Dataset shapers, convenient to map with */
+ import type { Dataset, DataFormat, DataType, Task } from "../index.js";
+ export * from "./image.js";
+ export * from "./tabular.js";
+ export * from "./text.js";
+ export declare function preprocess<D extends DataType>(task: Task<D>, dataset: Dataset<DataFormat.Raw[D]>): Promise<Dataset<DataFormat.ModelEncoded[D]>>;
+ export declare function preprocessWithoutLabel<D extends DataType>(task: Task<D>, dataset: Dataset<DataFormat.RawWithoutLabel[D]>): Promise<Dataset<DataFormat.ModelEncoded[D][0]>>;
+ export declare function postprocess<D extends DataType>(task: Task<D>, dataset: Dataset<DataFormat.ModelEncoded[D][1]>): Promise<Dataset<DataFormat.Inferred[D]>>;
package/dist/processing/index.js ADDED
@@ -0,0 +1,106 @@
+ /** Dataset shapers, convenient to map with */
+ import { List } from "immutable";
+ import { models } from "../index.js";
+ import * as processing from "./index.js";
+ export * from "./image.js";
+ export * from "./tabular.js";
+ export * from "./text.js";
+ export async function preprocess(task, dataset) {
+     switch (task.trainingInformation.dataType) {
+         case "image": {
+             // cast as typescript doesn't reduce generic type
+             const d = dataset;
+             const { IMAGE_H, IMAGE_W, LABEL_LIST } = task.trainingInformation;
+             return d.map(([image, label]) => [
+                 processing.normalize(processing.removeAlpha(processing.resize(IMAGE_W, IMAGE_H, image))),
+                 processing.indexInList(label, LABEL_LIST),
+             ]);
+         }
+         case "tabular": {
+             // cast as typescript doesn't reduce generic type
+             const d = dataset;
+             const { inputColumns, outputColumn } = task.trainingInformation;
+             return d.map((row) => {
+                 const output = processing.extractColumn(row, outputColumn);
+                 return [
+                     extractToNumbers(inputColumns, row),
+                     // TODO sanitization doesn't care about column distribution
+                     output !== "" ? processing.convertToNumber(output) : 0,
+                 ];
+             });
+         }
+         case "text": {
+             // cast as typescript doesn't reduce generic type
+             const d = dataset;
+             const t = task;
+             const tokenizer = await models.getTaskTokenizer(t);
+             const totalTokenCount = task.trainingInformation.maxSequenceLength ??
+                 tokenizer.model_max_length;
+             return d
+                 .map((line) => processing.tokenizeAndLeftPad(line, tokenizer, totalTokenCount))
+                 .map((tokens) => [tokens.pop(), tokens.last()]);
+         }
+     }
+ }
+ export async function preprocessWithoutLabel(task, dataset) {
+     switch (task.trainingInformation.dataType) {
+         case "image": {
+             // cast as typescript doesn't reduce generic type
+             const d = dataset;
+             const { IMAGE_H, IMAGE_W } = task.trainingInformation;
+             return d.map((image) => processing.normalize(processing.removeAlpha(processing.resize(IMAGE_W, IMAGE_H, image))));
+         }
+         case "tabular": {
+             // cast as typescript doesn't reduce generic type
+             const d = dataset;
+             const { inputColumns } = task.trainingInformation;
+             return d.map((row) => extractToNumbers(inputColumns, row));
+         }
+         case "text": {
+             // cast as typescript doesn't reduce generic type
+             const d = dataset;
+             const t = task;
+             const tokenizer = await models.getTaskTokenizer(t);
+             const totalTokenCount = t.trainingInformation.maxSequenceLength ??
+                 tokenizer.model_max_length;
+             return d
+                 .map((line) => processing.tokenizeAndLeftPad(line, tokenizer, totalTokenCount))
+                 .map((tokens) => tokens.pop());
+         }
+     }
+ }
+ export async function postprocess(task, dataset) {
+     switch (task.trainingInformation.dataType) {
+         case "image": {
+             // cast as typescript doesn't reduce generic type
+             const d = dataset;
+             const { LABEL_LIST } = task.trainingInformation;
+             const labels = List(LABEL_LIST);
+             return d.map((index) => {
+                 const v = labels.get(index);
+                 if (v === undefined)
+                     throw new Error("index not found in labels");
+                 return v;
+             });
+         }
+         case "tabular": {
+             // cast as typescript doesn't reduce generic type
+             const d = dataset;
+             return d;
+         }
+         case "text": {
+             // cast as typescript doesn't reduce generic type
+             const d = dataset;
+             const t = task;
+             const tokenizer = await models.getTaskTokenizer(t);
+             return d.map((token) => tokenizer.decode([token]));
+         }
+     }
+ }
+ function extractToNumbers(columns, row) {
+     return (List(columns)
+         .map((column) => processing.extractColumn(row, column))
+         // TODO sanitization doesn't care about column distribution
+         .map((v) => (v !== "" ? v : "0"))
+         .map(processing.convertToNumber));
+ }
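
`preprocess` is the one entry point mapped over a raw dataset; the task's `dataType` picks the branch. A hypothetical tabular run (the `Dataset`-from-array construction and the root exports are assumptions, not shown in this diff):

```ts
import { Dataset, processing, type Task } from "@epfml/discojs";

declare const task: Task<"tabular">; // e.g. obtained from fetchTasks()
// hypothetical rows; discojs Datasets wrap lazy (async) iterables
const rows = new Dataset([
  { age: "22", fare: "7.25", survived: "0" },
  { age: "38", fare: "71.28", survived: "1" },
]);

// each row becomes [List<number> of inputs, numeric label]
const encoded = await processing.preprocess(task, rows);
```
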
package/dist/processing/tabular.d.ts ADDED
@@ -0,0 +1,19 @@
+ import { List } from "immutable";
+ /**
+  * Convert a string to a number
+  *
+  * @throws if it isn't written as a number
+  */
+ export declare function convertToNumber(raw: string): number;
+ /**
+  * Return the named field of an object with string values
+  *
+  * @throws if the named field isn't there
+  */
+ export declare function extractColumn(row: Partial<Record<string, string>>, column: string): string;
+ /**
+  * Return the index of the element in the given list
+  *
+  * @throws if not found
+  */
+ export declare function indexInList(element: string, elements: List<string> | Array<string>): number;
package/dist/processing/tabular.js ADDED
@@ -0,0 +1,33 @@
+ /**
+  * Convert a string to a number
+  *
+  * @throws if it isn't written as a number
+  */
+ export function convertToNumber(raw) {
+     const num = Number.parseFloat(raw);
+     if (Number.isNaN(num))
+         throw new Error(`unable to parse "${raw}" as number`);
+     return num;
+ }
+ /**
+  * Return the named field of an object with string values
+  *
+  * @throws if the named field isn't there
+  */
+ export function extractColumn(row, column) {
+     const raw = row[column];
+     if (raw === undefined)
+         throw new Error(`${column} not found in row`);
+     return raw;
+ }
+ /**
+  * Return the index of the element in the given list
+  *
+  * @throws if not found
+  */
+ export function indexInList(element, elements) {
+     const ret = elements.indexOf(element);
+     if (ret === -1)
+         throw new Error(`${element} not found in list`);
+     return ret;
+ }
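
These helpers throw instead of guessing, which keeps the sanitization decisions in `preprocess` explicit. For illustration (import path assumed as above):

```ts
import { List } from "immutable";
import { processing } from "@epfml/discojs";

const row = { age: "42", city: "Lausanne" };
processing.extractColumn(row, "age");          // "42"
processing.convertToNumber("42");              // 42
processing.indexInList("b", List(["a", "b"])); // 1
// processing.extractColumn(row, "fare") would throw: "fare not found in row"
```
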
package/dist/processing/text.d.ts ADDED
@@ -0,0 +1,11 @@
+ import { List } from "immutable";
+ import { PreTrainedTokenizer } from "@xenova/transformers";
+ type Token = number;
+ /**
+  * Tokenizes and truncates input strings
+  *
+  * @param length number of tokens
+  * @returns the encoded line as a list of tokens of the given length
+  */
+ export declare function tokenizeAndLeftPad(line: string, tokenizer: PreTrainedTokenizer, length: number): List<Token>;
+ export {};
package/dist/processing/text.js ADDED
@@ -0,0 +1,33 @@
+ import { Repeat } from "immutable";
+ function isArrayOfNumber(raw) {
+     return Array.isArray(raw) && raw.every((e) => typeof e === "number");
+ }
+ /**
+  * Tokenizes and truncates input strings
+  *
+  * @param length number of tokens
+  * @returns the encoded line as a list of tokens of the given length
+  */
+ export function tokenizeAndLeftPad(line, tokenizer, length) {
+     if (!Number.isInteger(length))
+         throw new Error("length should be an integer");
+     // Transformers.js currently only supports right padding while we need left for text generation
+     // Left padding should be supported in the future; once it is, we can directly pad while tokenizing
+     // https://github.com/xenova/transformers.js/blob/8804c36591d11d8456788d1bb4b16489121b3be2/src/tokenizers.js#L2517
+     const tokenized = tokenizer(line, {
+         padding: false,
+         truncation: true,
+         return_tensor: false,
+         max_length: length,
+     });
+     if (typeof tokenized !== "object" ||
+         tokenized === null ||
+         !("input_ids" in tokenized) ||
+         !isArrayOfNumber(tokenized.input_ids))
+         throw new Error("tokenizer returns unexpected type");
+     const tokens = tokenized.input_ids;
+     const paddingSize = length - tokens.length;
+     if (paddingSize < 0)
+         throw new Error("tokenized returned more token than expected");
+     return Repeat(tokenizer.pad_token_id, paddingSize).concat(tokens).toList();
+ }
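
Because Transformers.js only right-pads, the function tokenizes with truncation and prepends the pad token itself. A sketch of a call, with `AutoTokenizer` from `@xenova/transformers` and an illustrative model name:

```ts
import { AutoTokenizer } from "@xenova/transformers";
import { processing } from "@epfml/discojs";

const tokenizer = await AutoTokenizer.from_pretrained("Xenova/gpt2");
const tokens = processing.tokenizeAndLeftPad("hello world", tokenizer, 8);
// tokens.size === 8: tokenizer.pad_token_id repeated on the left,
// followed by the GPT-2 encoding of "hello world"
```
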
package/dist/serialization/model.d.ts CHANGED
@@ -1,4 +1,4 @@
- import type { Model } from '../index.js';
+ import type { DataType, Model } from '../index.js';
  import { Encoded } from "./coder.js";
- export declare function encode(model: Model): Promise<Encoded>;
- export declare function decode(encoded: unknown): Promise<Model>;
+ export declare function encode(model: Model<DataType>): Promise<Encoded>;
+ export declare function decode(encoded: unknown): Promise<Model<DataType>>;
package/dist/serialization/model.js CHANGED
@@ -9,7 +9,7 @@ export async function encode(model) {
      switch (true) {
          case model instanceof models.TFJS: {
              const serialized = await model.serialize();
-             return coder.encode([Type.TFJS, serialized]);
+             return coder.encode([Type.TFJS, ...serialized]);
          }
          case model instanceof models.GPT: {
              const { weights, config } = model.serialize();
@@ -33,12 +33,25 @@ export async function decode(encoded) {
      }
      const rawModel = raw[1];
      switch (type) {
-         case Type.TFJS:
-             if (raw.length !== 2) {
-                 throw new Error('invalid encoding, TFJS model encoding should be an array of length 2');
+         case Type.TFJS: {
+             if (raw.length !== 3)
+                 throw new Error("invalid TFJS model encoding: should be an array of length 3");
+             const [rawDatatype, rawModel] = raw.slice(1);
+             let datatype;
+             switch (rawDatatype) {
+                 case "image":
+                 case "tabular":
+                     datatype = rawDatatype;
+                     break;
+                 default:
+                     throw new Error("invalid TFJS model encoding: invalid DataType");
              }
-             // TODO totally unsafe casting
-             return await models.TFJS.deserialize(rawModel);
+             return await models.TFJS.deserialize([
+                 datatype,
+                 // TODO totally unsafe casting
+                 rawModel,
+             ]);
+         }
          case Type.GPT: {
              let config;
              if (raw.length == 2) {
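
A TFJS payload thus grows from `[Type.TFJS, model]` to `[Type.TFJS, dataType, model]`, letting `decode` rebuild the generic `Model<D>`; old two-element payloads are now rejected. A round-trip sketch (the `serialization.model` namespace and the generic `models.TFJS` are assumptions based on this diff):

```ts
import { models, serialization } from "@epfml/discojs";

declare const model: models.TFJS<"image">; // e.g. from a TaskProvider
const encoded = await serialization.model.encode(model);   // encodes [Type.TFJS, "image", artifacts]
const decoded = await serialization.model.decode(encoded); // Promise<Model<DataType>>
```
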
package/dist/task/task.d.ts CHANGED
@@ -1,10 +1,11 @@
+ import { DataType } from "../index.js";
  import { type DisplayInformation } from './display_information.js';
  import { type TrainingInformation } from './training_information.js';
  export type TaskID = string;
- export interface Task {
+ export interface Task<D extends DataType> {
      id: TaskID;
      displayInformation: DisplayInformation;
-     trainingInformation: TrainingInformation;
+     trainingInformation: TrainingInformation<D>;
  }
  export declare function isTaskID(obj: unknown): obj is TaskID;
- export declare function isTask(raw: unknown): raw is Task;
+ export declare function isTask(raw: unknown): raw is Task<DataType>;
package/dist/task/task.js CHANGED
@@ -13,8 +13,10 @@ export function isTask(raw) {
          !isTrainingInformation(trainingInformation)) {
          return false;
      }
-     const repack = { id, displayInformation, trainingInformation };
-     const _correct = repack;
-     const _total = repack;
+     const _ = {
+         id,
+         displayInformation,
+         trainingInformation,
+     };
      return true;
  }
package/dist/task/task_handler.d.ts CHANGED
@@ -1,5 +1,5 @@
  import { Map } from "immutable";
- import type { Model } from "../index.js";
+ import type { DataType, Model } from "../index.js";
  import type { Task, TaskID } from "./task.js";
- export declare function pushTask(base: URL, task: Task, model: Model): Promise<void>;
- export declare function fetchTasks(base: URL): Promise<Map<TaskID, Task>>;
+ export declare function pushTask<D extends DataType>(base: URL, task: Task<D>, model: Model<D>): Promise<void>;
+ export declare function fetchTasks(base: URL): Promise<Map<TaskID, Task<DataType>>>;
package/dist/task/task_provider.d.ts CHANGED
@@ -1,5 +1,5 @@
- import type { Model, Task } from '../index.js';
- export interface TaskProvider {
-     getTask: () => Task;
-     getModel: () => Promise<Model>;
+ import type { DataType, Model, Task } from "../index.js";
+ export interface TaskProvider<D extends DataType> {
+     getTask(): Task<D>;
+     getModel(): Promise<Model<D>>;
  }
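
With the generic in place, a provider's task and model are tied to one data type at compile time; a type-level sketch (root re-exports assumed):

```ts
import type { Model, Task, TaskProvider } from "@epfml/discojs";

declare const provider: TaskProvider<"tabular">;
const task: Task<"tabular"> = provider.getTask();
const model: Promise<Model<"tabular">> = provider.getModel();
// assigning provider.getTask() to a Task<"image"> is now a type error
```
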
package/dist/task/training_information.d.ts CHANGED
@@ -1,29 +1,38 @@
- import type { Preprocessing } from '../dataset/data/preprocessing/index.js';
- import { PreTrainedTokenizer } from '@xenova/transformers';
+ import { PreTrainedTokenizer } from "@xenova/transformers";
+ import { DataType } from "../index.js";
  interface Privacy {
      clippingRadius?: number;
      noiseScale?: number;
  }
- export interface TrainingInformation {
+ export type TrainingInformation<D extends DataType> = {
      epochs: number;
      roundDuration: number;
      validationSplit: number;
      batchSize: number;
-     preprocessingFunctions?: Preprocessing[];
-     dataType: 'image' | 'tabular' | 'text';
-     inputColumns?: string[];
-     outputColumns?: string[];
-     IMAGE_H?: number;
-     IMAGE_W?: number;
-     LABEL_LIST?: string[];
-     scheme: 'decentralized' | 'federated' | 'local';
+     scheme: "decentralized" | "federated" | "local";
      privacy?: Privacy;
      maxShareValue?: number;
      minNbOfParticipants: number;
-     aggregationStrategy?: 'mean' | 'secure';
-     tokenizer?: string | PreTrainedTokenizer;
-     maxSequenceLength?: number;
-     tensorBackend: 'tfjs' | 'gpt';
+     aggregationStrategy?: "mean" | "secure";
+     tensorBackend: "tfjs" | "gpt";
+ } & DataTypeToTrainingInformation[D];
+ interface DataTypeToTrainingInformation {
+     image: {
+         dataType: "image";
+         LABEL_LIST: string[];
+         IMAGE_H: number;
+         IMAGE_W: number;
+     };
+     tabular: {
+         dataType: "tabular";
+         inputColumns: string[];
+         outputColumn: string;
+     };
+     text: {
+         dataType: "text";
+         tokenizer: string | PreTrainedTokenizer;
+         maxSequenceLength?: number;
+     };
  }
- export declare function isTrainingInformation(raw: unknown): raw is TrainingInformation;
+ export declare function isTrainingInformation(raw: unknown): raw is TrainingInformation<DataType>;
  export {};
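
Since `TrainingInformation<D>` intersects the shared fields with exactly one member of `DataTypeToTrainingInformation`, the formerly optional per-type fields become required where they apply and unavailable elsewhere. A sketch (root re-export of the type assumed):

```ts
import type { TrainingInformation } from "@epfml/discojs";

const info: TrainingInformation<"image"> = {
  epochs: 10,
  roundDuration: 2,
  validationSplit: 0.2,
  batchSize: 16,
  minNbOfParticipants: 3,
  scheme: "federated",
  tensorBackend: "tfjs",
  dataType: "image",
  LABEL_LIST: ["cat", "dog"],
  IMAGE_H: 32,
  IMAGE_W: 32,
  // inputColumns would be a type error here: it belongs to "tabular"
};
```
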
package/dist/task/training_information.js CHANGED
@@ -1,11 +1,4 @@
- import { PreTrainedTokenizer } from '@xenova/transformers';
- function isStringArray(raw) {
-     if (!Array.isArray(raw)) {
-         return false;
-     }
-     const arr = raw; // isArray is unsafely guarding with any[]
-     return arr.every((e) => typeof e === 'string');
- }
+ import { PreTrainedTokenizer } from "@xenova/transformers";
  function isPrivacy(raw) {
      if (typeof raw !== "object" || raw === null) {
          return false;
@@ -21,89 +14,100 @@ function isPrivacy(raw) {
      return true;
  }
  export function isTrainingInformation(raw) {
-     if (typeof raw !== 'object' || raw === null) {
+     if (typeof raw !== "object" || raw === null) {
          return false;
      }
-     const { IMAGE_H, IMAGE_W, LABEL_LIST, aggregationStrategy, batchSize, dataType, privacy, epochs, inputColumns, maxShareValue, minNbOfParticipants, outputColumns, preprocessingFunctions, roundDuration, scheme, validationSplit, tokenizer, maxSequenceLength, tensorBackend } = raw;
-     if (typeof dataType !== 'string' ||
-         typeof epochs !== 'number' ||
-         typeof batchSize !== 'number' ||
-         typeof roundDuration !== 'number' ||
-         typeof validationSplit !== 'number' ||
-         typeof minNbOfParticipants !== 'number' ||
-         (tokenizer !== undefined && typeof tokenizer !== 'string' && !(tokenizer instanceof PreTrainedTokenizer)) ||
-         (maxSequenceLength !== undefined && typeof maxSequenceLength !== 'number') ||
-         (aggregationStrategy !== undefined && typeof aggregationStrategy !== 'string') ||
+     const { aggregationStrategy, batchSize, dataType, privacy, epochs, maxShareValue, minNbOfParticipants, roundDuration, scheme, validationSplit, tensorBackend, } = raw;
+     if (typeof epochs !== "number" ||
+         typeof batchSize !== "number" ||
+         typeof roundDuration !== "number" ||
+         typeof validationSplit !== "number" ||
+         typeof minNbOfParticipants !== "number" ||
          (privacy !== undefined && !isPrivacy(privacy)) ||
-         (maxShareValue !== undefined && typeof maxShareValue !== 'number') ||
-         (IMAGE_H !== undefined && typeof IMAGE_H !== 'number') ||
-         (IMAGE_W !== undefined && typeof IMAGE_W !== 'number') ||
-         (LABEL_LIST !== undefined && !isStringArray(LABEL_LIST)) ||
-         (inputColumns !== undefined && !isStringArray(inputColumns)) ||
-         (outputColumns !== undefined && !isStringArray(outputColumns)) ||
-         (preprocessingFunctions !== undefined && !Array.isArray(preprocessingFunctions))) {
+         (maxShareValue !== undefined && typeof maxShareValue !== "number")) {
          return false;
      }
-     if (aggregationStrategy !== undefined) {
-         switch (aggregationStrategy) {
-             case 'mean': break;
-             case 'secure': break;
-             default: return false;
-         }
-     }
-     switch (dataType) {
-         case 'image': break;
-         case 'tabular': break;
-         case 'text': break;
-         default: return false;
-     }
-     // interdependencies on data type
-     if (dataType === 'image') {
-         if (typeof IMAGE_H !== 'number' || typeof IMAGE_W !== 'number') {
-             return false;
-         }
-     }
-     else if (dataType in ['text', 'tabular']) {
-         if (!(Array.isArray(inputColumns) && inputColumns.every((e) => typeof e === 'string'))) {
-             return false;
-         }
-         if (!(Array.isArray(outputColumns) && outputColumns.every((e) => typeof e === 'string'))) {
+     switch (aggregationStrategy) {
+         case undefined:
+         case "mean":
+         case "secure":
+             break;
+         default:
              return false;
-         }
      }
      switch (tensorBackend) {
-         case 'tfjs': break;
-         case 'gpt': break;
-         default: return false;
+         case "tfjs":
+         case "gpt":
+             break;
+         default:
+             return false;
      }
      switch (scheme) {
-         case 'decentralized': break;
-         case 'federated': break;
-         case 'local': break;
-         default: return false;
+         case "decentralized":
+         case "federated":
+         case "local":
+             break;
+         default:
+             return false;
      }
      const repack = {
-         IMAGE_W,
-         IMAGE_H,
-         LABEL_LIST,
          aggregationStrategy,
          batchSize,
-         dataType,
-         privacy,
          epochs,
-         inputColumns,
          maxShareValue,
          minNbOfParticipants,
-         outputColumns,
-         preprocessingFunctions,
+         privacy,
          roundDuration,
          scheme,
+         tensorBackend,
          validationSplit,
-         tokenizer,
-         maxSequenceLength,
-         tensorBackend
      };
-     const _correct = repack;
-     const _total = repack;
-     return true;
+     switch (dataType) {
+         case "image": {
+             const { LABEL_LIST, IMAGE_W, IMAGE_H } = raw;
+             if (!(Array.isArray(LABEL_LIST) &&
+                 LABEL_LIST.every((e) => typeof e === "string")) ||
+                 typeof IMAGE_H !== "number" ||
+                 typeof IMAGE_W !== "number")
+                 return false;
+             const _ = {
+                 ...repack,
+                 dataType,
+                 LABEL_LIST,
+                 IMAGE_W,
+                 IMAGE_H,
+             };
+             return true;
+         }
+         case "tabular": {
+             const { inputColumns, outputColumn } = raw;
+             if (!(Array.isArray(inputColumns) &&
+                 inputColumns.every((e) => typeof e === "string")) ||
+                 typeof outputColumn !== "string")
+                 return false;
+             const _ = {
+                 ...repack,
+                 dataType,
+                 inputColumns,
+                 outputColumn,
+             };
+             return true;
+         }
+         case "text": {
+             const { maxSequenceLength, tokenizer, } = raw;
+             if ((typeof tokenizer !== "string" &&
+                 !(tokenizer instanceof PreTrainedTokenizer)) ||
+                 (maxSequenceLength !== undefined &&
+                     typeof maxSequenceLength !== "number"))
+                 return false;
+             const _ = {
+                 ...repack,
+                 dataType,
+                 maxSequenceLength,
+                 tokenizer,
+             };
+             return true;
+         }
+     }
+     return false;
  }
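
The rewritten guard checks the shared fields once, then switches on `dataType` to validate the per-type fields, mirroring the type definition above. Used as a narrowing guard (root re-export assumed):

```ts
import { isTrainingInformation } from "@epfml/discojs";

declare const untrusted: unknown; // e.g. a task description fetched from a server
if (isTrainingInformation(untrusted)) {
  // narrowed to TrainingInformation<DataType>; the dataType discriminant
  // selects which extra fields are present
  if (untrusted.dataType === "image")
    console.log(untrusted.LABEL_LIST, untrusted.IMAGE_W, untrusted.IMAGE_H);
}
```
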