@epfml/discojs 2.0.0 → 2.1.2-p20240506085037.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/aggregator/base.d.ts +180 -0
- package/dist/aggregator/base.js +236 -0
- package/dist/aggregator/get.d.ts +16 -0
- package/dist/aggregator/get.js +31 -0
- package/dist/aggregator/index.d.ts +7 -0
- package/dist/aggregator/index.js +4 -0
- package/dist/aggregator/mean.d.ts +23 -0
- package/dist/aggregator/mean.js +69 -0
- package/dist/aggregator/secure.d.ts +27 -0
- package/dist/aggregator/secure.js +91 -0
- package/dist/async_informant.d.ts +15 -0
- package/dist/async_informant.js +42 -0
- package/dist/client/base.d.ts +76 -0
- package/dist/client/base.js +88 -0
- package/dist/client/decentralized/base.d.ts +32 -0
- package/dist/client/decentralized/base.js +192 -0
- package/dist/client/decentralized/index.d.ts +2 -0
- package/dist/client/decentralized/index.js +2 -0
- package/dist/client/decentralized/messages.d.ts +28 -0
- package/dist/client/decentralized/messages.js +44 -0
- package/dist/client/decentralized/peer.d.ts +40 -0
- package/dist/client/decentralized/peer.js +189 -0
- package/dist/client/decentralized/peer_pool.d.ts +12 -0
- package/dist/client/decentralized/peer_pool.js +44 -0
- package/dist/client/event_connection.d.ts +34 -0
- package/dist/client/event_connection.js +105 -0
- package/dist/client/federated/base.d.ts +54 -0
- package/dist/client/federated/base.js +151 -0
- package/dist/client/federated/index.d.ts +2 -0
- package/dist/client/federated/index.js +2 -0
- package/dist/client/federated/messages.d.ts +30 -0
- package/dist/client/federated/messages.js +24 -0
- package/dist/client/index.d.ts +8 -0
- package/dist/client/index.js +8 -0
- package/dist/client/local.d.ts +3 -0
- package/dist/client/local.js +3 -0
- package/dist/client/messages.d.ts +30 -0
- package/dist/client/messages.js +26 -0
- package/dist/client/types.d.ts +2 -0
- package/dist/client/types.js +4 -0
- package/dist/client/utils.d.ts +2 -0
- package/dist/client/utils.js +7 -0
- package/dist/dataset/data/data.d.ts +48 -0
- package/dist/dataset/data/data.js +72 -0
- package/dist/dataset/data/data_split.d.ts +8 -0
- package/dist/dataset/data/data_split.js +1 -0
- package/dist/dataset/data/image_data.d.ts +11 -0
- package/dist/dataset/data/image_data.js +38 -0
- package/dist/dataset/data/index.d.ts +6 -0
- package/dist/dataset/data/index.js +5 -0
- package/dist/dataset/data/preprocessing/base.d.ts +16 -0
- package/dist/dataset/data/preprocessing/base.js +1 -0
- package/dist/dataset/data/preprocessing/image_preprocessing.d.ts +13 -0
- package/dist/dataset/data/preprocessing/image_preprocessing.js +40 -0
- package/dist/dataset/data/preprocessing/index.d.ts +4 -0
- package/dist/dataset/data/preprocessing/index.js +3 -0
- package/dist/dataset/data/preprocessing/tabular_preprocessing.d.ts +13 -0
- package/dist/dataset/data/preprocessing/tabular_preprocessing.js +45 -0
- package/dist/dataset/data/preprocessing/text_preprocessing.d.ts +13 -0
- package/dist/dataset/data/preprocessing/text_preprocessing.js +85 -0
- package/dist/dataset/data/tabular_data.d.ts +11 -0
- package/dist/dataset/data/tabular_data.js +25 -0
- package/dist/dataset/data/text_data.d.ts +11 -0
- package/dist/dataset/data/text_data.js +14 -0
- package/dist/{core/dataset → dataset}/data_loader/data_loader.d.ts +3 -5
- package/dist/dataset/data_loader/data_loader.js +2 -0
- package/dist/dataset/data_loader/image_loader.d.ts +20 -3
- package/dist/dataset/data_loader/image_loader.js +98 -23
- package/dist/dataset/data_loader/index.d.ts +5 -2
- package/dist/dataset/data_loader/index.js +4 -7
- package/dist/dataset/data_loader/tabular_loader.d.ts +34 -3
- package/dist/dataset/data_loader/tabular_loader.js +75 -15
- package/dist/dataset/data_loader/text_loader.d.ts +14 -0
- package/dist/dataset/data_loader/text_loader.js +25 -0
- package/dist/dataset/dataset.d.ts +5 -0
- package/dist/dataset/dataset.js +1 -0
- package/dist/dataset/dataset_builder.d.ts +60 -0
- package/dist/dataset/dataset_builder.js +142 -0
- package/dist/dataset/index.d.ts +5 -0
- package/dist/dataset/index.js +3 -0
- package/dist/default_tasks/cifar10/index.d.ts +2 -0
- package/dist/default_tasks/cifar10/index.js +60 -0
- package/dist/default_tasks/cifar10/model.d.ts +434 -0
- package/dist/default_tasks/cifar10/model.js +2385 -0
- package/dist/default_tasks/geotags/index.d.ts +2 -0
- package/dist/default_tasks/geotags/index.js +65 -0
- package/dist/default_tasks/geotags/model.d.ts +593 -0
- package/dist/default_tasks/geotags/model.js +4715 -0
- package/dist/default_tasks/index.d.ts +8 -0
- package/dist/default_tasks/index.js +8 -0
- package/dist/default_tasks/lus_covid.d.ts +2 -0
- package/dist/default_tasks/lus_covid.js +89 -0
- package/dist/default_tasks/mnist.d.ts +2 -0
- package/dist/default_tasks/mnist.js +61 -0
- package/dist/default_tasks/simple_face/index.d.ts +2 -0
- package/dist/default_tasks/simple_face/index.js +48 -0
- package/dist/default_tasks/simple_face/model.d.ts +513 -0
- package/dist/default_tasks/simple_face/model.js +4301 -0
- package/dist/default_tasks/skin_mnist.d.ts +2 -0
- package/dist/default_tasks/skin_mnist.js +80 -0
- package/dist/default_tasks/titanic.d.ts +2 -0
- package/dist/default_tasks/titanic.js +88 -0
- package/dist/default_tasks/wikitext.d.ts +2 -0
- package/dist/default_tasks/wikitext.js +38 -0
- package/dist/index.d.ts +18 -2
- package/dist/index.js +18 -6
- package/dist/{core/informant → informant}/graph_informant.d.ts +1 -1
- package/dist/informant/graph_informant.js +20 -0
- package/dist/informant/index.d.ts +1 -0
- package/dist/informant/index.js +1 -0
- package/dist/{core/logging → logging}/console_logger.d.ts +2 -2
- package/dist/logging/console_logger.js +22 -0
- package/dist/logging/index.d.ts +2 -0
- package/dist/logging/index.js +1 -0
- package/dist/{core/logging → logging}/logger.d.ts +3 -3
- package/dist/logging/logger.js +1 -0
- package/dist/memory/base.d.ts +119 -0
- package/dist/memory/base.js +9 -0
- package/dist/memory/empty.d.ts +20 -0
- package/dist/memory/empty.js +43 -0
- package/dist/memory/index.d.ts +3 -1
- package/dist/memory/index.js +3 -5
- package/dist/memory/model_type.d.ts +9 -0
- package/dist/memory/model_type.js +10 -0
- package/dist/{core/privacy.d.ts → privacy.d.ts} +1 -1
- package/dist/{core/privacy.js → privacy.js} +11 -16
- package/dist/serialization/index.d.ts +2 -0
- package/dist/serialization/index.js +2 -0
- package/dist/serialization/model.d.ts +5 -0
- package/dist/serialization/model.js +67 -0
- package/dist/{core/serialization → serialization}/weights.d.ts +2 -2
- package/dist/serialization/weights.js +37 -0
- package/dist/task/data_example.js +14 -0
- package/dist/task/digest.d.ts +5 -0
- package/dist/task/digest.js +14 -0
- package/dist/{core/task → task}/display_information.d.ts +5 -3
- package/dist/task/display_information.js +46 -0
- package/dist/task/index.d.ts +7 -0
- package/dist/task/index.js +5 -0
- package/dist/task/label_type.d.ts +9 -0
- package/dist/task/label_type.js +28 -0
- package/dist/task/summary.js +13 -0
- package/dist/task/task.d.ts +12 -0
- package/dist/task/task.js +22 -0
- package/dist/task/task_handler.d.ts +5 -0
- package/dist/task/task_handler.js +20 -0
- package/dist/task/task_provider.d.ts +5 -0
- package/dist/task/task_provider.js +1 -0
- package/dist/{core/task → task}/training_information.d.ts +9 -10
- package/dist/task/training_information.js +88 -0
- package/dist/training/disco.d.ts +40 -0
- package/dist/training/disco.js +107 -0
- package/dist/training/index.d.ts +2 -0
- package/dist/training/index.js +1 -0
- package/dist/training/trainer/distributed_trainer.d.ts +20 -0
- package/dist/training/trainer/distributed_trainer.js +36 -0
- package/dist/training/trainer/local_trainer.d.ts +12 -0
- package/dist/training/trainer/local_trainer.js +19 -0
- package/dist/training/trainer/trainer.d.ts +33 -0
- package/dist/training/trainer/trainer.js +52 -0
- package/dist/{core/training → training}/trainer/trainer_builder.d.ts +5 -7
- package/dist/training/trainer/trainer_builder.js +43 -0
- package/dist/types.d.ts +8 -0
- package/dist/types.js +1 -0
- package/dist/utils/event_emitter.d.ts +40 -0
- package/dist/utils/event_emitter.js +57 -0
- package/dist/validation/index.d.ts +1 -0
- package/dist/validation/index.js +1 -0
- package/dist/validation/validator.d.ts +28 -0
- package/dist/validation/validator.js +132 -0
- package/dist/weights/aggregation.d.ts +21 -0
- package/dist/weights/aggregation.js +44 -0
- package/dist/weights/index.d.ts +2 -0
- package/dist/weights/index.js +2 -0
- package/dist/weights/weights_container.d.ts +68 -0
- package/dist/weights/weights_container.js +96 -0
- package/package.json +25 -16
- package/README.md +0 -53
- package/dist/core/async_buffer.d.ts +0 -41
- package/dist/core/async_buffer.js +0 -97
- package/dist/core/async_informant.d.ts +0 -20
- package/dist/core/async_informant.js +0 -69
- package/dist/core/client/base.d.ts +0 -33
- package/dist/core/client/base.js +0 -35
- package/dist/core/client/decentralized/base.d.ts +0 -32
- package/dist/core/client/decentralized/base.js +0 -212
- package/dist/core/client/decentralized/clear_text.d.ts +0 -14
- package/dist/core/client/decentralized/clear_text.js +0 -96
- package/dist/core/client/decentralized/index.d.ts +0 -4
- package/dist/core/client/decentralized/index.js +0 -9
- package/dist/core/client/decentralized/messages.d.ts +0 -41
- package/dist/core/client/decentralized/messages.js +0 -54
- package/dist/core/client/decentralized/peer.d.ts +0 -26
- package/dist/core/client/decentralized/peer.js +0 -210
- package/dist/core/client/decentralized/peer_pool.d.ts +0 -14
- package/dist/core/client/decentralized/peer_pool.js +0 -92
- package/dist/core/client/decentralized/sec_agg.d.ts +0 -22
- package/dist/core/client/decentralized/sec_agg.js +0 -190
- package/dist/core/client/decentralized/secret_shares.d.ts +0 -3
- package/dist/core/client/decentralized/secret_shares.js +0 -39
- package/dist/core/client/decentralized/types.d.ts +0 -2
- package/dist/core/client/decentralized/types.js +0 -7
- package/dist/core/client/event_connection.d.ts +0 -37
- package/dist/core/client/event_connection.js +0 -158
- package/dist/core/client/federated/client.d.ts +0 -37
- package/dist/core/client/federated/client.js +0 -273
- package/dist/core/client/federated/index.d.ts +0 -2
- package/dist/core/client/federated/index.js +0 -7
- package/dist/core/client/federated/messages.d.ts +0 -38
- package/dist/core/client/federated/messages.js +0 -25
- package/dist/core/client/index.d.ts +0 -5
- package/dist/core/client/index.js +0 -11
- package/dist/core/client/local.d.ts +0 -8
- package/dist/core/client/local.js +0 -36
- package/dist/core/client/messages.d.ts +0 -28
- package/dist/core/client/messages.js +0 -33
- package/dist/core/client/utils.d.ts +0 -2
- package/dist/core/client/utils.js +0 -19
- package/dist/core/dataset/data/data.d.ts +0 -11
- package/dist/core/dataset/data/data.js +0 -20
- package/dist/core/dataset/data/data_split.d.ts +0 -5
- package/dist/core/dataset/data/data_split.js +0 -2
- package/dist/core/dataset/data/image_data.d.ts +0 -8
- package/dist/core/dataset/data/image_data.js +0 -64
- package/dist/core/dataset/data/index.d.ts +0 -5
- package/dist/core/dataset/data/index.js +0 -11
- package/dist/core/dataset/data/preprocessing.d.ts +0 -13
- package/dist/core/dataset/data/preprocessing.js +0 -33
- package/dist/core/dataset/data/tabular_data.d.ts +0 -8
- package/dist/core/dataset/data/tabular_data.js +0 -40
- package/dist/core/dataset/data_loader/data_loader.js +0 -10
- package/dist/core/dataset/data_loader/image_loader.d.ts +0 -17
- package/dist/core/dataset/data_loader/image_loader.js +0 -141
- package/dist/core/dataset/data_loader/index.d.ts +0 -3
- package/dist/core/dataset/data_loader/index.js +0 -9
- package/dist/core/dataset/data_loader/tabular_loader.d.ts +0 -29
- package/dist/core/dataset/data_loader/tabular_loader.js +0 -101
- package/dist/core/dataset/dataset.d.ts +0 -2
- package/dist/core/dataset/dataset.js +0 -2
- package/dist/core/dataset/dataset_builder.d.ts +0 -18
- package/dist/core/dataset/dataset_builder.js +0 -96
- package/dist/core/dataset/index.d.ts +0 -4
- package/dist/core/dataset/index.js +0 -14
- package/dist/core/index.d.ts +0 -18
- package/dist/core/index.js +0 -41
- package/dist/core/informant/graph_informant.js +0 -23
- package/dist/core/informant/index.d.ts +0 -3
- package/dist/core/informant/index.js +0 -9
- package/dist/core/informant/training_informant/base.d.ts +0 -31
- package/dist/core/informant/training_informant/base.js +0 -83
- package/dist/core/informant/training_informant/decentralized.d.ts +0 -5
- package/dist/core/informant/training_informant/decentralized.js +0 -22
- package/dist/core/informant/training_informant/federated.d.ts +0 -14
- package/dist/core/informant/training_informant/federated.js +0 -32
- package/dist/core/informant/training_informant/index.d.ts +0 -4
- package/dist/core/informant/training_informant/index.js +0 -11
- package/dist/core/informant/training_informant/local.d.ts +0 -6
- package/dist/core/informant/training_informant/local.js +0 -20
- package/dist/core/logging/console_logger.js +0 -33
- package/dist/core/logging/index.d.ts +0 -3
- package/dist/core/logging/index.js +0 -9
- package/dist/core/logging/logger.js +0 -9
- package/dist/core/logging/trainer_logger.d.ts +0 -24
- package/dist/core/logging/trainer_logger.js +0 -59
- package/dist/core/memory/base.d.ts +0 -22
- package/dist/core/memory/base.js +0 -9
- package/dist/core/memory/empty.d.ts +0 -14
- package/dist/core/memory/empty.js +0 -75
- package/dist/core/memory/index.d.ts +0 -3
- package/dist/core/memory/index.js +0 -9
- package/dist/core/memory/model_type.d.ts +0 -4
- package/dist/core/memory/model_type.js +0 -9
- package/dist/core/serialization/index.d.ts +0 -2
- package/dist/core/serialization/index.js +0 -6
- package/dist/core/serialization/model.d.ts +0 -5
- package/dist/core/serialization/model.js +0 -55
- package/dist/core/serialization/weights.js +0 -64
- package/dist/core/task/data_example.js +0 -24
- package/dist/core/task/display_information.js +0 -49
- package/dist/core/task/index.d.ts +0 -3
- package/dist/core/task/index.js +0 -8
- package/dist/core/task/model_compile_data.d.ts +0 -6
- package/dist/core/task/model_compile_data.js +0 -22
- package/dist/core/task/summary.js +0 -19
- package/dist/core/task/task.d.ts +0 -10
- package/dist/core/task/task.js +0 -31
- package/dist/core/task/training_information.js +0 -66
- package/dist/core/tasks/cifar10.d.ts +0 -3
- package/dist/core/tasks/cifar10.js +0 -65
- package/dist/core/tasks/geotags.d.ts +0 -3
- package/dist/core/tasks/geotags.js +0 -67
- package/dist/core/tasks/index.d.ts +0 -6
- package/dist/core/tasks/index.js +0 -10
- package/dist/core/tasks/lus_covid.d.ts +0 -3
- package/dist/core/tasks/lus_covid.js +0 -87
- package/dist/core/tasks/mnist.d.ts +0 -3
- package/dist/core/tasks/mnist.js +0 -60
- package/dist/core/tasks/simple_face.d.ts +0 -2
- package/dist/core/tasks/simple_face.js +0 -41
- package/dist/core/tasks/titanic.d.ts +0 -3
- package/dist/core/tasks/titanic.js +0 -88
- package/dist/core/training/disco.d.ts +0 -23
- package/dist/core/training/disco.js +0 -130
- package/dist/core/training/index.d.ts +0 -2
- package/dist/core/training/index.js +0 -7
- package/dist/core/training/trainer/distributed_trainer.d.ts +0 -20
- package/dist/core/training/trainer/distributed_trainer.js +0 -65
- package/dist/core/training/trainer/local_trainer.d.ts +0 -11
- package/dist/core/training/trainer/local_trainer.js +0 -34
- package/dist/core/training/trainer/round_tracker.d.ts +0 -30
- package/dist/core/training/trainer/round_tracker.js +0 -47
- package/dist/core/training/trainer/trainer.d.ts +0 -65
- package/dist/core/training/trainer/trainer.js +0 -160
- package/dist/core/training/trainer/trainer_builder.js +0 -95
- package/dist/core/training/training_schemes.d.ts +0 -5
- package/dist/core/training/training_schemes.js +0 -10
- package/dist/core/types.d.ts +0 -4
- package/dist/core/types.js +0 -2
- package/dist/core/validation/index.d.ts +0 -1
- package/dist/core/validation/index.js +0 -5
- package/dist/core/validation/validator.d.ts +0 -17
- package/dist/core/validation/validator.js +0 -104
- package/dist/core/weights/aggregation.d.ts +0 -8
- package/dist/core/weights/aggregation.js +0 -96
- package/dist/core/weights/index.d.ts +0 -2
- package/dist/core/weights/index.js +0 -7
- package/dist/core/weights/weights_container.d.ts +0 -19
- package/dist/core/weights/weights_container.js +0 -64
- package/dist/imports.d.ts +0 -2
- package/dist/imports.js +0 -7
- package/dist/memory/memory.d.ts +0 -26
- package/dist/memory/memory.js +0 -160
- package/dist/{core/task → task}/data_example.d.ts +1 -1
- package/dist/{core/task → task}/summary.d.ts +1 -1
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
import type * as decentralized from './decentralized/messages.js';
import type * as federated from './federated/messages.js';
import { type NodeID } from './types.js';
/**
 * Numeric discriminant carried by every client/server message. The first two
 * values are shared; the implementation groups SignalForPeer..Payload as
 * decentralized and SendPayload..ReceiveServerStatistics as federated.
 */
export declare enum type {
    ClientConnected = 0,
    AssignNodeID = 1,
    // Decentralized
    SignalForPeer = 2,
    PeerIsReady = 3,
    PeersForRound = 4,
    Payload = 5,
    // Federated
    SendPayload = 6,
    ReceiveServerMetadata = 7,
    ReceiveServerPayload = 8,
    RequestServerStatistics = 9,
    ReceiveServerStatistics = 10
}
/** Message carrying only its discriminant (no payload). */
export interface ClientConnected {
    type: type.ClientConnected;
}
/** Message assigning a node identifier to the recipient. */
export interface AssignNodeID {
    type: type.AssignNodeID;
    id: NodeID;
}
/** Union of every message shape known to the protocol. */
export type Message = decentralized.MessageFromServer | decentralized.MessageToServer | decentralized.PeerMessage | federated.MessageFederated;
/** Narrows the Message union to the variant(s) whose `type` field matches D. */
export type NarrowMessage<D> = Extract<Message, {
    type: D;
}>;
/**
 * Runtime guard: true iff `raw` is a non-null object whose `type` property is
 * a numeric member of the `type` enum (see messages.js).
 */
export declare function hasMessageType(raw: unknown): raw is {
    type: type;
} & Record<string, unknown>;
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
/**
 * Message type discriminants shared by the decentralized and federated
 * protocols, plus a runtime guard for incoming raw data.
 */
export var type;
(function (type) {
    // Names listed in enum-value order; the loop reproduces the standard
    // tsc numeric-enum output, including reverse (number -> name) mappings.
    const names = [
        'ClientConnected',
        'AssignNodeID',
        // Decentralized
        'SignalForPeer',
        'PeerIsReady',
        'PeersForRound',
        'Payload',
        // Federated
        'SendPayload',
        'ReceiveServerMetadata',
        'ReceiveServerPayload',
        'RequestServerStatistics',
        'ReceiveServerStatistics'
    ];
    names.forEach((name, code) => {
        type[type[name] = code] = name;
    });
})(type || (type = {}));
/**
 * Type guard: true iff `raw` is a non-null object whose `type` field is a
 * known numeric message discriminant.
 */
export function hasMessageType(raw) {
    if (typeof raw !== 'object' || raw === null) {
        return false;
    }
    const o = raw;
    return 'type' in o && typeof o.type === 'number' && o.type in type;
}
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
import type tf from '@tensorflow/tfjs';
import type { List } from 'immutable';
import type { Task } from '../../index.js';
import type { Dataset } from '../index.js';
import type { PreprocessingFunction } from './preprocessing/base.js';
/**
 * Abstract class representing an immutable Disco dataset, including a TF.js dataset,
 * Disco task and set of preprocessing functions.
 */
export declare abstract class Data {
    // Underlying TF.js dataset wrapped by this object.
    readonly dataset: Dataset;
    // Task whose trainingInformation drives batching and preprocessing.
    readonly task: Task;
    // Number of samples, when known.
    readonly size?: number | undefined;
    // Preprocessing steps supported by the concrete data modality,
    // ordered by their type enum (used as priority).
    abstract readonly availablePreprocessing: List<PreprocessingFunction>;
    protected constructor(dataset: Dataset, task: Task, size?: number | undefined);
    // Async factory; concrete subclasses override this (the base rejects).
    static init(_dataset: Dataset, _task: Task, _size?: number): Promise<Data>;
    /**
     * Callable abstract method instead of constructor.
     */
    protected abstract create(dataset: Dataset, task: Task, size?: number): Data;
    /**
     * Creates a new Disco data object containing the batched TF.js dataset, according to the
     * task's parameters.
     * @returns The batched Disco data
     */
    batch(): Data;
    /**
     * The TF.js dataset batched according to the task's parameters.
     */
    get batchedDataset(): Dataset;
    /**
     * Creates a new Disco data object containing the preprocessed TF.js dataset,
     * according to the defined set of preprocessing functions and the task's parameters.
     * @returns The preprocessed Disco data
     */
    preprocess(): Data;
    /**
     * Creates a higher level preprocessing function applying the specified set of preprocessing
     * functions in a series. The preprocessing functions are chained according to their defined
     * priority.
     */
    get preprocessing(): (entry: tf.TensorContainer) => Promise<tf.TensorContainer>;
    /**
     * The TF.js dataset preprocessing according to the set of preprocessing functions and the task's
     * parameters.
     */
    get preprocessedDataset(): Dataset;
}
|
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Abstract class representing an immutable Disco dataset, including a TF.js dataset,
|
|
3
|
+
* Disco task and set of preprocessing functions.
|
|
4
|
+
*/
|
|
5
|
+
export class Data {
|
|
6
|
+
dataset;
|
|
7
|
+
task;
|
|
8
|
+
size;
|
|
9
|
+
constructor(dataset, task, size) {
|
|
10
|
+
this.dataset = dataset;
|
|
11
|
+
this.task = task;
|
|
12
|
+
this.size = size;
|
|
13
|
+
}
|
|
14
|
+
static init(_dataset, _task, _size) {
|
|
15
|
+
return Promise.reject(new Error('abstract'));
|
|
16
|
+
}
|
|
17
|
+
/**
|
|
18
|
+
* Creates a new Disco data object containing the batched TF.js dataset, according to the
|
|
19
|
+
* task's parameters.
|
|
20
|
+
* @returns The batched Disco data
|
|
21
|
+
*/
|
|
22
|
+
batch() {
|
|
23
|
+
return this.create(this.batchedDataset, this.task, this.size);
|
|
24
|
+
}
|
|
25
|
+
/**
|
|
26
|
+
* The TF.js dataset batched according to the task's parameters.
|
|
27
|
+
*/
|
|
28
|
+
get batchedDataset() {
|
|
29
|
+
const batchSize = this.task.trainingInformation.batchSize;
|
|
30
|
+
return batchSize === undefined
|
|
31
|
+
? this.dataset
|
|
32
|
+
: this.dataset.batch(batchSize);
|
|
33
|
+
}
|
|
34
|
+
/**
|
|
35
|
+
* Creates a new Disco data object containing the preprocessed TF.js dataset,
|
|
36
|
+
* according to the defined set of preprocessing functions and the task's parameters.
|
|
37
|
+
* @returns The preprocessed Disco data
|
|
38
|
+
*/
|
|
39
|
+
preprocess() {
|
|
40
|
+
return this.create(this.preprocessedDataset, this.task, this.size);
|
|
41
|
+
}
|
|
42
|
+
/**
|
|
43
|
+
* Creates a higher level preprocessing function applying the specified set of preprocessing
|
|
44
|
+
* functions in a series. The preprocessing functions are chained according to their defined
|
|
45
|
+
* priority.
|
|
46
|
+
*/
|
|
47
|
+
get preprocessing() {
|
|
48
|
+
const params = this.task.trainingInformation;
|
|
49
|
+
const taskPreprocessing = params.preprocessingFunctions;
|
|
50
|
+
if (taskPreprocessing === undefined ||
|
|
51
|
+
taskPreprocessing.length === 0 ||
|
|
52
|
+
this.availablePreprocessing === undefined ||
|
|
53
|
+
this.availablePreprocessing.size === 0) {
|
|
54
|
+
return x => Promise.resolve(x);
|
|
55
|
+
}
|
|
56
|
+
const applyPreprocessing = this.availablePreprocessing
|
|
57
|
+
.filter((e) => e.type in taskPreprocessing)
|
|
58
|
+
.map((e) => e.apply);
|
|
59
|
+
if (applyPreprocessing.size === 0) {
|
|
60
|
+
return x => Promise.resolve(x);
|
|
61
|
+
}
|
|
62
|
+
const preprocessingChain = applyPreprocessing.reduce((acc, fn) => x => fn(acc(x), this.task), (x) => x);
|
|
63
|
+
return x => preprocessingChain(Promise.resolve(x));
|
|
64
|
+
}
|
|
65
|
+
/**
|
|
66
|
+
* The TF.js dataset preprocessing according to the set of preprocessing functions and the task's
|
|
67
|
+
* parameters.
|
|
68
|
+
*/
|
|
69
|
+
get preprocessedDataset() {
|
|
70
|
+
return this.dataset.mapAsync(this.preprocessing);
|
|
71
|
+
}
|
|
72
|
+
}
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
// Compiled output of a type-only module: the bare `export {}` keeps this file
// an ES module while exporting no runtime bindings.
export {};
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
import type { Task } from '../../index.js';
import type { Dataset } from '../dataset.js';
import { Data } from './data.js';
/**
 * Disco data made of image samples (.jpg, .png, etc.).
 */
export declare class ImageData extends Data {
    // Image-specific preprocessing steps (resize, normalize).
    readonly availablePreprocessing: import("immutable").List<import("./preprocessing/base.js").PreprocessingFunction>;
    // Checks the first sample's dimensions against the task (unless the task
    // resizes images) before wrapping the dataset; see image_data.js.
    static init(dataset: Dataset, task: Task, size?: number): Promise<Data>;
    // NOTE(review): `size` is required here but optional in the base class's
    // abstract `create` — confirm this narrowing is intended.
    protected create(dataset: Dataset, task: Task, size: number): ImageData;
}
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
import { Data } from './data.js';
import { ImagePreprocessing, IMAGE_PREPROCESSING } from './preprocessing/index.js';
/**
 * Disco data made of image samples (.jpg, .png, etc.).
 */
export class ImageData extends Data {
    availablePreprocessing = IMAGE_PREPROCESSING;
    static async init(dataset, task, size) {
        // Best-effort format check before training: the dataset is lazy, so a
        // malformed image could otherwise only surface mid-training. Only the
        // first sample is inspected, and only when the task does not resize
        // images anyway (resizing would fix dimension mismatches).
        const willResize = task.trainingInformation.preprocessingFunctions?.includes(ImagePreprocessing.Resize) === true;
        if (!willResize) {
            const [sample] = await dataset.take(1).toArray();
            // TODO: We suppose the presence of labels
            // TODO: Typing (discojs-node/src/dataset/data_loader/image_loader.spec.ts)
            if (sample === null || typeof sample !== 'object') {
                throw new Error("Image is undefined or is not an object");
            }
            // Labelled samples carry the tensor under `xs`; otherwise the sample
            // is the tensor itself.
            const shape = ('xs' in sample && 'ys' in sample) ? sample.xs.shape : sample.shape;
            const { IMAGE_H, IMAGE_W } = task.trainingInformation;
            // NOTE(review): shape[0] is compared against IMAGE_W and shape[1]
            // against IMAGE_H; TF.js image tensors are typically
            // [height, width, channels], so this pairing looks swapped for
            // non-square images — confirm upstream before relying on it.
            const mismatch = IMAGE_W !== undefined && IMAGE_H !== undefined &&
                (shape[0] !== IMAGE_W || shape[1] !== IMAGE_H);
            if (mismatch) {
                throw new Error(`Image doesn't have the dimensions specified in the task's training information. Expected ${IMAGE_H}x${IMAGE_W} but got ${shape[0]}x${shape[1]}.`);
            }
        }
        return new ImageData(dataset, task, size);
    }
    create(dataset, task, size) {
        return new ImageData(dataset, task, size);
    }
}
|
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
// Barrel re-exporting the dataset/data public API: the Data hierarchy, the
// DataSplit type, and the per-modality preprocessing enums/lists.
export type { DataSplit } from './data_split.js';
export { Data } from './data.js';
export { ImageData } from './image_data.js';
export { TabularData } from './tabular_data.js';
export { TextData } from './text_data.js';
export { ImagePreprocessing, TabularPreprocessing, TextPreprocessing, IMAGE_PREPROCESSING, TABULAR_PREPROCESSING, TEXT_PREPROCESSING } from './preprocessing/index.js';
|
|
@@ -0,0 +1,5 @@
|
|
|
1
|
+
// Runtime counterpart of the dataset/data barrel (type-only DataSplit export
// is erased at compile time).
export { Data } from './data.js';
export { ImageData } from './image_data.js';
export { TabularData } from './tabular_data.js';
export { TextData } from './text_data.js';
export { ImagePreprocessing, TabularPreprocessing, TextPreprocessing, IMAGE_PREPROCESSING, TABULAR_PREPROCESSING, TEXT_PREPROCESSING } from './preprocessing/index.js';
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
import type tf from '@tensorflow/tfjs';
import type { Task } from '../../../index.js';
import type { ImagePreprocessing } from './image_preprocessing.js';
import type { TabularPreprocessing } from './tabular_preprocessing.js';
import type { TextPreprocessing } from './text_preprocessing.js';
/**
 * All available preprocessing type enums.
 */
export type Preprocessing = ImagePreprocessing | TextPreprocessing | TabularPreprocessing;
/**
 * Preprocessing function associating a preprocessing type enum to a sample transformation.
 */
export interface PreprocessingFunction {
    // Enum value identifying this step; also used to order steps when chaining.
    type: Preprocessing;
    // Transforms a promised sample; the task is passed along so the
    // transformation can read its training parameters.
    apply: (x: Promise<tf.TensorContainer>, task: Task) => Promise<tf.TensorContainer>;
}
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
// Compiled output of a type-only module: the bare `export {}` keeps this file
// an ES module while exporting no runtime bindings.
export {};
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
import { List } from 'immutable';
import type { PreprocessingFunction } from './base.js';
/**
 * Available image preprocessing types.
 */
export declare enum ImagePreprocessing {
    // Resize images to the task's configured IMAGE_H x IMAGE_W.
    Resize = 0,
    // Scale pixel values from [0, 255] down to [0, 1].
    Normalize = 1
}
/**
 * Available image preprocessing functions.
 */
export declare const AVAILABLE_PREPROCESSING: List<PreprocessingFunction>;
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
import { List } from 'immutable';
import * as tf from '@tensorflow/tfjs';
/**
 * Available image preprocessing types.
 */
export var ImagePreprocessing;
(function (ImagePreprocessing) {
    ImagePreprocessing[ImagePreprocessing["Resize"] = 0] = "Resize";
    ImagePreprocessing[ImagePreprocessing["Normalize"] = 1] = "Normalize";
})(ImagePreprocessing || (ImagePreprocessing = {}));
// Resizes `xs` to the task's IMAGE_H x IMAGE_W when both are configured;
// otherwise passes the sample through unchanged.
const resize = {
    type: ImagePreprocessing.Resize,
    apply: async (entry, task) => {
        const { xs, ys } = await entry;
        const { IMAGE_H, IMAGE_W } = task.trainingInformation;
        const resized = IMAGE_W !== undefined && IMAGE_H !== undefined
            ? xs.resizeBilinear([IMAGE_H, IMAGE_W])
            : xs;
        return { xs: resized, ys };
    }
};
// Scales pixel values from [0, 255] down to [0, 1].
const normalize = {
    type: ImagePreprocessing.Normalize,
    apply: async (entry) => {
        const { xs, ys } = await entry;
        return { xs: xs.div(tf.scalar(255)), ys };
    }
};
/**
 * Available image preprocessing functions.
 */
export const AVAILABLE_PREPROCESSING = List.of(resize, normalize)
    .sortBy((e) => e.type);
|
|
@@ -0,0 +1,4 @@
|
|
|
1
|
+
// Barrel for dataset preprocessing: shared types plus each modality's enum,
// with its AVAILABLE_PREPROCESSING list re-exported under a modality alias.
export type { Preprocessing, PreprocessingFunction } from './base.js';
export { AVAILABLE_PREPROCESSING as IMAGE_PREPROCESSING, ImagePreprocessing } from './image_preprocessing.js';
export { AVAILABLE_PREPROCESSING as TABULAR_PREPROCESSING, TabularPreprocessing } from './tabular_preprocessing.js';
export { AVAILABLE_PREPROCESSING as TEXT_PREPROCESSING, TextPreprocessing } from './text_preprocessing.js';
|
|
@@ -0,0 +1,3 @@
|
|
|
1
|
+
// Runtime counterpart of the preprocessing barrel (type-only exports erased).
export { AVAILABLE_PREPROCESSING as IMAGE_PREPROCESSING, ImagePreprocessing } from './image_preprocessing.js';
export { AVAILABLE_PREPROCESSING as TABULAR_PREPROCESSING, TabularPreprocessing } from './tabular_preprocessing.js';
export { AVAILABLE_PREPROCESSING as TEXT_PREPROCESSING, TextPreprocessing } from './text_preprocessing.js';
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
import { List } from 'immutable';
import type { PreprocessingFunction } from './base.js';
/**
 * Available tabular preprocessing types.
 * The numeric value doubles as the application order: the pipeline list
 * is sorted by this id.
 */
export declare enum TabularPreprocessing {
    Sanitize = 0,
    Normalize = 1
}
/**
 * Available tabular preprocessing functions, sorted by their type id.
 * NOTE(review): the implementation currently registers only a Sanitize step;
 * confirm whether a Normalize function is intended to be added.
 */
export declare const AVAILABLE_PREPROCESSING: List<PreprocessingFunction>;
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
import { List } from 'immutable';
|
|
2
|
+
/**
 * Available tabular preprocessing types.
 */
export var TabularPreprocessing;
(function (TabularPreprocessing) {
    TabularPreprocessing[TabularPreprocessing["Sanitize"] = 0] = "Sanitize";
    TabularPreprocessing[TabularPreprocessing["Normalize"] = 1] = "Normalize";
})(TabularPreprocessing || (TabularPreprocessing = {}));
/**
 * Replaces missing (null/undefined) values with 0 in a tabular entry.
 *
 * Accepted entry formats (after awaiting):
 *  - an array of numbers (dataset without labels)
 *  - a tensor container { xs, ys } with feature array xs and labels ys
 *  - a dict mapping feature names to values
 * Any other format raises an error.
 */
const sanitize = {
    type: TabularPreprocessing.Sanitize,
    apply: async (entry) => {
        const entryContainer = await entry;
        // If preprocessing a dataset without labels, the entry is an array of numbers.
        if (Array.isArray(entryContainer)) {
            return entryContainer.map((i) => i ?? 0);
        }
        // BUG FIX: the original checked `entry !== null` (the promise, which is
        // never null) instead of the awaited value, so a null entry fell into
        // this branch and crashed in the own-property check with a TypeError
        // instead of the intended "Unrecognized format" error.
        if (typeof entryContainer === 'object' && entryContainer !== null) {
            // Tensor container with features `xs` and labels `ys`.
            // (hasOwnProperty.call is equivalent to Object.hasOwn and avoids
            // requiring an ES2022 lib.)
            if (Object.prototype.hasOwnProperty.call(entryContainer, 'xs')) {
                const { xs, ys } = entryContainer;
                return { xs: xs.map((i) => i ?? 0), ys };
            }
            // Dict of feature name -> value: sanitize the values.
            return Object.values(entryContainer).map((i) => i ?? 0);
        }
        throw new Error('Unrecognized format during tabular preprocessing');
    }
};
|
|
40
|
+
/**
 * Available tabular preprocessing functions, sorted by their type id.
 */
export const AVAILABLE_PREPROCESSING = List.of(sanitize).sortBy((e) => e.type);
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
import { List } from 'immutable';
import type { PreprocessingFunction } from './base.js';
/**
 * Available text preprocessing types.
 * Tokenize runs before LeftPadding (the pipeline is sorted by these ids);
 * LeftPadding consumes the `tokens` produced by the tokenization step.
 */
export declare enum TextPreprocessing {
    Tokenize = 0,
    LeftPadding = 1
}
/**
 * Available text preprocessing functions (Tokenize, then LeftPadding),
 * sorted by their type id.
 */
export declare const AVAILABLE_PREPROCESSING: List<PreprocessingFunction>;
|
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
import { List } from 'immutable';
|
|
2
|
+
import * as tf from '@tensorflow/tfjs';
|
|
3
|
+
import { models } from '../../../index.js';
|
|
4
|
+
/**
 * Available text preprocessing types.
 */
export var TextPreprocessing;
(function (TextPreprocessing) {
    // Mirror tsc's numeric-enum emit: a forward (name -> id) entry and a
    // reverse (id -> name) entry for each member.
    ["Tokenize", "LeftPadding"].forEach((name, id) => {
        TextPreprocessing[TextPreprocessing[name] = id] = name;
    });
})(TextPreprocessing || (TextPreprocessing = {}));
|
|
12
|
+
/**
 * LeftPadding pads all incoming inputs to be a fixed length, which should be specified
 * in `task.trainingInformation.maxSequenceLength`.
 *
 * We are currently only implementing left padding for text generation
 * https://huggingface.co/docs/transformers/en/llm_tutorial#wrong-padding-side
 * The function can easily be extended to support right padding if needed
 *
 * Once Transformers.js supports left padding, it will be possible to pad inputs
 * directly when tokenizing
 * https://github.com/xenova/transformers.js/blob/8804c36591d11d8456788d1bb4b16489121b3be2/src/tokenizers.js#L2517
 */
const leftPadding = {
    type: TextPreprocessing.LeftPadding,
    apply: async (x, task) => {
        const { tokens } = await x;
        // BUG FIX: the original validated `x` itself (a promise, never an array)
        // with `typeof (x[0] != 'number')` — the string "boolean", always truthy —
        // and then built the Error without throwing it, so the check was dead
        // code. Validate the awaited token array and actually throw.
        if (!Array.isArray(tokens) || tokens.length === 0 || typeof tokens[0] !== 'number') {
            throw new Error("The leftPadding preprocessing expects a non empty 1D array of number");
        }
        const tokenizer = await models.getTaskTokenizer(task);
        return tf.tidy(() => {
            // maxLength is the final length of xs.
            // Because ys contains the tokens of xs shifted by one (to predict the
            // next token), we need one more token than maxSequenceLength so that
            // the maxSequenceLength'th token also has a next-token label.
            const maxLength = task.trainingInformation.maxSequenceLength ?? tokenizer.model_max_length;
            const maxLengthPlusLabel = maxLength + 1;
            let fixedLengthTokens = tf.tensor(tokens, undefined, 'int32'); // cast tokens from float to int for gpt-tfjs
            if (fixedLengthTokens.size > maxLengthPlusLabel) { // Should never happen because tokenization truncates inputs
                throw Error("There are more tokens than expected after tokenization and truncation");
            }
            else if (fixedLengthTokens.size < maxLengthPlusLabel) { // Pad inputs (on the left) to fixed length
                const paddingToken = tokenizer.pad_token_id;
                fixedLengthTokens = fixedLengthTokens.pad([[Math.max(0, maxLengthPlusLabel - fixedLengthTokens.size), 0]], paddingToken);
            }
            // if tokens.size == maxLengthPlusLabel we can leave it as it is
            // ys is a one-hot encoding of the next token (i.e. xs shifted by one)
            const ys = tf.oneHot(fixedLengthTokens.slice([1]), tokenizer.model.vocab.length + 1);
            // remove the extra token now that ys is created
            const xs = fixedLengthTokens.slice([0], maxLength);
            return { xs, ys };
        });
    }
};
|
|
55
|
+
/**
 * Tokenize and truncates input strings
 */
const tokenize = {
    type: TextPreprocessing.Tokenize,
    apply: async (x, task) => {
        const xs = await x; // tf.TextLineDataset yields strings
        // BUG FIX: the original checked the promise `x` (never a string when a
        // promise is passed) and constructed the Error without throwing it, so
        // the validation never fired. Check the awaited value and throw.
        if (typeof xs !== 'string') {
            throw new Error("The tokenize preprocessing expects a string as input");
        }
        const tokenizer = await models.getTaskTokenizer(task);
        // Add plus one to include the next token label of the last token in the input sequence
        // The inputs are truncated down to exactly maxSequenceLength in leftPadding
        const maxLength = task.trainingInformation.maxSequenceLength ?? tokenizer.model_max_length;
        const maxLengthPlusLabel = maxLength + 1;
        const { input_ids: tokens } = tokenizer(xs, {
            // Transformers.js currently only supports right padding while we need left for text generation
            // Right padding should be supported in the future, once it is, we can directly pad while tokenizing
            // https://github.com/xenova/transformers.js/blob/8804c36591d11d8456788d1bb4b16489121b3be2/src/tokenizers.js#L2517
            padding: false,
            truncation: true,
            return_tensor: false,
            max_length: maxLengthPlusLabel,
        });
        return { tokens };
    }
};
|
|
82
|
+
/**
 * Available text preprocessing functions, sorted by their type id
 * (Tokenize first, then LeftPadding).
 */
export const AVAILABLE_PREPROCESSING = List([tokenize, leftPadding])
    .sortBy((e) => e.type);
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
import type { Task } from '../../index.js';
import type { Dataset } from '../dataset.js';
import { Data } from './data.js';
/**
 * Disco data made of tabular (.csv, .tsv, etc.) files.
 */
export declare class TabularData extends Data {
    // Preprocessing pipeline applied to tabular entries (see ./preprocessing).
    readonly availablePreprocessing: import("immutable").List<import("./preprocessing/base.js").PreprocessingFunction>;
    /**
     * Builds a TabularData instance, eagerly requesting an iterator on the
     * dataset so an incompatible column format fails here rather than
     * mid-training (the dataset is lazy and only reads lines on demand).
     * @throws the underlying iterator error when the data format is incompatible
     */
    static init(dataset: Dataset, task: Task, size?: number): Promise<TabularData>;
    /** Recreates a TabularData wrapping the given dataset. */
    protected create(dataset: Dataset, task: Task, size: number): TabularData;
}
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
import { Data } from './data.js';
|
|
2
|
+
import { TABULAR_PREPROCESSING } from './preprocessing/index.js';
|
|
3
|
+
/**
 * Disco data made of tabular (.csv, .tsv, etc.) files.
 */
export class TabularData extends Data {
    availablePreprocessing = TABULAR_PREPROCESSING;
    /** Recreates a TabularData wrapping the given dataset. */
    create(dataset, task, size) {
        return new TabularData(dataset, task, size);
    }
    /**
     * Builds a TabularData instance.
     * Eagerly requests an iterator so that an incorrect data column format
     * fails here rather than mid-training: the dataset is lazy and only
     * reads the tabular file's lines when training consumes them.
     */
    static async init(dataset, task, size) {
        try {
            await dataset.iterator();
        }
        catch (err) {
            console.error('Data input format is not compatible with the chosen task.');
            throw err;
        }
        return new TabularData(dataset, task, size);
    }
}
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
import type { Task } from '../../index.js';
import type { Dataset } from '../dataset.js';
import { Data } from './data.js';
/**
 * Disco data made of textual samples.
 */
export declare class TextData extends Data {
    // Preprocessing pipeline applied to text entries (see ./preprocessing).
    readonly availablePreprocessing: import("immutable").List<import("./preprocessing/base.js").PreprocessingFunction>;
    /** Builds a TextData instance; performs no eager validation of the dataset. */
    static init(dataset: Dataset, task: Task, size?: number): Promise<TextData>;
    /** Recreates a TextData wrapping the given dataset. */
    protected create(dataset: Dataset, task: Task, size?: number): TextData;
}
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
import { Data } from './data.js';
|
|
2
|
+
import { TEXT_PREPROCESSING } from './preprocessing/index.js';
|
|
3
|
+
/**
 * Disco data made of textual samples.
 */
export class TextData extends Data {
    availablePreprocessing = TEXT_PREPROCESSING;
    /**
     * Builds a TextData instance.
     * Unlike TabularData, no eager validation is performed here, so
     * construction itself cannot fail.
     */
    static async init(dataset, task, size) {
        return new TextData(dataset, task, size);
    }
    /** Recreates a TextData wrapping the given dataset. */
    create(dataset, task, size) {
        return new TextData(dataset, task, size);
    }
}
|
|
@@ -1,15 +1,13 @@
|
|
|
1
|
-
import type { DataSplit, Dataset } from '../index.js';
/** Options controlling how raw sources are turned into Disco datasets. */
export interface DataConfig {
    features?: string[];
    labels?: string[];
    shuffle?: boolean;
    validationSplit?: number;
    // presumably toggles label-free loading for inference — confirm against loaders
    inference?: boolean;
    // number of color channels to read per image (forwarded to readImageFrom)
    channels?: number;
}
/** Loads data from platform-specific sources into Disco datasets. */
export declare abstract class DataLoader<Source> {
    abstract load(source: Source, config: DataConfig): Promise<Dataset>;
    abstract loadAll(sources: Source[], config: DataConfig): Promise<DataSplit>;
}
|
|
@@ -1,4 +1,21 @@
|
|
|
1
|
-
import * as tf from '@tensorflow/tfjs';
import type { Task } from '../../index.js';
import type { Dataset, DataSplit } from '../index.js';
import type { DataConfig } from '../data_loader/index.js';
import { DataLoader } from '../data_loader/index.js';
/**
 * Image data loader whose instantiable implementation is delegated by the platform-dependent Disco subprojects, namely,
 * @epfml/discojs-web and @epfml/discojs-node.
 * Load labels and correctly match them with their respective images, with the following constraints:
 * 1. Images are given as 1 image/1 file;
 * 2. Labels are given as multiple labels/1 file, each label file can contain a different amount of labels.
 */
export declare abstract class ImageLoader<Source> extends DataLoader<Source> {
    private readonly task;
    /** Reads a single image from a platform-specific source as a rank-3 tensor. */
    abstract readImageFrom(source: Source, channels?: number): Promise<tf.Tensor3D>;
    constructor(task: Task);
    /** Loads a single image source into a dataset. */
    load(image: Source, config?: DataConfig): Promise<Dataset>;
    private buildDataset;
    /** Loads several image sources into a DataSplit — presumably split per config.validationSplit; confirm in implementation. */
    loadAll(images: Source[], config?: DataConfig): Promise<DataSplit>;
    /** Shuffles a numeric index array in place (void return). */
    shuffle(array: number[]): void;
}
|