@epfml/discojs 2.1.2-p20240515133413.0 → 2.1.2-p20240531085945.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/aggregator/base.js +1 -0
- package/dist/aggregator/mean.d.ts +10 -15
- package/dist/aggregator/mean.js +36 -50
- package/dist/aggregator/secure.d.ts +5 -7
- package/dist/aggregator/secure.js +56 -44
- package/dist/client/federated/messages.d.ts +1 -8
- package/dist/client/federated/messages.js +1 -10
- package/dist/client/messages.d.ts +1 -3
- package/dist/client/messages.js +0 -2
- package/dist/dataset/dataset_builder.d.ts +2 -11
- package/dist/dataset/dataset_builder.js +22 -46
- package/dist/default_tasks/cifar10.d.ts +2 -0
- package/dist/default_tasks/{cifar10/index.js → cifar10.js} +2 -2
- package/dist/default_tasks/index.d.ts +3 -2
- package/dist/default_tasks/index.js +3 -2
- package/dist/default_tasks/lus_covid.js +1 -1
- package/dist/default_tasks/simple_face.d.ts +2 -0
- package/dist/default_tasks/{simple_face/index.js → simple_face.js} +3 -3
- package/dist/default_tasks/skin_condition.d.ts +2 -0
- package/dist/default_tasks/skin_condition.js +79 -0
- package/dist/models/gpt/config.d.ts +32 -0
- package/dist/models/gpt/config.js +42 -0
- package/dist/models/gpt/evaluate.d.ts +7 -0
- package/dist/models/gpt/evaluate.js +44 -0
- package/dist/models/gpt/index.d.ts +35 -0
- package/dist/models/gpt/index.js +104 -0
- package/dist/models/gpt/layers.d.ts +13 -0
- package/dist/models/gpt/layers.js +272 -0
- package/dist/models/gpt/model.d.ts +43 -0
- package/dist/models/gpt/model.js +191 -0
- package/dist/models/gpt/optimizers.d.ts +4 -0
- package/dist/models/gpt/optimizers.js +95 -0
- package/dist/models/index.d.ts +5 -0
- package/dist/models/index.js +4 -0
- package/dist/{default_tasks/simple_face/model.js → models/mobileNetV2_35_alpha_2_classes.js} +2 -0
- package/dist/{default_tasks/cifar10/model.js → models/mobileNet_v1_025_224.js} +1 -0
- package/dist/models/model.d.ts +51 -0
- package/dist/models/model.js +8 -0
- package/dist/models/tfjs.d.ts +24 -0
- package/dist/models/tfjs.js +107 -0
- package/dist/models/tokenizer.d.ts +14 -0
- package/dist/models/tokenizer.js +22 -0
- package/dist/validation/validator.js +8 -7
- package/package.json +1 -1
- package/dist/default_tasks/cifar10/index.d.ts +0 -2
- package/dist/default_tasks/simple_face/index.d.ts +0 -2
- /package/dist/{default_tasks/simple_face/model.d.ts → models/mobileNetV2_35_alpha_2_classes.d.ts} +0 -0
- /package/dist/{default_tasks/cifar10/model.d.ts → models/mobileNet_v1_025_224.d.ts} +0 -0
package/dist/aggregator/base.js
CHANGED
|
@@ -24,6 +24,7 @@ export class Base {
|
|
|
24
24
|
* It defines the effective aggregation group, which is possibly a subset
|
|
25
25
|
* of all active nodes, depending on the aggregation scheme.
|
|
26
26
|
*/
|
|
27
|
+
// communication round -> NodeID -> T
|
|
27
28
|
contributions;
|
|
28
29
|
/**
|
|
29
30
|
* Emits the aggregation event whenever an aggregation step is performed.
|
|
@@ -1,23 +1,18 @@
|
|
|
1
|
-
import type { Map } from
|
|
2
|
-
import { Base as Aggregator } from
|
|
3
|
-
import type { Model, WeightsContainer, client } from
|
|
4
|
-
/**
|
|
5
|
-
* Mean aggregator whose aggregation step consists in computing the mean of the received weights.
|
|
6
|
-
*/
|
|
1
|
+
import type { Map } from "immutable";
|
|
2
|
+
import { Base as Aggregator } from "./base.js";
|
|
3
|
+
import type { Model, WeightsContainer, client } from "../index.js";
|
|
4
|
+
/** Mean aggregator whose aggregation step consists in computing the mean of the received weights. */
|
|
7
5
|
export declare class MeanAggregator extends Aggregator<WeightsContainer> {
|
|
6
|
+
#private;
|
|
8
7
|
/**
|
|
9
|
-
*
|
|
10
|
-
*
|
|
11
|
-
*
|
|
8
|
+
* @param threshold - how many contributions are needed to trigger an aggregation step.
|
|
9
|
+
* - relative: 0 < t <= 1, thus requiring t * |nodes| contributions
|
|
10
|
+
* - absolute: t > 1, thus requiring t contributions
|
|
12
11
|
*/
|
|
13
|
-
readonly threshold: number;
|
|
14
12
|
constructor(model?: Model, roundCutoff?: number, threshold?: number);
|
|
15
|
-
/**
|
|
16
|
-
* Checks whether the contributions buffer is full, according to the set threshold.
|
|
17
|
-
* @returns Whether the contributions buffer is full
|
|
18
|
-
*/
|
|
13
|
+
/** Checks whether the contributions buffer is full. */
|
|
19
14
|
isFull(): boolean;
|
|
20
|
-
add(nodeId: client.NodeID, contribution: WeightsContainer, round: number): boolean;
|
|
15
|
+
add(nodeId: client.NodeID, contribution: WeightsContainer, round: number, currentContributions?: number): boolean;
|
|
21
16
|
aggregate(): void;
|
|
22
17
|
makePayloads(weights: WeightsContainer): Map<client.NodeID, WeightsContainer>;
|
|
23
18
|
}
|
package/dist/aggregator/mean.js
CHANGED
|
@@ -1,65 +1,51 @@
|
|
|
1
|
-
import { AggregationStep, Base as Aggregator } from
|
|
2
|
-
import { aggregation } from
|
|
3
|
-
/**
|
|
4
|
-
* Mean aggregator whose aggregation step consists in computing the mean of the received weights.
|
|
5
|
-
*/
|
|
1
|
+
import { AggregationStep, Base as Aggregator } from "./base.js";
|
|
2
|
+
import { aggregation } from "../index.js";
|
|
3
|
+
/** Mean aggregator whose aggregation step consists in computing the mean of the received weights. */
|
|
6
4
|
export class MeanAggregator extends Aggregator {
|
|
5
|
+
#threshold;
|
|
7
6
|
/**
|
|
8
|
-
*
|
|
9
|
-
*
|
|
10
|
-
*
|
|
7
|
+
* @param threshold - how many contributions are needed to trigger an aggregation step.
|
|
8
|
+
* - relative: 0 < t <= 1, thus requiring t * |nodes| contributions
|
|
9
|
+
* - absolute: t > 1, thus requiring t contributions
|
|
11
10
|
*/
|
|
12
|
-
|
|
11
|
+
// TODO no way to require a single contribution
|
|
13
12
|
constructor(model, roundCutoff = 0, threshold = 1) {
|
|
13
|
+
if (threshold <= 0)
|
|
14
|
+
throw new Error("threshold must be striclty positive");
|
|
15
|
+
if (threshold > 1 && !Number.isInteger(threshold))
|
|
16
|
+
throw new Error("absolute thresholds must be integeral");
|
|
14
17
|
super(model, roundCutoff, 1);
|
|
15
|
-
|
|
16
|
-
if (threshold === undefined) {
|
|
17
|
-
this.threshold = 1;
|
|
18
|
-
// Threshold must be positive
|
|
19
|
-
}
|
|
20
|
-
else if (threshold <= 0) {
|
|
21
|
-
throw new Error('threshold must be positive');
|
|
22
|
-
// Thresholds greater than 1 are considered absolute instead of relative to the number of nodes
|
|
23
|
-
}
|
|
24
|
-
else if (threshold > 1 && Math.round(threshold) !== threshold) {
|
|
25
|
-
throw new Error('absolute thresholds must integers');
|
|
26
|
-
}
|
|
27
|
-
else {
|
|
28
|
-
this.threshold = threshold;
|
|
29
|
-
}
|
|
18
|
+
this.#threshold = threshold;
|
|
30
19
|
}
|
|
31
|
-
/**
|
|
32
|
-
* Checks whether the contributions buffer is full, according to the set threshold.
|
|
33
|
-
* @returns Whether the contributions buffer is full
|
|
34
|
-
*/
|
|
20
|
+
/** Checks whether the contributions buffer is full. */
|
|
35
21
|
isFull() {
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
}
|
|
41
|
-
return contribs.size >= this.threshold * this.nodes.size;
|
|
42
|
-
}
|
|
43
|
-
return this.contributions.size >= this.threshold;
|
|
22
|
+
const actualThreshold = this.#threshold <= 1
|
|
23
|
+
? this.#threshold * this.nodes.size
|
|
24
|
+
: this.#threshold;
|
|
25
|
+
return (this.contributions.get(0)?.size ?? 0) >= actualThreshold;
|
|
44
26
|
}
|
|
45
|
-
add(nodeId, contribution, round) {
|
|
46
|
-
if (
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
27
|
+
add(nodeId, contribution, round, currentContributions = 0) {
|
|
28
|
+
if (currentContributions !== 0)
|
|
29
|
+
throw new Error("only a single communication round");
|
|
30
|
+
if (!this.nodes.has(nodeId) || !this.isWithinRoundCutoff(round))
|
|
31
|
+
return false;
|
|
32
|
+
this.log(this.contributions.hasIn([0, nodeId])
|
|
33
|
+
? AggregationStep.UPDATE
|
|
34
|
+
: AggregationStep.ADD, nodeId);
|
|
35
|
+
this.contributions = this.contributions.setIn([0, nodeId], contribution);
|
|
36
|
+
this.informant?.update();
|
|
37
|
+
if (this.isFull())
|
|
38
|
+
this.aggregate();
|
|
39
|
+
return true;
|
|
56
40
|
}
|
|
57
41
|
aggregate() {
|
|
42
|
+
const currentContributions = this.contributions.get(0);
|
|
43
|
+
if (currentContributions === undefined)
|
|
44
|
+
throw new Error("aggregating without any contribution");
|
|
58
45
|
this.log(AggregationStep.AGGREGATE);
|
|
59
|
-
const result = aggregation.avg(
|
|
60
|
-
if (this.model !== undefined)
|
|
46
|
+
const result = aggregation.avg(currentContributions.values());
|
|
47
|
+
if (this.model !== undefined)
|
|
61
48
|
this.model.weights = result;
|
|
62
|
-
}
|
|
63
49
|
this.emit(result);
|
|
64
50
|
}
|
|
65
51
|
makePayloads(weights) {
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
import { Map, List } from
|
|
2
|
-
import { Base as Aggregator } from
|
|
3
|
-
import type { Model, WeightsContainer, client } from
|
|
1
|
+
import { Map, List } from "immutable";
|
|
2
|
+
import { Base as Aggregator } from "./base.js";
|
|
3
|
+
import type { Model, WeightsContainer, client } from "../index.js";
|
|
4
4
|
/**
|
|
5
5
|
* Aggregator implementing secure multi-party computation for decentralized learning.
|
|
6
6
|
* An aggregation consists of two communication rounds:
|
|
@@ -12,12 +12,10 @@ export declare class SecureAggregator extends Aggregator<WeightsContainer> {
|
|
|
12
12
|
private readonly maxShareValue;
|
|
13
13
|
constructor(model?: Model, maxShareValue?: number);
|
|
14
14
|
aggregate(): void;
|
|
15
|
-
add(nodeId: client.NodeID, contribution: WeightsContainer, round: number, communicationRound
|
|
15
|
+
add(nodeId: client.NodeID, contribution: WeightsContainer, round: number, communicationRound?: number): boolean;
|
|
16
16
|
isFull(): boolean;
|
|
17
17
|
makePayloads(weights: WeightsContainer): Map<client.NodeID, WeightsContainer>;
|
|
18
|
-
/**
|
|
19
|
-
* Generate N additive shares that aggregate to the secret weights array, where N is the number of peers.
|
|
20
|
-
*/
|
|
18
|
+
/** Generate N additive shares that aggregate to the secret weights array, where N is the number of peers. */
|
|
21
19
|
generateAllShares(secret: WeightsContainer): List<WeightsContainer>;
|
|
22
20
|
/**
|
|
23
21
|
* Generates one share in the same shape as the secret that is populated with values randomly chosen from
|
|
@@ -1,7 +1,7 @@
|
|
|
1
|
-
import { Map, List, Range } from
|
|
2
|
-
import * as tf from
|
|
3
|
-
import { AggregationStep, Base as Aggregator } from
|
|
4
|
-
import { aggregation } from
|
|
1
|
+
import { Map, List, Range } from "immutable";
|
|
2
|
+
import * as tf from "@tensorflow/tfjs";
|
|
3
|
+
import { AggregationStep, Base as Aggregator } from "./base.js";
|
|
4
|
+
import { aggregation } from "../index.js";
|
|
5
5
|
/**
|
|
6
6
|
* Aggregator implementing secure multi-party computation for decentralized learning.
|
|
7
7
|
* An aggregation consists of two communication rounds:
|
|
@@ -17,60 +17,72 @@ export class SecureAggregator extends Aggregator {
|
|
|
17
17
|
}
|
|
18
18
|
aggregate() {
|
|
19
19
|
this.log(AggregationStep.AGGREGATE);
|
|
20
|
-
|
|
20
|
+
switch (this.communicationRound) {
|
|
21
21
|
// Sum the received shares
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
22
|
+
case 0: {
|
|
23
|
+
const currentContributions = this.contributions.get(0);
|
|
24
|
+
if (currentContributions === undefined)
|
|
25
|
+
throw new Error("aggregating without any contribution");
|
|
26
|
+
const result = aggregation.sum(currentContributions.values());
|
|
27
|
+
this.emit(result);
|
|
28
|
+
break;
|
|
29
|
+
}
|
|
26
30
|
// Average the received partial sums
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
31
|
+
case 1: {
|
|
32
|
+
const currentContributions = this.contributions.get(1);
|
|
33
|
+
if (currentContributions === undefined)
|
|
34
|
+
throw new Error("aggregating without any contribution");
|
|
35
|
+
const result = aggregation.avg(currentContributions.values());
|
|
36
|
+
if (this.model !== undefined)
|
|
37
|
+
this.model.weights = result;
|
|
38
|
+
this.emit(result);
|
|
39
|
+
break;
|
|
30
40
|
}
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
else {
|
|
34
|
-
throw new Error('communication round is out of bounds');
|
|
41
|
+
default:
|
|
42
|
+
throw new Error("communication round is out of bounds");
|
|
35
43
|
}
|
|
36
44
|
}
|
|
37
45
|
add(nodeId, contribution, round, communicationRound) {
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
}
|
|
45
|
-
return true;
|
|
46
|
+
switch (communicationRound) {
|
|
47
|
+
case 0:
|
|
48
|
+
case 1:
|
|
49
|
+
break;
|
|
50
|
+
default:
|
|
51
|
+
throw new Error("requires communication round to be 0 or 1");
|
|
46
52
|
}
|
|
47
|
-
|
|
53
|
+
if (!this.nodes.has(nodeId) || !this.isWithinRoundCutoff(round))
|
|
54
|
+
return false;
|
|
55
|
+
this.log(this.contributions.hasIn([communicationRound, nodeId])
|
|
56
|
+
? AggregationStep.UPDATE
|
|
57
|
+
: AggregationStep.ADD, nodeId);
|
|
58
|
+
this.contributions = this.contributions.setIn([communicationRound, nodeId], contribution);
|
|
59
|
+
this.informant?.update();
|
|
60
|
+
if (this.isFull())
|
|
61
|
+
this.aggregate();
|
|
62
|
+
return true;
|
|
48
63
|
}
|
|
49
64
|
isFull() {
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
return false;
|
|
53
|
-
}
|
|
54
|
-
return contribs.size === this.nodes.size;
|
|
65
|
+
return ((this.contributions.get(this.communicationRound)?.size ?? 0) ===
|
|
66
|
+
this.nodes.size);
|
|
55
67
|
}
|
|
56
68
|
makePayloads(weights) {
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
69
|
+
switch (this.communicationRound) {
|
|
70
|
+
case 0: {
|
|
71
|
+
const shares = this.generateAllShares(weights);
|
|
72
|
+
// Arbitrarily assign our shares to the available nodes
|
|
73
|
+
return Map(List(this.nodes).zip(shares));
|
|
74
|
+
}
|
|
63
75
|
// Send our partial sum to every other node
|
|
64
|
-
|
|
76
|
+
case 1:
|
|
77
|
+
return this.nodes.toMap().map(() => weights);
|
|
78
|
+
default:
|
|
79
|
+
throw new Error("communication round is out of bounds");
|
|
65
80
|
}
|
|
66
81
|
}
|
|
67
|
-
/**
|
|
68
|
-
* Generate N additive shares that aggregate to the secret weights array, where N is the number of peers.
|
|
69
|
-
*/
|
|
82
|
+
/** Generate N additive shares that aggregate to the secret weights array, where N is the number of peers. */
|
|
70
83
|
generateAllShares(secret) {
|
|
71
|
-
if (this.nodes.size === 0)
|
|
72
|
-
throw new Error(
|
|
73
|
-
}
|
|
84
|
+
if (this.nodes.size === 0)
|
|
85
|
+
throw new Error("too few participants to generate shares");
|
|
74
86
|
// Generate N-1 shares
|
|
75
87
|
const shares = Range(0, this.nodes.size - 1)
|
|
76
88
|
.map(() => this.generateRandomShare(secret))
|
|
@@ -86,6 +98,6 @@ export class SecureAggregator extends Aggregator {
|
|
|
86
98
|
const MAX_SEED_BITS = 47;
|
|
87
99
|
const random = crypto.getRandomValues(new BigInt64Array(1))[0];
|
|
88
100
|
const seed = Number(BigInt.asUintN(MAX_SEED_BITS, random));
|
|
89
|
-
return secret.map((t) => tf.randomUniform(t.shape, -this.maxShareValue, this.maxShareValue,
|
|
101
|
+
return secret.map((t) => tf.randomUniform(t.shape, -this.maxShareValue, this.maxShareValue, "float32", seed));
|
|
90
102
|
}
|
|
91
103
|
}
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import { type client, type MetadataKey, type MetadataValue } from '../../index.js';
|
|
2
2
|
import { type weights } from '../../serialization/index.js';
|
|
3
3
|
import { type, type AssignNodeID, type ClientConnected } from '../messages.js';
|
|
4
|
-
export type MessageFederated = ClientConnected | SendPayload | ReceiveServerPayload |
|
|
4
|
+
export type MessageFederated = ClientConnected | SendPayload | ReceiveServerPayload | ReceiveServerMetadata | AssignNodeID;
|
|
5
5
|
export interface SendPayload {
|
|
6
6
|
type: type.SendPayload;
|
|
7
7
|
payload: weights.Encoded;
|
|
@@ -12,13 +12,6 @@ export interface ReceiveServerPayload {
|
|
|
12
12
|
payload: weights.Encoded;
|
|
13
13
|
round: number;
|
|
14
14
|
}
|
|
15
|
-
export interface RequestServerStatistics {
|
|
16
|
-
type: type.RequestServerStatistics;
|
|
17
|
-
}
|
|
18
|
-
export interface ReceiveServerStatistics {
|
|
19
|
-
type: type.ReceiveServerStatistics;
|
|
20
|
-
statistics: Record<string, number>;
|
|
21
|
-
}
|
|
22
15
|
export interface ReceiveServerMetadata {
|
|
23
16
|
type: type.ReceiveServerMetadata;
|
|
24
17
|
nodeId: client.NodeID;
|
|
@@ -5,20 +5,11 @@ export function isMessageFederated(raw) {
|
|
|
5
5
|
}
|
|
6
6
|
switch (raw.type) {
|
|
7
7
|
case type.ClientConnected:
|
|
8
|
-
return true;
|
|
9
8
|
case type.SendPayload:
|
|
10
|
-
return true;
|
|
11
9
|
case type.ReceiveServerPayload:
|
|
12
|
-
return true;
|
|
13
|
-
case type.RequestServerStatistics:
|
|
14
|
-
return true;
|
|
15
|
-
case type.ReceiveServerStatistics:
|
|
16
|
-
return true;
|
|
17
10
|
case type.ReceiveServerMetadata:
|
|
18
|
-
return true;
|
|
19
11
|
case type.AssignNodeID:
|
|
20
12
|
return true;
|
|
21
|
-
default:
|
|
22
|
-
return false;
|
|
23
13
|
}
|
|
14
|
+
return false;
|
|
24
15
|
}
|
|
@@ -10,9 +10,7 @@ export declare enum type {
|
|
|
10
10
|
Payload = 5,
|
|
11
11
|
SendPayload = 6,
|
|
12
12
|
ReceiveServerMetadata = 7,
|
|
13
|
-
ReceiveServerPayload = 8
|
|
14
|
-
RequestServerStatistics = 9,
|
|
15
|
-
ReceiveServerStatistics = 10
|
|
13
|
+
ReceiveServerPayload = 8
|
|
16
14
|
}
|
|
17
15
|
export interface ClientConnected {
|
|
18
16
|
type: type.ClientConnected;
|
package/dist/client/messages.js
CHANGED
|
@@ -11,8 +11,6 @@ export var type;
|
|
|
11
11
|
type[type["SendPayload"] = 6] = "SendPayload";
|
|
12
12
|
type[type["ReceiveServerMetadata"] = 7] = "ReceiveServerMetadata";
|
|
13
13
|
type[type["ReceiveServerPayload"] = 8] = "ReceiveServerPayload";
|
|
14
|
-
type[type["RequestServerStatistics"] = 9] = "RequestServerStatistics";
|
|
15
|
-
type[type["ReceiveServerStatistics"] = 10] = "ReceiveServerStatistics";
|
|
16
14
|
})(type || (type = {}));
|
|
17
15
|
export function hasMessageType(raw) {
|
|
18
16
|
if (typeof raw !== 'object' || raw === null) {
|
|
@@ -17,15 +17,11 @@ export declare class DatasetBuilder<Source> {
|
|
|
17
17
|
/**
|
|
18
18
|
* The buffer of unlabelled file sources.
|
|
19
19
|
*/
|
|
20
|
-
private
|
|
20
|
+
private _unlabeledSources;
|
|
21
21
|
/**
|
|
22
22
|
* The buffer of labelled file sources.
|
|
23
23
|
*/
|
|
24
|
-
private
|
|
25
|
-
/**
|
|
26
|
-
* Whether a dataset was already produced.
|
|
27
|
-
*/
|
|
28
|
-
private _built;
|
|
24
|
+
private _labeledSources;
|
|
29
25
|
constructor(
|
|
30
26
|
/**
|
|
31
27
|
* The data loader used to load the data contained in the provided files.
|
|
@@ -48,13 +44,8 @@ export declare class DatasetBuilder<Source> {
|
|
|
48
44
|
* @param label The file sources label
|
|
49
45
|
*/
|
|
50
46
|
clearFiles(label?: string): void;
|
|
51
|
-
private resetBuiltState;
|
|
52
47
|
private getLabels;
|
|
53
48
|
build(config?: DataConfig): Promise<DataSplit>;
|
|
54
|
-
/**
|
|
55
|
-
* Whether the dataset builder has already been consumed to produce a dataset.
|
|
56
|
-
*/
|
|
57
|
-
get built(): boolean;
|
|
58
49
|
get size(): number;
|
|
59
50
|
get sources(): Source[];
|
|
60
51
|
}
|
|
@@ -9,16 +9,11 @@ export class DatasetBuilder {
|
|
|
9
9
|
/**
|
|
10
10
|
* The buffer of unlabelled file sources.
|
|
11
11
|
*/
|
|
12
|
-
|
|
12
|
+
_unlabeledSources;
|
|
13
13
|
/**
|
|
14
14
|
* The buffer of labelled file sources.
|
|
15
15
|
*/
|
|
16
|
-
|
|
17
|
-
/**
|
|
18
|
-
* Whether a dataset was already produced.
|
|
19
|
-
*/
|
|
20
|
-
// TODO useless, responsibility on callers
|
|
21
|
-
_built;
|
|
16
|
+
_labeledSources;
|
|
22
17
|
constructor(
|
|
23
18
|
/**
|
|
24
19
|
* The data loader used to load the data contained in the provided files.
|
|
@@ -30,9 +25,9 @@ export class DatasetBuilder {
|
|
|
30
25
|
task) {
|
|
31
26
|
this.dataLoader = dataLoader;
|
|
32
27
|
this.task = task;
|
|
33
|
-
this.
|
|
34
|
-
|
|
35
|
-
this.
|
|
28
|
+
this._unlabeledSources = [];
|
|
29
|
+
// Map from label to sources
|
|
30
|
+
this._labeledSources = Map();
|
|
36
31
|
}
|
|
37
32
|
/**
|
|
38
33
|
* Adds the given file sources to the builder's buffer. Sources may be provided a label in the case
|
|
@@ -41,19 +36,16 @@ export class DatasetBuilder {
|
|
|
41
36
|
* @param label The file sources label
|
|
42
37
|
*/
|
|
43
38
|
addFiles(sources, label) {
|
|
44
|
-
if (this.built) {
|
|
45
|
-
this.resetBuiltState();
|
|
46
|
-
}
|
|
47
39
|
if (label === undefined) {
|
|
48
|
-
this.
|
|
40
|
+
this._unlabeledSources = this._unlabeledSources.concat(sources);
|
|
49
41
|
}
|
|
50
42
|
else {
|
|
51
|
-
const currentSources = this.
|
|
43
|
+
const currentSources = this._labeledSources.get(label);
|
|
52
44
|
if (currentSources === undefined) {
|
|
53
|
-
this.
|
|
45
|
+
this._labeledSources = this._labeledSources.set(label, sources);
|
|
54
46
|
}
|
|
55
47
|
else {
|
|
56
|
-
this.
|
|
48
|
+
this._labeledSources = this._labeledSources.set(label, currentSources.concat(sources));
|
|
57
49
|
}
|
|
58
50
|
}
|
|
59
51
|
}
|
|
@@ -63,27 +55,19 @@ export class DatasetBuilder {
|
|
|
63
55
|
* @param label The file sources label
|
|
64
56
|
*/
|
|
65
57
|
clearFiles(label) {
|
|
66
|
-
if (this.built) {
|
|
67
|
-
this.resetBuiltState();
|
|
68
|
-
}
|
|
69
58
|
if (label === undefined) {
|
|
70
|
-
this.
|
|
59
|
+
this._unlabeledSources = [];
|
|
71
60
|
}
|
|
72
61
|
else {
|
|
73
|
-
this.
|
|
62
|
+
this._labeledSources = this._labeledSources.delete(label);
|
|
74
63
|
}
|
|
75
64
|
}
|
|
76
|
-
// If files are added or removed, then this should be called since the latest
|
|
77
|
-
// version of the dataset_builder has not yet been built.
|
|
78
|
-
resetBuiltState() {
|
|
79
|
-
this._built = false;
|
|
80
|
-
}
|
|
81
65
|
getLabels() {
|
|
82
66
|
// We need to duplicate the labels as we need one for each source.
|
|
83
67
|
// Say for label A we have sources [img1, img2, img3], then we
|
|
84
68
|
// need labels [A, A, A].
|
|
85
69
|
let labels = [];
|
|
86
|
-
this.
|
|
70
|
+
this._labeledSources.forEach((sources, label) => {
|
|
87
71
|
const sourcesLabels = Array.from({ length: sources.length }, (_) => label);
|
|
88
72
|
labels = labels.concat(sourcesLabels);
|
|
89
73
|
});
|
|
@@ -91,17 +75,17 @@ export class DatasetBuilder {
|
|
|
91
75
|
}
|
|
92
76
|
async build(config) {
|
|
93
77
|
// Require that at least one source collection is non-empty, but not both
|
|
94
|
-
if (
|
|
95
|
-
throw new Error('
|
|
78
|
+
if (this._unlabeledSources.length + this._labeledSources.size === 0) {
|
|
79
|
+
throw new Error('No input files connected'); // This error message is parsed in Trainer.vue
|
|
96
80
|
}
|
|
97
81
|
let dataTuple;
|
|
98
|
-
if (this.
|
|
82
|
+
if (this._unlabeledSources.length > 0) {
|
|
99
83
|
let defaultConfig = {};
|
|
100
84
|
if (config?.inference === true) {
|
|
101
85
|
// Inferring model, no labels needed
|
|
102
86
|
defaultConfig = {
|
|
103
87
|
features: this.task.trainingInformation.inputColumns,
|
|
104
|
-
shuffle:
|
|
88
|
+
shuffle: true
|
|
105
89
|
};
|
|
106
90
|
}
|
|
107
91
|
else {
|
|
@@ -109,34 +93,26 @@ export class DatasetBuilder {
|
|
|
109
93
|
defaultConfig = {
|
|
110
94
|
features: this.task.trainingInformation.inputColumns,
|
|
111
95
|
labels: this.task.trainingInformation.outputColumns,
|
|
112
|
-
shuffle:
|
|
96
|
+
shuffle: true
|
|
113
97
|
};
|
|
114
98
|
}
|
|
115
|
-
dataTuple = await this.dataLoader.loadAll(this.
|
|
99
|
+
dataTuple = await this.dataLoader.loadAll(this._unlabeledSources, { ...defaultConfig, ...config });
|
|
116
100
|
}
|
|
117
101
|
else {
|
|
118
102
|
// Labels are inferred from the file selection boxes
|
|
119
103
|
const defaultConfig = {
|
|
120
104
|
labels: this.getLabels(),
|
|
121
|
-
shuffle:
|
|
105
|
+
shuffle: true
|
|
122
106
|
};
|
|
123
|
-
const sources = this.
|
|
107
|
+
const sources = this._labeledSources.valueSeq().toArray().flat();
|
|
124
108
|
dataTuple = await this.dataLoader.loadAll(sources, { ...defaultConfig, ...config });
|
|
125
109
|
}
|
|
126
|
-
// TODO @s314cy: Support .csv labels for image datasets (supervised training or testing)
|
|
127
|
-
this._built = true;
|
|
128
110
|
return dataTuple;
|
|
129
111
|
}
|
|
130
|
-
/**
|
|
131
|
-
* Whether the dataset builder has already been consumed to produce a dataset.
|
|
132
|
-
*/
|
|
133
|
-
get built() {
|
|
134
|
-
return this._built;
|
|
135
|
-
}
|
|
136
112
|
get size() {
|
|
137
|
-
return Math.max(this.
|
|
113
|
+
return Math.max(this._unlabeledSources.length, this._labeledSources.size);
|
|
138
114
|
}
|
|
139
115
|
get sources() {
|
|
140
|
-
return this.
|
|
116
|
+
return this._unlabeledSources.length > 0 ? this._unlabeledSources : this._labeledSources.valueSeq().toArray().flat();
|
|
141
117
|
}
|
|
142
118
|
}
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import * as tf from '@tensorflow/tfjs';
|
|
2
|
-
import { data, models } from '
|
|
3
|
-
import baseModel from '
|
|
2
|
+
import { data, models } from '../index.js';
|
|
3
|
+
import baseModel from '../models/mobileNet_v1_025_224.js';
|
|
4
4
|
export const cifar10 = {
|
|
5
5
|
getTask() {
|
|
6
6
|
return {
|
|
@@ -1,6 +1,7 @@
|
|
|
1
|
-
export { cifar10 } from './cifar10
|
|
1
|
+
export { cifar10 } from './cifar10.js';
|
|
2
2
|
export { lusCovid } from './lus_covid.js';
|
|
3
|
+
export { skinCondition } from './skin_condition.js';
|
|
3
4
|
export { mnist } from './mnist.js';
|
|
4
|
-
export { simpleFace } from './simple_face
|
|
5
|
+
export { simpleFace } from './simple_face.js';
|
|
5
6
|
export { titanic } from './titanic.js';
|
|
6
7
|
export { wikitext } from './wikitext.js';
|
|
@@ -1,6 +1,7 @@
|
|
|
1
|
-
export { cifar10 } from './cifar10
|
|
1
|
+
export { cifar10 } from './cifar10.js';
|
|
2
2
|
export { lusCovid } from './lus_covid.js';
|
|
3
|
+
export { skinCondition } from './skin_condition.js';
|
|
3
4
|
export { mnist } from './mnist.js';
|
|
4
|
-
export { simpleFace } from './simple_face
|
|
5
|
+
export { simpleFace } from './simple_face.js';
|
|
5
6
|
export { titanic } from './titanic.js';
|
|
6
7
|
export { wikitext } from './wikitext.js';
|
|
@@ -8,7 +8,7 @@ export const lusCovid = {
|
|
|
8
8
|
taskTitle: 'COVID Lung Ultrasound',
|
|
9
9
|
summary: {
|
|
10
10
|
preview: 'Do you have a data of lung ultrasound images on patients <b>suspected of Lower Respiratory Tract infection (LRTI) during the COVID pandemic</b>? <br> Learn how to discriminate between COVID positive and negative patients by joining this task.',
|
|
11
|
-
overview: "Don
|
|
11
|
+
overview: "Don't have a dataset of your own? Download a sample of a few cases <a class='underline' href='https://drive.switch.ch/index.php/s/zM5ZrUWK3taaIly' target='_blank'>here</a>."
|
|
12
12
|
},
|
|
13
13
|
model: "We use a simplified* version of the <b>DeepChest model</b>: A deep learning model developed in our lab (<a class='underline' href='https://www.epfl.ch/labs/mlo/igh-intelligent-global-health/'>intelligent Global Health</a>.). On a cohort of 400 Swiss patients suspected of LRTI, the model obtained over 90% area under the ROC curve for this task. <br><br>*Simplified to ensure smooth running on your browser, the performance is minimally affected. Details of the adaptations are below <br>- <b>Removed</b>: positional embedding (i.e. we don’t take the anatomic position into consideration). Rather, the model now does mean pooling over the feature vector of the images for each patient <br>- <b>Replaced</b>: ResNet18 by Mobilenet",
|
|
14
14
|
dataFormatInformation: 'This model takes as input an image dataset. It consists on a set of lung ultrasound images per patient with its corresponding label of covid positive or negative. Moreover, to identify the images per patient you have to follow the follwing naming pattern: "patientId_*.png"',
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import * as tf from '@tensorflow/tfjs';
|
|
2
|
-
import { data, models } from '
|
|
3
|
-
import baseModel from '
|
|
2
|
+
import { data, models } from '../index.js';
|
|
3
|
+
import baseModel from '../models/mobileNetV2_35_alpha_2_classes.js';
|
|
4
4
|
export const simpleFace = {
|
|
5
5
|
getTask() {
|
|
6
6
|
return {
|
|
@@ -12,7 +12,7 @@ export const simpleFace = {
|
|
|
12
12
|
overview: 'Simple face is a small subset of face_task from Kaggle'
|
|
13
13
|
},
|
|
14
14
|
dataFormatInformation: '',
|
|
15
|
-
dataExampleText: 'Below you find an example',
|
|
15
|
+
dataExampleText: 'Below you can find an example',
|
|
16
16
|
dataExampleImage: 'https://storage.googleapis.com/deai-313515.appspot.com/example_training_data/simple_face-example.png'
|
|
17
17
|
},
|
|
18
18
|
trainingInformation: {
|