@epfml/discojs 2.2.1 → 3.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. package/dist/aggregator/base.d.ts +9 -48
  2. package/dist/aggregator/base.js +8 -69
  3. package/dist/aggregator/get.d.ts +23 -11
  4. package/dist/aggregator/get.js +40 -23
  5. package/dist/aggregator/index.d.ts +1 -1
  6. package/dist/aggregator/index.js +1 -1
  7. package/dist/aggregator/mean.d.ts +25 -6
  8. package/dist/aggregator/mean.js +62 -17
  9. package/dist/aggregator/secure.d.ts +2 -2
  10. package/dist/aggregator/secure.js +4 -7
  11. package/dist/client/base.d.ts +3 -3
  12. package/dist/client/base.js +6 -8
  13. package/dist/client/decentralized/base.d.ts +27 -10
  14. package/dist/client/decentralized/base.js +123 -86
  15. package/dist/client/decentralized/peer.js +7 -12
  16. package/dist/client/decentralized/peer_pool.js +6 -2
  17. package/dist/client/event_connection.d.ts +1 -1
  18. package/dist/client/event_connection.js +3 -3
  19. package/dist/client/federated/base.d.ts +5 -21
  20. package/dist/client/federated/base.js +38 -61
  21. package/dist/client/federated/messages.d.ts +2 -10
  22. package/dist/client/federated/messages.js +0 -1
  23. package/dist/client/index.d.ts +1 -1
  24. package/dist/client/index.js +1 -1
  25. package/dist/client/local.d.ts +3 -1
  26. package/dist/client/local.js +4 -1
  27. package/dist/client/messages.d.ts +1 -2
  28. package/dist/client/messages.js +8 -3
  29. package/dist/client/utils.d.ts +4 -2
  30. package/dist/client/utils.js +18 -3
  31. package/dist/dataset/data/data.d.ts +1 -1
  32. package/dist/dataset/data/data.js +13 -2
  33. package/dist/dataset/data/preprocessing/image_preprocessing.js +6 -4
  34. package/dist/default_tasks/cifar10.js +1 -2
  35. package/dist/default_tasks/lus_covid.js +0 -5
  36. package/dist/default_tasks/mnist.js +15 -14
  37. package/dist/default_tasks/simple_face.js +0 -2
  38. package/dist/default_tasks/titanic.js +2 -4
  39. package/dist/default_tasks/wikitext.js +7 -1
  40. package/dist/index.d.ts +0 -1
  41. package/dist/index.js +0 -1
  42. package/dist/models/gpt/config.js +1 -1
  43. package/dist/privacy.d.ts +8 -10
  44. package/dist/privacy.js +25 -40
  45. package/dist/task/task_handler.js +10 -2
  46. package/dist/task/training_information.d.ts +7 -4
  47. package/dist/task/training_information.js +25 -6
  48. package/dist/training/disco.d.ts +30 -28
  49. package/dist/training/disco.js +75 -73
  50. package/dist/training/index.d.ts +1 -1
  51. package/dist/training/index.js +1 -0
  52. package/dist/training/trainer.d.ts +16 -0
  53. package/dist/training/trainer.js +72 -0
  54. package/dist/types.d.ts +0 -2
  55. package/dist/weights/weights_container.d.ts +0 -5
  56. package/dist/weights/weights_container.js +0 -7
  57. package/package.json +1 -1
  58. package/dist/async_informant.d.ts +0 -15
  59. package/dist/async_informant.js +0 -42
  60. package/dist/training/trainer/distributed_trainer.d.ts +0 -20
  61. package/dist/training/trainer/distributed_trainer.js +0 -41
  62. package/dist/training/trainer/local_trainer.d.ts +0 -12
  63. package/dist/training/trainer/local_trainer.js +0 -24
  64. package/dist/training/trainer/trainer.d.ts +0 -32
  65. package/dist/training/trainer/trainer.js +0 -61
  66. package/dist/training/trainer/trainer_builder.d.ts +0 -23
  67. package/dist/training/trainer/trainer_builder.js +0 -47
@@ -4,5 +4,5 @@ export * as aggregator from '../aggregator/index.js';
  export * as decentralized from './decentralized/index.js';
  export * as federated from './federated/index.js';
  export * as messages from './messages.js';
- export * as utils from './utils.js';
+ export { getClient, timeout } from './utils.js';
  export { Local } from './local.js';
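
For code consuming the `client` namespace, this is a small breaking change: the `utils` sub-namespace is gone and its two surviving helpers are re-exported directly. A hypothetical migration sketch, assuming the package is used through its top-level `client` export:

```ts
import { client } from "@epfml/discojs";

// 2.2.1: await client.utils.timeout(5_000);
// 3.0.0: the helper hangs directly off `client`, with an optional message.
async function withDeadline<T>(work: Promise<T>): Promise<T> {
  // timeout() only ever rejects, so racing it against real work
  // turns it into a deadline.
  return await Promise.race([work, client.timeout(5_000, "round took too long")]);
}
```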
package/dist/client/local.d.ts CHANGED
@@ -1,3 +1,5 @@
- import { Base } from './base.js';
+ import { WeightsContainer } from "../weights/weights_container.js";
+ import { Base } from "./base.js";
  export declare class Local extends Base {
+     onRoundEndCommunication(weights: WeightsContainer): Promise<WeightsContainer>;
  }
package/dist/client/local.js CHANGED
@@ -1,3 +1,6 @@
- import { Base } from './base.js';
+ import { Base } from "./base.js";
  export class Local extends Base {
+     onRoundEndCommunication(weights) {
+         return Promise.resolve(weights);
+     }
  }
package/dist/client/messages.d.ts CHANGED
@@ -9,8 +9,7 @@ export declare enum type {
      PeersForRound = 4,
      Payload = 5,
      SendPayload = 6,
-     ReceiveServerMetadata = 7,
-     ReceiveServerPayload = 8
+     ReceiveServerPayload = 7
  }
  export interface ClientConnected {
      type: type.ClientConnected;
package/dist/client/messages.js CHANGED
@@ -1,16 +1,21 @@
  export var type;
  (function (type) {
+     // Sent from client to server as first point of contact to join a task.
+     // The server answers with an node id in a AssignNodeID message
      type[type["ClientConnected"] = 0] = "ClientConnected";
+     // When a user joins a task with a ClientConnected message, the server
+     // answers with an AssignNodeID message with its peer id.
      type[type["AssignNodeID"] = 1] = "AssignNodeID";
-     // Decentralized
+     /* Decentralized */
+     // Message forwarded by the server from a client to another client
+     // to establish a peer-to-peer (WebRTC) connection
      type[type["SignalForPeer"] = 2] = "SignalForPeer";
      type[type["PeerIsReady"] = 3] = "PeerIsReady";
      type[type["PeersForRound"] = 4] = "PeersForRound";
      type[type["Payload"] = 5] = "Payload";
      // Federated
      type[type["SendPayload"] = 6] = "SendPayload";
-     type[type["ReceiveServerMetadata"] = 7] = "ReceiveServerMetadata";
-     type[type["ReceiveServerPayload"] = 8] = "ReceiveServerPayload";
+     type[type["ReceiveServerPayload"] = 7] = "ReceiveServerPayload";
  })(type || (type = {}));
  export function hasMessageType(raw) {
      if (typeof raw !== 'object' || raw === null) {
package/dist/client/utils.d.ts CHANGED
@@ -1,2 +1,4 @@
- export declare const MAX_WAIT_PER_ROUND = 15000;
- export declare function timeout(ms?: number): Promise<never>;
+ import type { Task } from '../index.js';
+ import { client as clients, type aggregator } from '../index.js';
+ export declare function timeout(ms?: number, errorMsg?: string): Promise<never>;
+ export declare function getClient(trainingScheme: Required<Task['trainingInformation']['scheme']>, serverURL: URL, task: Task, aggregator: aggregator.Aggregator): clients.Client;
package/dist/client/utils.js CHANGED
@@ -1,7 +1,22 @@
+ import { client as clients } from '../index.js';
  // Time to wait for the others in milliseconds.
- export const MAX_WAIT_PER_ROUND = 15_000;
- export async function timeout(ms = MAX_WAIT_PER_ROUND) {
+ const MAX_WAIT_PER_ROUND = 15_000;
+ export async function timeout(ms = MAX_WAIT_PER_ROUND, errorMsg = 'timeout') {
      return await new Promise((_, reject) => {
-         setTimeout(() => { reject(new Error('timeout')); }, ms);
+         setTimeout(() => { reject(new Error(errorMsg)); }, ms);
      });
  }
+ export function getClient(trainingScheme, serverURL, task, aggregator) {
+     switch (trainingScheme) {
+         case 'decentralized':
+             return new clients.decentralized.DecentralizedClient(serverURL, task, aggregator);
+         case 'federated':
+             return new clients.federated.FederatedClient(serverURL, task, aggregator);
+         case 'local':
+             return new clients.Local(serverURL, task, aggregator);
+         default: {
+             const _ = trainingScheme;
+             throw new Error('should never happen');
+         }
+     }
+ }
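
The new `getClient` factory centralizes client construction by switching on the task's scheme. A sketch of calling it through the re-exported namespace (the `MeanAggregator` name is an assumption based on `aggregator/mean.js`; check the aggregator index for the exact export):

```ts
import { client, aggregator, type Task } from "@epfml/discojs";

declare const task: Task; // assumed to have been fetched from a Disco server

// Picks DecentralizedClient, FederatedClient or Local based on
// task.trainingInformation.scheme ('decentralized' | 'federated' | 'local').
const c: client.Client = client.getClient(
  task.trainingInformation.scheme,
  new URL("http://localhost:8080"),
  task,
  new aggregator.MeanAggregator(), // assumed constructor signature
);
```

The dead-looking `const _ = trainingScheme` in the `default:` branch is presumably `never`-typed in the TypeScript source, so adding a fourth scheme without extending the switch becomes a compile-time error.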
package/dist/dataset/data/data.d.ts CHANGED
@@ -1,4 +1,4 @@
- import type tf from '@tensorflow/tfjs';
+ import * as tf from '@tensorflow/tfjs';
  import type { List } from 'immutable';
  import type { Task } from '../../index.js';
  import type { Dataset } from '../index.js';
package/dist/dataset/data/data.js CHANGED
@@ -1,3 +1,4 @@
+ import * as tf from '@tensorflow/tfjs';
  /**
   * Abstract class representing an immutable Disco dataset, including a TF.js dataset,
   * Disco task and set of preprocessing functions.
@@ -59,8 +60,18 @@ export class Data {
      if (applyPreprocessing.size === 0) {
          return x => Promise.resolve(x);
      }
-     const preprocessingChain = applyPreprocessing.reduce((acc, fn) => x => fn(acc(x), this.task), (x) => x);
-     return x => preprocessingChain(Promise.resolve(x));
+     const preprocessingChain = async (input) => {
+         let currentContainer = await input; // Start with the initial tensor container
+         for (const fn of applyPreprocessing) {
+             const newContainer = await fn(Promise.resolve(currentContainer), this.task);
+             if (currentContainer !== newContainer) {
+                 tf.dispose(currentContainer); // Dispose of the old container
+             }
+             currentContainer = newContainer;
+         }
+         return currentContainer; // Return the final tensor container
+     };
+     return async (entry) => await preprocessingChain(Promise.resolve(entry));
  }
  /**
   * The TF.js dataset preprocessing according to the set of preprocessing functions and the task's
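
The rewritten chain materializes each step and calls `tf.dispose` on every tensor of the previous container once a step returns a new one. A sketch of a custom step honoring that contract (a hypothetical step, not part of the package): it returns a fresh container and clones `ys` rather than passing it through, since a passed-through tensor would be disposed along with the rest of the old container:

```ts
import * as tf from "@tensorflow/tfjs";

type Container = { xs: tf.Tensor; ys: tf.Tensor };

// Hypothetical preprocessing step: add a trailing channel dimension.
const addChannelDim = async (
  entry: Promise<Container>,
  _task: unknown,
): Promise<Container> => {
  const { xs, ys } = await entry;
  // Return a new container; the chain disposes the old one in full.
  return { xs: xs.expandDims(-1), ys: ys.clone() };
};
```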
package/dist/dataset/data/preprocessing/image_preprocessing.js CHANGED
@@ -25,10 +25,12 @@ const normalize = {
      type: ImagePreprocessing.Normalize,
      apply: async (entry) => {
          const { xs, ys } = await entry;
-         return {
-             xs: xs.div(tf.scalar(255)),
-             ys
-         };
+         return tf.tidy(() => {
+             return {
+                 xs: xs.div(tf.scalar(255)),
+                 ys
+             };
+         });
      }
  };
  /**
package/dist/default_tasks/cifar10.js CHANGED
@@ -30,8 +30,7 @@ export const cifar10 = {
      IMAGE_W: 224,
      LABEL_LIST: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'],
      scheme: 'decentralized',
-     noiseScale: undefined,
-     clippingRadius: 20,
+     privacy: { clippingRadius: 20, noiseScale: 1 },
      decentralizedSecure: true,
      minimumReadyPeers: 3,
      maxShareValue: 100,
package/dist/default_tasks/lus_covid.js CHANGED
@@ -29,11 +29,6 @@ export const lusCovid = {
      LABEL_LIST: ['COVID-Positive', 'COVID-Negative'],
      dataType: 'image',
      scheme: 'federated',
-     noiseScale: undefined,
-     clippingRadius: 20,
-     decentralizedSecure: true,
-     minimumReadyPeers: 2,
-     maxShareValue: 100,
      tensorBackend: 'tfjs'
  }
};
package/dist/default_tasks/mnist.js CHANGED
@@ -20,17 +20,16 @@ export const mnist = {
      trainingInformation: {
          modelID: 'mnist-model',
          epochs: 20,
-         roundDuration: 10,
+         roundDuration: 2,
          validationSplit: 0.2,
-         batchSize: 30,
+         batchSize: 64,
          dataType: 'image',
          IMAGE_H: 28,
          IMAGE_W: 28,
-         preprocessingFunctions: [data.ImagePreprocessing.Normalize],
+         // Images should already be at the right size but resizing just in case
+         preprocessingFunctions: [data.ImagePreprocessing.Resize, data.ImagePreprocessing.Normalize],
          LABEL_LIST: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'],
          scheme: 'decentralized',
-         noiseScale: undefined,
-         clippingRadius: 20,
          decentralizedSecure: true,
          minimumReadyPeers: 3,
          maxShareValue: 100,
@@ -39,22 +38,24 @@
      };
  },
  getModel() {
+     // Architecture from the PyTorch MNIST example (I made it slightly smaller, 650kB instead of 5MB)
+     // https://github.com/pytorch/examples/blob/main/mnist/main.py
      const model = tf.sequential();
      model.add(tf.layers.conv2d({
          inputShape: [28, 28, 3],
-         kernelSize: 3,
-         filters: 16,
-         activation: 'relu'
+         kernelSize: 5,
+         filters: 8,
+         activation: 'relu',
      }));
+     model.add(tf.layers.conv2d({ kernelSize: 5, filters: 16, activation: 'relu' }));
      model.add(tf.layers.maxPooling2d({ poolSize: 2, strides: 2 }));
-     model.add(tf.layers.conv2d({ kernelSize: 3, filters: 32, activation: 'relu' }));
-     model.add(tf.layers.maxPooling2d({ poolSize: 2, strides: 2 }));
-     model.add(tf.layers.conv2d({ kernelSize: 3, filters: 32, activation: 'relu' }));
-     model.add(tf.layers.flatten({}));
-     model.add(tf.layers.dense({ units: 64, activation: 'relu' }));
+     model.add(tf.layers.dropout({ rate: 0.25 }));
+     model.add(tf.layers.flatten());
+     model.add(tf.layers.dense({ units: 32, activation: 'relu' }));
+     model.add(tf.layers.dropout({ rate: 0.25 }));
      model.add(tf.layers.dense({ units: 10, activation: 'softmax' }));
      model.compile({
-         optimizer: 'rmsprop',
+         optimizer: 'adam',
          loss: 'categoricalCrossentropy',
          metrics: ['accuracy']
      });
package/dist/default_tasks/simple_face.js CHANGED
@@ -29,8 +29,6 @@ export const simpleFace = {
      IMAGE_W: 200,
      LABEL_LIST: ['child', 'adult'],
      scheme: 'federated', // secure aggregation not yet implemented for federated
-     noiseScale: undefined,
-     clippingRadius: undefined,
      tensorBackend: 'tfjs'
  }
};
package/dist/default_tasks/titanic.js CHANGED
@@ -46,8 +46,8 @@ export const titanic = {
  },
  trainingInformation: {
      modelID: 'titanic-model',
-     epochs: 40,
-     roundDuration: 10,
+     epochs: 10,
+     roundDuration: 2,
      validationSplit: 0.2,
      batchSize: 30,
      preprocessingFunctions: [data.TabularPreprocessing.Sanitize],
@@ -63,8 +63,6 @@
          'Survived'
      ],
      scheme: 'federated', // secure aggregation not yet implemented for FeAI
-     noiseScale: undefined,
-     clippingRadius: undefined,
      tensorBackend: 'tfjs'
  }
};
package/dist/default_tasks/wikitext.js CHANGED
@@ -9,7 +9,13 @@ export const wikitext = {
      preview: 'Train a language model (L)LM in your browser, collaboratively and from scratch.',
      overview: "You can train a GPT-2 model in your browser and in a collaborative manner on any textual dataset. As an example, you can try the Wikitext-103 dataset, composed of Wikipedia articles, widely used in natural language modeling, which you can download <a class='underline text-blue-400' target='_blank' href='https://dax-cdn.cdn.appdomain.cloud/dax-wikitext-103/1.0.1/wikitext-103.tar.gz'>here</a>. More information on how to connect the dataset at the next step."
  },
- model: 'The model follows the exact GPT-2 architecture and is implemented in TensorFlow.js. The tokenizer used for preprocessing is the GPT-2 Byte-Pair encoding tokenizer. The model is trained via an Adam optimizer with unit gradient clipping and softmax cross-entropy loss. To accommodate all devices, the context length is currently kept at 128 and the batch size at 1.',
+ model: [
+     "The model follows the exact GPT-2 architecture and is implemented in TensorFlow.js.",
+     "The tokenizer used for preprocessing is the GPT-2 Byte-Pair encoding tokenizer.",
+     "The model is trained via an Adam optimizer with unit gradient clipping and softmax cross-entropy loss.",
+     "It has around 5M parameters.",
+     "To accommodate all devices, the context length is currently kept at 128 and the batch size at 1.",
+ ].join(" "),
  dataFormatInformation: 'You can use any natural language (text) dataset you like. For example the Wikitext-103 dataset is organized as a large text file, with each line representing a segment of raw text from Wikipedia articles.',
  dataExampleText: 'An example excerpt from the dataset is: <i>"For the first twenty years of its existence , the only staged performances of Parsifal took place in the Bayreuth Festspielhaus , the venue for which Wagner conceived the work ( except eight private performances for Ludwig II at Munich in 1884 and 1885 ) ."</i>',
  sampleDatasetLink: 'https://dax-cdn.cdn.appdomain.cloud/dax-wikitext-103/1.0.1/wikitext-103.tar.gz',
package/dist/index.d.ts CHANGED
@@ -5,7 +5,6 @@ export * as privacy from './privacy.js';
  export * as client from './client/index.js';
  export * as aggregator from './aggregator/index.js';
  export { WeightsContainer, aggregation } from './weights/index.js';
- export { AsyncInformant } from './async_informant.js';
  export { Logger, ConsoleLogger } from './logging/index.js';
  export { Memory, type ModelInfo, type Path, type ModelSource, Empty as EmptyMemory } from './memory/index.js';
  export { Disco, RoundLogs } from './training/index.js';
package/dist/index.js CHANGED
@@ -5,7 +5,6 @@ export * as privacy from './privacy.js';
  export * as client from './client/index.js';
  export * as aggregator from './aggregator/index.js';
  export { WeightsContainer, aggregation } from './weights/index.js';
- export { AsyncInformant } from './async_informant.js';
  export { ConsoleLogger } from './logging/index.js';
  export { Memory, Empty as EmptyMemory } from './memory/index.js';
  export { Disco } from './training/index.js';
package/dist/models/gpt/config.js CHANGED
@@ -3,7 +3,7 @@ export const DEFAULT_CONFIG = {
      name: 'transformer',
      lr: 0.001,
      weightDecay: 0,
-     maxIter: 5,
+     maxIter: 10,
      verbose: 0,
      modelType: 'gpt-nano',
      evaluate: true,
package/dist/privacy.d.ts CHANGED
@@ -1,11 +1,9 @@
- import type { Task, WeightsContainer } from './index.js';
+ import type { WeightsContainer } from "./index.js";
+ /** Scramble weights */
+ export declare function addNoise(weights: WeightsContainer, deviation: number): WeightsContainer;
  /**
-  * Add task-parametrized Gaussian noise to and clip the weights update between the previous and current rounds.
-  * The previous round's weights are the last weights pulled from server/peers.
-  * The current round's weights are obtained after a single round of training, from the previous round's weights.
-  * @param updatedWeights weights from the current round
-  * @param staleWeights weights from the previous round
-  * @param task the task
-  * @returns the noised weights for the current round
-  */
- export declare function addDifferentialPrivacy(updatedWeights: WeightsContainer, staleWeights: WeightsContainer, task: Task): WeightsContainer;
+  * Keep weights' norm within radius
+  *
+  * @param radius maximum norm
+  **/
+ export declare function clipNorm(weights: WeightsContainer, radius: number): Promise<WeightsContainer>;
package/dist/privacy.js CHANGED
@@ -1,42 +1,27 @@
- import * as tf from '@tensorflow/tfjs';
+ import * as tf from "@tensorflow/tfjs";
+ async function frobeniusNorm(weights) {
+     const squared = await weights
+         .map((w) => w.square().sum())
+         .reduce((a, b) => a.add(b))
+         .data();
+     if (squared.length !== 1)
+         throw new Error("unexcepted weights shape");
+     return Math.sqrt(squared[0]);
+ }
+ /** Scramble weights */
+ export function addNoise(weights, deviation) {
+     const variance = Math.pow(deviation, 2);
+     return weights.map((w) => w.add(tf.randomNormal(w.shape, 0, variance)));
+ }
  /**
-  * Add task-parametrized Gaussian noise to and clip the weights update between the previous and current rounds.
-  * The previous round's weights are the last weights pulled from server/peers.
-  * The current round's weights are obtained after a single round of training, from the previous round's weights.
-  * @param updatedWeights weights from the current round
-  * @param staleWeights weights from the previous round
-  * @param task the task
-  * @returns the noised weights for the current round
-  */
- export function addDifferentialPrivacy(updatedWeights, staleWeights, task) {
-     const noiseScale = task.trainingInformation?.noiseScale;
-     const clippingRadius = task.trainingInformation?.clippingRadius;
-     const weightsDiff = updatedWeights.sub(staleWeights);
-     let newWeightsDiff;
-     if (clippingRadius !== undefined) {
-         // Frobenius norm
-         const norm = weightsDiff.frobeniusNorm();
-         newWeightsDiff = weightsDiff.map((w) => {
-             const clipped = w.div(Math.max(1, norm / clippingRadius));
-             if (noiseScale !== undefined) {
-                 // Add clipping and noise
-                 const noise = tf.randomNormal(w.shape, 0, (noiseScale * noiseScale) * (clippingRadius * clippingRadius));
-                 return clipped.add(noise);
-             }
-             else {
-                 // Add clipping without any noise
-                 return clipped;
-             }
-         });
-     }
-     else {
-         if (noiseScale !== undefined) {
-             // Add noise without any clipping
-             newWeightsDiff = weightsDiff.map((w) => tf.randomNormal(w.shape, 0, (noiseScale * noiseScale)));
-         }
-         else {
-             return updatedWeights;
-         }
-     }
-     return staleWeights.add(newWeightsDiff);
+  * Keep weights' norm within radius
+  *
+  * @param radius maximum norm
+  **/
+ export async function clipNorm(weights, radius) {
+     if (radius <= 0)
+         throw new Error("invalid radius");
+     const norm = await frobeniusNorm(weights);
+     const scaling = Math.max(1, norm / radius);
+     return weights.map((w) => w.div(scaling));
  }
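
The monolithic `addDifferentialPrivacy` is replaced by two composable primitives, so callers now decide the order themselves. A sketch of the equivalent clip-then-noise sequence, using the values the updated task files put in their `privacy` objects:

```ts
import { privacy, type WeightsContainer } from "@epfml/discojs";

// Clip the round's weight update to a maximum Frobenius norm,
// then perturb it with Gaussian noise.
async function privatize(update: WeightsContainer): Promise<WeightsContainer> {
  const clipped = await privacy.clipNorm(update, 20); // clippingRadius
  return privacy.addNoise(clipped, 1); // noiseScale
}
```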
package/dist/task/task_handler.js CHANGED
@@ -13,8 +13,16 @@ export async function pushTask(url, task, model) {
  export async function fetchTasks(url) {
      const response = await axios.get(new URL(TASK_ENDPOINT, url).href);
      const tasks = response.data;
-     if (!(Array.isArray(tasks) && tasks.every(isTask))) {
-         throw new Error('invalid tasks response');
+     if (!Array.isArray(tasks)) {
+         throw new Error('Expected to receive an array of Tasks when fetching tasks');
+     }
+     else if (!tasks.every(isTask)) {
+         for (const task of tasks) {
+             if (!isTask(task)) {
+                 console.error("task has invalid format:", task);
+             }
+         }
+         throw new Error('invalid tasks response, the task object received is not well formatted');
      }
      return Map(tasks.map((t) => [t.id, t]));
  }
package/dist/task/training_information.d.ts CHANGED
@@ -1,6 +1,9 @@
- import type { AggregatorChoice } from '../aggregator/get.js';
  import type { Preprocessing } from '../dataset/data/preprocessing/index.js';
  import { PreTrainedTokenizer } from '@xenova/transformers';
+ interface Privacy {
+     clippingRadius?: number;
+     noiseScale?: number;
+ }
  export interface TrainingInformation {
      modelID: string;
      epochs: number;
@@ -15,14 +18,14 @@ export interface TrainingInformation {
      IMAGE_W?: number;
      LABEL_LIST?: string[];
      scheme: 'decentralized' | 'federated' | 'local';
-     noiseScale?: number;
-     clippingRadius?: number;
+     privacy?: Privacy;
      decentralizedSecure?: boolean;
      maxShareValue?: number;
      minimumReadyPeers?: number;
-     aggregator?: AggregatorChoice;
+     aggregator?: 'mean' | 'secure';
      tokenizer?: string | PreTrainedTokenizer;
      maxSequenceLength?: number;
      tensorBackend: 'tfjs' | 'gpt';
  }
  export declare function isTrainingInformation(raw: unknown): raw is TrainingInformation;
+ export {};
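
For task definitions, the two flat DP fields collapse into one optional `privacy` object, and `aggregator` switches from the removed numeric `AggregatorChoice` enum to a string union. A migration sketch:

```ts
// 2.2.1
// trainingInformation: { ..., noiseScale: undefined, clippingRadius: 20 }

// 3.0.0
const trainingInformation = {
  // ...other fields unchanged...
  privacy: { clippingRadius: 20, noiseScale: 1 }, // both fields optional
  aggregator: "mean" as const, // or "secure"
};
```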
package/dist/task/training_information.js CHANGED
@@ -6,11 +6,25 @@ function isStringArray(raw) {
      const arr = raw; // isArray is unsafely guarding with any[]
      return arr.every((e) => typeof e === 'string');
  }
+ function isPrivacy(raw) {
+     if (typeof raw !== "object" || raw === null) {
+         return false;
+     }
+     const { clippingRadius, noiseScale, } = raw;
+     if ((clippingRadius !== undefined && typeof clippingRadius !== "number") ||
+         (noiseScale !== undefined && typeof noiseScale !== "number"))
+         return false;
+     const _ = {
+         clippingRadius,
+         noiseScale,
+     };
+     return true;
+ }
  export function isTrainingInformation(raw) {
      if (typeof raw !== 'object' || raw === null) {
          return false;
      }
-     const { IMAGE_H, IMAGE_W, LABEL_LIST, aggregator, batchSize, clippingRadius, dataType, decentralizedSecure, epochs, inputColumns, maxShareValue, minimumReadyPeers, modelID, noiseScale, outputColumns, preprocessingFunctions, roundDuration, scheme, validationSplit, tokenizer, maxSequenceLength, tensorBackend } = raw;
+     const { IMAGE_H, IMAGE_W, LABEL_LIST, aggregator, batchSize, dataType, decentralizedSecure, privacy, epochs, inputColumns, maxShareValue, minimumReadyPeers, modelID, outputColumns, preprocessingFunctions, roundDuration, scheme, validationSplit, tokenizer, maxSequenceLength, tensorBackend } = raw;
      if (typeof dataType !== 'string' ||
          typeof modelID !== 'string' ||
          typeof epochs !== 'number' ||
@@ -19,12 +33,11 @@ export function isTrainingInformation(raw) {
          typeof validationSplit !== 'number' ||
          (tokenizer !== undefined && typeof tokenizer !== 'string' && !(tokenizer instanceof PreTrainedTokenizer)) ||
          (maxSequenceLength !== undefined && typeof maxSequenceLength !== 'number') ||
-         (aggregator !== undefined && typeof aggregator !== 'number') ||
-         (clippingRadius !== undefined && typeof clippingRadius !== 'number') ||
+         (aggregator !== undefined && typeof aggregator !== 'string') ||
          (decentralizedSecure !== undefined && typeof decentralizedSecure !== 'boolean') ||
+         (privacy !== undefined && !isPrivacy(privacy)) ||
          (maxShareValue !== undefined && typeof maxShareValue !== 'number') ||
          (minimumReadyPeers !== undefined && typeof minimumReadyPeers !== 'number') ||
-         (noiseScale !== undefined && typeof noiseScale !== 'number') ||
          (IMAGE_H !== undefined && typeof IMAGE_H !== 'number') ||
          (IMAGE_W !== undefined && typeof IMAGE_W !== 'number') ||
          (LABEL_LIST !== undefined && !isStringArray(LABEL_LIST)) ||
@@ -33,6 +46,13 @@ export function isTrainingInformation(raw) {
          (preprocessingFunctions !== undefined && !Array.isArray(preprocessingFunctions))) {
          return false;
      }
+     if (aggregator !== undefined) {
+         switch (aggregator) {
+             case 'mean': break;
+             case 'secure': break;
+             default: return false;
+         }
+     }
      switch (dataType) {
          case 'image': break;
          case 'tabular': break;
@@ -70,15 +90,14 @@ export function isTrainingInformation(raw) {
          LABEL_LIST,
          aggregator,
          batchSize,
-         clippingRadius,
          dataType,
          decentralizedSecure,
+         privacy,
          epochs,
          inputColumns,
          maxShareValue,
          minimumReadyPeers,
          modelID,
-         noiseScale,
          outputColumns,
          preprocessingFunctions,
          roundDuration,
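
The seemingly dead `const _ = { clippingRadius, noiseScale }` in `isPrivacy` makes more sense in the TypeScript source: reassembling the narrowed fields into a `Privacy`-typed value presumably asks the compiler to confirm that the guard's checks match the interface. A reconstruction of the idiom:

```ts
interface Privacy {
  clippingRadius?: number;
  noiseScale?: number;
}

function isPrivacy(raw: unknown): raw is Privacy {
  if (typeof raw !== "object" || raw === null) return false;
  const { clippingRadius, noiseScale } = raw as Record<string, unknown>;
  if (
    (clippingRadius !== undefined && typeof clippingRadius !== "number") ||
    (noiseScale !== undefined && typeof noiseScale !== "number")
  )
    return false;
  // Both bindings are now narrowed to `number | undefined`; if this
  // assignment stops type-checking, the guard no longer matches Privacy.
  const _: Privacy = { clippingRadius, noiseScale };
  return true;
}
```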
package/dist/training/disco.d.ts CHANGED
@@ -1,14 +1,11 @@
- import { BatchLogs, data, EpochLogs, Logger, Memory, Task, TrainingInformation } from '../index.js';
- import { client as clients } from '../index.js';
- import type { Aggregator } from '../aggregator/index.js';
- import type { RoundLogs } from './trainer/trainer.js';
- export interface DiscoOptions {
-     client?: clients.Client;
-     aggregator?: Aggregator;
-     url?: string | URL;
-     scheme?: TrainingInformation['scheme'];
-     logger?: Logger;
-     memory?: Memory;
+ import { data, BatchLogs, EpochLogs, Logger, Memory, Task, TrainingInformation } from "../index.js";
+ import { client as clients } from "../index.js";
+ import type { Aggregator } from "../aggregator/index.js";
+ import { RoundLogs, Trainer } from "./trainer.js";
+ interface Config {
+     scheme: TrainingInformation["scheme"];
+     logger: Logger;
+     memory: Memory;
  }
  /**
   * Top-level class handling distributed training from a client's perspective. It is meant to be
@@ -16,16 +13,22 @@ export interface DiscoOptions {
   * communication with nodes, logs and model memory.
   */
  export declare class Disco {
-     readonly task: Task;
-     readonly logger: Logger;
-     readonly memory: Memory;
-     private readonly client;
-     private readonly trainer;
-     constructor(task: Task, options: DiscoOptions);
+     #private;
+     readonly trainer: Trainer;
+     private constructor();
+     /**
+      * Connect to the given task and get ready to train.
+      *
+      * Will load the model from memory if available or fetch it from the server.
+      *
+      * @param clientConfig client to connect with or parameters on how to create one.
+      **/
+     static fromTask(task: Task, clientConfig: clients.Client | URL | {
+         aggregator: Aggregator;
+         url: URL;
+     }, config: Partial<Config>): Promise<Disco>;
      /** Train on dataset, yielding logs of every round. */
-     trainByRound(dataTuple: data.DataSplit): AsyncGenerator<RoundLogs & {
-         participants: number;
-     }>;
+     trainByRound(dataTuple: data.DataSplit): AsyncGenerator<RoundLogs>;
      /** Train on dataset, yielding logs of every epoch. */
      trainByEpoch(dataTuple: data.DataSplit): AsyncGenerator<EpochLogs>;
      /** Train on dataset, yielding logs of every batch. */
@@ -33,14 +36,12 @@ export declare class Disco {
      /** Run whole train on dataset. */
      trainFully(dataTuple: data.DataSplit): Promise<void>;
      /**
-      * Train on dataset, yield the nested steps.
-      *
-      * Don't forget to await the yielded generator otherwise nothing will progress.
-      * If you don't care about the whole process, use one of the other train methods.
-      **/
-     train(dataTuple: data.DataSplit): AsyncGenerator<AsyncGenerator<AsyncGenerator<BatchLogs, EpochLogs>, RoundLogs & {
-         participants: number;
-     }>>;
+     * Train on dataset, yield the nested steps.
+     *
+     * Don't forget to await the yielded generator otherwise nothing will progress.
+     * If you don't care about the whole process, use one of the other train methods.
+     **/
+     train(dataTuple: data.DataSplit): AsyncGenerator<AsyncGenerator<AsyncGenerator<BatchLogs, EpochLogs>, RoundLogs>>;
      /**
       * Stops the ongoing training instance without disconnecting the client.
       */
@@ -50,3 +51,4 @@
       */
      close(): Promise<void>;
  }
+ export {};
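
`Disco` construction becomes an async factory: the constructor is now private, and `fromTask` both builds the client (when given a URL or an `{ aggregator, url }` pair) and loads the model. A usage sketch, assuming an ESM context with top-level `await` and already-loaded `task` and `dataSplit`:

```ts
import { Disco, data, type Task } from "@epfml/discojs";

declare const task: Task;
declare const dataSplit: data.DataSplit;

// Passing a URL lets fromTask create the client and aggregator itself;
// an empty Partial<Config> presumably falls back to default scheme,
// logger and memory.
const disco = await Disco.fromTask(task, new URL("http://localhost:8080"), {});

for await (const round of disco.trainByRound(dataSplit)) {
  console.log(round); // RoundLogs, no longer intersected with { participants }
}
await disco.close();
```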