@epfml/discojs 3.0.1-p20241025115642.0 → 3.0.1-p20241028120035.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93)
  1. package/dist/aggregator/get.d.ts +3 -3
  2. package/dist/client/client.d.ts +5 -5
  3. package/dist/client/decentralized/decentralized_client.d.ts +2 -2
  4. package/dist/client/federated/federated_client.d.ts +2 -2
  5. package/dist/client/utils.d.ts +2 -2
  6. package/dist/dataset/dataset.d.ts +9 -2
  7. package/dist/dataset/dataset.js +83 -36
  8. package/dist/dataset/image.d.ts +5 -0
  9. package/dist/dataset/image.js +6 -1
  10. package/dist/dataset/index.d.ts +0 -1
  11. package/dist/dataset/index.js +0 -1
  12. package/dist/dataset/types.d.ts +2 -0
  13. package/dist/default_tasks/cifar10.d.ts +1 -1
  14. package/dist/default_tasks/cifar10.js +2 -3
  15. package/dist/default_tasks/lus_covid.d.ts +1 -1
  16. package/dist/default_tasks/lus_covid.js +2 -3
  17. package/dist/default_tasks/mnist.d.ts +1 -1
  18. package/dist/default_tasks/mnist.js +2 -4
  19. package/dist/default_tasks/simple_face.d.ts +1 -1
  20. package/dist/default_tasks/simple_face.js +2 -3
  21. package/dist/default_tasks/titanic.d.ts +1 -1
  22. package/dist/default_tasks/titanic.js +3 -6
  23. package/dist/default_tasks/wikitext.d.ts +1 -1
  24. package/dist/default_tasks/wikitext.js +1 -2
  25. package/dist/index.d.ts +4 -5
  26. package/dist/index.js +4 -5
  27. package/dist/models/gpt/index.d.ts +13 -16
  28. package/dist/models/gpt/index.js +62 -43
  29. package/dist/models/gpt/model.d.ts +1 -15
  30. package/dist/models/gpt/model.js +1 -75
  31. package/dist/models/model.d.ts +7 -12
  32. package/dist/models/tfjs.d.ts +10 -8
  33. package/dist/models/tfjs.js +106 -44
  34. package/dist/models/tokenizer.d.ts +1 -1
  35. package/dist/privacy.js +1 -1
  36. package/dist/processing/image.d.ts +18 -0
  37. package/dist/processing/image.js +75 -0
  38. package/dist/processing/index.d.ts +8 -0
  39. package/dist/processing/index.js +106 -0
  40. package/dist/processing/tabular.d.ts +19 -0
  41. package/dist/processing/tabular.js +33 -0
  42. package/dist/processing/text.d.ts +11 -0
  43. package/dist/processing/text.js +33 -0
  44. package/dist/serialization/model.d.ts +3 -3
  45. package/dist/serialization/model.js +19 -6
  46. package/dist/task/task.d.ts +4 -3
  47. package/dist/task/task.js +5 -3
  48. package/dist/task/task_handler.d.ts +3 -3
  49. package/dist/task/task_provider.d.ts +4 -4
  50. package/dist/task/training_information.d.ts +25 -16
  51. package/dist/task/training_information.js +76 -72
  52. package/dist/training/disco.d.ts +20 -12
  53. package/dist/training/disco.js +32 -13
  54. package/dist/training/trainer.d.ts +6 -7
  55. package/dist/training/trainer.js +6 -6
  56. package/dist/types/data_format.d.ts +40 -0
  57. package/dist/types/index.d.ts +2 -0
  58. package/dist/types/index.js +1 -0
  59. package/dist/validator.d.ts +10 -0
  60. package/dist/validator.js +30 -0
  61. package/package.json +4 -2
  62. package/dist/dataset/data/data.d.ts +0 -47
  63. package/dist/dataset/data/data.js +0 -88
  64. package/dist/dataset/data/data_split.d.ts +0 -8
  65. package/dist/dataset/data/helpers.d.ts +0 -10
  66. package/dist/dataset/data/helpers.js +0 -97
  67. package/dist/dataset/data/image_data.d.ts +0 -11
  68. package/dist/dataset/data/image_data.js +0 -43
  69. package/dist/dataset/data/index.d.ts +0 -5
  70. package/dist/dataset/data/index.js +0 -5
  71. package/dist/dataset/data/preprocessing/base.d.ts +0 -16
  72. package/dist/dataset/data/preprocessing/base.js +0 -1
  73. package/dist/dataset/data/preprocessing/image_preprocessing.d.ts +0 -13
  74. package/dist/dataset/data/preprocessing/image_preprocessing.js +0 -42
  75. package/dist/dataset/data/preprocessing/index.d.ts +0 -4
  76. package/dist/dataset/data/preprocessing/index.js +0 -3
  77. package/dist/dataset/data/preprocessing/tabular_preprocessing.d.ts +0 -13
  78. package/dist/dataset/data/preprocessing/tabular_preprocessing.js +0 -45
  79. package/dist/dataset/data/preprocessing/text_preprocessing.d.ts +0 -13
  80. package/dist/dataset/data/preprocessing/text_preprocessing.js +0 -100
  81. package/dist/dataset/data/tabular_data.d.ts +0 -11
  82. package/dist/dataset/data/tabular_data.js +0 -24
  83. package/dist/dataset/data/text_data.d.ts +0 -11
  84. package/dist/dataset/data/text_data.js +0 -14
  85. package/dist/processing.d.ts +0 -35
  86. package/dist/processing.js +0 -89
  87. package/dist/types.d.ts +0 -3
  88. package/dist/types.js +0 -1
  89. package/dist/validation/index.d.ts +0 -1
  90. package/dist/validation/index.js +0 -1
  91. package/dist/validation/validator.d.ts +0 -10
  92. package/dist/validation/validator.js +0 -113
  93. /package/dist/{dataset/data/data_split.js → types/data_format.js} +0 -0
package/dist/aggregator/get.d.ts CHANGED
@@ -1,7 +1,7 @@
- import type { Task } from '../index.js';
+ import type { DataType, Task } from '../index.js';
  import { aggregator } from '../index.js';
  type AggregatorOptions = Partial<{
-     scheme: Task['trainingInformation']['scheme'];
+     scheme: Task<DataType>["trainingInformation"]["scheme"];
      roundCutOff: number;
      threshold: number;
      thresholdType: 'relative' | 'absolute';
@@ -24,5 +24,5 @@ type AggregatorOptions = Partial<{
   * @param options Options passed down to the aggregator's constructor
   * @returns The aggregator
   */
- export declare function getAggregator(task: Task, options?: AggregatorOptions): aggregator.Aggregator;
+ export declare function getAggregator(task: Task<DataType>, options?: AggregatorOptions): aggregator.Aggregator;
  export {};
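The change above is the theme of this release: `Task`, `Model`, and `TaskProvider` are now generic over a `DataType` tag. A minimal sketch of a call site after the upgrade, assuming `getAggregator` is reachable through the re-exported `aggregator` namespace (this diff does not confirm that re-export):

```ts
import { defaultTasks, aggregator } from "@epfml/discojs";

// cifar10 is now typed TaskProvider<'image'> (see its .d.ts below),
// so the returned Task carries 'image' at the type level.
const task = defaultTasks.cifar10.getTask();

// hypothetical call site; assumes getAggregator is exposed
// under the aggregator namespace
const agg = aggregator.getAggregator(task, { scheme: "federated" });
```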
package/dist/client/client.d.ts CHANGED
@@ -1,4 +1,4 @@
- import type { Model, Task, WeightsContainer, RoundStatus } from '../index.js';
+ import type { DataType, Model, RoundStatus, Task, WeightsContainer } from "../index.js";
  import type { NodeID } from './types.js';
  import type { EventConnection } from './event_connection.js';
  import type { Aggregator } from '../aggregator/index.js';
@@ -11,7 +11,7 @@ export declare abstract class Client extends EventEmitter<{
      'status': RoundStatus;
  }> {
      readonly url: URL;
-     readonly task: Task;
+     readonly task: Task<DataType>;
      readonly aggregator: Aggregator;
      protected _ownId?: NodeID;
      protected _server?: EventConnection;
@@ -30,7 +30,7 @@ export declare abstract class Client extends EventEmitter<{
       */
      private previousStatus;
      constructor(url: URL, // The network server's URL to connect to
-     task: Task, // The client's corresponding task
+     task: Task<DataType>, // The client's corresponding task
      aggregator: Aggregator);
      /**
       * Communication callback called at the beginning of every training round.
@@ -47,7 +47,7 @@ export declare abstract class Client extends EventEmitter<{
       * This method is overriden by the federated and decentralized clients
       * By default, it fetches and returns the server's base model
       */
-     connect(): Promise<Model>;
+     connect(): Promise<Model<DataType>>;
      /**
       * Handles the disconnection process of the client from any sort of network server.
       */
@@ -94,7 +94,7 @@ export declare abstract class Client extends EventEmitter<{
       * Fetches the latest model available on the network's server, for the adequate task.
       * @returns The latest model
       */
-     getLatestModel(): Promise<Model>;
+     getLatestModel(): Promise<Model<DataType>>;
      /**
       * Number of contributors to a collaborative session
       * If decentralized, it should be the number of peers
package/dist/client/decentralized/decentralized_client.d.ts CHANGED
@@ -1,4 +1,4 @@
- import type { Model, WeightsContainer } from "../../index.js";
+ import type { DataType, Model, WeightsContainer } from "../../index.js";
  import { Client } from '../client.js';
  /**
   * Represents a decentralized client in a network of peers. Peers coordinate each other with the
@@ -18,7 +18,7 @@ export declare class DecentralizedClient extends Client {
       * create peer-to-peer WebRTC connections with peers. The server is used to exchange
       * peers network information.
       */
-     connect(): Promise<Model>;
+     connect(): Promise<Model<DataType>>;
      disconnect(): Promise<void>;
      /**
       * At the beginning of a round, each peer tells the server it is ready to proceed
package/dist/client/federated/federated_client.d.ts CHANGED
@@ -1,4 +1,4 @@
- import type { Model, WeightsContainer } from "../../index.js";
+ import type { DataType, Model, WeightsContainer } from "../../index.js";
  import { Client } from "../client.js";
  /**
   * Client class that communicates with a centralized, federated server, when training
@@ -12,7 +12,7 @@ export declare class FederatedClient extends Client {
       * as well as the latest training information: latest global model, current round and
       * whether we are waiting for more participants.
       */
-     connect(): Promise<Model>;
+     connect(): Promise<Model<DataType>>;
      /**
       * Disconnection process when user quits the task.
       */
package/dist/client/utils.d.ts CHANGED
@@ -1,4 +1,4 @@
- import type { Task } from '../index.js';
+ import type { DataType, Task } from '../index.js';
  import { client as clients, type aggregator } from '../index.js';
  export declare function timeout(ms?: number, errorMsg?: string): Promise<never>;
- export declare function getClient(trainingScheme: Required<Task['trainingInformation']['scheme']>, serverURL: URL, task: Task, aggregator: aggregator.Aggregator): clients.Client;
+ export declare function getClient(trainingScheme: Task<DataType>["trainingInformation"]["scheme"], serverURL: URL, task: Task<DataType>, aggregator: aggregator.Aggregator): clients.Client;
package/dist/dataset/dataset.d.ts CHANGED
@@ -1,4 +1,4 @@
- import { List } from "immutable";
+ import { Batched } from "./types.js";
  type DatasetLike<T> = AsyncIterable<T> | Iterable<T> | (() => AsyncIterator<T, void>) | (() => Iterator<T, void>);
  /** Immutable series of data */
  export declare class Dataset<T> implements AsyncIterable<T> {
@@ -31,7 +31,9 @@ export declare class Dataset<T> implements AsyncIterable<T> {
       *
       * @param size count of element per chunk
       */
-     batch(size: number): Dataset<List<T>>;
+     batch(size: number): Dataset<Batched<T>>;
+     /** Flatten chunks */
+     unbatch<U>(this: Dataset<Batched<U>>): Dataset<U>;
      /** Join side-by-side
       *
       * Stops as soon as one runs out
@@ -44,5 +46,10 @@ export declare class Dataset<T> implements AsyncIterable<T> {
       * This is a costly operation as we need to go through the whole Dataset.
       */
      size(): Promise<number>;
+     /** Try to keep generated elements to avoid recomputing
+      *
+      * Drops everything when memory pressure is applied.
+      */
+     cached(): Dataset<T>;
  }
  export {};
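For orientation, a sketch of the new surface: `Batched<T>` is an immutable `List<T>` (see `dataset/types.d.ts` below), and the element values here are made up:

```ts
import { Dataset } from "@epfml/discojs";

async function demo(): Promise<void> {
  // the constructor accepts a DatasetLike, e.g. a plain iterable
  const numbers = new Dataset([1, 2, 3, 4, 5]);

  // batch/unbatch round-trip: chunks of at most 2, then flattened back
  for await (const chunk of numbers.batch(2))
    console.log(chunk.toArray()); // [1, 2], [3, 4], [5]
  console.log(await numbers.batch(2).unbatch().size()); // 5

  // cached() may serve the second pass from a weakly-held cache
  const squares = numbers.map((n) => n * n).cached();
  console.log(await squares.size()); // first pass fills the cache
  console.log(await squares.size()); // second pass can replay it
}
```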
package/dist/dataset/dataset.js CHANGED
@@ -1,4 +1,6 @@
- import { List } from "immutable";
+ import createDebug from "debug";
+ import { List, Range } from "immutable";
+ const debug = createDebug("discojs:dataset");
  /** Immutable series of data */
  export class Dataset {
      #content;
@@ -32,13 +34,10 @@ export class Dataset {
       * @param mapper how to change each element
       */
      map(mapper) {
-         const content = {
-             [Symbol.asyncIterator]: () => this.#content(),
-         };
          return new Dataset(async function* () {
-             for await (const e of content)
+             for await (const e of this)
                  yield await mapper(e);
-         });
+         }.bind(this));
      }
      /** Combine with another Dataset.
       *
@@ -47,13 +46,10 @@
      chain(other) {
          if (!(other instanceof Dataset))
              other = new Dataset(other);
-         const self = {
-             [Symbol.asyncIterator]: () => this.#content(),
-         };
          return new Dataset(async function* () {
-             yield* self;
+             yield* this;
              yield* other;
-         });
+         }.bind(this));
      }
      /** Divide into two based on given ratio
       *
@@ -62,16 +58,13 @@
      split(ratio) {
          if (ratio < 0 || ratio > 1)
              throw new Error("ratio out of range");
-         const content = {
-             [Symbol.asyncIterator]: () => this.#content(),
-         };
          // to avoid using random sampling or knowing the size beforehand,
          // we compute the actual ratio and make it converge towards the wanted one
          return [
              new Dataset(async function* () {
                  let yielded_by_other = 0;
                  let total_size = 0;
-                 for await (const e of content) {
+                 for await (const e of this) {
                      total_size++;
                      if (yielded_by_other / total_size >= ratio) {
                          yield e;
@@ -80,18 +73,18 @@
                          yielded_by_other++;
                      }
                  }
-             }),
+             }.bind(this)),
              new Dataset(async function* () {
                  let yielded = 0;
                  let total_size = 0;
-                 for await (const e of content) {
+                 for await (const e of this) {
                      total_size++;
                      if (yielded / total_size < ratio) {
                          yielded++;
                          yield e;
                      }
                  }
-             }),
+             }.bind(this)),
          ];
      }
      /** Slice into chunks
@@ -103,21 +96,30 @@
      batch(size) {
          if (size <= 0 || !Number.isInteger(size))
              throw new Error("invalid size");
-         const content = {
-             [Symbol.asyncIterator]: () => this.#content(),
-         };
          return new Dataset(async function* () {
-             let batch = List();
-             for await (const e of content) {
-                 batch = batch.push(e);
-                 if (batch.size === size) {
-                     yield batch;
-                     batch = List();
-                 }
-             }
-             if (!batch.isEmpty())
+             const iter = this[Symbol.asyncIterator]();
+             for (;;) {
+                 const batch = List(await Promise.all(Range(0, size).map(() => iter.next()))).flatMap((res) => {
+                     if (res.done)
+                         return [];
+                     else
+                         return [res.value];
+                 });
+                 if (batch.isEmpty())
+                     break;
                  yield batch;
-         });
+                 // iterator couldn't generate more
+                 if (batch.size < size)
+                     break;
+             }
+         }.bind(this));
+     }
+     /** Flatten chunks */
+     unbatch() {
+         return new Dataset(async function* () {
+             for await (const batch of this)
+                 yield* batch;
+         }.bind(this));
      }
      /** Join side-by-side
       *
@@ -128,11 +130,8 @@
      zip(other) {
          if (!(other instanceof Dataset))
              other = new Dataset(other);
-         const content = {
-             [Symbol.asyncIterator]: () => this.#content(),
-         };
          return new Dataset(async function* () {
-             const left = content[Symbol.asyncIterator]();
+             const left = this[Symbol.asyncIterator]();
              const right = other[Symbol.asyncIterator]();
              while (true) {
                  const [l, r] = await Promise.all([left.next(), right.next()]);
@@ -140,7 +139,7 @@
                      return;
                  yield [l.value, r.value];
              }
-         });
+         }.bind(this));
      }
      /** Compute size
       *
@@ -152,4 +151,52 @@
              ret++;
          return ret;
      }
+     /** Try to keep generated elements to avoid recomputing
+      *
+      * Drops everything when memory pressure is applied.
+      */
+     cached() {
+         return new CachingDataset(this.#content);
+     }
  }
+ /**
+  * Avoid recomputing the parent dataset, without hogging memory
+  *
+  * As dataset operations can be time-consuming, this keeps a weak reference to
+  * the generated elements so that a second iteration might yield theses directly.
+  **/
+ class CachingDataset extends Dataset {
+     // potential reference to all elements
+     // tristate: undefined == empty, [false, _] == filling, [true, _] == filled
+     #cache = new WeakRef([false, List()]);
+     [Symbol.asyncIterator]() {
+         const cached = this.#cache.deref();
+         if (cached !== undefined && cached[0]) {
+             debug("valid cache, reading from it");
+             // eslint-disable-next-line @typescript-eslint/require-await
+             return (async function* () {
+                 yield* cached[1];
+             })();
+         }
+         debug("cache invalid, reading from dataset");
+         this.#cache = new WeakRef([false, List()]);
+         const parentContent = {
+             [Symbol.asyncIterator]: () => super[Symbol.asyncIterator](),
+         };
+         return async function* () {
+             for await (const e of parentContent) {
+                 yield e;
+                 const caching = this.#cache.deref();
+                 if (caching !== undefined)
+                     caching[1] = caching[1].push(e);
+             }
+             const caching = this.#cache.deref();
+             if (caching === undefined) {
+                 debug("cache evicted while filling");
+                 return;
+             }
+             debug("cache filled");
+             caching[0] = true;
+         }.bind(this)();
+     }
  }
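The `CachingDataset` added above leans on `WeakRef` so the whole cache can be garbage-collected under memory pressure, plus a filling/filled flag so a partially-filled cache is never replayed. The same pattern in isolation, as a synchronous sketch (the names here are ours, not the package's):

```ts
// Standalone illustration of the pattern: hold the cache only weakly,
// and mark it complete only after one full pass over the source.
class WeaklyCached<T> implements Iterable<T> {
  // tristate as in CachingDataset: collected / [false, items] = filling / [true, items] = filled
  #cache = new WeakRef<[boolean, T[]]>([false, []]);

  constructor(private readonly source: () => Iterable<T>) {}

  *[Symbol.iterator](): Iterator<T> {
    const cached = this.#cache.deref();
    if (cached?.[0]) {
      yield* cached[1]; // replay without recomputing
      return;
    }
    this.#cache = new WeakRef<[boolean, T[]]>([false, []]);
    for (const e of this.source()) {
      yield e;
      this.#cache.deref()?.[1].push(e); // keep filling unless evicted
    }
    const filling = this.#cache.deref();
    if (filling !== undefined) filling[0] = true; // survived a full pass
  }
}
```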
package/dist/dataset/image.d.ts CHANGED
@@ -1,6 +1,11 @@
  /**
   * Raw image with type level dimensions.
   *
+  * Per convention, `data` layout is as follow
+  * `height` chunk each containing
+  * `width` chunk each containing
+  * a chunk of `depth` bytes
+  *
   * @typeParam D depth of the image
   * @typeParam W width, positive and integral
   * @typeParam H height, positive and integral
package/dist/dataset/image.js CHANGED
@@ -1,6 +1,11 @@
  /**
   * Raw image with type level dimensions.
   *
+  * Per convention, `data` layout is as follow
+  * `height` chunk each containing
+  * `width` chunk each containing
+  * a chunk of `depth` bytes
+  *
   * @typeParam D depth of the image
   * @typeParam W width, positive and integral
   * @typeParam H height, positive and integral
@@ -16,6 +21,6 @@ export class Image {
          this.height = height;
          this.depth = depth;
          if (data.length != width * height * depth)
-             throw new Error("data isn't of excepted size");
+             throw new Error("data isn't of expected size");
      }
  }
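Given that layout (row-major: `height` rows, each of `width` pixels, each of `depth` bytes), a single channel value is addressed as below. This helper is purely illustrative, not part of the package, and assumes `data` is a byte array:

```ts
// Hypothetical helper following the documented layout of Image#data:
// index = (y * width + x) * depth + c
function channelAt(
  data: Uint8Array,
  width: number,
  depth: number,
  x: number, // column, 0 <= x < width
  y: number, // row, 0 <= y < height
  c: number, // channel, 0 <= c < depth
): number {
  return data[(y * width + x) * depth + c];
}
```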
package/dist/dataset/index.d.ts CHANGED
@@ -1,3 +1,2 @@
  export { Dataset } from "./dataset.js";
  export * from "./types.js";
- export { Data, TabularData, ImageData, TextData, ImagePreprocessing, TabularPreprocessing, TextPreprocessing, IMAGE_PREPROCESSING, TABULAR_PREPROCESSING, TEXT_PREPROCESSING, } from "./data/index.js";
package/dist/dataset/index.js CHANGED
@@ -1,3 +1,2 @@
  export { Dataset } from "./dataset.js";
  export * from "./types.js";
- export { Data, TabularData, ImageData, TextData, ImagePreprocessing, TabularPreprocessing, TextPreprocessing, IMAGE_PREPROCESSING, TABULAR_PREPROCESSING, TEXT_PREPROCESSING, } from "./data/index.js";
package/dist/dataset/types.d.ts CHANGED
@@ -1,4 +1,6 @@
+ import { List } from "immutable";
  import { Image } from "./image.js";
+ export type Batched<T> = List<T>;
  export { Image };
  export type Tabular = Partial<Record<string, string>>;
  export type Text = string;
package/dist/default_tasks/cifar10.d.ts CHANGED
@@ -1,2 +1,2 @@
  import type { TaskProvider } from '../index.js';
- export declare const cifar10: TaskProvider;
+ export declare const cifar10: TaskProvider<'image'>;
package/dist/default_tasks/cifar10.js CHANGED
@@ -1,5 +1,5 @@
  import * as tf from '@tensorflow/tfjs';
- import { data, models } from '../index.js';
+ import { models } from '../index.js';
  import baseModel from '../models/mobileNet_v1_025_224.js';
  export const cifar10 = {
      getTask() {
@@ -24,7 +24,6 @@ export const cifar10 = {
              validationSplit: 0.2,
              batchSize: 10,
              dataType: 'image',
-             preprocessingFunctions: [data.ImagePreprocessing.Resize, data.ImagePreprocessing.Normalize],
              IMAGE_H: 224,
              IMAGE_W: 224,
              LABEL_LIST: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'],
@@ -55,6 +54,6 @@ export const cifar10 = {
          loss: 'categoricalCrossentropy',
          metrics: ['accuracy']
      });
-     return new models.TFJS(model);
+     return new models.TFJS('image', model);
  }
  };
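As the last hunk shows, `models.TFJS` now takes the task's `DataType` tag as its first constructor argument; the same edit recurs in every default task below. A sketch with a throwaway layers model:

```ts
import * as tf from "@tensorflow/tfjs";
import { models } from "@epfml/discojs";

// placeholder single-layer model, just to give the wrapper something to hold
const layersModel = tf.sequential({
  layers: [tf.layers.dense({ inputShape: [4], units: 1 })],
});

// before: new models.TFJS(layersModel)
// after: the DataType tag ('image' | 'tabular' | 'text') comes first
const model = new models.TFJS("tabular", layersModel);
```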
package/dist/default_tasks/lus_covid.d.ts CHANGED
@@ -1,2 +1,2 @@
  import type { TaskProvider } from '../index.js';
- export declare const lusCovid: TaskProvider;
+ export declare const lusCovid: TaskProvider<'image'>;
package/dist/default_tasks/lus_covid.js CHANGED
@@ -1,5 +1,5 @@
  import * as tf from '@tensorflow/tfjs';
- import { data, models } from '../index.js';
+ import { models } from '../index.js';
  export const lusCovid = {
      getTask() {
          return {
@@ -24,7 +24,6 @@ export const lusCovid = {
              batchSize: 5,
              IMAGE_H: 100,
              IMAGE_W: 100,
-             preprocessingFunctions: [data.ImagePreprocessing.Resize, data.ImagePreprocessing.Normalize],
              LABEL_LIST: ['COVID-Positive', 'COVID-Negative'],
              dataType: 'image',
              scheme: 'federated',
@@ -82,6 +81,6 @@ export const lusCovid = {
          loss: 'binaryCrossentropy',
          metrics: ['accuracy']
      });
-     return Promise.resolve(new models.TFJS(model));
+     return Promise.resolve(new models.TFJS('image', model));
  }
  };
package/dist/default_tasks/mnist.d.ts CHANGED
@@ -1,2 +1,2 @@
  import type { TaskProvider } from '../index.js';
- export declare const mnist: TaskProvider;
+ export declare const mnist: TaskProvider<'image'>;
package/dist/default_tasks/mnist.js CHANGED
@@ -1,5 +1,5 @@
  import * as tf from '@tensorflow/tfjs';
- import { data, models } from '../index.js';
+ import { models } from '../index.js';
  export const mnist = {
      getTask() {
          return {
@@ -25,8 +25,6 @@ export const mnist = {
              dataType: 'image',
              IMAGE_H: 28,
              IMAGE_W: 28,
-             // Images should already be at the right size but resizing just in case
-             preprocessingFunctions: [data.ImagePreprocessing.Resize, data.ImagePreprocessing.Normalize],
              LABEL_LIST: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'],
              scheme: 'decentralized',
              aggregationStrategy: 'secure',
@@ -58,6 +56,6 @@ export const mnist = {
          loss: 'categoricalCrossentropy',
          metrics: ['accuracy']
      });
-     return Promise.resolve(new models.TFJS(model));
+     return Promise.resolve(new models.TFJS('image', model));
  }
  };
package/dist/default_tasks/simple_face.d.ts CHANGED
@@ -1,2 +1,2 @@
  import type { TaskProvider } from '../index.js';
- export declare const simpleFace: TaskProvider;
+ export declare const simpleFace: TaskProvider<'image'>;
package/dist/default_tasks/simple_face.js CHANGED
@@ -1,5 +1,5 @@
  import * as tf from '@tensorflow/tfjs';
- import { data, models } from '../index.js';
+ import { models } from '../index.js';
  import baseModel from '../models/mobileNetV2_35_alpha_2_classes.js';
  export const simpleFace = {
      getTask() {
@@ -22,7 +22,6 @@ export const simpleFace = {
              roundDuration: 1,
              validationSplit: 0.2,
              batchSize: 10,
-             preprocessingFunctions: [data.ImagePreprocessing.Normalize],
              dataType: 'image',
              IMAGE_H: 200,
              IMAGE_W: 200,
@@ -43,6 +42,6 @@ export const simpleFace = {
          loss: 'categoricalCrossentropy',
          metrics: ['accuracy']
      });
-     return new models.TFJS(model);
+     return new models.TFJS('image', model);
  }
  };
package/dist/default_tasks/titanic.d.ts CHANGED
@@ -1,2 +1,2 @@
  import type { TaskProvider } from '../index.js';
- export declare const titanic: TaskProvider;
+ export declare const titanic: TaskProvider<'tabular'>;
package/dist/default_tasks/titanic.js CHANGED
@@ -1,5 +1,5 @@
  import * as tf from '@tensorflow/tfjs';
- import { data, models } from '../index.js';
+ import { models } from '../index.js';
  export const titanic = {
      getTask() {
          return {
@@ -49,7 +49,6 @@ export const titanic = {
              roundDuration: 2,
              validationSplit: 0.2,
              batchSize: 30,
-             preprocessingFunctions: [data.TabularPreprocessing.Sanitize],
              dataType: 'tabular',
              inputColumns: [
                  'Age',
@@ -58,9 +57,7 @@ export const titanic = {
                  'Fare',
                  'Pclass'
              ],
-             outputColumns: [
-                 'Survived'
-             ],
+             outputColumn: 'Survived',
              scheme: 'federated',
              aggregationStrategy: 'mean',
              minNbOfParticipants: 2,
@@ -84,6 +81,6 @@ export const titanic = {
          loss: 'binaryCrossentropy',
          metrics: ['accuracy']
      });
-     return Promise.resolve(new models.TFJS(model));
+     return Promise.resolve(new models.TFJS('tabular', model));
  }
  };
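Beyond the shared `models.TFJS` and `preprocessingFunctions` edits, titanic shows a schema change for tabular tasks: the `outputColumns` array becomes a single `outputColumn`. For a custom tabular task, the migration looks like this (abridged; the other required `trainingInformation` fields are omitted):

```ts
// abridged trainingInformation for a hypothetical custom tabular task
const trainingInformation = {
  dataType: "tabular",
  inputColumns: ["Age", "SibSp", "Parch", "Fare", "Pclass"],
  outputColumn: "Survived", // was: outputColumns: ["Survived"]
} as const;
```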
package/dist/default_tasks/wikitext.d.ts CHANGED
@@ -1,2 +1,2 @@
  import type { TaskProvider } from '../index.js';
- export declare const wikitext: TaskProvider;
+ export declare const wikitext: TaskProvider<'text'>;
package/dist/default_tasks/wikitext.js CHANGED
@@ -1,4 +1,4 @@
- import { data, models } from '../index.js';
+ import { models } from '../index.js';
  export const wikitext = {
      getTask() {
          return {
@@ -23,7 +23,6 @@ export const wikitext = {
          },
          trainingInformation: {
              dataType: 'text',
-             preprocessingFunctions: [data.TextPreprocessing.Tokenize, data.TextPreprocessing.LeftPadding],
              scheme: 'federated',
              aggregationStrategy: 'mean',
              minNbOfParticipants: 2,
package/dist/index.d.ts CHANGED
@@ -7,14 +7,13 @@ export * as aggregator from './aggregator/index.js';
  export { WeightsContainer, aggregation } from './weights/index.js';
  export { Logger, ConsoleLogger } from './logging/index.js';
  export { Disco, RoundLogs, RoundStatus } from './training/index.js';
- export { Validator } from './validation/index.js';
+ export { Validator } from './validator.js';
  export { Model, BatchLogs, EpochLogs, ValidationMetrics } from './models/index.js';
  export * as models from './models/index.js';
  export * from './task/index.js';
  export * as defaultTasks from './default_tasks/index.js';
  export * as async_iterator from "./utils/async_iterator.js";
  export { EventEmitter } from "./utils/event_emitter.js";
- export { Dataset } from "./dataset/index.js";
- export * from "./dataset/types.js";
- export * from "./types.js";
- export * as processing from "./processing.js";
+ export * from "./dataset/index.js";
+ export * from "./types/index.js";
+ export * as processing from "./processing/index.js";
package/dist/index.js CHANGED
@@ -7,14 +7,13 @@ export * as aggregator from './aggregator/index.js';
  export { WeightsContainer, aggregation } from './weights/index.js';
  export { ConsoleLogger } from './logging/index.js';
  export { Disco } from './training/index.js';
- export { Validator } from './validation/index.js';
+ export { Validator } from './validator.js';
  export { Model, EpochLogs } from './models/index.js';
  export * as models from './models/index.js';
  export * from './task/index.js';
  export * as defaultTasks from './default_tasks/index.js';
  export * as async_iterator from "./utils/async_iterator.js";
  export { EventEmitter } from "./utils/event_emitter.js";
- export { Dataset } from "./dataset/index.js";
- export * from "./dataset/types.js"; // TODO merge with above
- export * from "./types.js";
- export * as processing from "./processing.js";
+ export * from "./dataset/index.js";
+ export * from "./types/index.js";
+ export * as processing from "./processing/index.js";
package/dist/models/gpt/index.d.ts CHANGED
@@ -1,17 +1,20 @@
  /**
   * this code is taken from gpt-tfjs with modifications from @peacefulotter and @lukemovement
   **/
- import * as tf from '@tensorflow/tfjs';
- import { PreTrainedTokenizer } from '@xenova/transformers';
- import { WeightsContainer } from '../../index.js';
+ import * as tf from "@tensorflow/tfjs";
+ import type { Batched, Dataset, DataFormat } from "../../index.js";
+ import { WeightsContainer } from "../../index.js";
  import { BatchLogs, Model, EpochLogs } from "../index.js";
- import type { Prediction, Sample } from '../model.js';
- import { type GPTConfig } from './config.js';
+ import { type GPTConfig } from "./config.js";
  export type GPTSerialization = {
      weights: WeightsContainer;
      config?: GPTConfig;
  };
- export declare class GPT extends Model {
+ interface PredictConfig {
+     temperature: number;
+     doSample: boolean;
+ }
+ export declare class GPT extends Model<"text"> {
      #private;
      private readonly model;
      constructor(partialConfig?: GPTConfig, layersModel?: tf.LayersModel);
@@ -24,20 +27,14 @@ export declare class GPT extends Model {
       * @param epochs the number of passes of the training dataset
       * @param tracker
       */
-     train(trainingData: tf.data.Dataset<{
-         xs: tf.Tensor2D;
-         ys: tf.Tensor3D;
-     }>, validationData?: tf.data.Dataset<{
-         xs: tf.Tensor2D;
-         ys: tf.Tensor3D;
-     }>): AsyncGenerator<BatchLogs, EpochLogs>;
-     predict(input: Sample): Promise<Prediction>;
-     generate(input: string, tokenizer: PreTrainedTokenizer, newTokens?: number): Promise<string>;
+     train(trainingDataset: Dataset<Batched<DataFormat.ModelEncoded["text"]>>, validationDataset?: Dataset<Batched<DataFormat.ModelEncoded["text"]>>): AsyncGenerator<BatchLogs, EpochLogs>;
+     predict(batch: Batched<DataFormat.ModelEncoded["text"][0]>, options?: Partial<PredictConfig>): Promise<Batched<DataFormat.ModelEncoded["text"][1]>>;
      get config(): Required<GPTConfig>;
      get weights(): WeightsContainer;
      set weights(ws: WeightsContainer);
-     static deserialize(data: GPTSerialization): Model;
+     static deserialize(data: GPTSerialization): Model<"text">;
      serialize(): GPTSerialization;
      extract(): tf.LayersModel;
      [Symbol.dispose](): void;
  }
+ export {};
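Note that `generate(input, tokenizer, newTokens)` is gone from the public type: inference now goes through `predict` on a batch of model-encoded text, with optional sampling controls (`temperature`, `doSample`). A hedged sketch, assuming `GPT` is reachable via the `models` namespace and that `DataFormat.ModelEncoded['text'][0]` is a token-id sequence (neither is spelled out by this diff):

```ts
import { List } from "immutable";
import { models } from "@epfml/discojs";

async function complete(): Promise<void> {
  const gpt = new models.GPT(); // default GPTConfig

  // one made-up, already-tokenized prompt; real code would run the
  // task's tokenizer through the processing helpers first
  const batch = List.of(List.of(101, 2023, 2003));

  const next = await gpt.predict(batch, { doSample: true, temperature: 0.8 });
  console.log(next.toArray()); // one predicted continuation per batch entry
}
```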