@epfml/discojs-node 2.1.1 → 2.1.2-p20240506085559.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/data/image_loader.d.ts +5 -0
- package/dist/data/image_loader.js +11 -0
- package/dist/data/index.d.ts +3 -0
- package/dist/data/index.js +3 -0
- package/dist/data/tabular_loader.d.ts +4 -0
- package/dist/data/tabular_loader.js +11 -0
- package/dist/data/text_loader.d.ts +4 -0
- package/dist/data/text_loader.js +14 -0
- package/dist/index.d.ts +2 -2
- package/dist/index.js +2 -6
- package/package.json +13 -16
- package/README.md +0 -53
- package/dist/core/async_buffer.d.ts +0 -41
- package/dist/core/async_buffer.js +0 -97
- package/dist/core/async_informant.d.ts +0 -20
- package/dist/core/async_informant.js +0 -69
- package/dist/core/client/base.d.ts +0 -33
- package/dist/core/client/base.js +0 -35
- package/dist/core/client/decentralized/base.d.ts +0 -32
- package/dist/core/client/decentralized/base.js +0 -212
- package/dist/core/client/decentralized/clear_text.d.ts +0 -14
- package/dist/core/client/decentralized/clear_text.js +0 -96
- package/dist/core/client/decentralized/index.d.ts +0 -4
- package/dist/core/client/decentralized/index.js +0 -9
- package/dist/core/client/decentralized/messages.d.ts +0 -41
- package/dist/core/client/decentralized/messages.js +0 -54
- package/dist/core/client/decentralized/peer.d.ts +0 -26
- package/dist/core/client/decentralized/peer.js +0 -210
- package/dist/core/client/decentralized/peer_pool.d.ts +0 -14
- package/dist/core/client/decentralized/peer_pool.js +0 -92
- package/dist/core/client/decentralized/sec_agg.d.ts +0 -22
- package/dist/core/client/decentralized/sec_agg.js +0 -190
- package/dist/core/client/decentralized/secret_shares.d.ts +0 -3
- package/dist/core/client/decentralized/secret_shares.js +0 -39
- package/dist/core/client/decentralized/types.d.ts +0 -2
- package/dist/core/client/decentralized/types.js +0 -7
- package/dist/core/client/event_connection.d.ts +0 -37
- package/dist/core/client/event_connection.js +0 -158
- package/dist/core/client/federated/client.d.ts +0 -37
- package/dist/core/client/federated/client.js +0 -273
- package/dist/core/client/federated/index.d.ts +0 -2
- package/dist/core/client/federated/index.js +0 -7
- package/dist/core/client/federated/messages.d.ts +0 -38
- package/dist/core/client/federated/messages.js +0 -25
- package/dist/core/client/index.d.ts +0 -5
- package/dist/core/client/index.js +0 -11
- package/dist/core/client/local.d.ts +0 -8
- package/dist/core/client/local.js +0 -36
- package/dist/core/client/messages.d.ts +0 -28
- package/dist/core/client/messages.js +0 -33
- package/dist/core/client/utils.d.ts +0 -2
- package/dist/core/client/utils.js +0 -19
- package/dist/core/dataset/data/data.d.ts +0 -11
- package/dist/core/dataset/data/data.js +0 -20
- package/dist/core/dataset/data/data_split.d.ts +0 -5
- package/dist/core/dataset/data/data_split.js +0 -2
- package/dist/core/dataset/data/image_data.d.ts +0 -8
- package/dist/core/dataset/data/image_data.js +0 -64
- package/dist/core/dataset/data/index.d.ts +0 -5
- package/dist/core/dataset/data/index.js +0 -11
- package/dist/core/dataset/data/preprocessing.d.ts +0 -13
- package/dist/core/dataset/data/preprocessing.js +0 -33
- package/dist/core/dataset/data/tabular_data.d.ts +0 -8
- package/dist/core/dataset/data/tabular_data.js +0 -40
- package/dist/core/dataset/data_loader/data_loader.d.ts +0 -15
- package/dist/core/dataset/data_loader/data_loader.js +0 -10
- package/dist/core/dataset/data_loader/image_loader.d.ts +0 -17
- package/dist/core/dataset/data_loader/image_loader.js +0 -141
- package/dist/core/dataset/data_loader/index.d.ts +0 -3
- package/dist/core/dataset/data_loader/index.js +0 -9
- package/dist/core/dataset/data_loader/tabular_loader.d.ts +0 -29
- package/dist/core/dataset/data_loader/tabular_loader.js +0 -101
- package/dist/core/dataset/dataset.d.ts +0 -2
- package/dist/core/dataset/dataset.js +0 -2
- package/dist/core/dataset/dataset_builder.d.ts +0 -18
- package/dist/core/dataset/dataset_builder.js +0 -96
- package/dist/core/dataset/index.d.ts +0 -4
- package/dist/core/dataset/index.js +0 -14
- package/dist/core/default_tasks/cifar10.d.ts +0 -2
- package/dist/core/default_tasks/cifar10.js +0 -68
- package/dist/core/default_tasks/geotags.d.ts +0 -2
- package/dist/core/default_tasks/geotags.js +0 -69
- package/dist/core/default_tasks/index.d.ts +0 -6
- package/dist/core/default_tasks/index.js +0 -15
- package/dist/core/default_tasks/lus_covid.d.ts +0 -2
- package/dist/core/default_tasks/lus_covid.js +0 -96
- package/dist/core/default_tasks/mnist.d.ts +0 -2
- package/dist/core/default_tasks/mnist.js +0 -69
- package/dist/core/default_tasks/simple_face.d.ts +0 -2
- package/dist/core/default_tasks/simple_face.js +0 -53
- package/dist/core/default_tasks/titanic.d.ts +0 -2
- package/dist/core/default_tasks/titanic.js +0 -97
- package/dist/core/index.d.ts +0 -18
- package/dist/core/index.js +0 -39
- package/dist/core/informant/graph_informant.d.ts +0 -10
- package/dist/core/informant/graph_informant.js +0 -23
- package/dist/core/informant/index.d.ts +0 -3
- package/dist/core/informant/index.js +0 -9
- package/dist/core/informant/training_informant/base.d.ts +0 -31
- package/dist/core/informant/training_informant/base.js +0 -83
- package/dist/core/informant/training_informant/decentralized.d.ts +0 -5
- package/dist/core/informant/training_informant/decentralized.js +0 -22
- package/dist/core/informant/training_informant/federated.d.ts +0 -14
- package/dist/core/informant/training_informant/federated.js +0 -32
- package/dist/core/informant/training_informant/index.d.ts +0 -4
- package/dist/core/informant/training_informant/index.js +0 -11
- package/dist/core/informant/training_informant/local.d.ts +0 -6
- package/dist/core/informant/training_informant/local.js +0 -20
- package/dist/core/logging/console_logger.d.ts +0 -18
- package/dist/core/logging/console_logger.js +0 -33
- package/dist/core/logging/index.d.ts +0 -3
- package/dist/core/logging/index.js +0 -9
- package/dist/core/logging/logger.d.ts +0 -12
- package/dist/core/logging/logger.js +0 -9
- package/dist/core/logging/trainer_logger.d.ts +0 -24
- package/dist/core/logging/trainer_logger.js +0 -59
- package/dist/core/memory/base.d.ts +0 -22
- package/dist/core/memory/base.js +0 -9
- package/dist/core/memory/empty.d.ts +0 -14
- package/dist/core/memory/empty.js +0 -75
- package/dist/core/memory/index.d.ts +0 -3
- package/dist/core/memory/index.js +0 -9
- package/dist/core/memory/model_type.d.ts +0 -4
- package/dist/core/memory/model_type.js +0 -9
- package/dist/core/privacy.d.ts +0 -11
- package/dist/core/privacy.js +0 -47
- package/dist/core/serialization/index.d.ts +0 -2
- package/dist/core/serialization/index.js +0 -6
- package/dist/core/serialization/model.d.ts +0 -5
- package/dist/core/serialization/model.js +0 -55
- package/dist/core/serialization/weights.d.ts +0 -5
- package/dist/core/serialization/weights.js +0 -64
- package/dist/core/task/data_example.d.ts +0 -5
- package/dist/core/task/data_example.js +0 -24
- package/dist/core/task/digest.d.ts +0 -5
- package/dist/core/task/digest.js +0 -18
- package/dist/core/task/display_information.d.ts +0 -15
- package/dist/core/task/display_information.js +0 -49
- package/dist/core/task/index.d.ts +0 -6
- package/dist/core/task/index.js +0 -15
- package/dist/core/task/model_compile_data.d.ts +0 -6
- package/dist/core/task/model_compile_data.js +0 -22
- package/dist/core/task/summary.d.ts +0 -5
- package/dist/core/task/summary.js +0 -19
- package/dist/core/task/task.d.ts +0 -12
- package/dist/core/task/task.js +0 -35
- package/dist/core/task/task_handler.d.ts +0 -5
- package/dist/core/task/task_handler.js +0 -53
- package/dist/core/task/task_provider.d.ts +0 -6
- package/dist/core/task/task_provider.js +0 -13
- package/dist/core/task/training_information.d.ts +0 -28
- package/dist/core/task/training_information.js +0 -66
- package/dist/core/training/disco.d.ts +0 -23
- package/dist/core/training/disco.js +0 -130
- package/dist/core/training/index.d.ts +0 -2
- package/dist/core/training/index.js +0 -7
- package/dist/core/training/trainer/distributed_trainer.d.ts +0 -20
- package/dist/core/training/trainer/distributed_trainer.js +0 -65
- package/dist/core/training/trainer/local_trainer.d.ts +0 -11
- package/dist/core/training/trainer/local_trainer.js +0 -34
- package/dist/core/training/trainer/round_tracker.d.ts +0 -30
- package/dist/core/training/trainer/round_tracker.js +0 -47
- package/dist/core/training/trainer/trainer.d.ts +0 -65
- package/dist/core/training/trainer/trainer.js +0 -160
- package/dist/core/training/trainer/trainer_builder.d.ts +0 -25
- package/dist/core/training/trainer/trainer_builder.js +0 -95
- package/dist/core/training/training_schemes.d.ts +0 -5
- package/dist/core/training/training_schemes.js +0 -10
- package/dist/core/types.d.ts +0 -4
- package/dist/core/types.js +0 -2
- package/dist/core/validation/index.d.ts +0 -1
- package/dist/core/validation/index.js +0 -5
- package/dist/core/validation/validator.d.ts +0 -17
- package/dist/core/validation/validator.js +0 -104
- package/dist/core/weights/aggregation.d.ts +0 -7
- package/dist/core/weights/aggregation.js +0 -72
- package/dist/core/weights/index.d.ts +0 -2
- package/dist/core/weights/index.js +0 -7
- package/dist/core/weights/weights_container.d.ts +0 -19
- package/dist/core/weights/weights_container.js +0 -64
- package/dist/dataset/data_loader/image_loader.d.ts +0 -4
- package/dist/dataset/data_loader/image_loader.js +0 -21
- package/dist/dataset/data_loader/index.d.ts +0 -2
- package/dist/dataset/data_loader/index.js +0 -7
- package/dist/dataset/data_loader/tabular_loader.d.ts +0 -4
- package/dist/dataset/data_loader/tabular_loader.js +0 -20
- package/dist/imports.d.ts +0 -1
- package/dist/imports.js +0 -5
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
import fs from 'node:fs/promises';
|
|
2
|
+
import { node as tfNode } from '@tensorflow/tfjs-node';
|
|
3
|
+
import { data } from '@epfml/discojs';
|
|
4
|
+
export class ImageLoader extends data.ImageLoader {
|
|
5
|
+
async readImageFrom(source, channels) {
|
|
6
|
+
// We allow specifying the number of channels because the default number of channels
|
|
7
|
+
// differs between web and node for the same image
|
|
8
|
+
// E.g. lus covid images have 1 channel with fs.readFile but 3 when loaded with discojs-web
|
|
9
|
+
return tfNode.decodeImage(await fs.readFile(source), channels);
|
|
10
|
+
}
|
|
11
|
+
}
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
import { data as tfData } from '@tensorflow/tfjs-node';
|
|
2
|
+
import { data } from '@epfml/discojs';
|
|
3
|
+
export class TabularLoader extends data.TabularLoader {
|
|
4
|
+
loadDatasetFrom(source, csvConfig) {
|
|
5
|
+
const prefix = 'file://';
|
|
6
|
+
if (source.slice(0, 7) !== prefix) {
|
|
7
|
+
source = prefix + source;
|
|
8
|
+
}
|
|
9
|
+
return Promise.resolve(tfData.csv(source, csvConfig));
|
|
10
|
+
}
|
|
11
|
+
}
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
import { data as tfData } from '@tensorflow/tfjs-node';
|
|
2
|
+
import fs from 'node:fs/promises';
|
|
3
|
+
import { data } from '@epfml/discojs';
|
|
4
|
+
export class TextLoader extends data.TextLoader {
|
|
5
|
+
async loadDatasetFrom(source) {
|
|
6
|
+
// TODO: reads all the file at once,
|
|
7
|
+
// inputting the file path to FileDataSource isn't supported anymore
|
|
8
|
+
const inputFile = await fs.readFile(source);
|
|
9
|
+
const file = new tfData.FileDataSource(inputFile, { chunkSize: 1024 });
|
|
10
|
+
// TODO: reading files line by line is an issue for LLM tokenization
|
|
11
|
+
const dataset = new tfData.TextLineDataset(file).filter(s => s != ' '); // newline creates empty strings
|
|
12
|
+
return Promise.resolve(dataset);
|
|
13
|
+
}
|
|
14
|
+
}
|
package/dist/index.d.ts
CHANGED
|
@@ -1,2 +1,2 @@
|
|
|
1
|
-
export * from './
|
|
2
|
-
export
|
|
1
|
+
export * from './data/index.js';
|
|
2
|
+
export { saveModelToDisk, loadModelFromDisk } from './models/model_loader.js';
|
package/dist/index.js
CHANGED
|
@@ -1,6 +1,2 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
3
|
-
exports.node = void 0;
|
|
4
|
-
var tslib_1 = require("tslib");
|
|
5
|
-
(0, tslib_1.__exportStar)(require("./core"), exports);
|
|
6
|
-
exports.node = (0, tslib_1.__importStar)(require("./imports"));
|
|
1
|
+
export * from './data/index.js';
|
|
2
|
+
export { saveModelToDisk, loadModelFromDisk } from './models/model_loader.js';
|
package/package.json
CHANGED
|
@@ -1,13 +1,14 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "@epfml/discojs-node",
|
|
3
|
-
"version": "2.1.
|
|
3
|
+
"version": "2.1.2-p20240506085559.0",
|
|
4
|
+
"type": "module",
|
|
4
5
|
"main": "dist/index.js",
|
|
5
6
|
"types": "dist/index.d.ts",
|
|
6
7
|
"scripts": {
|
|
7
|
-
"
|
|
8
|
-
"build
|
|
9
|
-
"
|
|
10
|
-
"
|
|
8
|
+
"watch": "nodemon --ext ts --ignore dist --watch ../discojs/dist --watch . --exec npm run",
|
|
9
|
+
"build": "tsc",
|
|
10
|
+
"lint": "npx eslint .",
|
|
11
|
+
"test": "mocha"
|
|
11
12
|
},
|
|
12
13
|
"repository": {
|
|
13
14
|
"type": "git",
|
|
@@ -18,17 +19,13 @@
|
|
|
18
19
|
},
|
|
19
20
|
"homepage": "https://github.com/epfml/disco#readme",
|
|
20
21
|
"dependencies": {
|
|
21
|
-
"
|
|
22
|
-
"immutable": "4",
|
|
23
|
-
"tslib": "2",
|
|
24
|
-
"@tensorflow/tfjs-node": "4",
|
|
25
|
-
"isomorphic-ws": "4",
|
|
26
|
-
"url": "0.11",
|
|
22
|
+
"@epfml/discojs": "*",
|
|
27
23
|
"@koush/wrtc": "0.5",
|
|
28
|
-
"
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
"
|
|
32
|
-
"
|
|
24
|
+
"@tensorflow/tfjs-node": "4"
|
|
25
|
+
},
|
|
26
|
+
"devDependencies": {
|
|
27
|
+
"@types/node": "20",
|
|
28
|
+
"nodemon": "3",
|
|
29
|
+
"ts-node": "10"
|
|
33
30
|
}
|
|
34
31
|
}
|
package/README.md
DELETED
|
@@ -1,53 +0,0 @@
|
|
|
1
|
-
# Disco.js Node Module
|
|
2
|
-
|
|
3
|
-
`discojs-node` contains the Node.js code of Disco.js, based off and extending `discojs-core`.
|
|
4
|
-
|
|
5
|
-
## Installation
|
|
6
|
-
|
|
7
|
-
The `discojs-node` project is available as the `@epfml/discojs-node` NPM package, which can be installed with
|
|
8
|
-
`npm i @epfml/discojs-node`.
|
|
9
|
-
|
|
10
|
-
### Development Environment
|
|
11
|
-
|
|
12
|
-
The dev tools run on Node.js and require `npm`, a package manager for the Node.js runtime environment.
|
|
13
|
-
We recommend using [nvm](https://github.com/nvm-sh/nvm) for installing both Node.js and NPM.
|
|
14
|
-
|
|
15
|
-
To install the project's dependencies, run:
|
|
16
|
-
|
|
17
|
-
```
|
|
18
|
-
cd ..
|
|
19
|
-
npm ci
|
|
20
|
-
```
|
|
21
|
-
|
|
22
|
-
Since the dependencies of `discojs-core`, `discojs-web` and `discojs-node` are the same, they are specified in a top-level `package.json` file, to ease installation and building.
|
|
23
|
-
|
|
24
|
-
> **⚠ WARNING: Apple Silicon.**
|
|
25
|
-
> `TensorFlow.js` version `3` do support M1 processors for macs. To do so, make sure you have an `arm` Node.js executable installed (not `x86_64`). It can be checked using:
|
|
26
|
-
|
|
27
|
-
```
|
|
28
|
-
node -p "process.arch"
|
|
29
|
-
```
|
|
30
|
-
|
|
31
|
-
which should return something similar to `arm64`.
|
|
32
|
-
|
|
33
|
-
## Build
|
|
34
|
-
|
|
35
|
-
The server and CLI modules, as well as all unit tests (except Cypress) use the `discojs-node` interface, i.e. they all run on Node.js. This Disco.js Node module is build on top of and extends `discojs-core`, whose code is [symlinked](https://en.wikipedia.org/wiki/Symbolic_link) into `discojs-node/src/core`. To build this project:
|
|
36
|
-
|
|
37
|
-
```
|
|
38
|
-
npm run build
|
|
39
|
-
```
|
|
40
|
-
|
|
41
|
-
This invokes the TypeScript compiler (`tsc`). It will output the compilation files of `discojs-node` in a `dist/` directory. To recompile from stratch, simply `rm -rf dist/` before running `npm run build` again.
|
|
42
|
-
|
|
43
|
-
## Development
|
|
44
|
-
|
|
45
|
-
### Contributing
|
|
46
|
-
|
|
47
|
-
Contributions to `discojs-node` must only include Node-specific code. Code common to both Node and the browser must be added to `discojs-core` instead.
|
|
48
|
-
|
|
49
|
-
As a rule of thumb, the `src/core/` directory must never be modified when modifying `discojs-node`, since it is [symlinked](https://en.wikipedia.org/wiki/Symbolic_link) to `discojs-core`.
|
|
50
|
-
|
|
51
|
-
If you wish to add a new file or submodule to the project, please do so in a similar way as `src/core/` is structured. That is, [adding a new task](../../docs/TASK.md) to `discojs-node` would mean adding a new file to `src/tasks/` and modifying `src/tasks/index.ts` (NOT `src/core/tasks/...`).
|
|
52
|
-
|
|
53
|
-
Note that, if you end up making calls to the Tensorflow.js API, you must import it from the root index. This is to ensure the Node version of TF.js is loaded, and only once.
|
|
@@ -1,41 +0,0 @@
|
|
|
1
|
-
import { Map } from 'immutable';
|
|
2
|
-
import { TaskID, AsyncInformant } from '.';
|
|
3
|
-
/**
|
|
4
|
-
* The AsyncWeightsBuffer class holds and manipulates information about the
|
|
5
|
-
* async weights buffer. It works as follows:
|
|
6
|
-
*
|
|
7
|
-
* Setup: Init round to zero and create empty buffer (a map from user id to weights)
|
|
8
|
-
*
|
|
9
|
-
* - When a user adds weights only do so when they are recent weights: i.e. this.round - round <= roundCutoff.
|
|
10
|
-
* - If a user already added weights, update them. (-> there can be at most one entry of weights per id in a buffer).
|
|
11
|
-
* - When the buffer is full, call aggregateAndStoreWeights with the weights in the buffer and then increment round by one and reset the buffer.
|
|
12
|
-
*
|
|
13
|
-
* @remarks
|
|
14
|
-
* taskID: corresponds to the task that weights correspond to.
|
|
15
|
-
* bufferCapacity: size of the buffer.
|
|
16
|
-
* buffer: holds a map of users to their added weights.
|
|
17
|
-
* round: the latest round of the weight buffer.
|
|
18
|
-
* roundCutoff: cutoff for accepted rounds.
|
|
19
|
-
*/
|
|
20
|
-
export declare class AsyncBuffer<T> {
|
|
21
|
-
readonly taskID: TaskID;
|
|
22
|
-
private readonly bufferCapacity;
|
|
23
|
-
private readonly aggregateAndStoreWeights;
|
|
24
|
-
private readonly roundCutoff;
|
|
25
|
-
buffer: Map<string, T>;
|
|
26
|
-
round: number;
|
|
27
|
-
private observer;
|
|
28
|
-
constructor(taskID: TaskID, bufferCapacity: number, aggregateAndStoreWeights: (weights: Iterable<T>) => Promise<void>, roundCutoff?: number);
|
|
29
|
-
registerObserver(observer: AsyncInformant<T>): void;
|
|
30
|
-
bufferIsFull(): boolean;
|
|
31
|
-
private updateWeightsIfBufferIsFull;
|
|
32
|
-
isNotWithinRoundCutoff(round: number): boolean;
|
|
33
|
-
/**
|
|
34
|
-
* Add weights originating from weights of a given round.
|
|
35
|
-
* Only add to buffer if the given round is not old.
|
|
36
|
-
* @param weights
|
|
37
|
-
* @param round
|
|
38
|
-
* @returns true if weights were added, and false otherwise
|
|
39
|
-
*/
|
|
40
|
-
add(id: string, weights: T, round: number): Promise<boolean>;
|
|
41
|
-
}
|
|
@@ -1,97 +0,0 @@
|
|
|
1
|
-
"use strict";
|
|
2
|
-
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
-
exports.AsyncBuffer = void 0;
|
|
4
|
-
var tslib_1 = require("tslib");
|
|
5
|
-
var immutable_1 = require("immutable");
|
|
6
|
-
/**
|
|
7
|
-
* The AsyncWeightsBuffer class holds and manipulates information about the
|
|
8
|
-
* async weights buffer. It works as follows:
|
|
9
|
-
*
|
|
10
|
-
* Setup: Init round to zero and create empty buffer (a map from user id to weights)
|
|
11
|
-
*
|
|
12
|
-
* - When a user adds weights only do so when they are recent weights: i.e. this.round - round <= roundCutoff.
|
|
13
|
-
* - If a user already added weights, update them. (-> there can be at most one entry of weights per id in a buffer).
|
|
14
|
-
* - When the buffer is full, call aggregateAndStoreWeights with the weights in the buffer and then increment round by one and reset the buffer.
|
|
15
|
-
*
|
|
16
|
-
* @remarks
|
|
17
|
-
* taskID: corresponds to the task that weights correspond to.
|
|
18
|
-
* bufferCapacity: size of the buffer.
|
|
19
|
-
* buffer: holds a map of users to their added weights.
|
|
20
|
-
* round: the latest round of the weight buffer.
|
|
21
|
-
* roundCutoff: cutoff for accepted rounds.
|
|
22
|
-
*/
|
|
23
|
-
var AsyncBuffer = /** @class */ (function () {
|
|
24
|
-
function AsyncBuffer(taskID, bufferCapacity, aggregateAndStoreWeights, roundCutoff) {
|
|
25
|
-
if (roundCutoff === void 0) { roundCutoff = 0; }
|
|
26
|
-
this.taskID = taskID;
|
|
27
|
-
this.bufferCapacity = bufferCapacity;
|
|
28
|
-
this.aggregateAndStoreWeights = aggregateAndStoreWeights;
|
|
29
|
-
this.roundCutoff = roundCutoff;
|
|
30
|
-
this.buffer = (0, immutable_1.Map)();
|
|
31
|
-
this.round = 0;
|
|
32
|
-
}
|
|
33
|
-
AsyncBuffer.prototype.registerObserver = function (observer) {
|
|
34
|
-
this.observer = observer;
|
|
35
|
-
};
|
|
36
|
-
// TODO do not test private
|
|
37
|
-
AsyncBuffer.prototype.bufferIsFull = function () {
|
|
38
|
-
return this.buffer.size >= this.bufferCapacity;
|
|
39
|
-
};
|
|
40
|
-
AsyncBuffer.prototype.updateWeightsIfBufferIsFull = function () {
|
|
41
|
-
var _a;
|
|
42
|
-
return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
|
|
43
|
-
return (0, tslib_1.__generator)(this, function (_b) {
|
|
44
|
-
switch (_b.label) {
|
|
45
|
-
case 0:
|
|
46
|
-
if (!this.bufferIsFull()) return [3 /*break*/, 2];
|
|
47
|
-
return [4 /*yield*/, this.aggregateAndStoreWeights(this.buffer.values())];
|
|
48
|
-
case 1:
|
|
49
|
-
_b.sent();
|
|
50
|
-
this.round += 1;
|
|
51
|
-
(_a = this.observer) === null || _a === void 0 ? void 0 : _a.update();
|
|
52
|
-
this.buffer = (0, immutable_1.Map)();
|
|
53
|
-
console.log('\n************************************************************');
|
|
54
|
-
console.log("Buffer is full; Aggregating weights and starting round: " + this.round + "\n");
|
|
55
|
-
_b.label = 2;
|
|
56
|
-
case 2: return [2 /*return*/];
|
|
57
|
-
}
|
|
58
|
-
});
|
|
59
|
-
});
|
|
60
|
-
};
|
|
61
|
-
// TODO do not test private
|
|
62
|
-
AsyncBuffer.prototype.isNotWithinRoundCutoff = function (round) {
|
|
63
|
-
// Note that always this.round >= round
|
|
64
|
-
return this.round - round > this.roundCutoff;
|
|
65
|
-
};
|
|
66
|
-
/**
|
|
67
|
-
* Add weights originating from weights of a given round.
|
|
68
|
-
* Only add to buffer if the given round is not old.
|
|
69
|
-
* @param weights
|
|
70
|
-
* @param round
|
|
71
|
-
* @returns true if weights were added, and false otherwise
|
|
72
|
-
*/
|
|
73
|
-
AsyncBuffer.prototype.add = function (id, weights, round) {
|
|
74
|
-
return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
|
|
75
|
-
var weightsUpdatedByUser, msg;
|
|
76
|
-
return (0, tslib_1.__generator)(this, function (_a) {
|
|
77
|
-
switch (_a.label) {
|
|
78
|
-
case 0:
|
|
79
|
-
if (this.isNotWithinRoundCutoff(round)) {
|
|
80
|
-
console.log("Did not add weights of " + id + " to buffer. Due to old round update: " + round + ", current round is " + this.round);
|
|
81
|
-
return [2 /*return*/, false];
|
|
82
|
-
}
|
|
83
|
-
weightsUpdatedByUser = this.buffer.has(id);
|
|
84
|
-
msg = weightsUpdatedByUser ? '\tUpdating' : '-> Adding new';
|
|
85
|
-
console.log(msg + " weights of " + id + " to buffer.");
|
|
86
|
-
this.buffer = this.buffer.set(id, weights);
|
|
87
|
-
return [4 /*yield*/, this.updateWeightsIfBufferIsFull()];
|
|
88
|
-
case 1:
|
|
89
|
-
_a.sent();
|
|
90
|
-
return [2 /*return*/, true];
|
|
91
|
-
}
|
|
92
|
-
});
|
|
93
|
-
});
|
|
94
|
-
};
|
|
95
|
-
return AsyncBuffer;
|
|
96
|
-
}());
|
|
97
|
-
exports.AsyncBuffer = AsyncBuffer;
|
|
@@ -1,20 +0,0 @@
|
|
|
1
|
-
import { AsyncBuffer } from './async_buffer';
|
|
2
|
-
export declare class AsyncInformant<T> {
|
|
3
|
-
private readonly asyncBuffer;
|
|
4
|
-
private round;
|
|
5
|
-
private currentNumberOfParticipants;
|
|
6
|
-
private totalNumberOfParticipants;
|
|
7
|
-
private averageNumberOfParticipants;
|
|
8
|
-
constructor(asyncBuffer: AsyncBuffer<T>);
|
|
9
|
-
update(): void;
|
|
10
|
-
private updateRound;
|
|
11
|
-
private updateNumberOfParticipants;
|
|
12
|
-
private updateAverageNumberOfParticipants;
|
|
13
|
-
private updateTotalNumberOfParticipants;
|
|
14
|
-
getCurrentRound(): number;
|
|
15
|
-
getNumberOfParticipants(): number;
|
|
16
|
-
getTotalNumberOfParticipants(): number;
|
|
17
|
-
getAverageNumberOfParticipants(): number;
|
|
18
|
-
getAllStatistics(): Record<'round' | 'currentNumberOfParticipants' | 'totalNumberOfParticipants' | 'averageNumberOfParticipants', number>;
|
|
19
|
-
printAllInfos(): void;
|
|
20
|
-
}
|
|
@@ -1,69 +0,0 @@
|
|
|
1
|
-
"use strict";
|
|
2
|
-
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
-
exports.AsyncInformant = void 0;
|
|
4
|
-
var AsyncInformant = /** @class */ (function () {
|
|
5
|
-
function AsyncInformant(asyncBuffer) {
|
|
6
|
-
this.asyncBuffer = asyncBuffer;
|
|
7
|
-
this.round = 0;
|
|
8
|
-
this.currentNumberOfParticipants = 0;
|
|
9
|
-
this.totalNumberOfParticipants = 0;
|
|
10
|
-
this.averageNumberOfParticipants = 0;
|
|
11
|
-
this.asyncBuffer.registerObserver(this);
|
|
12
|
-
}
|
|
13
|
-
// Update functions
|
|
14
|
-
AsyncInformant.prototype.update = function () {
|
|
15
|
-
// DEBUG
|
|
16
|
-
console.log('Before update');
|
|
17
|
-
this.printAllInfos();
|
|
18
|
-
this.updateRound();
|
|
19
|
-
this.updateNumberOfParticipants();
|
|
20
|
-
// DEBUG
|
|
21
|
-
console.log('After update');
|
|
22
|
-
this.printAllInfos();
|
|
23
|
-
};
|
|
24
|
-
AsyncInformant.prototype.updateRound = function () {
|
|
25
|
-
this.round = this.asyncBuffer.round;
|
|
26
|
-
};
|
|
27
|
-
AsyncInformant.prototype.updateNumberOfParticipants = function () {
|
|
28
|
-
this.currentNumberOfParticipants = this.asyncBuffer.buffer.size;
|
|
29
|
-
this.updateTotalNumberOfParticipants(this.currentNumberOfParticipants);
|
|
30
|
-
this.updateAverageNumberOfParticipants();
|
|
31
|
-
};
|
|
32
|
-
AsyncInformant.prototype.updateAverageNumberOfParticipants = function () {
|
|
33
|
-
this.averageNumberOfParticipants = this.totalNumberOfParticipants / this.round;
|
|
34
|
-
};
|
|
35
|
-
AsyncInformant.prototype.updateTotalNumberOfParticipants = function (currentNumberOfParticipants) {
|
|
36
|
-
this.totalNumberOfParticipants += currentNumberOfParticipants;
|
|
37
|
-
};
|
|
38
|
-
// Getter functions
|
|
39
|
-
AsyncInformant.prototype.getCurrentRound = function () {
|
|
40
|
-
return this.round;
|
|
41
|
-
};
|
|
42
|
-
AsyncInformant.prototype.getNumberOfParticipants = function () {
|
|
43
|
-
return this.currentNumberOfParticipants;
|
|
44
|
-
};
|
|
45
|
-
AsyncInformant.prototype.getTotalNumberOfParticipants = function () {
|
|
46
|
-
return this.totalNumberOfParticipants;
|
|
47
|
-
};
|
|
48
|
-
AsyncInformant.prototype.getAverageNumberOfParticipants = function () {
|
|
49
|
-
return this.averageNumberOfParticipants;
|
|
50
|
-
};
|
|
51
|
-
AsyncInformant.prototype.getAllStatistics = function () {
|
|
52
|
-
return {
|
|
53
|
-
round: this.getCurrentRound(),
|
|
54
|
-
currentNumberOfParticipants: this.getNumberOfParticipants(),
|
|
55
|
-
totalNumberOfParticipants: this.getTotalNumberOfParticipants(),
|
|
56
|
-
averageNumberOfParticipants: this.getAverageNumberOfParticipants()
|
|
57
|
-
};
|
|
58
|
-
};
|
|
59
|
-
// Debug
|
|
60
|
-
AsyncInformant.prototype.printAllInfos = function () {
|
|
61
|
-
console.log('task : ', this.asyncBuffer.taskID);
|
|
62
|
-
console.log('round : ', this.getCurrentRound());
|
|
63
|
-
console.log('participants : ', this.getNumberOfParticipants());
|
|
64
|
-
console.log('total : ', this.getTotalNumberOfParticipants());
|
|
65
|
-
console.log('average : ', this.getAverageNumberOfParticipants());
|
|
66
|
-
};
|
|
67
|
-
return AsyncInformant;
|
|
68
|
-
}());
|
|
69
|
-
exports.AsyncInformant = AsyncInformant;
|
|
@@ -1,33 +0,0 @@
|
|
|
1
|
-
import { tf, WeightsContainer, Task, TrainingInformant } from '..';
|
|
2
|
-
export declare abstract class Base {
|
|
3
|
-
readonly url: URL;
|
|
4
|
-
readonly task: Task;
|
|
5
|
-
protected connected: boolean;
|
|
6
|
-
constructor(url: URL, task: Task);
|
|
7
|
-
/**
|
|
8
|
-
* Handles the connection process from the client to any sort of
|
|
9
|
-
* centralized server.
|
|
10
|
-
*/
|
|
11
|
-
abstract connect(): Promise<void>;
|
|
12
|
-
/**
|
|
13
|
-
* Handles the disconnection process of the client from any sort
|
|
14
|
-
* of centralized server.
|
|
15
|
-
*/
|
|
16
|
-
abstract disconnect(): Promise<void>;
|
|
17
|
-
getLatestModel(): Promise<tf.LayersModel>;
|
|
18
|
-
/**
|
|
19
|
-
* The training manager matches this function with the training loop's
|
|
20
|
-
* onTrainEnd callback when training a TFJS model object. See the
|
|
21
|
-
* training manager for more details.
|
|
22
|
-
*/
|
|
23
|
-
abstract onTrainEndCommunication(weights: WeightsContainer, trainingInformant: TrainingInformant): Promise<void>;
|
|
24
|
-
/**
|
|
25
|
-
* This function will be called whenever a local round has ended.
|
|
26
|
-
*
|
|
27
|
-
* @param updatedWeights
|
|
28
|
-
* @param staleWeights
|
|
29
|
-
* @param round
|
|
30
|
-
* @param trainingInformant
|
|
31
|
-
*/
|
|
32
|
-
abstract onRoundEndCommunication(updatedWeights: WeightsContainer, staleWeights: WeightsContainer, round: number, trainingInformant: TrainingInformant): Promise<WeightsContainer>;
|
|
33
|
-
}
|
package/dist/core/client/base.js
DELETED
|
@@ -1,35 +0,0 @@
|
|
|
1
|
-
"use strict";
|
|
2
|
-
Object.defineProperty(exports, "__esModule", { value: true });
|
|
3
|
-
exports.Base = void 0;
|
|
4
|
-
var tslib_1 = require("tslib");
|
|
5
|
-
var axios_1 = (0, tslib_1.__importDefault)(require("axios"));
|
|
6
|
-
var __1 = require("..");
|
|
7
|
-
var Base = /** @class */ (function () {
|
|
8
|
-
function Base(url, task) {
|
|
9
|
-
this.url = url;
|
|
10
|
-
this.task = task;
|
|
11
|
-
this.connected = false;
|
|
12
|
-
}
|
|
13
|
-
Base.prototype.getLatestModel = function () {
|
|
14
|
-
return (0, tslib_1.__awaiter)(this, void 0, void 0, function () {
|
|
15
|
-
var url, response;
|
|
16
|
-
return (0, tslib_1.__generator)(this, function (_a) {
|
|
17
|
-
switch (_a.label) {
|
|
18
|
-
case 0:
|
|
19
|
-
url = new URL('', this.url.href);
|
|
20
|
-
if (!url.pathname.endsWith('/')) {
|
|
21
|
-
url.pathname += '/';
|
|
22
|
-
}
|
|
23
|
-
url.pathname += "tasks/" + this.task.taskID + "/model.json";
|
|
24
|
-
return [4 /*yield*/, axios_1.default.get(url.href)];
|
|
25
|
-
case 1:
|
|
26
|
-
response = _a.sent();
|
|
27
|
-
return [4 /*yield*/, __1.serialization.model.decode(response.data)];
|
|
28
|
-
case 2: return [2 /*return*/, _a.sent()];
|
|
29
|
-
}
|
|
30
|
-
});
|
|
31
|
-
});
|
|
32
|
-
};
|
|
33
|
-
return Base;
|
|
34
|
-
}());
|
|
35
|
-
exports.Base = Base;
|
|
@@ -1,32 +0,0 @@
|
|
|
1
|
-
import { List, Map } from 'immutable';
|
|
2
|
-
import { TrainingInformant, WeightsContainer, Task } from '../..';
|
|
3
|
-
import { Base as ClientBase } from '../base';
|
|
4
|
-
import { PeerID } from './types';
|
|
5
|
-
import * as messages from './messages';
|
|
6
|
-
import { PeerConnection } from '../event_connection';
|
|
7
|
-
/**
|
|
8
|
-
* Abstract class for decentralized clients, executes onRoundEndCommunication as well as connecting
|
|
9
|
-
* to the signaling server
|
|
10
|
-
*/
|
|
11
|
-
export declare abstract class Base extends ClientBase {
|
|
12
|
-
readonly url: URL;
|
|
13
|
-
readonly task: Task;
|
|
14
|
-
protected readonly minimumReadyPeers: number;
|
|
15
|
-
private server?;
|
|
16
|
-
private peers?;
|
|
17
|
-
private ID?;
|
|
18
|
-
private pool?;
|
|
19
|
-
constructor(url: URL, task: Task);
|
|
20
|
-
private waitForPeers;
|
|
21
|
-
protected sendMessagetoPeer(peer: PeerConnection, msg: messages.PeerMessage): void;
|
|
22
|
-
private connectServer;
|
|
23
|
-
/**
|
|
24
|
-
* Initialize the connection to the peers and to the other nodes.
|
|
25
|
-
*/
|
|
26
|
-
connect(): Promise<void>;
|
|
27
|
-
disconnect(): Promise<void>;
|
|
28
|
-
onTrainEndCommunication(_: WeightsContainer, trainingInformant: TrainingInformant): Promise<void>;
|
|
29
|
-
onRoundEndCommunication(updatedWeights: WeightsContainer, staleWeights: WeightsContainer, round: number, trainingInformant: TrainingInformant): Promise<WeightsContainer>;
|
|
30
|
-
abstract sendAndReceiveWeights(peers: Map<PeerID, PeerConnection>, noisyWeights: WeightsContainer, round: number, trainingInformant: TrainingInformant): Promise<List<WeightsContainer>>;
|
|
31
|
-
abstract clientHandle(peers: Map<PeerID, PeerConnection>): void;
|
|
32
|
-
}
|