@epfml/discojs 3.0.1-p20250729132444.0 → 3.0.1-p20250924113522.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/aggregator/get.d.ts +3 -3
- package/dist/aggregator/get.js +1 -2
- package/dist/client/client.d.ts +6 -6
- package/dist/client/decentralized/decentralized_client.d.ts +1 -1
- package/dist/client/decentralized/peer_pool.d.ts +1 -1
- package/dist/client/federated/federated_client.d.ts +1 -1
- package/dist/client/local_client.d.ts +1 -1
- package/dist/client/utils.d.ts +2 -2
- package/dist/client/utils.js +19 -10
- package/dist/default_tasks/cifar10.d.ts +2 -2
- package/dist/default_tasks/cifar10.js +9 -8
- package/dist/default_tasks/lus_covid.d.ts +2 -2
- package/dist/default_tasks/lus_covid.js +9 -8
- package/dist/default_tasks/mnist.d.ts +2 -2
- package/dist/default_tasks/mnist.js +9 -8
- package/dist/default_tasks/simple_face.d.ts +2 -2
- package/dist/default_tasks/simple_face.js +9 -8
- package/dist/default_tasks/tinder_dog.d.ts +1 -1
- package/dist/default_tasks/tinder_dog.js +12 -10
- package/dist/default_tasks/titanic.d.ts +2 -2
- package/dist/default_tasks/titanic.js +20 -33
- package/dist/default_tasks/wikitext.d.ts +2 -2
- package/dist/default_tasks/wikitext.js +16 -13
- package/dist/index.d.ts +1 -1
- package/dist/index.js +1 -1
- package/dist/models/gpt/config.d.ts +2 -2
- package/dist/models/hellaswag.d.ts +2 -3
- package/dist/models/hellaswag.js +3 -4
- package/dist/models/index.d.ts +2 -3
- package/dist/models/index.js +2 -3
- package/dist/models/tokenizer.d.ts +24 -14
- package/dist/models/tokenizer.js +42 -21
- package/dist/processing/index.d.ts +4 -5
- package/dist/processing/index.js +16 -21
- package/dist/serialization/coder.d.ts +5 -1
- package/dist/serialization/coder.js +4 -1
- package/dist/serialization/index.d.ts +4 -0
- package/dist/serialization/index.js +1 -0
- package/dist/serialization/task.d.ts +5 -0
- package/dist/serialization/task.js +34 -0
- package/dist/task/display_information.d.ts +91 -14
- package/dist/task/display_information.js +34 -58
- package/dist/task/index.d.ts +5 -5
- package/dist/task/index.js +4 -3
- package/dist/task/task.d.ts +837 -10
- package/dist/task/task.js +49 -21
- package/dist/task/task_handler.d.ts +4 -4
- package/dist/task/task_handler.js +14 -18
- package/dist/task/task_provider.d.ts +3 -3
- package/dist/task/training_information.d.ts +157 -35
- package/dist/task/training_information.js +85 -110
- package/dist/training/disco.d.ts +8 -8
- package/dist/training/disco.js +2 -1
- package/dist/training/trainer.d.ts +3 -3
- package/dist/training/trainer.js +2 -1
- package/dist/types/index.d.ts +1 -0
- package/dist/validator.d.ts +4 -4
- package/dist/validator.js +7 -6
- package/package.json +4 -7
- package/dist/processing/text.d.ts +0 -21
- package/dist/processing/text.js +0 -36
- package/dist/task/data_example.d.ts +0 -5
- package/dist/task/data_example.js +0 -14
- package/dist/task/summary.d.ts +0 -5
- package/dist/task/summary.js +0 -13
package/dist/aggregator/get.d.ts
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
|
-
import type { DataType, Task } from '../index.js';
|
|
1
|
+
import type { DataType, Network, Task } from '../index.js';
|
|
2
2
|
import { aggregator } from '../index.js';
|
|
3
3
|
type AggregatorOptions = Partial<{
|
|
4
|
-
scheme: Task<DataType>["trainingInformation"]["scheme"];
|
|
4
|
+
scheme: Task<DataType, Network>["trainingInformation"]["scheme"];
|
|
5
5
|
roundCutOff: number;
|
|
6
6
|
threshold: number;
|
|
7
7
|
thresholdType: 'relative' | 'absolute';
|
|
@@ -24,5 +24,5 @@ type AggregatorOptions = Partial<{
|
|
|
24
24
|
* @param options Options passed down to the aggregator's constructor
|
|
25
25
|
* @returns The aggregator
|
|
26
26
|
*/
|
|
27
|
-
export declare function getAggregator(task: Task<DataType>, options?: AggregatorOptions): aggregator.Aggregator;
|
|
27
|
+
export declare function getAggregator(task: Task<DataType, Network>, options?: AggregatorOptions): aggregator.Aggregator;
|
|
28
28
|
export {};
|
package/dist/aggregator/get.js
CHANGED
|
@@ -18,9 +18,8 @@ import { aggregator } from '../index.js';
|
|
|
18
18
|
* @returns The aggregator
|
|
19
19
|
*/
|
|
20
20
|
export function getAggregator(task, options = {}) {
|
|
21
|
-
const aggregationStrategy = task.trainingInformation.aggregationStrategy ?? 'mean';
|
|
22
21
|
const scheme = options.scheme ?? task.trainingInformation.scheme;
|
|
23
|
-
switch (aggregationStrategy) {
|
|
22
|
+
switch (task.trainingInformation.aggregationStrategy) {
|
|
24
23
|
case 'mean':
|
|
25
24
|
if (scheme === 'decentralized') {
|
|
26
25
|
// If options are not specified, we default to expecting a contribution from all peers, so we set the threshold to 100%
|
package/dist/client/client.d.ts
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import type { DataType, Model, RoundStatus, Task, WeightsContainer } from "../index.js";
|
|
1
|
+
import type { DataType, Model, Network, RoundStatus, Task, WeightsContainer } from "../index.js";
|
|
2
2
|
import type { NodeID } from './types.js';
|
|
3
3
|
import type { EventConnection } from './event_connection.js';
|
|
4
4
|
import type { Aggregator } from '../aggregator/index.js';
|
|
@@ -7,13 +7,13 @@ import { EventEmitter } from '../utils/event_emitter.js';
|
|
|
7
7
|
* Main, abstract, class representing a Disco client in a network, which handles
|
|
8
8
|
* communication with other nodes, be it peers or a server.
|
|
9
9
|
*/
|
|
10
|
-
export declare abstract class Client extends EventEmitter<{
|
|
11
|
-
|
|
12
|
-
|
|
10
|
+
export declare abstract class Client<N extends Network> extends EventEmitter<{
|
|
11
|
+
status: RoundStatus;
|
|
12
|
+
participants: number;
|
|
13
13
|
}> {
|
|
14
14
|
#private;
|
|
15
15
|
readonly url: URL;
|
|
16
|
-
readonly task: Task<DataType>;
|
|
16
|
+
readonly task: Task<DataType, N>;
|
|
17
17
|
readonly aggregator: Aggregator;
|
|
18
18
|
protected _ownId?: NodeID;
|
|
19
19
|
protected _server?: EventConnection;
|
|
@@ -25,7 +25,7 @@ export declare abstract class Client extends EventEmitter<{
|
|
|
25
25
|
*/
|
|
26
26
|
protected promiseForMoreParticipants: Promise<void> | undefined;
|
|
27
27
|
constructor(url: URL, // The network server's URL to connect to
|
|
28
|
-
task: Task<DataType>, // The client's corresponding task
|
|
28
|
+
task: Task<DataType, N>, // The client's corresponding task
|
|
29
29
|
aggregator: Aggregator);
|
|
30
30
|
/**
|
|
31
31
|
* Communication callback called at the beginning of every training round.
|
|
@@ -6,7 +6,7 @@ import { Client } from '../client.js';
|
|
|
6
6
|
* with the server is based off regular WebSockets, whereas peer-to-peer communication uses
|
|
7
7
|
* WebRTC for Node.js.
|
|
8
8
|
*/
|
|
9
|
-
export declare class DecentralizedClient extends Client {
|
|
9
|
+
export declare class DecentralizedClient extends Client<"decentralized"> {
|
|
10
10
|
#private;
|
|
11
11
|
private get isDisconnected();
|
|
12
12
|
private setAggregatorNodes;
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import { Map, type Set } from 'immutable';
|
|
2
2
|
import { type SignalData } from './peer.js';
|
|
3
|
-
import {
|
|
3
|
+
import type { NodeID } from '../types.js';
|
|
4
4
|
import { PeerConnection, type EventConnection } from '../event_connection.js';
|
|
5
5
|
export declare class PeerPool {
|
|
6
6
|
private readonly id;
|
|
@@ -4,7 +4,7 @@ import { Client } from "../client.js";
|
|
|
4
4
|
* Client class that communicates with a centralized, federated server, when training
|
|
5
5
|
* a specific task in the federated setting.
|
|
6
6
|
*/
|
|
7
|
-
export declare class FederatedClient extends Client {
|
|
7
|
+
export declare class FederatedClient extends Client<"federated"> {
|
|
8
8
|
/**
|
|
9
9
|
* Initializes the connection to the server, gets our node ID
|
|
10
10
|
* as well as the latest training information: latest global model, current round and
|
|
@@ -4,7 +4,7 @@ import { Client } from "./client.js";
|
|
|
4
4
|
* A LocalClient represents a Disco user training only on their local data without collaborating
|
|
5
5
|
* with anyone. Thus LocalClient doesn't do anything during communication
|
|
6
6
|
*/
|
|
7
|
-
export declare class LocalClient extends Client {
|
|
7
|
+
export declare class LocalClient extends Client<"local"> {
|
|
8
8
|
onRoundBeginCommunication(): Promise<void>;
|
|
9
9
|
onRoundEndCommunication(weights: WeightsContainer): Promise<WeightsContainer>;
|
|
10
10
|
}
|
package/dist/client/utils.d.ts
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import type { DataType, Task } from
|
|
1
|
+
import type { DataType, Network, Task } from "../index.js";
|
|
2
2
|
import { client as clients, type aggregator } from '../index.js';
|
|
3
3
|
export declare function timeout(ms?: number, errorMsg?: string): Promise<never>;
|
|
4
|
-
export declare function getClient(
|
|
4
|
+
export declare function getClient<D extends DataType, N extends Network>(scheme: N | "local", serverURL: URL, task: Task<D, N>, aggregator: aggregator.Aggregator): clients.Client<N>;
|
package/dist/client/utils.js
CHANGED
|
@@ -6,17 +6,26 @@ export async function timeout(ms = MAX_WAIT_PER_ROUND, errorMsg = 'timeout') {
|
|
|
6
6
|
setTimeout(() => { reject(new Error(errorMsg)); }, ms);
|
|
7
7
|
});
|
|
8
8
|
}
|
|
9
|
-
export function getClient(
|
|
10
|
-
switch (
|
|
11
|
-
case
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
return new clients.
|
|
15
|
-
|
|
16
|
-
|
|
9
|
+
export function getClient(scheme, serverURL, task, aggregator) {
|
|
10
|
+
switch (scheme) {
|
|
11
|
+
case "decentralized": {
|
|
12
|
+
const t = task;
|
|
13
|
+
t.trainingInformation.scheme = scheme;
|
|
14
|
+
return new clients.decentralized.DecentralizedClient(serverURL, t, aggregator);
|
|
15
|
+
}
|
|
16
|
+
case "federated": {
|
|
17
|
+
const t = task;
|
|
18
|
+
t.trainingInformation.scheme = scheme;
|
|
19
|
+
return new clients.federated.FederatedClient(serverURL, t, aggregator);
|
|
20
|
+
}
|
|
21
|
+
case "local": {
|
|
22
|
+
const t = task;
|
|
23
|
+
t.trainingInformation.scheme = scheme;
|
|
24
|
+
return new clients.LocalClient(serverURL, t, aggregator);
|
|
25
|
+
}
|
|
17
26
|
default: {
|
|
18
|
-
const _ =
|
|
19
|
-
throw new Error(
|
|
27
|
+
const _ = scheme;
|
|
28
|
+
throw new Error("should never happen");
|
|
20
29
|
}
|
|
21
30
|
}
|
|
22
31
|
}
|
|
@@ -1,2 +1,2 @@
|
|
|
1
|
-
import type { TaskProvider } from
|
|
2
|
-
export declare const cifar10: TaskProvider<
|
|
1
|
+
import type { TaskProvider } from "../index.js";
|
|
2
|
+
export declare const cifar10: TaskProvider<"image", "decentralized">;
|
|
@@ -3,27 +3,28 @@ import { models } from '../index.js';
|
|
|
3
3
|
import baseModel from '../models/mobileNet_v1_025_224.js';
|
|
4
4
|
export const cifar10 = {
|
|
5
5
|
getTask() {
|
|
6
|
-
return {
|
|
6
|
+
return Promise.resolve({
|
|
7
7
|
id: 'cifar10',
|
|
8
|
+
dataType: "image",
|
|
8
9
|
displayInformation: {
|
|
9
|
-
|
|
10
|
+
title: 'CIFAR10',
|
|
10
11
|
summary: {
|
|
11
12
|
preview: 'CIFAR-10 is a classic image classification task, and one of the most widely used datasets for machine learning research.',
|
|
12
13
|
overview: "The dataset contains 60,000 32x32 color images in 10 different classes: airplanes, cars, birds, cats, deer, dogs, frogs, horses, ships, and trucks. The official CIFAR-10 website can be found <a class='underline text-blue-400' href='https://www.cs.toronto.edu/~kriz/cifar.html' target='_blank'>here</a>. You can find a link to a sample dataset at the next step (Connect Your Data)."
|
|
13
14
|
},
|
|
14
15
|
model: 'The model is a pretrained <a class="underline text-blue-400" target="_blank" href="https://github.com/tensorflow/tfjs-models/tree/master/mobilenet">MobileNetV1 model</a> trained in Tensorflow.js. The last output layer is replaced with a fully connected layer with softmax activation and one output neuron per CIFAR10 category. The data preprocessing reshapes images into 224x224 pixels and normalizes values between 0 and 1. The neural network is optimized via Stochastic Gradient Descent and a categorical Cross Entropy loss.',
|
|
15
16
|
dataFormatInformation: 'Images should be of .png format and of size 32x32. <br> The CSV file should start with the exact header "filename,label", and each row should contain an image filename (without extension) and its label.<br><br> For example if you have images: 0.png (of a frog) and 1.png (of a car) <br> The CSV file should be: <br>filename, label <br><br> 0, frog <br> 1, car',
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
17
|
+
dataExample: "https://storage.googleapis.com/deai-313515.appspot.com/example_training_data/cifar10-example.png",
|
|
18
|
+
sampleDataset: {
|
|
19
|
+
link: "https://storage.googleapis.com/deai-313515.appspot.com/example_training_data.tar.gz",
|
|
20
|
+
instructions: 'Opening the link should start downloading a zip file which you can unzip. To connect the data, use the CSV option below and select the file named "cifar10-labels.csv". You can now connect the images located in the "CIFAR10" folder. Note that there are only 24 images in this sample dataset which is far too few to successfully train a machine learning model.',
|
|
21
|
+
},
|
|
20
22
|
},
|
|
21
23
|
trainingInformation: {
|
|
22
24
|
epochs: 10,
|
|
23
25
|
roundDuration: 10,
|
|
24
26
|
validationSplit: 0.2,
|
|
25
27
|
batchSize: 10,
|
|
26
|
-
dataType: 'image',
|
|
27
28
|
IMAGE_H: 224,
|
|
28
29
|
IMAGE_W: 224,
|
|
29
30
|
LABEL_LIST: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'],
|
|
@@ -34,7 +35,7 @@ export const cifar10 = {
|
|
|
34
35
|
maxShareValue: 100,
|
|
35
36
|
tensorBackend: 'tfjs'
|
|
36
37
|
}
|
|
37
|
-
};
|
|
38
|
+
});
|
|
38
39
|
},
|
|
39
40
|
async getModel() {
|
|
40
41
|
const mobilenet = await tf.loadLayersModel({
|
|
@@ -1,2 +1,2 @@
|
|
|
1
|
-
import type { TaskProvider } from
|
|
2
|
-
export declare const lusCovid: TaskProvider<
|
|
1
|
+
import type { TaskProvider } from "../index.js";
|
|
2
|
+
export declare const lusCovid: TaskProvider<"image", "federated">;
|
|
@@ -2,20 +2,22 @@ import * as tf from '@tensorflow/tfjs';
|
|
|
2
2
|
import { models } from '../index.js';
|
|
3
3
|
export const lusCovid = {
|
|
4
4
|
getTask() {
|
|
5
|
-
return {
|
|
5
|
+
return Promise.resolve({
|
|
6
6
|
id: 'lus_covid',
|
|
7
|
+
dataType: "image",
|
|
7
8
|
displayInformation: {
|
|
8
|
-
|
|
9
|
+
title: 'Lung Ultrasound Image Classification',
|
|
9
10
|
summary: {
|
|
10
11
|
preview: "Medical images are a typical example of data that exists in huge quantity yet that can't be shared due to confidentiality reasons. Medical applications would immensely benefit from training on data currently locked. More data diversity leads to better generalization and bias mitigation.",
|
|
11
12
|
overview: "Disco allows data owners to collaboratively train machine learning models using their respective data without any privacy breach. This example problem is about diagnosing whether patients are positive or negative to COVID-19 from lung ultrasounds images. <br>Don't have a dataset of your own? You can find a link to a sample dataset at the next step."
|
|
12
13
|
},
|
|
13
14
|
model: "The model is a simple Convolutional Neural Network composed of two convolutional layers with ReLU activations and max pooling layers, followed by a fully connected output layer. The data preprocessing reshapes images into 100x100 pixels and normalizes values between 0 and 1",
|
|
14
15
|
dataFormatInformation: 'This model takes as input an image dataset of lung ultrasounds. The images are resized automatically.',
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
16
|
+
dataExample: "https://storage.googleapis.com/deai-313515.appspot.com/example_training_data/2_QAID_1.masked.reshaped.squared.224.png",
|
|
17
|
+
sampleDataset: {
|
|
18
|
+
link: "https://drive.switch.ch/index.php/s/zM5ZrUWK3taaIly",
|
|
19
|
+
instructions: 'Opening the link will take you to a Switch Drive folder. You can click on the Download button in the top right corner. Unzip the file and you will get two subfolders: "COVID-" and "COVID+". You can connect the data by using the Group option and selecting each image group in its respective field.',
|
|
20
|
+
},
|
|
19
21
|
},
|
|
20
22
|
trainingInformation: {
|
|
21
23
|
epochs: 50,
|
|
@@ -25,13 +27,12 @@ export const lusCovid = {
|
|
|
25
27
|
IMAGE_H: 100,
|
|
26
28
|
IMAGE_W: 100,
|
|
27
29
|
LABEL_LIST: ['COVID-Positive', 'COVID-Negative'],
|
|
28
|
-
dataType: 'image',
|
|
29
30
|
scheme: 'federated',
|
|
30
31
|
aggregationStrategy: 'mean',
|
|
31
32
|
minNbOfParticipants: 2,
|
|
32
33
|
tensorBackend: 'tfjs'
|
|
33
34
|
}
|
|
34
|
-
};
|
|
35
|
+
});
|
|
35
36
|
},
|
|
36
37
|
// Model architecture from tensorflow.js docs:
|
|
37
38
|
// https://codelabs.developers.google.com/codelabs/tfjs-training-classfication/index.html#4
|
|
@@ -1,2 +1,2 @@
|
|
|
1
|
-
import type { TaskProvider } from
|
|
2
|
-
export declare const mnist: TaskProvider<
|
|
1
|
+
import type { TaskProvider } from "../index.js";
|
|
2
|
+
export declare const mnist: TaskProvider<"image", "decentralized">;
|
|
@@ -2,27 +2,28 @@ import * as tf from '@tensorflow/tfjs';
|
|
|
2
2
|
import { models } from '../index.js';
|
|
3
3
|
export const mnist = {
|
|
4
4
|
getTask() {
|
|
5
|
-
return {
|
|
5
|
+
return Promise.resolve({
|
|
6
6
|
id: 'mnist',
|
|
7
|
+
dataType: "image",
|
|
7
8
|
displayInformation: {
|
|
8
|
-
|
|
9
|
+
title: 'Handwritten Digit Recognition',
|
|
9
10
|
summary: {
|
|
10
11
|
preview: "The MNIST handwritten digit classification problem is a classic dataset used in computer vision and deep learning. The objective is to classify handwritten digits from 28x28 pixel images.",
|
|
11
12
|
overview: "Download the classic MNIST dataset of hand-written numbers <a class='underline text-blue-400' target='_blank' href='https://www.kaggle.com/scolianni/mnistasjpg'>here</a>. You can also find a sample dataset at the next step."
|
|
12
13
|
},
|
|
13
14
|
model: "The model is a simple Convolutional Neural Network composed of three convolutional layers with ReLU activations and max pooling layers, followed by two fully connected layers. The data preprocessing simply normalizes values between 0 and 1. The neural network is optimized via RMSProp and a categorical cross-entropy loss.",
|
|
14
15
|
dataFormatInformation: 'This model is trained on images corresponding to digits 0 to 9. You can connect your own images of each digit in the box corresponding to its label. The model takes images of size 28x28 as input.',
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
16
|
+
dataExample: "http://storage.googleapis.com/deai-313515.appspot.com/example_training_data/9-mnist-example.png",
|
|
17
|
+
sampleDataset: {
|
|
18
|
+
link: "https://storage.googleapis.com/deai-313515.appspot.com/MNIST_samples.tar.gz",
|
|
19
|
+
instructions: 'Opening the link should start downloading a zip file which you can unzip. You can connect the data with the CSV option below using the CSV file named "mnist_labels.csv". After selecting in the CSV file, you will be able to connect the data under in the "images" folder.',
|
|
20
|
+
},
|
|
19
21
|
},
|
|
20
22
|
trainingInformation: {
|
|
21
23
|
epochs: 20,
|
|
22
24
|
roundDuration: 2,
|
|
23
25
|
validationSplit: 0.2,
|
|
24
26
|
batchSize: 64,
|
|
25
|
-
dataType: 'image',
|
|
26
27
|
IMAGE_H: 28,
|
|
27
28
|
IMAGE_W: 28,
|
|
28
29
|
LABEL_LIST: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'],
|
|
@@ -32,7 +33,7 @@ export const mnist = {
|
|
|
32
33
|
maxShareValue: 100,
|
|
33
34
|
tensorBackend: 'tfjs'
|
|
34
35
|
}
|
|
35
|
-
};
|
|
36
|
+
});
|
|
36
37
|
},
|
|
37
38
|
getModel() {
|
|
38
39
|
// Architecture from the PyTorch MNIST example (I made it slightly smaller, 650kB instead of 5MB)
|
|
@@ -1,2 +1,2 @@
|
|
|
1
|
-
import type { TaskProvider } from
|
|
2
|
-
export declare const simpleFace: TaskProvider<
|
|
1
|
+
import type { TaskProvider } from "../index.js";
|
|
2
|
+
export declare const simpleFace: TaskProvider<"image", "federated">;
|
|
@@ -3,26 +3,27 @@ import { models } from '../index.js';
|
|
|
3
3
|
import baseModel from '../models/mobileNetV2_35_alpha_2_classes.js';
|
|
4
4
|
export const simpleFace = {
|
|
5
5
|
getTask() {
|
|
6
|
-
return {
|
|
6
|
+
return Promise.resolve({
|
|
7
7
|
id: 'simple_face',
|
|
8
|
+
dataType: "image",
|
|
8
9
|
displayInformation: {
|
|
9
|
-
|
|
10
|
+
title: 'Simple Face',
|
|
10
11
|
summary: {
|
|
11
12
|
preview: 'Can you detect if the person in a picture is a child or an adult?',
|
|
12
13
|
overview: 'Simple face is a small subset of the public face_task dataset from Kaggle'
|
|
13
14
|
},
|
|
14
15
|
dataFormatInformation: '',
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
16
|
+
dataExample: "https://storage.googleapis.com/deai-313515.appspot.com/example_training_data/simple_face-example.png",
|
|
17
|
+
sampleDataset: {
|
|
18
|
+
link: "https://storage.googleapis.com/deai-313515.appspot.com/example_training_data.tar.gz",
|
|
19
|
+
instructions: 'Opening the link should start downloading a zip file which you can unzip. Inside the "example_training_data" directory you should find the "simple_face" folder which contains the "adult" and "child" folders. To connect the data, select the Group option below and connect adults and children image groups.',
|
|
20
|
+
},
|
|
19
21
|
},
|
|
20
22
|
trainingInformation: {
|
|
21
23
|
epochs: 50,
|
|
22
24
|
roundDuration: 1,
|
|
23
25
|
validationSplit: 0.2,
|
|
24
26
|
batchSize: 10,
|
|
25
|
-
dataType: 'image',
|
|
26
27
|
IMAGE_H: 200,
|
|
27
28
|
IMAGE_W: 200,
|
|
28
29
|
LABEL_LIST: ['child', 'adult'],
|
|
@@ -31,7 +32,7 @@ export const simpleFace = {
|
|
|
31
32
|
minNbOfParticipants: 2,
|
|
32
33
|
tensorBackend: 'tfjs'
|
|
33
34
|
}
|
|
34
|
-
};
|
|
35
|
+
});
|
|
35
36
|
},
|
|
36
37
|
async getModel() {
|
|
37
38
|
const model = await tf.loadLayersModel({
|
|
@@ -1,2 +1,2 @@
|
|
|
1
1
|
import type { TaskProvider } from '../index.js';
|
|
2
|
-
export declare const tinderDog: TaskProvider<
|
|
2
|
+
export declare const tinderDog: TaskProvider<"image", "federated">;
|
|
@@ -2,27 +2,28 @@ import * as tf from '@tensorflow/tfjs';
|
|
|
2
2
|
import { models } from '../index.js';
|
|
3
3
|
export const tinderDog = {
|
|
4
4
|
getTask() {
|
|
5
|
-
return {
|
|
5
|
+
return Promise.resolve({
|
|
6
6
|
id: 'tinder_dog',
|
|
7
|
+
dataType: "image",
|
|
7
8
|
displayInformation: {
|
|
8
|
-
|
|
9
|
+
title: "GDHF 2024 | TinderDog",
|
|
9
10
|
summary: {
|
|
10
11
|
preview: 'Which dog is the cutest....or not?',
|
|
11
12
|
overview: "Binary classification model for dog cuteness."
|
|
12
13
|
},
|
|
13
14
|
model: 'The model is a simple Convolutional Neural Network composed of two convolutional layers with ReLU activations and max pooling layers, followed by a fully connected output layer. The data preprocessing reshapes images into 64x64 pixels and normalizes values between 0 and 1',
|
|
14
15
|
dataFormatInformation: 'Accepted image formats are .png .jpg and .jpeg.',
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
16
|
+
dataExample: "https://storage.googleapis.com/deai-313515.appspot.com/tinder_dog_preview.png",
|
|
17
|
+
sampleDataset: {
|
|
18
|
+
link: "https://storage.googleapis.com/deai-313515.appspot.com/tinder_dog.zip",
|
|
19
|
+
instructions: 'Opening the link should start downloading a zip file which you can unzip. To connect the data, pick one of the data splits (the folder 0 for example) and use the CSV option below to select the file named "labels.csv". You can now connect the images located in the same folder.',
|
|
20
|
+
},
|
|
19
21
|
},
|
|
20
22
|
trainingInformation: {
|
|
21
23
|
epochs: 10,
|
|
22
24
|
roundDuration: 2,
|
|
23
25
|
validationSplit: 0, // nicer plot for GDHF demo
|
|
24
26
|
batchSize: 10,
|
|
25
|
-
dataType: 'image',
|
|
26
27
|
IMAGE_H: 64,
|
|
27
28
|
IMAGE_W: 64,
|
|
28
29
|
LABEL_LIST: ['Cute dogs', 'Less cute dogs'],
|
|
@@ -31,12 +32,13 @@ export const tinderDog = {
|
|
|
31
32
|
minNbOfParticipants: 3,
|
|
32
33
|
tensorBackend: 'tfjs'
|
|
33
34
|
}
|
|
34
|
-
};
|
|
35
|
+
});
|
|
35
36
|
},
|
|
36
37
|
async getModel() {
|
|
38
|
+
const task = await this.getTask();
|
|
37
39
|
const seed = 42; // set a seed to ensure reproducibility during GDHF demo
|
|
38
|
-
const imageHeight =
|
|
39
|
-
const imageWidth =
|
|
40
|
+
const imageHeight = task.trainingInformation.IMAGE_H;
|
|
41
|
+
const imageWidth = task.trainingInformation.IMAGE_W;
|
|
40
42
|
const imageChannels = 3;
|
|
41
43
|
const model = tf.sequential();
|
|
42
44
|
model.add(tf.layers.conv2d({
|
|
@@ -1,2 +1,2 @@
|
|
|
1
|
-
import type { TaskProvider } from
|
|
2
|
-
export declare const titanic: TaskProvider<
|
|
1
|
+
import type { TaskProvider } from "../index.js";
|
|
2
|
+
export declare const titanic: TaskProvider<"tabular", "federated">;
|
|
@@ -2,54 +2,41 @@ import * as tf from '@tensorflow/tfjs';
|
|
|
2
2
|
import { models } from '../index.js';
|
|
3
3
|
export const titanic = {
|
|
4
4
|
getTask() {
|
|
5
|
-
return {
|
|
5
|
+
return Promise.resolve({
|
|
6
6
|
id: 'titanic',
|
|
7
|
+
dataType: "tabular",
|
|
7
8
|
displayInformation: {
|
|
8
|
-
|
|
9
|
+
title: 'Titanic Prediction',
|
|
9
10
|
summary: {
|
|
10
11
|
preview: "The Titanic classification task is one of the main entrypoints into machine learning. Using passenger data (name, age, gender, socio-economic class, etc), the goal is to identify who was more likely to survive the infamous shipwreck.",
|
|
11
12
|
overview: "The original competition can be found on <a target='_blank' class='underline text-blue-400' href='https://www.kaggle.com/c/titanic'>Kaggle</a> and a link to the training set can be found here <a target='_blank' class='underline text-blue-400' href='https://storage.googleapis.com/deai-313515.appspot.com/example_training_data/titanic_train.csv'>here</a>."
|
|
12
13
|
},
|
|
13
14
|
model: 'The model is a simple 5-layer feedforward network with ReLU activations. The model is optimized with Adam and binary cross-entropy loss. The preprocessing only fills missing value with a placeholder value (0).',
|
|
14
15
|
dataFormatInformation: 'The expected format for the tabular dataset is exactly the same as the sample data provided above or in the Kaggle competition. It is a CSV file with 12 columns. The features are general information about the passenger (sex, age, name, etc.) and specific related Titanic data such as the ticket class bought by the passenger, its cabin number, etc.<br>The first line of the CSV contains the header: "PassengerId, Survived, Pclass, Name, Sex, Age, SibSp, Parch, Ticket, Fare, Cabin, Embarked"<br>Each subsequent row contains passenger data.',
|
|
15
|
-
dataExampleText: "Here's an example of one data point:",
|
|
16
16
|
dataExample: [
|
|
17
|
-
{
|
|
18
|
-
{
|
|
19
|
-
{
|
|
20
|
-
{
|
|
21
|
-
{
|
|
22
|
-
{
|
|
23
|
-
{
|
|
24
|
-
{
|
|
25
|
-
{
|
|
26
|
-
{
|
|
27
|
-
{
|
|
28
|
-
{
|
|
17
|
+
{ name: "PassengerId", data: "1" },
|
|
18
|
+
{ name: "Survived", data: "0" },
|
|
19
|
+
{ name: "Name", data: "Braund, Mr. Owen Harris" },
|
|
20
|
+
{ name: "Sex", data: "male" },
|
|
21
|
+
{ name: "Age", data: "22" },
|
|
22
|
+
{ name: "SibSp", data: "1" },
|
|
23
|
+
{ name: "Parch", data: "0" },
|
|
24
|
+
{ name: "Ticket", data: "1/5 21171" },
|
|
25
|
+
{ name: "Fare", data: "7.25" },
|
|
26
|
+
{ name: "Cabin", data: "E46" },
|
|
27
|
+
{ name: "Embarked", data: "S" },
|
|
28
|
+
{ name: "Pclass", data: "3" },
|
|
29
29
|
],
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
'Sex',
|
|
35
|
-
'Age',
|
|
36
|
-
'SibSp',
|
|
37
|
-
'Parch',
|
|
38
|
-
'Ticket',
|
|
39
|
-
'Fare',
|
|
40
|
-
'Cabin',
|
|
41
|
-
'Embarked',
|
|
42
|
-
'Pclass'
|
|
43
|
-
],
|
|
44
|
-
sampleDatasetLink: "https://storage.googleapis.com/deai-313515.appspot.com/example_training_data/titanic_train.csv",
|
|
45
|
-
sampleDatasetInstructions: 'Opening the link should start downloading a CSV file which you can drag and drop in the field below.'
|
|
30
|
+
sampleDataset: {
|
|
31
|
+
link: "https://storage.googleapis.com/deai-313515.appspot.com/example_training_data/titanic_train.csv",
|
|
32
|
+
instructions: "Opening the link should start downloading a CSV file which you can drag and drop in the field below.",
|
|
33
|
+
},
|
|
46
34
|
},
|
|
47
35
|
trainingInformation: {
|
|
48
36
|
epochs: 10,
|
|
49
37
|
roundDuration: 2,
|
|
50
38
|
validationSplit: 0.2,
|
|
51
39
|
batchSize: 30,
|
|
52
|
-
dataType: 'tabular',
|
|
53
40
|
inputColumns: [
|
|
54
41
|
'Age',
|
|
55
42
|
'SibSp',
|
|
@@ -63,7 +50,7 @@ export const titanic = {
|
|
|
63
50
|
minNbOfParticipants: 2,
|
|
64
51
|
tensorBackend: 'tfjs'
|
|
65
52
|
}
|
|
66
|
-
};
|
|
53
|
+
});
|
|
67
54
|
},
|
|
68
55
|
getModel() {
|
|
69
56
|
const model = tf.sequential();
|
|
@@ -1,2 +1,2 @@
|
|
|
1
|
-
import type { TaskProvider } from
|
|
2
|
-
export declare const wikitext: TaskProvider<
|
|
1
|
+
import type { TaskProvider } from "../index.js";
|
|
2
|
+
export declare const wikitext: TaskProvider<"text", "federated">;
|
|
@@ -1,10 +1,11 @@
|
|
|
1
|
-
import { models } from
|
|
1
|
+
import { Tokenizer, models } from "../index.js";
|
|
2
2
|
export const wikitext = {
|
|
3
|
-
getTask() {
|
|
3
|
+
async getTask() {
|
|
4
4
|
return {
|
|
5
5
|
id: 'llm_task',
|
|
6
|
+
dataType: "text",
|
|
6
7
|
displayInformation: {
|
|
7
|
-
|
|
8
|
+
title: "GPT Language Modeling",
|
|
8
9
|
summary: {
|
|
9
10
|
preview: 'Train a language model (L)LM in your browser, collaboratively and from scratch.',
|
|
10
11
|
overview: "You can train a GPT-2 model in your browser and in a collaborative manner on any textual dataset. As an example, you can try the Wikitext-103 dataset, composed of Wikipedia articles, widely used in natural language modeling, which you can download <a class='underline text-blue-400' target='_blank' href='https://dax-cdn.cdn.appdomain.cloud/dax-wikitext-103/1.0.1/wikitext-103.tar.gz'>here</a>. More information on how to connect the dataset at the next step."
|
|
@@ -17,12 +18,13 @@ export const wikitext = {
|
|
|
17
18
|
"To accommodate all devices, the context length is currently kept at 128 and the batch size at 1.",
|
|
18
19
|
].join(" "),
|
|
19
20
|
dataFormatInformation: 'You can use any natural language (text) dataset you like. For example the Wikitext-103 dataset is organized as a large text file, with each line representing a segment of raw text from Wikipedia articles.',
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
21
|
+
dataExample: "For the first twenty years of its existence , the only staged performances of Parsifal took place in the Bayreuth Festspielhaus , the venue for which Wagner conceived the work ( except eight private performances for Ludwig II at Munich in 1884 and 1885 ) .",
|
|
22
|
+
sampleDataset: {
|
|
23
|
+
link: "https://storage.googleapis.com/deai-313515.appspot.com/wikitext.zip",
|
|
24
|
+
instructions: 'Opening the link should start downloading a zip file. Unzip it and drag and drop the training set named "wiki.train.tokens" in the field below (or use the "Select File" button). Even though the file extension is ".tokens" it is indeed a text file. You can use "wiki.test.tokens" at the evaluation step after training a language model.',
|
|
25
|
+
},
|
|
23
26
|
},
|
|
24
27
|
trainingInformation: {
|
|
25
|
-
dataType: 'text',
|
|
26
28
|
scheme: 'federated',
|
|
27
29
|
aggregationStrategy: 'mean',
|
|
28
30
|
minNbOfParticipants: 2,
|
|
@@ -32,15 +34,16 @@ export const wikitext = {
|
|
|
32
34
|
validationSplit: 0.1,
|
|
33
35
|
roundDuration: 2,
|
|
34
36
|
batchSize: 8, // If set too high firefox raises a WebGL error
|
|
35
|
-
tokenizer:
|
|
37
|
+
tokenizer: await Tokenizer.from_pretrained("Xenova/gpt2"),
|
|
36
38
|
contextLength: 64,
|
|
37
39
|
tensorBackend: 'gpt'
|
|
38
40
|
}
|
|
39
41
|
};
|
|
40
42
|
},
|
|
41
|
-
getModel() {
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
43
|
+
async getModel() {
|
|
44
|
+
const task = await this.getTask();
|
|
45
|
+
return new models.GPT({
|
|
46
|
+
contextLength: task.trainingInformation.contextLength,
|
|
47
|
+
});
|
|
48
|
+
},
|
|
46
49
|
};
|
package/dist/index.d.ts
CHANGED
|
@@ -8,7 +8,7 @@ export { WeightsContainer, aggregation } from './weights/index.js';
|
|
|
8
8
|
export { Logger, ConsoleLogger } from './logging/index.js';
|
|
9
9
|
export { Disco, RoundLogs, RoundStatus } from './training/index.js';
|
|
10
10
|
export { Validator } from './validator.js';
|
|
11
|
-
export { Model, BatchLogs, EpochLogs, ValidationMetrics } from
|
|
11
|
+
export { Model, BatchLogs, EpochLogs, Tokenizer, ValidationMetrics, } from "./models/index.js";
|
|
12
12
|
export * as models from './models/index.js';
|
|
13
13
|
export * from './task/index.js';
|
|
14
14
|
export * as defaultTasks from './default_tasks/index.js';
|
package/dist/index.js
CHANGED
|
@@ -8,7 +8,7 @@ export { WeightsContainer, aggregation } from './weights/index.js';
|
|
|
8
8
|
export { ConsoleLogger } from './logging/index.js';
|
|
9
9
|
export { Disco } from './training/index.js';
|
|
10
10
|
export { Validator } from './validator.js';
|
|
11
|
-
export { Model, EpochLogs } from
|
|
11
|
+
export { Model, EpochLogs, Tokenizer, } from "./models/index.js";
|
|
12
12
|
export * as models from './models/index.js';
|
|
13
13
|
export * from './task/index.js';
|
|
14
14
|
export * as defaultTasks from './default_tasks/index.js';
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
type GPTModelType = 'gpt2' | 'gpt2-medium' | 'gpt2-large' | 'gpt2-xl' | 'gpt-mini' | 'gpt-micro' | 'gpt-nano';
|
|
2
|
-
export
|
|
2
|
+
export type GPTConfig = {
|
|
3
3
|
lr: number;
|
|
4
4
|
contextLength: number;
|
|
5
5
|
vocabSize?: number;
|
|
@@ -19,7 +19,7 @@ export interface GPTConfig {
|
|
|
19
19
|
nHead?: number;
|
|
20
20
|
nEmbd?: number;
|
|
21
21
|
seed?: number;
|
|
22
|
-
}
|
|
22
|
+
};
|
|
23
23
|
export declare const DefaultGPTConfig: Required<GPTConfig>;
|
|
24
24
|
export type ModelSize = {
|
|
25
25
|
nLayer: number;
|