@epfml/discojs 3.0.1-p20241025115642.0 → 3.0.1-p20241107104659.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/aggregator/get.d.ts +3 -3
- package/dist/client/client.d.ts +5 -5
- package/dist/client/decentralized/decentralized_client.d.ts +2 -2
- package/dist/client/federated/federated_client.d.ts +2 -2
- package/dist/client/utils.d.ts +2 -2
- package/dist/dataset/dataset.d.ts +9 -2
- package/dist/dataset/dataset.js +83 -36
- package/dist/dataset/image.d.ts +5 -0
- package/dist/dataset/image.js +6 -1
- package/dist/dataset/index.d.ts +0 -1
- package/dist/dataset/index.js +0 -1
- package/dist/dataset/types.d.ts +2 -0
- package/dist/default_tasks/cifar10.d.ts +1 -1
- package/dist/default_tasks/cifar10.js +2 -3
- package/dist/default_tasks/lus_covid.d.ts +1 -1
- package/dist/default_tasks/lus_covid.js +2 -3
- package/dist/default_tasks/mnist.d.ts +1 -1
- package/dist/default_tasks/mnist.js +3 -5
- package/dist/default_tasks/simple_face.d.ts +1 -1
- package/dist/default_tasks/simple_face.js +2 -3
- package/dist/default_tasks/titanic.d.ts +1 -1
- package/dist/default_tasks/titanic.js +3 -6
- package/dist/default_tasks/wikitext.d.ts +1 -1
- package/dist/default_tasks/wikitext.js +1 -2
- package/dist/index.d.ts +4 -5
- package/dist/index.js +4 -5
- package/dist/models/gpt/index.d.ts +13 -16
- package/dist/models/gpt/index.js +62 -43
- package/dist/models/gpt/model.d.ts +1 -15
- package/dist/models/gpt/model.js +1 -75
- package/dist/models/model.d.ts +7 -12
- package/dist/models/tfjs.d.ts +10 -8
- package/dist/models/tfjs.js +106 -44
- package/dist/models/tokenizer.d.ts +1 -1
- package/dist/privacy.js +1 -1
- package/dist/processing/image.d.ts +18 -0
- package/dist/processing/image.js +75 -0
- package/dist/processing/index.d.ts +8 -0
- package/dist/processing/index.js +106 -0
- package/dist/processing/tabular.d.ts +19 -0
- package/dist/processing/tabular.js +33 -0
- package/dist/processing/text.d.ts +11 -0
- package/dist/processing/text.js +33 -0
- package/dist/serialization/model.d.ts +3 -3
- package/dist/serialization/model.js +19 -6
- package/dist/task/task.d.ts +4 -3
- package/dist/task/task.js +5 -3
- package/dist/task/task_handler.d.ts +3 -3
- package/dist/task/task_provider.d.ts +4 -4
- package/dist/task/training_information.d.ts +25 -16
- package/dist/task/training_information.js +76 -72
- package/dist/training/disco.d.ts +20 -12
- package/dist/training/disco.js +32 -13
- package/dist/training/trainer.d.ts +6 -7
- package/dist/training/trainer.js +6 -6
- package/dist/types/data_format.d.ts +40 -0
- package/dist/types/index.d.ts +2 -0
- package/dist/types/index.js +1 -0
- package/dist/validator.d.ts +10 -0
- package/dist/validator.js +30 -0
- package/package.json +4 -2
- package/dist/dataset/data/data.d.ts +0 -47
- package/dist/dataset/data/data.js +0 -88
- package/dist/dataset/data/data_split.d.ts +0 -8
- package/dist/dataset/data/helpers.d.ts +0 -10
- package/dist/dataset/data/helpers.js +0 -97
- package/dist/dataset/data/image_data.d.ts +0 -11
- package/dist/dataset/data/image_data.js +0 -43
- package/dist/dataset/data/index.d.ts +0 -5
- package/dist/dataset/data/index.js +0 -5
- package/dist/dataset/data/preprocessing/base.d.ts +0 -16
- package/dist/dataset/data/preprocessing/base.js +0 -1
- package/dist/dataset/data/preprocessing/image_preprocessing.d.ts +0 -13
- package/dist/dataset/data/preprocessing/image_preprocessing.js +0 -42
- package/dist/dataset/data/preprocessing/index.d.ts +0 -4
- package/dist/dataset/data/preprocessing/index.js +0 -3
- package/dist/dataset/data/preprocessing/tabular_preprocessing.d.ts +0 -13
- package/dist/dataset/data/preprocessing/tabular_preprocessing.js +0 -45
- package/dist/dataset/data/preprocessing/text_preprocessing.d.ts +0 -13
- package/dist/dataset/data/preprocessing/text_preprocessing.js +0 -100
- package/dist/dataset/data/tabular_data.d.ts +0 -11
- package/dist/dataset/data/tabular_data.js +0 -24
- package/dist/dataset/data/text_data.d.ts +0 -11
- package/dist/dataset/data/text_data.js +0 -14
- package/dist/processing.d.ts +0 -35
- package/dist/processing.js +0 -89
- package/dist/types.d.ts +0 -3
- package/dist/types.js +0 -1
- package/dist/validation/index.d.ts +0 -1
- package/dist/validation/index.js +0 -1
- package/dist/validation/validator.d.ts +0 -10
- package/dist/validation/validator.js +0 -113
- /package/dist/{dataset/data/data_split.js → types/data_format.js} +0 -0
package/dist/processing/image.js
ADDED
@@ -0,0 +1,75 @@
+import { Repeat, Seq } from "immutable";
+import { createJimp } from "@jimp/core";
+import * as jimpResize from "@jimp/plugin-resize";
+import { Image } from "../index.js";
+/** Image where intensity is represented in the range 0..1 */
+export class NormalizedImage {
+    data;
+    width;
+    height;
+    depth;
+    // private as it doesn't check that array content is valid
+    constructor(data, width, height, depth) {
+        this.data = data;
+        this.width = width;
+        this.height = height;
+        this.depth = depth;
+        if (data.length != width * height * depth)
+            throw new Error("data isn't of expected size");
+    }
+    static from(image) {
+        return new NormalizedImage(Float32Array.from(image.data).map((v) => v / 255), image.width, image.height, image.depth);
+    }
+}
+/** Add a full opaque alpha channel to an image */
+function addAlpha(image) {
+    switch (image.depth) {
+        case 3:
+            return new Image(Uint8Array.from(
+            // we are adding a channel, so for every 3 byte in the base image,
+            // we need to add a fourth. we choose to "expand" the last channel
+            // to two value, the channel base value and the transparency.
+            // let's say we want to add a byte A to the bytestring RGB
+            // [R, G, B] -> [[R], [G], [B, A]] -> [R, G, B, A]
+            Seq(image.data).flatMap((v, i) => {
+                const OPAQUE = 0xff;
+                if (i % 3 !== 2)
+                    return [v];
+                else
+                    return [v, OPAQUE];
+            })), image.width, image.height, 4);
+        case 4:
+            return new Image(image.data, image.width, image.height, image.depth);
+    }
+}
+export function removeAlpha(image) {
+    switch (image.depth) {
+        case 1:
+        case 3:
+            return new Image(image.data, image.width, image.height, image.depth);
+        case 4:
+            return new Image(image.data.filter((_, i) => i % 4 !== 3), image.width, image.height, 3);
+    }
+}
+export function expandToMulticolor(image) {
+    switch (image.depth) {
+        case 1:
+            return new Image(Uint8Array.from(Seq(image.data).flatMap((v) => Repeat(v, 3))), image.width, image.height, 3);
+        case 3:
+        case 4:
+            return new Image(image.data, image.width, image.height, image.depth);
+    }
+}
+export function resize(width, height, image) {
+    const Jimp = createJimp({
+        plugins: [jimpResize.methods],
+    });
+    const resized = new Jimp(addAlpha(expandToMulticolor(image))).resize({
+        w: width,
+        h: height,
+    });
+    return new Image(new Uint8Array(resized.bitmap.data), width, height, 4);
+}
+export function normalize(image) {
+    return NormalizedImage.from(image);
+}
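These helpers compose into one fixed pipeline: expand a grayscale image to RGB, add an opaque alpha channel, resize through Jimp, then drop the alpha and rescale intensities to the 0..1 range. A minimal sketch of how a caller might chain them, assuming Image and the helpers are importable from the package root (the import path, pixel values and the 16x16 target size are illustrative, not taken from the package):

    // Assumed import path; the helpers live in dist/processing/image.js and
    // how they are re-exported from the package root is not shown in this diff.
    import { Image, normalize, removeAlpha, resize } from "@epfml/discojs";

    // a 2x2 grayscale image (depth 1), one byte per pixel
    const gray = new Image(Uint8Array.of(0, 85, 170, 255), 2, 2, 1);

    const resized = resize(16, 16, gray); // RGBA output: depth 4, 16 * 16 * 4 bytes
    const rgb = removeAlpha(resized);     // depth 3, alpha bytes dropped
    const pixels = normalize(rgb);        // NormalizedImage backed by a Float32Array in 0..1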
package/dist/processing/index.d.ts
ADDED
@@ -0,0 +1,8 @@
+/** Dataset shapers, convenient to map with */
+import type { Dataset, DataFormat, DataType, Task } from "../index.js";
+export * from "./image.js";
+export * from "./tabular.js";
+export * from "./text.js";
+export declare function preprocess<D extends DataType>(task: Task<D>, dataset: Dataset<DataFormat.Raw[D]>): Promise<Dataset<DataFormat.ModelEncoded[D]>>;
+export declare function preprocessWithoutLabel<D extends DataType>(task: Task<D>, dataset: Dataset<DataFormat.RawWithoutLabel[D]>): Promise<Dataset<DataFormat.ModelEncoded[D][0]>>;
+export declare function postprocess<D extends DataType>(task: Task<D>, dataset: Dataset<DataFormat.ModelEncoded[D][1]>): Promise<Dataset<DataFormat.Inferred[D]>>;
package/dist/processing/index.js
ADDED
@@ -0,0 +1,106 @@
+/** Dataset shapers, convenient to map with */
+import { List } from "immutable";
+import { models } from "../index.js";
+import * as processing from "./index.js";
+export * from "./image.js";
+export * from "./tabular.js";
+export * from "./text.js";
+export async function preprocess(task, dataset) {
+    switch (task.trainingInformation.dataType) {
+        case "image": {
+            // cast as typescript doesn't reduce generic type
+            const d = dataset;
+            const { IMAGE_H, IMAGE_W, LABEL_LIST } = task.trainingInformation;
+            return d.map(([image, label]) => [
+                processing.normalize(processing.removeAlpha(processing.resize(IMAGE_W, IMAGE_H, image))),
+                processing.indexInList(label, LABEL_LIST),
+            ]);
+        }
+        case "tabular": {
+            // cast as typescript doesn't reduce generic type
+            const d = dataset;
+            const { inputColumns, outputColumn } = task.trainingInformation;
+            return d.map((row) => {
+                const output = processing.extractColumn(row, outputColumn);
+                return [
+                    extractToNumbers(inputColumns, row),
+                    // TODO sanitization doesn't care about column distribution
+                    output !== "" ? processing.convertToNumber(output) : 0,
+                ];
+            });
+        }
+        case "text": {
+            // cast as typescript doesn't reduce generic type
+            const d = dataset;
+            const t = task;
+            const tokenizer = await models.getTaskTokenizer(t);
+            const totalTokenCount = task.trainingInformation.maxSequenceLength ??
+                tokenizer.model_max_length;
+            return d
+                .map((line) => processing.tokenizeAndLeftPad(line, tokenizer, totalTokenCount))
+                .map((tokens) => [tokens.pop(), tokens.last()]);
+        }
+    }
+}
+export async function preprocessWithoutLabel(task, dataset) {
+    switch (task.trainingInformation.dataType) {
+        case "image": {
+            // cast as typescript doesn't reduce generic type
+            const d = dataset;
+            const { IMAGE_H, IMAGE_W } = task.trainingInformation;
+            return d.map((image) => processing.normalize(processing.removeAlpha(processing.resize(IMAGE_W, IMAGE_H, image))));
+        }
+        case "tabular": {
+            // cast as typescript doesn't reduce generic type
+            const d = dataset;
+            const { inputColumns } = task.trainingInformation;
+            return d.map((row) => extractToNumbers(inputColumns, row));
+        }
+        case "text": {
+            // cast as typescript doesn't reduce generic type
+            const d = dataset;
+            const t = task;
+            const tokenizer = await models.getTaskTokenizer(t);
+            const totalTokenCount = t.trainingInformation.maxSequenceLength ??
+                tokenizer.model_max_length;
+            return d
+                .map((line) => processing.tokenizeAndLeftPad(line, tokenizer, totalTokenCount))
+                .map((tokens) => tokens.pop());
+        }
+    }
+}
+export async function postprocess(task, dataset) {
+    switch (task.trainingInformation.dataType) {
+        case "image": {
+            // cast as typescript doesn't reduce generic type
+            const d = dataset;
+            const { LABEL_LIST } = task.trainingInformation;
+            const labels = List(LABEL_LIST);
+            return d.map((index) => {
+                const v = labels.get(index);
+                if (v === undefined)
+                    throw new Error("index not found in labels");
+                return v;
+            });
+        }
+        case "tabular": {
+            // cast as typescript doesn't reduce generic type
+            const d = dataset;
+            return d;
+        }
+        case "text": {
+            // cast as typescript doesn't reduce generic type
+            const d = dataset;
+            const t = task;
+            const tokenizer = await models.getTaskTokenizer(t);
+            return d.map((token) => tokenizer.decode([token]));
+        }
+    }
+}
+function extractToNumbers(columns, row) {
+    return (List(columns)
+        .map((column) => processing.extractColumn(row, column))
+        // TODO sanitization doesn't care about column distribution
+        .map((v) => (v !== "" ? v : "0"))
+        .map(processing.convertToNumber));
+}
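preprocess, preprocessWithoutLabel and postprocess all dispatch on task.trainingInformation.dataType, so each branch can rely on the fields its TrainingInformation variant guarantees. As a rough illustration of what the "tabular" branch computes per row, mirroring the code above rather than calling it (column names and values are made up):

    const row: Partial<Record<string, string>> = { age: "42", fare: "7.25", survived: "1" };
    const inputColumns = ["age", "fare"];
    const outputColumn = "survived";

    // features from inputColumns, label from outputColumn, everything parsed as numbers
    const features = inputColumns.map((column) => {
      const raw = row[column] ?? "";                    // extractColumn instead throws on a missing column
      return Number.parseFloat(raw !== "" ? raw : "0"); // empty cells fall back to 0
    });
    const label = Number.parseFloat(row[outputColumn] ?? "0");
    // features -> [42, 7.25], label -> 1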
package/dist/processing/tabular.d.ts
ADDED
@@ -0,0 +1,19 @@
+import { List } from "immutable";
+/**
+ * Convert a string to a number
+ *
+ * @throws if it isn't written as a number
+ */
+export declare function convertToNumber(raw: string): number;
+/**
+ * Return the named field of an object with string values
+ *
+ * @throws if the named field isn't there
+ */
+export declare function extractColumn(row: Partial<Record<string, string>>, column: string): string;
+/**
+ * Return the index of the element in the given list
+ *
+ * @throws if not found
+ */
+export declare function indexInList(element: string, elements: List<string> | Array<string>): number;
package/dist/processing/tabular.js
ADDED
@@ -0,0 +1,33 @@
+/**
+ * Convert a string to a number
+ *
+ * @throws if it isn't written as a number
+ */
+export function convertToNumber(raw) {
+    const num = Number.parseFloat(raw);
+    if (Number.isNaN(num))
+        throw new Error(`unable to parse "${raw}" as number`);
+    return num;
+}
+/**
+ * Return the named field of an object with string values
+ *
+ * @throws if the named field isn't there
+ */
+export function extractColumn(row, column) {
+    const raw = row[column];
+    if (raw === undefined)
+        throw new Error(`${column} not found in row`);
+    return raw;
+}
+/**
+ * Return the index of the element in the given list
+ *
+ * @throws if not found
+ */
+export function indexInList(element, elements) {
+    const ret = elements.indexOf(element);
+    if (ret === -1)
+        throw new Error(`${element} not found in list`);
+    return ret;
+}
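All three helpers throw instead of silently substituting a default, which keeps bad rows visible during preprocessing. A short sketch of the expected behaviour (values are illustrative):

    // assuming: import { convertToNumber, extractColumn, indexInList } from "@epfml/discojs"
    // (the exact export path is not shown in this diff)
    convertToNumber("3.14");                 // 3.14
    extractColumn({ age: "42" }, "age");     // "42"
    indexInList("cat", ["dog", "cat"]);      // 1

    convertToNumber("not a number");         // throws: unable to parse "not a number" as number
    extractColumn({ age: "42" }, "height");  // throws: height not found in row
    indexInList("bird", ["dog", "cat"]);     // throws: bird not found in list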
package/dist/processing/text.d.ts
ADDED
@@ -0,0 +1,11 @@
+import { List } from "immutable";
+import { PreTrainedTokenizer } from "@xenova/transformers";
+type Token = number;
+/**
+ * Tokenize and truncates input strings
+ *
+ * @param length number of tokens
+ * @returns encoded string in an array of token, size of max_length
+ */
+export declare function tokenizeAndLeftPad(line: string, tokenizer: PreTrainedTokenizer, length: number): List<Token>;
+export {};
package/dist/processing/text.js
ADDED
@@ -0,0 +1,33 @@
+import { Repeat } from "immutable";
+function isArrayOfNumber(raw) {
+    return Array.isArray(raw) && raw.every((e) => typeof e === "number");
+}
+/**
+ * Tokenize and truncates input strings
+ *
+ * @param length number of tokens
+ * @returns encoded string in an array of token, size of max_length
+ */
+export function tokenizeAndLeftPad(line, tokenizer, length) {
+    if (!Number.isInteger(length))
+        throw new Error("length should be an integer");
+    // Transformers.js currently only supports right padding while we need left for text generation
+    // Right padding should be supported in the future, once it is, we can directly pad while tokenizing
+    // https://github.com/xenova/transformers.js/blob/8804c36591d11d8456788d1bb4b16489121b3be2/src/tokenizers.js#L2517
+    const tokenized = tokenizer(line, {
+        padding: false,
+        truncation: true,
+        return_tensor: false,
+        max_length: length,
+    });
+    if (typeof tokenized !== "object" ||
+        tokenized === null ||
+        !("input_ids" in tokenized) ||
+        !isArrayOfNumber(tokenized.input_ids))
+        throw new Error("tokenizer returns unexpected type");
+    const tokens = tokenized.input_ids;
+    const paddingSize = length - tokens.length;
+    if (paddingSize < 0)
+        throw new Error("tokenized returned more token than expected");
+    return Repeat(tokenizer.pad_token_id, paddingSize).concat(tokens).toList();
+}
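Transformers.js only right-pads at the moment (see the linked issue in the comment above), so tokenizeAndLeftPad truncates to `length` tokens and prepends the tokenizer's pad token itself. A sketch of the expected result, assuming a GPT-2-style tokenizer loaded through @xenova/transformers (the model id and token ids are illustrative, and the import of tokenizeAndLeftPad is assumed):

    import { AutoTokenizer } from "@xenova/transformers";

    const tokenizer = await AutoTokenizer.from_pretrained("Xenova/gpt2"); // illustrative model id
    const padded = tokenizeAndLeftPad("hello world", tokenizer, 5);       // function shown above

    // "hello world" tokenizes to two ids, so three pad tokens are prepended:
    // padded.size === 5
    // padded.toArray() ~ [pad, pad, pad, id("hello"), id(" world")]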
package/dist/serialization/model.d.ts
CHANGED
@@ -1,4 +1,4 @@
-import type { Model } from '../index.js';
+import type { DataType, Model } from '../index.js';
 import { Encoded } from "./coder.js";
-export declare function encode(model: Model): Promise<Encoded>;
-export declare function decode(encoded: unknown): Promise<Model>;
+export declare function encode(model: Model<DataType>): Promise<Encoded>;
+export declare function decode(encoded: unknown): Promise<Model<DataType>>;
package/dist/serialization/model.js
CHANGED
@@ -9,7 +9,7 @@ export async function encode(model) {
     switch (true) {
         case model instanceof models.TFJS: {
            const serialized = await model.serialize();
-            return coder.encode([Type.TFJS, serialized]);
+            return coder.encode([Type.TFJS, ...serialized]);
         }
         case model instanceof models.GPT: {
             const { weights, config } = model.serialize();
@@ -33,12 +33,25 @@ export async function decode(encoded) {
     }
     const rawModel = raw[1];
     switch (type) {
-        case Type.TFJS:
-            if (raw.length !==
-                throw new Error(
+        case Type.TFJS: {
+            if (raw.length !== 3)
+                throw new Error("invalid TFJS model encoding: should be an array of length 3");
+            const [rawDatatype, rawModel] = raw.slice(1);
+            let datatype;
+            switch (rawDatatype) {
+                case "image":
+                case "tabular":
+                    datatype = rawDatatype;
+                    break;
+                default:
+                    throw new Error("invalid TFJS model encoding: invalid DataType");
             }
-
-
+            return await models.TFJS.deserialize([
+                datatype,
+                // TODO totally unsafe casting
+                rawModel,
+            ]);
+        }
         case Type.GPT: {
             let config;
             if (raw.length == 2) {
package/dist/task/task.d.ts
CHANGED
@@ -1,10 +1,11 @@
+import { DataType } from "../index.js";
 import { type DisplayInformation } from './display_information.js';
 import { type TrainingInformation } from './training_information.js';
 export type TaskID = string;
-export interface Task {
+export interface Task<D extends DataType> {
     id: TaskID;
     displayInformation: DisplayInformation;
-    trainingInformation: TrainingInformation;
+    trainingInformation: TrainingInformation<D>;
 }
 export declare function isTaskID(obj: unknown): obj is TaskID;
-export declare function isTask(raw: unknown): raw is Task;
+export declare function isTask(raw: unknown): raw is Task<DataType>;
package/dist/task/task.js
CHANGED
@@ -13,8 +13,10 @@ export function isTask(raw) {
         !isTrainingInformation(trainingInformation)) {
         return false;
     }
-    const
-
-
+    const _ = {
+        id,
+        displayInformation,
+        trainingInformation,
+    };
     return true;
 }
package/dist/task/task_handler.d.ts
CHANGED
@@ -1,5 +1,5 @@
 import { Map } from "immutable";
-import type { Model } from "../index.js";
+import type { DataType, Model } from "../index.js";
 import type { Task, TaskID } from "./task.js";
-export declare function pushTask(base: URL, task: Task
-export declare function fetchTasks(base: URL): Promise<Map<TaskID, Task
+export declare function pushTask<D extends DataType>(base: URL, task: Task<D>, model: Model<D>): Promise<void>;
+export declare function fetchTasks(base: URL): Promise<Map<TaskID, Task<DataType>>>;
package/dist/task/task_provider.d.ts
CHANGED
@@ -1,5 +1,5 @@
-import type { Model, Task } from
-export interface TaskProvider {
-    getTask
-    getModel
+import type { DataType, Model, Task } from "../index.js";
+export interface TaskProvider<D extends DataType> {
+    getTask(): Task<D>;
+    getModel(): Promise<Model<D>>;
 }
package/dist/task/training_information.d.ts
CHANGED
@@ -1,29 +1,38 @@
-import
-import {
+import { PreTrainedTokenizer } from "@xenova/transformers";
+import { DataType } from "../index.js";
 interface Privacy {
     clippingRadius?: number;
     noiseScale?: number;
 }
-export
+export type TrainingInformation<D extends DataType> = {
     epochs: number;
     roundDuration: number;
     validationSplit: number;
     batchSize: number;
-
-    dataType: 'image' | 'tabular' | 'text';
-    inputColumns?: string[];
-    outputColumns?: string[];
-    IMAGE_H?: number;
-    IMAGE_W?: number;
-    LABEL_LIST?: string[];
-    scheme: 'decentralized' | 'federated' | 'local';
+    scheme: "decentralized" | "federated" | "local";
     privacy?: Privacy;
     maxShareValue?: number;
     minNbOfParticipants: number;
-    aggregationStrategy?:
-
-
-
+    aggregationStrategy?: "mean" | "secure";
+    tensorBackend: "tfjs" | "gpt";
+} & DataTypeToTrainingInformation[D];
+interface DataTypeToTrainingInformation {
+    image: {
+        dataType: "image";
+        LABEL_LIST: string[];
+        IMAGE_H: number;
+        IMAGE_W: number;
+    };
+    tabular: {
+        dataType: "tabular";
+        inputColumns: string[];
+        outputColumn: string;
+    };
+    text: {
+        dataType: "text";
+        tokenizer: string | PreTrainedTokenizer;
+        maxSequenceLength?: number;
+    };
 }
-export declare function isTrainingInformation(raw: unknown): raw is TrainingInformation;
+export declare function isTrainingInformation(raw: unknown): raw is TrainingInformation<DataType>;
 export {};
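TrainingInformation is now a discriminated union over DataType: the common training fields are intersected with the variant selected by D, so IMAGE_H/IMAGE_W/LABEL_LIST are required for "image" tasks and simply do not exist on "tabular" ones. A sketch of what a task definition could look like under the new type, assuming Task is re-exported from the package root (task id, columns and hyper-parameter values are illustrative; DisplayInformation's shape is not part of this hunk):

    import type { Task } from "@epfml/discojs"; // assumed re-export path

    declare const displayInformation: Task<"tabular">["displayInformation"]; // defined elsewhere

    const example: Task<"tabular"> = {
      id: "example-tabular-task", // illustrative
      displayInformation,
      trainingInformation: {
        // fields common to every data type
        epochs: 10,
        roundDuration: 2,
        validationSplit: 0.2,
        batchSize: 32,
        minNbOfParticipants: 2,
        scheme: "federated",
        tensorBackend: "tfjs",
        // fields contributed by DataTypeToTrainingInformation["tabular"]
        dataType: "tabular",
        inputColumns: ["age", "fare"],
        outputColumn: "survived",
      },
    };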
package/dist/task/training_information.js
CHANGED
@@ -1,11 +1,4 @@
-import { PreTrainedTokenizer } from
-function isStringArray(raw) {
-    if (!Array.isArray(raw)) {
-        return false;
-    }
-    const arr = raw; // isArray is unsafely guarding with any[]
-    return arr.every((e) => typeof e === 'string');
-}
+import { PreTrainedTokenizer } from "@xenova/transformers";
 function isPrivacy(raw) {
     if (typeof raw !== "object" || raw === null) {
         return false;
@@ -21,89 +14,100 @@ function isPrivacy(raw) {
     return true;
 }
 export function isTrainingInformation(raw) {
-    if (typeof raw !==
+    if (typeof raw !== "object" || raw === null) {
         return false;
     }
-    const {
-    if (typeof
-        typeof
-        typeof
-        typeof
-        typeof
-        typeof minNbOfParticipants !== 'number' ||
-        (tokenizer !== undefined && typeof tokenizer !== 'string' && !(tokenizer instanceof PreTrainedTokenizer)) ||
-        (maxSequenceLength !== undefined && typeof maxSequenceLength !== 'number') ||
-        (aggregationStrategy !== undefined && typeof aggregationStrategy !== 'string') ||
+    const { aggregationStrategy, batchSize, dataType, privacy, epochs, maxShareValue, minNbOfParticipants, roundDuration, scheme, validationSplit, tensorBackend, } = raw;
+    if (typeof epochs !== "number" ||
+        typeof batchSize !== "number" ||
+        typeof roundDuration !== "number" ||
+        typeof validationSplit !== "number" ||
+        typeof minNbOfParticipants !== "number" ||
         (privacy !== undefined && !isPrivacy(privacy)) ||
-        (maxShareValue !== undefined && typeof maxShareValue !==
-        (IMAGE_H !== undefined && typeof IMAGE_H !== 'number') ||
-        (IMAGE_W !== undefined && typeof IMAGE_W !== 'number') ||
-        (LABEL_LIST !== undefined && !isStringArray(LABEL_LIST)) ||
-        (inputColumns !== undefined && !isStringArray(inputColumns)) ||
-        (outputColumns !== undefined && !isStringArray(outputColumns)) ||
-        (preprocessingFunctions !== undefined && !Array.isArray(preprocessingFunctions))) {
+        (maxShareValue !== undefined && typeof maxShareValue !== "number")) {
         return false;
     }
-
-
-
-
-
-
-    }
-    switch (dataType) {
-        case 'image': break;
-        case 'tabular': break;
-        case 'text': break;
-        default: return false;
-    }
-    // interdependencies on data type
-    if (dataType === 'image') {
-        if (typeof IMAGE_H !== 'number' || typeof IMAGE_W !== 'number') {
-            return false;
-        }
-    }
-    else if (dataType in ['text', 'tabular']) {
-        if (!(Array.isArray(inputColumns) && inputColumns.every((e) => typeof e === 'string'))) {
-            return false;
-        }
-        if (!(Array.isArray(outputColumns) && outputColumns.every((e) => typeof e === 'string'))) {
+    switch (aggregationStrategy) {
+        case undefined:
+        case "mean":
+        case "secure":
+            break;
+        default:
            return false;
-        }
     }
     switch (tensorBackend) {
-        case
-        case
-
+        case "tfjs":
+        case "gpt":
+            break;
+        default:
+            return false;
     }
     switch (scheme) {
-        case
-        case
-        case
-
+        case "decentralized":
+        case "federated":
+        case "local":
+            break;
+        default:
+            return false;
     }
     const repack = {
-        IMAGE_W,
-        IMAGE_H,
-        LABEL_LIST,
         aggregationStrategy,
         batchSize,
-        dataType,
-        privacy,
         epochs,
-        inputColumns,
         maxShareValue,
         minNbOfParticipants,
-
-        preprocessingFunctions,
+        privacy,
         roundDuration,
         scheme,
+        tensorBackend,
         validationSplit,
-        tokenizer,
-        maxSequenceLength,
-        tensorBackend
     };
-
-
-
+    switch (dataType) {
+        case "image": {
+            const { LABEL_LIST, IMAGE_W, IMAGE_H } = raw;
+            if (!(Array.isArray(LABEL_LIST) &&
+                LABEL_LIST.every((e) => typeof e === "string")) ||
+                typeof IMAGE_H !== "number" ||
+                typeof IMAGE_W !== "number")
+                return false;
+            const _ = {
+                ...repack,
+                dataType,
+                LABEL_LIST,
+                IMAGE_W,
+                IMAGE_H,
+            };
+            return true;
+        }
+        case "tabular": {
+            const { inputColumns, outputColumn } = raw;
+            if (!(Array.isArray(inputColumns) &&
+                inputColumns.every((e) => typeof e === "string")) ||
+                typeof outputColumn !== "string")
+                return false;
+            const _ = {
+                ...repack,
+                dataType,
+                inputColumns,
+                outputColumn,
+            };
+            return true;
+        }
+        case "text": {
+            const { maxSequenceLength, tokenizer, } = raw;
+            if ((typeof tokenizer !== "string" &&
+                !(tokenizer instanceof PreTrainedTokenizer)) ||
+                (maxSequenceLength !== undefined &&
+                typeof maxSequenceLength !== "number"))
+                return false;
+            const _ = {
+                ...repack,
+                dataType,
+                maxSequenceLength,
+                tokenizer,
+            };
+            return true;
+        }
+    }
+    return false;
 }