@epfml/discojs 2.1.2-p20240515133413.0 → 2.1.2-p20240531085945.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48) hide show
  1. package/dist/aggregator/base.js +1 -0
  2. package/dist/aggregator/mean.d.ts +10 -15
  3. package/dist/aggregator/mean.js +36 -50
  4. package/dist/aggregator/secure.d.ts +5 -7
  5. package/dist/aggregator/secure.js +56 -44
  6. package/dist/client/federated/messages.d.ts +1 -8
  7. package/dist/client/federated/messages.js +1 -10
  8. package/dist/client/messages.d.ts +1 -3
  9. package/dist/client/messages.js +0 -2
  10. package/dist/dataset/dataset_builder.d.ts +2 -11
  11. package/dist/dataset/dataset_builder.js +22 -46
  12. package/dist/default_tasks/cifar10.d.ts +2 -0
  13. package/dist/default_tasks/{cifar10/index.js → cifar10.js} +2 -2
  14. package/dist/default_tasks/index.d.ts +3 -2
  15. package/dist/default_tasks/index.js +3 -2
  16. package/dist/default_tasks/lus_covid.js +1 -1
  17. package/dist/default_tasks/simple_face.d.ts +2 -0
  18. package/dist/default_tasks/{simple_face/index.js → simple_face.js} +3 -3
  19. package/dist/default_tasks/skin_condition.d.ts +2 -0
  20. package/dist/default_tasks/skin_condition.js +79 -0
  21. package/dist/models/gpt/config.d.ts +32 -0
  22. package/dist/models/gpt/config.js +42 -0
  23. package/dist/models/gpt/evaluate.d.ts +7 -0
  24. package/dist/models/gpt/evaluate.js +44 -0
  25. package/dist/models/gpt/index.d.ts +35 -0
  26. package/dist/models/gpt/index.js +104 -0
  27. package/dist/models/gpt/layers.d.ts +13 -0
  28. package/dist/models/gpt/layers.js +272 -0
  29. package/dist/models/gpt/model.d.ts +43 -0
  30. package/dist/models/gpt/model.js +191 -0
  31. package/dist/models/gpt/optimizers.d.ts +4 -0
  32. package/dist/models/gpt/optimizers.js +95 -0
  33. package/dist/models/index.d.ts +5 -0
  34. package/dist/models/index.js +4 -0
  35. package/dist/{default_tasks/simple_face/model.js → models/mobileNetV2_35_alpha_2_classes.js} +2 -0
  36. package/dist/{default_tasks/cifar10/model.js → models/mobileNet_v1_025_224.js} +1 -0
  37. package/dist/models/model.d.ts +51 -0
  38. package/dist/models/model.js +8 -0
  39. package/dist/models/tfjs.d.ts +24 -0
  40. package/dist/models/tfjs.js +107 -0
  41. package/dist/models/tokenizer.d.ts +14 -0
  42. package/dist/models/tokenizer.js +22 -0
  43. package/dist/validation/validator.js +8 -7
  44. package/package.json +1 -1
  45. package/dist/default_tasks/cifar10/index.d.ts +0 -2
  46. package/dist/default_tasks/simple_face/index.d.ts +0 -2
  47. /package/dist/{default_tasks/simple_face/model.d.ts → models/mobileNetV2_35_alpha_2_classes.d.ts} +0 -0
  48. /package/dist/{default_tasks/cifar10/model.d.ts → models/mobileNet_v1_025_224.d.ts} +0 -0
@@ -24,6 +24,7 @@ export class Base {
24
24
  * It defines the effective aggregation group, which is possibly a subset
25
25
  * of all active nodes, depending on the aggregation scheme.
26
26
  */
27
+ // communication round -> NodeID -> T
27
28
  contributions;
28
29
  /**
29
30
  * Emits the aggregation event whenever an aggregation step is performed.
@@ -1,23 +1,18 @@
1
- import type { Map } from 'immutable';
2
- import { Base as Aggregator } from './base.js';
3
- import type { Model, WeightsContainer, client } from '../index.js';
4
- /**
5
- * Mean aggregator whose aggregation step consists in computing the mean of the received weights.
6
- */
1
+ import type { Map } from "immutable";
2
+ import { Base as Aggregator } from "./base.js";
3
+ import type { Model, WeightsContainer, client } from "../index.js";
4
+ /** Mean aggregator whose aggregation step consists in computing the mean of the received weights. */
7
5
  export declare class MeanAggregator extends Aggregator<WeightsContainer> {
6
+ #private;
8
7
  /**
9
- * The threshold t to fulfill to trigger an aggregation step. It can either be:
10
- * - relative: 0 < t <= 1, thus requiring t * |nodes| contributions
11
- * - absolute: t > 1, thus requiring t contributions
8
+ * @param threshold - how many contributions for trigger an aggregation step.
9
+ * - relative: 0 < t <= 1, thus requiring t * |nodes| contributions
10
+ * - absolute: t > 1, thus requiring t contributions
12
11
  */
13
- readonly threshold: number;
14
12
  constructor(model?: Model, roundCutoff?: number, threshold?: number);
15
- /**
16
- * Checks whether the contributions buffer is full, according to the set threshold.
17
- * @returns Whether the contributions buffer is full
18
- */
13
+ /** Checks whether the contributions buffer is full. */
19
14
  isFull(): boolean;
20
- add(nodeId: client.NodeID, contribution: WeightsContainer, round: number): boolean;
15
+ add(nodeId: client.NodeID, contribution: WeightsContainer, round: number, currentContributions?: number): boolean;
21
16
  aggregate(): void;
22
17
  makePayloads(weights: WeightsContainer): Map<client.NodeID, WeightsContainer>;
23
18
  }
@@ -1,65 +1,51 @@
1
- import { AggregationStep, Base as Aggregator } from './base.js';
2
- import { aggregation } from '../index.js';
3
- /**
4
- * Mean aggregator whose aggregation step consists in computing the mean of the received weights.
5
- */
1
+ import { AggregationStep, Base as Aggregator } from "./base.js";
2
+ import { aggregation } from "../index.js";
3
+ /** Mean aggregator whose aggregation step consists in computing the mean of the received weights. */
6
4
  export class MeanAggregator extends Aggregator {
5
+ #threshold;
7
6
  /**
8
- * The threshold t to fulfill to trigger an aggregation step. It can either be:
9
- * - relative: 0 < t <= 1, thus requiring t * |nodes| contributions
10
- * - absolute: t > 1, thus requiring t contributions
7
+ * @param threshold - how many contributions for trigger an aggregation step.
8
+ * - relative: 0 < t <= 1, thus requiring t * |nodes| contributions
9
+ * - absolute: t > 1, thus requiring t contributions
11
10
  */
12
- threshold;
11
+ // TODO no way to require a single contribution
13
12
  constructor(model, roundCutoff = 0, threshold = 1) {
13
+ if (threshold <= 0)
14
+ throw new Error("threshold must be striclty positive");
15
+ if (threshold > 1 && !Number.isInteger(threshold))
16
+ throw new Error("absolute thresholds must be integeral");
14
17
  super(model, roundCutoff, 1);
15
- // Default threshold is 100% of node participation
16
- if (threshold === undefined) {
17
- this.threshold = 1;
18
- // Threshold must be positive
19
- }
20
- else if (threshold <= 0) {
21
- throw new Error('threshold must be positive');
22
- // Thresholds greater than 1 are considered absolute instead of relative to the number of nodes
23
- }
24
- else if (threshold > 1 && Math.round(threshold) !== threshold) {
25
- throw new Error('absolute thresholds must integers');
26
- }
27
- else {
28
- this.threshold = threshold;
29
- }
18
+ this.#threshold = threshold;
30
19
  }
31
- /**
32
- * Checks whether the contributions buffer is full, according to the set threshold.
33
- * @returns Whether the contributions buffer is full
34
- */
20
+ /** Checks whether the contributions buffer is full. */
35
21
  isFull() {
36
- if (this.threshold <= 1) {
37
- const contribs = this.contributions.get(this.communicationRound);
38
- if (contribs === undefined) {
39
- return false;
40
- }
41
- return contribs.size >= this.threshold * this.nodes.size;
42
- }
43
- return this.contributions.size >= this.threshold;
22
+ const actualThreshold = this.#threshold <= 1
23
+ ? this.#threshold * this.nodes.size
24
+ : this.#threshold;
25
+ return (this.contributions.get(0)?.size ?? 0) >= actualThreshold;
44
26
  }
45
- add(nodeId, contribution, round) {
46
- if (this.nodes.has(nodeId) && this.isWithinRoundCutoff(round)) {
47
- this.log(this.contributions.hasIn([0, nodeId]) ? AggregationStep.UPDATE : AggregationStep.ADD, nodeId);
48
- this.contributions = this.contributions.setIn([0, nodeId], contribution);
49
- this.informant?.update();
50
- if (this.isFull()) {
51
- this.aggregate();
52
- }
53
- return true;
54
- }
55
- return false;
27
+ add(nodeId, contribution, round, currentContributions = 0) {
28
+ if (currentContributions !== 0)
29
+ throw new Error("only a single communication round");
30
+ if (!this.nodes.has(nodeId) || !this.isWithinRoundCutoff(round))
31
+ return false;
32
+ this.log(this.contributions.hasIn([0, nodeId])
33
+ ? AggregationStep.UPDATE
34
+ : AggregationStep.ADD, nodeId);
35
+ this.contributions = this.contributions.setIn([0, nodeId], contribution);
36
+ this.informant?.update();
37
+ if (this.isFull())
38
+ this.aggregate();
39
+ return true;
56
40
  }
57
41
  aggregate() {
42
+ const currentContributions = this.contributions.get(0);
43
+ if (currentContributions === undefined)
44
+ throw new Error("aggregating without any contribution");
58
45
  this.log(AggregationStep.AGGREGATE);
59
- const result = aggregation.avg(this.contributions.get(0)?.values());
60
- if (this.model !== undefined) {
46
+ const result = aggregation.avg(currentContributions.values());
47
+ if (this.model !== undefined)
61
48
  this.model.weights = result;
62
- }
63
49
  this.emit(result);
64
50
  }
65
51
  makePayloads(weights) {
@@ -1,6 +1,6 @@
1
- import { Map, List } from 'immutable';
2
- import { Base as Aggregator } from './base.js';
3
- import type { Model, WeightsContainer, client } from '../index.js';
1
+ import { Map, List } from "immutable";
2
+ import { Base as Aggregator } from "./base.js";
3
+ import type { Model, WeightsContainer, client } from "../index.js";
4
4
  /**
5
5
  * Aggregator implementing secure multi-party computation for decentralized learning.
6
6
  * An aggregation consists of two communication rounds:
@@ -12,12 +12,10 @@ export declare class SecureAggregator extends Aggregator<WeightsContainer> {
12
12
  private readonly maxShareValue;
13
13
  constructor(model?: Model, maxShareValue?: number);
14
14
  aggregate(): void;
15
- add(nodeId: client.NodeID, contribution: WeightsContainer, round: number, communicationRound: number): boolean;
15
+ add(nodeId: client.NodeID, contribution: WeightsContainer, round: number, communicationRound?: number): boolean;
16
16
  isFull(): boolean;
17
17
  makePayloads(weights: WeightsContainer): Map<client.NodeID, WeightsContainer>;
18
- /**
19
- * Generate N additive shares that aggregate to the secret weights array, where N is the number of peers.
20
- */
18
+ /** Generate N additive shares that aggregate to the secret weights array, where N is the number of peers. */
21
19
  generateAllShares(secret: WeightsContainer): List<WeightsContainer>;
22
20
  /**
23
21
  * Generates one share in the same shape as the secret that is populated with values randomly chosen from
@@ -1,7 +1,7 @@
1
- import { Map, List, Range } from 'immutable';
2
- import * as tf from '@tensorflow/tfjs';
3
- import { AggregationStep, Base as Aggregator } from './base.js';
4
- import { aggregation } from '../index.js';
1
+ import { Map, List, Range } from "immutable";
2
+ import * as tf from "@tensorflow/tfjs";
3
+ import { AggregationStep, Base as Aggregator } from "./base.js";
4
+ import { aggregation } from "../index.js";
5
5
  /**
6
6
  * Aggregator implementing secure multi-party computation for decentralized learning.
7
7
  * An aggregation consists of two communication rounds:
@@ -17,60 +17,72 @@ export class SecureAggregator extends Aggregator {
17
17
  }
18
18
  aggregate() {
19
19
  this.log(AggregationStep.AGGREGATE);
20
- if (this.communicationRound === 0) {
20
+ switch (this.communicationRound) {
21
21
  // Sum the received shares
22
- const result = aggregation.sum(this.contributions.get(0)?.values());
23
- this.emit(result);
24
- }
25
- else if (this.communicationRound === 1) {
22
+ case 0: {
23
+ const currentContributions = this.contributions.get(0);
24
+ if (currentContributions === undefined)
25
+ throw new Error("aggregating without any contribution");
26
+ const result = aggregation.sum(currentContributions.values());
27
+ this.emit(result);
28
+ break;
29
+ }
26
30
  // Average the received partial sums
27
- const result = aggregation.avg(this.contributions.get(1)?.values());
28
- if (this.model !== undefined) {
29
- this.model.weights = result;
31
+ case 1: {
32
+ const currentContributions = this.contributions.get(1);
33
+ if (currentContributions === undefined)
34
+ throw new Error("aggregating without any contribution");
35
+ const result = aggregation.avg(currentContributions.values());
36
+ if (this.model !== undefined)
37
+ this.model.weights = result;
38
+ this.emit(result);
39
+ break;
30
40
  }
31
- this.emit(result);
32
- }
33
- else {
34
- throw new Error('communication round is out of bounds');
41
+ default:
42
+ throw new Error("communication round is out of bounds");
35
43
  }
36
44
  }
37
45
  add(nodeId, contribution, round, communicationRound) {
38
- if (this.nodes.has(nodeId) && this.isWithinRoundCutoff(round)) {
39
- this.log(this.contributions.hasIn([communicationRound, nodeId]) ? AggregationStep.UPDATE : AggregationStep.ADD, nodeId);
40
- this.contributions = this.contributions.setIn([communicationRound, nodeId], contribution);
41
- this.informant?.update();
42
- if (this.isFull()) {
43
- this.aggregate();
44
- }
45
- return true;
46
+ switch (communicationRound) {
47
+ case 0:
48
+ case 1:
49
+ break;
50
+ default:
51
+ throw new Error("requires communication round to be 0 or 1");
46
52
  }
47
- return false;
53
+ if (!this.nodes.has(nodeId) || !this.isWithinRoundCutoff(round))
54
+ return false;
55
+ this.log(this.contributions.hasIn([communicationRound, nodeId])
56
+ ? AggregationStep.UPDATE
57
+ : AggregationStep.ADD, nodeId);
58
+ this.contributions = this.contributions.setIn([communicationRound, nodeId], contribution);
59
+ this.informant?.update();
60
+ if (this.isFull())
61
+ this.aggregate();
62
+ return true;
48
63
  }
49
64
  isFull() {
50
- const contribs = this.contributions.get(this.communicationRound);
51
- if (contribs === undefined) {
52
- return false;
53
- }
54
- return contribs.size === this.nodes.size;
65
+ return ((this.contributions.get(this.communicationRound)?.size ?? 0) ===
66
+ this.nodes.size);
55
67
  }
56
68
  makePayloads(weights) {
57
- if (this.communicationRound === 0) {
58
- const shares = this.generateAllShares(weights);
59
- // Abitrarily assign our shares to the available nodes
60
- return Map(List(this.nodes).zip(shares));
61
- }
62
- else {
69
+ switch (this.communicationRound) {
70
+ case 0: {
71
+ const shares = this.generateAllShares(weights);
72
+ // Abitrarily assign our shares to the available nodes
73
+ return Map(List(this.nodes).zip(shares));
74
+ }
63
75
  // Send our partial sum to every other nodes
64
- return this.nodes.toMap().map(() => weights);
76
+ case 1:
77
+ return this.nodes.toMap().map(() => weights);
78
+ default:
79
+ throw new Error("communication round is out of bounds");
65
80
  }
66
81
  }
67
- /**
68
- * Generate N additive shares that aggregate to the secret weights array, where N is the number of peers.
69
- */
82
+ /** Generate N additive shares that aggregate to the secret weights array, where N is the number of peers. */
70
83
  generateAllShares(secret) {
71
- if (this.nodes.size === 0) {
72
- throw new Error('too few participants to generate shares');
73
- }
84
+ if (this.nodes.size === 0)
85
+ throw new Error("too few participants to generate shares");
74
86
  // Generate N-1 shares
75
87
  const shares = Range(0, this.nodes.size - 1)
76
88
  .map(() => this.generateRandomShare(secret))
@@ -86,6 +98,6 @@ export class SecureAggregator extends Aggregator {
86
98
  const MAX_SEED_BITS = 47;
87
99
  const random = crypto.getRandomValues(new BigInt64Array(1))[0];
88
100
  const seed = Number(BigInt.asUintN(MAX_SEED_BITS, random));
89
- return secret.map((t) => tf.randomUniform(t.shape, -this.maxShareValue, this.maxShareValue, 'float32', seed));
101
+ return secret.map((t) => tf.randomUniform(t.shape, -this.maxShareValue, this.maxShareValue, "float32", seed));
90
102
  }
91
103
  }
@@ -1,7 +1,7 @@
1
1
  import { type client, type MetadataKey, type MetadataValue } from '../../index.js';
2
2
  import { type weights } from '../../serialization/index.js';
3
3
  import { type, type AssignNodeID, type ClientConnected } from '../messages.js';
4
- export type MessageFederated = ClientConnected | SendPayload | ReceiveServerPayload | RequestServerStatistics | ReceiveServerStatistics | ReceiveServerMetadata | AssignNodeID;
4
+ export type MessageFederated = ClientConnected | SendPayload | ReceiveServerPayload | ReceiveServerMetadata | AssignNodeID;
5
5
  export interface SendPayload {
6
6
  type: type.SendPayload;
7
7
  payload: weights.Encoded;
@@ -12,13 +12,6 @@ export interface ReceiveServerPayload {
12
12
  payload: weights.Encoded;
13
13
  round: number;
14
14
  }
15
- export interface RequestServerStatistics {
16
- type: type.RequestServerStatistics;
17
- }
18
- export interface ReceiveServerStatistics {
19
- type: type.ReceiveServerStatistics;
20
- statistics: Record<string, number>;
21
- }
22
15
  export interface ReceiveServerMetadata {
23
16
  type: type.ReceiveServerMetadata;
24
17
  nodeId: client.NodeID;
@@ -5,20 +5,11 @@ export function isMessageFederated(raw) {
5
5
  }
6
6
  switch (raw.type) {
7
7
  case type.ClientConnected:
8
- return true;
9
8
  case type.SendPayload:
10
- return true;
11
9
  case type.ReceiveServerPayload:
12
- return true;
13
- case type.RequestServerStatistics:
14
- return true;
15
- case type.ReceiveServerStatistics:
16
- return true;
17
10
  case type.ReceiveServerMetadata:
18
- return true;
19
11
  case type.AssignNodeID:
20
12
  return true;
21
- default:
22
- return false;
23
13
  }
14
+ return false;
24
15
  }
@@ -10,9 +10,7 @@ export declare enum type {
10
10
  Payload = 5,
11
11
  SendPayload = 6,
12
12
  ReceiveServerMetadata = 7,
13
- ReceiveServerPayload = 8,
14
- RequestServerStatistics = 9,
15
- ReceiveServerStatistics = 10
13
+ ReceiveServerPayload = 8
16
14
  }
17
15
  export interface ClientConnected {
18
16
  type: type.ClientConnected;
@@ -11,8 +11,6 @@ export var type;
11
11
  type[type["SendPayload"] = 6] = "SendPayload";
12
12
  type[type["ReceiveServerMetadata"] = 7] = "ReceiveServerMetadata";
13
13
  type[type["ReceiveServerPayload"] = 8] = "ReceiveServerPayload";
14
- type[type["RequestServerStatistics"] = 9] = "RequestServerStatistics";
15
- type[type["ReceiveServerStatistics"] = 10] = "ReceiveServerStatistics";
16
14
  })(type || (type = {}));
17
15
  export function hasMessageType(raw) {
18
16
  if (typeof raw !== 'object' || raw === null) {
@@ -17,15 +17,11 @@ export declare class DatasetBuilder<Source> {
17
17
  /**
18
18
  * The buffer of unlabelled file sources.
19
19
  */
20
- private _sources;
20
+ private _unlabeledSources;
21
21
  /**
22
22
  * The buffer of labelled file sources.
23
23
  */
24
- private labelledSources;
25
- /**
26
- * Whether a dataset was already produced.
27
- */
28
- private _built;
24
+ private _labeledSources;
29
25
  constructor(
30
26
  /**
31
27
  * The data loader used to load the data contained in the provided files.
@@ -48,13 +44,8 @@ export declare class DatasetBuilder<Source> {
48
44
  * @param label The file sources label
49
45
  */
50
46
  clearFiles(label?: string): void;
51
- private resetBuiltState;
52
47
  private getLabels;
53
48
  build(config?: DataConfig): Promise<DataSplit>;
54
- /**
55
- * Whether the dataset builder has already been consumed to produce a dataset.
56
- */
57
- get built(): boolean;
58
49
  get size(): number;
59
50
  get sources(): Source[];
60
51
  }
@@ -9,16 +9,11 @@ export class DatasetBuilder {
9
9
  /**
10
10
  * The buffer of unlabelled file sources.
11
11
  */
12
- _sources;
12
+ _unlabeledSources;
13
13
  /**
14
14
  * The buffer of labelled file sources.
15
15
  */
16
- labelledSources;
17
- /**
18
- * Whether a dataset was already produced.
19
- */
20
- // TODO useless, responsibility on callers
21
- _built;
16
+ _labeledSources;
22
17
  constructor(
23
18
  /**
24
19
  * The data loader used to load the data contained in the provided files.
@@ -30,9 +25,9 @@ export class DatasetBuilder {
30
25
  task) {
31
26
  this.dataLoader = dataLoader;
32
27
  this.task = task;
33
- this._sources = [];
34
- this.labelledSources = Map();
35
- this._built = false;
28
+ this._unlabeledSources = [];
29
+ // Map from label to sources
30
+ this._labeledSources = Map();
36
31
  }
37
32
  /**
38
33
  * Adds the given file sources to the builder's buffer. Sources may be provided a label in the case
@@ -41,19 +36,16 @@ export class DatasetBuilder {
41
36
  * @param label The file sources label
42
37
  */
43
38
  addFiles(sources, label) {
44
- if (this.built) {
45
- this.resetBuiltState();
46
- }
47
39
  if (label === undefined) {
48
- this._sources = this._sources.concat(sources);
40
+ this._unlabeledSources = this._unlabeledSources.concat(sources);
49
41
  }
50
42
  else {
51
- const currentSources = this.labelledSources.get(label);
43
+ const currentSources = this._labeledSources.get(label);
52
44
  if (currentSources === undefined) {
53
- this.labelledSources = this.labelledSources.set(label, sources);
45
+ this._labeledSources = this._labeledSources.set(label, sources);
54
46
  }
55
47
  else {
56
- this.labelledSources = this.labelledSources.set(label, currentSources.concat(sources));
48
+ this._labeledSources = this._labeledSources.set(label, currentSources.concat(sources));
57
49
  }
58
50
  }
59
51
  }
@@ -63,27 +55,19 @@ export class DatasetBuilder {
63
55
  * @param label The file sources label
64
56
  */
65
57
  clearFiles(label) {
66
- if (this.built) {
67
- this.resetBuiltState();
68
- }
69
58
  if (label === undefined) {
70
- this._sources = [];
59
+ this._unlabeledSources = [];
71
60
  }
72
61
  else {
73
- this.labelledSources = this.labelledSources.delete(label);
62
+ this._labeledSources = this._labeledSources.delete(label);
74
63
  }
75
64
  }
76
- // If files are added or removed, then this should be called since the latest
77
- // version of the dataset_builder has not yet been built.
78
- resetBuiltState() {
79
- this._built = false;
80
- }
81
65
  getLabels() {
82
66
  // We need to duplicate the labels as we need one for each source.
83
67
  // Say for label A we have sources [img1, img2, img3], then we
84
68
  // need labels [A, A, A].
85
69
  let labels = [];
86
- this.labelledSources.forEach((sources, label) => {
70
+ this._labeledSources.forEach((sources, label) => {
87
71
  const sourcesLabels = Array.from({ length: sources.length }, (_) => label);
88
72
  labels = labels.concat(sourcesLabels);
89
73
  });
@@ -91,17 +75,17 @@ export class DatasetBuilder {
91
75
  }
92
76
  async build(config) {
93
77
  // Require that at least one source collection is non-empty, but not both
94
- if ((this._sources.length > 0) === (this.labelledSources.size > 0)) {
95
- throw new Error('Please provide dataset input files'); // This error message is parsed in DatasetInput.vue
78
+ if (this._unlabeledSources.length + this._labeledSources.size === 0) {
79
+ throw new Error('No input files connected'); // This error message is parsed in Trainer.vue
96
80
  }
97
81
  let dataTuple;
98
- if (this._sources.length > 0) {
82
+ if (this._unlabeledSources.length > 0) {
99
83
  let defaultConfig = {};
100
84
  if (config?.inference === true) {
101
85
  // Inferring model, no labels needed
102
86
  defaultConfig = {
103
87
  features: this.task.trainingInformation.inputColumns,
104
- shuffle: false
88
+ shuffle: true
105
89
  };
106
90
  }
107
91
  else {
@@ -109,34 +93,26 @@ export class DatasetBuilder {
109
93
  defaultConfig = {
110
94
  features: this.task.trainingInformation.inputColumns,
111
95
  labels: this.task.trainingInformation.outputColumns,
112
- shuffle: false
96
+ shuffle: true
113
97
  };
114
98
  }
115
- dataTuple = await this.dataLoader.loadAll(this._sources, { ...defaultConfig, ...config });
99
+ dataTuple = await this.dataLoader.loadAll(this._unlabeledSources, { ...defaultConfig, ...config });
116
100
  }
117
101
  else {
118
102
  // Labels are inferred from the file selection boxes
119
103
  const defaultConfig = {
120
104
  labels: this.getLabels(),
121
- shuffle: false
105
+ shuffle: true
122
106
  };
123
- const sources = this.labelledSources.valueSeq().toArray().flat();
107
+ const sources = this._labeledSources.valueSeq().toArray().flat();
124
108
  dataTuple = await this.dataLoader.loadAll(sources, { ...defaultConfig, ...config });
125
109
  }
126
- // TODO @s314cy: Support .csv labels for image datasets (supervised training or testing)
127
- this._built = true;
128
110
  return dataTuple;
129
111
  }
130
- /**
131
- * Whether the dataset builder has already been consumed to produce a dataset.
132
- */
133
- get built() {
134
- return this._built;
135
- }
136
112
  get size() {
137
- return Math.max(this._sources.length, this.labelledSources.size);
113
+ return Math.max(this._unlabeledSources.length, this._labeledSources.size);
138
114
  }
139
115
  get sources() {
140
- return this._sources.length > 0 ? this._sources : this.labelledSources.valueSeq().toArray().flat();
116
+ return this._unlabeledSources.length > 0 ? this._unlabeledSources : this._labeledSources.valueSeq().toArray().flat();
141
117
  }
142
118
  }
@@ -0,0 +1,2 @@
1
+ import type { TaskProvider } from '../index.js';
2
+ export declare const cifar10: TaskProvider;
@@ -1,6 +1,6 @@
1
1
  import * as tf from '@tensorflow/tfjs';
2
- import { data, models } from '../../index.js';
3
- import baseModel from './model.js';
2
+ import { data, models } from '../index.js';
3
+ import baseModel from '../models/mobileNet_v1_025_224.js';
4
4
  export const cifar10 = {
5
5
  getTask() {
6
6
  return {
@@ -1,6 +1,7 @@
1
- export { cifar10 } from './cifar10/index.js';
1
+ export { cifar10 } from './cifar10.js';
2
2
  export { lusCovid } from './lus_covid.js';
3
+ export { skinCondition } from './skin_condition.js';
3
4
  export { mnist } from './mnist.js';
4
- export { simpleFace } from './simple_face/index.js';
5
+ export { simpleFace } from './simple_face.js';
5
6
  export { titanic } from './titanic.js';
6
7
  export { wikitext } from './wikitext.js';
@@ -1,6 +1,7 @@
1
- export { cifar10 } from './cifar10/index.js';
1
+ export { cifar10 } from './cifar10.js';
2
2
  export { lusCovid } from './lus_covid.js';
3
+ export { skinCondition } from './skin_condition.js';
3
4
  export { mnist } from './mnist.js';
4
- export { simpleFace } from './simple_face/index.js';
5
+ export { simpleFace } from './simple_face.js';
5
6
  export { titanic } from './titanic.js';
6
7
  export { wikitext } from './wikitext.js';
@@ -8,7 +8,7 @@ export const lusCovid = {
8
8
  taskTitle: 'COVID Lung Ultrasound',
9
9
  summary: {
10
10
  preview: 'Do you have a data of lung ultrasound images on patients <b>suspected of Lower Respiratory Tract infection (LRTI) during the COVID pandemic</b>? <br> Learn how to discriminate between COVID positive and negative patients by joining this task.',
11
- overview: "Dont have a dataset of your own? Download a sample of a few cases <a class='underline' href='https://drive.switch.ch/index.php/s/zM5ZrUWK3taaIly' target='_blank'>here</a>."
11
+ overview: "Don't have a dataset of your own? Download a sample of a few cases <a class='underline' href='https://drive.switch.ch/index.php/s/zM5ZrUWK3taaIly' target='_blank'>here</a>."
12
12
  },
13
13
  model: "We use a simplified* version of the <b>DeepChest model</b>: A deep learning model developed in our lab (<a class='underline' href='https://www.epfl.ch/labs/mlo/igh-intelligent-global-health/'>intelligent Global Health</a>.). On a cohort of 400 Swiss patients suspected of LRTI, the model obtained over 90% area under the ROC curve for this task. <br><br>*Simplified to ensure smooth running on your browser, the performance is minimally affected. Details of the adaptations are below <br>- <b>Removed</b>: positional embedding (i.e. we don’t take the anatomic position into consideration). Rather, the model now does mean pooling over the feature vector of the images for each patient <br>- <b>Replaced</b>: ResNet18 by Mobilenet",
14
14
  dataFormatInformation: 'This model takes as input an image dataset. It consists on a set of lung ultrasound images per patient with its corresponding label of covid positive or negative. Moreover, to identify the images per patient you have to follow the follwing naming pattern: "patientId_*.png"',
@@ -0,0 +1,2 @@
1
+ import type { TaskProvider } from '../index.js';
2
+ export declare const simpleFace: TaskProvider;
@@ -1,6 +1,6 @@
1
1
  import * as tf from '@tensorflow/tfjs';
2
- import { data, models } from '../../index.js';
3
- import baseModel from './model.js';
2
+ import { data, models } from '../index.js';
3
+ import baseModel from '../models/mobileNetV2_35_alpha_2_classes.js';
4
4
  export const simpleFace = {
5
5
  getTask() {
6
6
  return {
@@ -12,7 +12,7 @@ export const simpleFace = {
12
12
  overview: 'Simple face is a small subset of face_task from Kaggle'
13
13
  },
14
14
  dataFormatInformation: '',
15
- dataExampleText: 'Below you find an example',
15
+ dataExampleText: 'Below you can find an example',
16
16
  dataExampleImage: 'https://storage.googleapis.com/deai-313515.appspot.com/example_training_data/simple_face-example.png'
17
17
  },
18
18
  trainingInformation: {
@@ -0,0 +1,2 @@
1
+ import type { TaskProvider } from '../index.js';
2
+ export declare const skinCondition: TaskProvider;