cesnet-datazoo 0.0.16__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cesnet_datazoo/config.py +174 -167
- cesnet_datazoo/constants.py +4 -6
- cesnet_datazoo/datasets/cesnet_dataset.py +200 -172
- cesnet_datazoo/datasets/datasets.py +22 -2
- cesnet_datazoo/datasets/datasets_constants.py +670 -0
- cesnet_datazoo/datasets/loaders.py +3 -0
- cesnet_datazoo/datasets/metadata/dataset_metadata.py +6 -5
- cesnet_datazoo/datasets/metadata/metadata.csv +4 -4
- cesnet_datazoo/datasets/statistics.py +36 -16
- cesnet_datazoo/pytables_data/data_scalers.py +110 -0
- cesnet_datazoo/pytables_data/indices_setup.py +29 -33
- cesnet_datazoo/pytables_data/pytables_dataset.py +103 -229
- cesnet_datazoo/utils/class_info.py +7 -5
- cesnet_datazoo/utils/download.py +6 -1
- {cesnet_datazoo-0.0.16.dist-info → cesnet_datazoo-0.1.0.dist-info}/METADATA +2 -1
- cesnet_datazoo-0.1.0.dist-info/RECORD +30 -0
- {cesnet_datazoo-0.0.16.dist-info → cesnet_datazoo-0.1.0.dist-info}/WHEEL +1 -1
- cesnet_datazoo-0.0.16.dist-info/RECORD +0 -28
- {cesnet_datazoo-0.0.16.dist-info → cesnet_datazoo-0.1.0.dist-info}/LICENCE +0 -0
- {cesnet_datazoo-0.0.16.dist-info → cesnet_datazoo-0.1.0.dist-info}/top_level.txt +0 -0
@@ -15,11 +15,14 @@ from sklearn.preprocessing import LabelEncoder
 from torch.utils.data import BatchSampler, DataLoader, RandomSampler, Sampler, SequentialSampler
 from typing_extensions import assert_never
 
-from cesnet_datazoo.config import DataLoaderOrder, DatasetConfig,
-from cesnet_datazoo.constants import DATASET_SIZES, INDICES_LABEL_POS,
-
+from cesnet_datazoo.config import AppSelection, DataLoaderOrder, DatasetConfig, ValidationApproach
+from cesnet_datazoo.constants import (APP_COLUMN, CATEGORY_COLUMN, DATASET_SIZES, INDICES_LABEL_POS,
+                                      SERVICEMAP_FILE, UNKNOWN_STR_LABEL)
+from cesnet_datazoo.datasets.loaders import collate_fn_simple, create_df_from_dataloader
 from cesnet_datazoo.datasets.metadata.dataset_metadata import DatasetMetadata, load_metadata
 from cesnet_datazoo.datasets.statistics import compute_dataset_statistics
+from cesnet_datazoo.pytables_data.apps_split import is_background_app
+from cesnet_datazoo.pytables_data.data_scalers import fit_scalers
 from cesnet_datazoo.pytables_data.indices_setup import (IndicesTuple, compute_known_app_counts,
                                                         compute_unknown_app_counts,
                                                         date_weight_sample_train_indices,
@@ -27,8 +30,7 @@ from cesnet_datazoo.pytables_data.indices_setup import (IndicesTuple, compute_kn
                                                         init_or_load_train_indices,
                                                         init_or_load_val_indices,
                                                         subset_and_sort_indices)
-from cesnet_datazoo.pytables_data.pytables_dataset import
-                                                          pytables_collate_fn, worker_init_fn)
+from cesnet_datazoo.pytables_data.pytables_dataset import PyTablesDataset, worker_init_fn
 from cesnet_datazoo.utils.class_info import ClassInfo, create_class_info
 from cesnet_datazoo.utils.download import resumable_download, simple_download
 from cesnet_datazoo.utils.random import RandomizedSection, get_fresh_random_generator
@@ -38,8 +40,7 @@ DATAFRAME_SAMPLES_WARNING_THRESHOLD = 20_000_000
 
 class CesnetDataset():
     """
-    The main class for accessing CESNET datasets. It handles downloading, data
-    train/validation/test splitting, and class selection. Access to data is provided through:
+    The main class for accessing CESNET datasets. It handles downloading, train/validation/test splitting, and class selection. Access to data is provided through:
 
     - Iterable PyTorch DataLoader for batch processing. See [using dataloaders][using-dataloaders] for more details.
    - Pandas DataFrame for loading the entire train, validation, or test set at once.
@@ -53,7 +54,7 @@ class CesnetDataset():
 
     1. Create an instance of the [dataset class][dataset-classes] with the desired size and data root. This will download the dataset if it has not already been downloaded.
     2. Create an instance of [`DatasetConfig`][config.DatasetConfig] and set it with [`set_dataset_config_and_initialize`][datasets.cesnet_dataset.CesnetDataset.set_dataset_config_and_initialize].
-    This will initialize the dataset — select classes, split data into train/validation/test sets, and fit data scalers. All is done according to the provided configuration and is cached for later use.
+    This will initialize the dataset — select classes, split data into train/validation/test sets, and fit data scalers if needed. All is done according to the provided configuration and is cached for later use.
     3. Use [`get_train_dataloader`][datasets.cesnet_dataset.CesnetDataset.get_train_dataloader] or [`get_train_df`][datasets.cesnet_dataset.CesnetDataset.get_train_df] to get training data for a classification model.
     4. Validate the model and perform hyperparameter optimization on [`get_val_dataloader`][datasets.cesnet_dataset.CesnetDataset.get_val_dataloader] or [`get_val_df`][datasets.cesnet_dataset.CesnetDataset.get_val_df].
     5. Evaluate the model on [`get_test_dataloader`][datasets.cesnet_dataset.CesnetDataset.get_test_dataloader] or [`get_test_df`][datasets.cesnet_dataset.CesnetDataset.get_test_df].
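The five steps above map to only a few lines of user code. A minimal sketch, assuming the CESNET-QUIC22 dataset class; the data root, week period names, and keyword-argument names follow the project README and may differ slightly between versions:

```python
from cesnet_datazoo.config import AppSelection, DatasetConfig
from cesnet_datazoo.datasets import CESNET_QUIC22

# 1. Download (if needed) and open the dataset at the given data root.
dataset = CESNET_QUIC22("/datasets/CESNET-QUIC22/", size="XS")
# 2. Configure class selection and time periods, then initialize the dataset
#    (splits data and fits scalers if needed; results are cached for reuse).
config = DatasetConfig(dataset=dataset,
                       apps_selection=AppSelection.ALL_KNOWN,
                       train_period_name="W-2022-44",   # illustrative period names
                       test_period_name="W-2022-45")
dataset.set_dataset_config_and_initialize(config)
# 3.-5. Fetch data for training, validation, and final evaluation.
train_df = dataset.get_train_df()
val_df = dataset.get_val_df()
test_df = dataset.get_test_df()
```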
@@ -68,9 +69,10 @@ class CesnetDataset():
         database_filename: Name of the database file.
         database_path: Path to the database file.
         servicemap_path: Path to the servicemap file.
-        statistics_path: Path to the dataset statistics.
+        statistics_path: Path to the dataset statistics folder.
         bucket_url: URL of the bucket where the database is stored.
         metadata: Additional [dataset metadata][metadata].
+        available_classes: List of all available classes in the dataset.
         available_dates: List of all available dates in the dataset.
         time_periods: Predefined time periods. Each time period is a list of dates.
         default_train_period_name: Default time period for training.
@@ -85,35 +87,30 @@ class CesnetDataset():
         train_dataset: Train set in the form of `PyTablesDataset` instance wrapping the PyTables database.
         val_dataset: Validation set in the form of `PyTablesDataset` instance wrapping the PyTables database.
         test_dataset: Test set in the form of `PyTablesDataset` instance wrapping the PyTables database.
-        known_apps_database_enum: Dictionary that maps the database integer labels (different to those from `encoder`) of known applications to their names.
-        unknown_apps_database_enum: Dictionary that maps the database integer labels (different to those from `encoder`) of unknown applications to their names.
         known_app_counts: Known application counts in the train, validation, and test sets.
         unknown_app_counts: Unknown application counts in the validation and test sets.
-        collate_fn: Collate function used for creating batches in dataloaders.
-        encoder: Scikit-learn [`LabelEncoder`](https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.LabelEncoder.html) used to encode class names into integers. It is fitted during the initialization of the dataset.
-        flowstats_scaler: Scaler for flow statistics. It is fitted during the initialization of the dataset.
-        psizes_scaler: Scaler for packet sizes.
-        ipt_scaler: Scaler for inter-packet times.
         train_dataloader: Iterable PyTorch [`DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) for training.
         train_dataloader_sampler: Sampler used for iterating the training dataloader. Either [`RandomSampler`](https://pytorch.org/docs/stable/data.html#torch.utils.data.RandomSampler) or [`SequentialSampler`](https://pytorch.org/docs/stable/data.html#torch.utils.data.SequentialSampler).
+        train_dataloader_drop_last: Whether to drop the last incomplete batch when iterating the training dataloader.
         val_dataloader: Iterable PyTorch [`DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) for validation.
         test_dataloader: Iterable PyTorch [`DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) for testing.
     """
-    name: str
-    size: str
     data_root: str
+    size: str
+    silent: bool = False
+
+    name: str
     database_filename: str
     database_path: str
     servicemap_path: str
     statistics_path: str
     bucket_url: str
     metadata: DatasetMetadata
+    available_classes: list[str]
     available_dates: list[str]
     time_periods: dict[str, list[str]]
     default_train_period_name: str
     default_test_period_name: str
-    time_periods_gen: bool = False
-    silent: bool = False
 
     dataset_config: Optional[DatasetConfig] = None
     class_info: Optional[ClassInfo] = None
@@ -121,24 +118,19 @@ class CesnetDataset():
     train_dataset: Optional[PyTablesDataset] = None
     val_dataset: Optional[PyTablesDataset] = None
     test_dataset: Optional[PyTablesDataset] = None
-    known_apps_database_enum: Optional[dict[int, str]] = None
-    unknown_apps_database_enum: Optional[dict[int, str]] = None
     known_app_counts: Optional[pd.DataFrame] = None
     unknown_app_counts: Optional[pd.DataFrame] = None
-
-    collate_fn: Optional[Callable] = None
-    encoder: Optional[LabelEncoder] = None
-    flowstats_scaler: Scaler = None
-    psizes_scaler: Scaler = None
-    ipt_scaler: Scaler = None
-
     train_dataloader: Optional[DataLoader] = None
     train_dataloader_sampler: Optional[Sampler] = None
     train_dataloader_drop_last: bool = True
     val_dataloader: Optional[DataLoader] = None
     test_dataloader: Optional[DataLoader] = None
 
-
+    _collate_fn: Optional[Callable] = None
+    _tables_app_enum: dict[int, str]
+    _tables_cat_enum: dict[int, str]
+
+    def __init__(self, data_root: str, size: str = "S", database_checks_at_init: bool = False, silent: bool = False) -> None:
         self.silent = silent
         self.metadata = load_metadata(self.name)
         self.size = size
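The reworked constructor exposes the new `database_checks_at_init` flag used in the next hunk. A short sketch of the call, assuming the CESNET-QUIC22 dataset class (the data root is illustrative):

```python
from cesnet_datazoo.datasets import CESNET_QUIC22

# database_checks_at_init triggers the enum, date, and sample-count
# validation of the PyTables database shown below.
dataset = CESNET_QUIC22("/datasets/CESNET-QUIC22/", size="XS",
                        database_checks_at_init=True, silent=True)
```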
@@ -156,24 +148,31 @@ class CesnetDataset():
            os.makedirs(self.data_root)
        if not self._is_downloaded():
            self._download()
-        if
+        if database_checks_at_init:
            with tb.open_file(self.database_path, mode="r") as database:
                tables_paths = list(map(lambda x: x._v_pathname, iter(database.get_node(f"/flows"))))
                num_samples = 0
                for p in tables_paths:
-
+                    table = database.get_node(p)
+                    assert isinstance(table, tb.Table)
+                    if self._tables_app_enum != {v: k for k, v in dict(table.get_enum(APP_COLUMN)).items()}:
+                        raise ValueError(f"Found mismatch between _tables_app_enum and the PyTables database enum in table {p}. Please report this issue.")
+                    if self._tables_cat_enum != {v: k for k, v in dict(table.get_enum(CATEGORY_COLUMN)).items()}:
+                        raise ValueError(f"Found mismatch between _tables_cat_enum and the PyTables database enum in table {p}. Please report this issue.")
+                    num_samples += len(table)
                if self.size == "ORIG" and num_samples != self.metadata.available_samples:
                    raise ValueError(f"Expected {self.metadata.available_samples} samples, but got {num_samples} in the database. Please delete the data root folder, update cesnet-datazoo, and redownload the dataset.")
                if self.size != "ORIG" and num_samples != DATASET_SIZES[self.size]:
                    raise ValueError(f"Expected {DATASET_SIZES[self.size]} samples, but got {num_samples} in the database. Please delete the data root folder, update cesnet-datazoo, and redownload the dataset.")
-        self.available_dates
-
-        self.available_dates = []
-        if self.time_periods_gen:
-            self._generate_time_periods()
+                if self.available_dates != list(map(lambda x: x.removeprefix("/flows/D"), tables_paths)):
+                    raise ValueError(f"Found mismatch between available_dates and the dates available in the PyTables database. Please report this issue.")
        # Add all available dates as single date time periods
        for d in self.available_dates:
            self.time_periods[d] = [d]
+        available_applications = sorted([app for app in pd.read_csv(self.servicemap_path, index_col="Tag").index if not is_background_app(app)])
+        if len(available_applications) != self.metadata.application_count:
+            raise ValueError(f"Found {len(available_applications)} applications in the servicemap (omitting background traffic classes), but expected {self.metadata.application_count}. Please report this issue.")
+        self.available_classes = available_applications + self.metadata.background_traffic_classes
 
    def set_dataset_config_and_initialize(self, dataset_config: DatasetConfig, disable_indices_cache: bool = False) -> None:
        """
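The new checks rely on the fact that PyTables stores an enum as a name-to-integer mapping, while the datazoo keeps the inverted integer-to-name dictionary; hence the `{v: k ...}` comprehension. A standalone sketch of the same pattern, with an illustrative file path, node name, and column name (the real values come from `cesnet_datazoo.constants` and the dataset metadata):

```python
import tables as tb

DATABASE_PATH = "CESNET-QUIC22-XS.h5"  # illustrative path
APP_COLUMN = "APP"                     # illustrative column name

with tb.open_file(DATABASE_PATH, mode="r") as database:
    table = database.get_node("/flows/D20221031")  # one table per date; name illustrative
    # Invert the PyTables name -> integer enum into integer -> name.
    app_enum = {v: k for k, v in dict(table.get_enum(APP_COLUMN)).items()}
```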
@@ -205,6 +204,8 @@ class CesnetDataset():
        """
        if self.dataset_config is None:
            raise ValueError("Dataset is not initialized, use set_dataset_config_and_initialize() before getting train dataloader")
+        if not self.dataset_config.need_train_set:
+            raise ValueError("Train dataloader is not available when need_train_set is false")
        assert self.train_dataset
        if self.train_dataloader:
            return self.train_dataloader
@@ -227,7 +228,7 @@ class CesnetDataset():
            self.train_dataset,
            num_workers=self.dataset_config.train_workers,
            worker_init_fn=worker_init_fn,
-            collate_fn=self.
+            collate_fn=self._collate_fn,
            persistent_workers=self.dataset_config.train_workers > 0,
            batch_size=None,
            sampler=batch_sampler,)
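Note the `DataLoader` is built with `batch_size=None` and a batch-producing sampler, so automatic batching is disabled: the sampler yields whole batches of indices, the dataset fetches each batch in one PyTables read, and `_collate_fn` receives complete batches. Consuming it is plain iteration; a sketch assuming an initialized dataset and the common batch layout (the exact layout depends on the configured features and `return_tensors`):

```python
train_dl = dataset.get_train_dataloader()
# Assumption: with the default collate function each batch is a
# (ppi, flowstats, labels) triple.
for ppi, flowstats, labels in train_dl:
    print(ppi.shape, flowstats.shape, labels.shape)
    break
```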
@@ -252,8 +253,8 @@ class CesnetDataset():
        """
        if self.dataset_config is None:
            raise ValueError("Dataset is not initialized, use set_dataset_config_and_initialize() before getting validation dataloader")
-        if self.dataset_config.
-            raise ValueError("Validation dataloader is not available when
+        if not self.dataset_config.need_val_set:
+            raise ValueError("Validation dataloader is not available when need_val_set is false")
        assert self.val_dataset is not None
        if self.val_dataloader:
            return self.val_dataloader
@@ -262,7 +263,7 @@ class CesnetDataset():
            self.val_dataset,
            num_workers=self.dataset_config.val_workers,
            worker_init_fn=worker_init_fn,
-            collate_fn=self.
+            collate_fn=self._collate_fn,
            persistent_workers=self.dataset_config.val_workers > 0,
            batch_size=None,
            sampler=batch_sampler,)
@@ -291,8 +292,8 @@ class CesnetDataset():
        """
        if self.dataset_config is None:
            raise ValueError("Dataset is not initialized, use set_dataset_config_and_initialize() before getting test dataloader")
-        if self.dataset_config.
-            raise ValueError("Test dataloader is not available when
+        if not self.dataset_config.need_test_set:
+            raise ValueError("Test dataloader is not available when need_test_set is false")
        assert self.test_dataset is not None
        if self.test_dataloader:
            return self.test_dataloader
@@ -301,7 +302,7 @@ class CesnetDataset():
            self.test_dataset,
            num_workers=self.dataset_config.test_workers,
            worker_init_fn=worker_init_fn,
-            collate_fn=self.
+            collate_fn=self._collate_fn,
            persistent_workers=False,
            batch_size=None,
            sampler=batch_sampler,)
@@ -333,7 +334,7 @@ class CesnetDataset():
        Returns:
            Train data as a dataframe.
        """
-        self._check_before_dataframe()
+        self._check_before_dataframe(check_train=True)
        assert self.dataset_config is not None and self.train_dataset is not None
        if len(self.train_dataset) > DATAFRAME_SAMPLES_WARNING_THRESHOLD:
            warnings.warn(f"Train set has ({len(self.train_dataset)} samples), consider using get_train_dataloader() instead")
@@ -366,7 +367,7 @@ class CesnetDataset():
        Returns:
            Validation data as a dataframe.
        """
-        self._check_before_dataframe(
+        self._check_before_dataframe(check_val=True)
        assert self.dataset_config is not None and self.val_dataset is not None
        if len(self.val_dataset) > DATAFRAME_SAMPLES_WARNING_THRESHOLD:
            warnings.warn(f"Validation set has ({len(self.val_dataset)} samples), consider using get_val_dataloader() instead")
@@ -395,7 +396,7 @@ class CesnetDataset():
        Returns:
            Test data as a dataframe.
        """
-        self._check_before_dataframe(
+        self._check_before_dataframe(check_test=True)
        assert self.dataset_config is not None and self.test_dataset is not None
        if len(self.test_dataset) > DATAFRAME_SAMPLES_WARNING_THRESHOLD:
            warnings.warn(f"Test set has ({len(self.test_dataset)} samples), consider using get_test_dataloader() instead")
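All three `get_*_df` methods now route through `_check_before_dataframe` with the matching `check_*` flag and warn above `DATAFRAME_SAMPLES_WARNING_THRESHOLD` (20 million samples). A short sketch, assuming the dataset was initialized with the test set enabled and that the label column is exposed as `APP` (column name is an assumption):

```python
test_df = dataset.get_test_df()
print(len(test_df))
print(test_df["APP"].value_counts().head())  # per-application sample counts
```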
@@ -433,12 +434,18 @@ class CesnetDataset():
            batch_size: Number of samples per batch for loading data.
            disabled_apps: List of applications to exclude from the statistics.
        """
-
+        if disabled_apps:
+            bad_disabled_apps = [a for a in disabled_apps if a not in self.available_classes]
+            if len(bad_disabled_apps) > 0:
+                raise ValueError(f"Bad applications in disabled_apps {bad_disabled_apps}. Use applications available in dataset.available_classes")
        if not os.path.exists(self.statistics_path):
            os.mkdir(self.statistics_path)
        compute_dataset_statistics(database_path=self.database_path,
+                                   tables_app_enum=self._tables_app_enum,
+                                   tables_cat_enum=self._tables_cat_enum,
                                   output_dir=self.statistics_path,
-
+                                   packet_histograms=self.metadata.packet_histograms,
+                                   flowstats_features_boolean=self.metadata.flowstats_features_boolean,
                                   protocol=self.metadata.protocol,
                                   extra_fields=not self.name.startswith("CESNET-TLS22"),
                                   disabled_apps=disabled_apps if disabled_apps is not None else [],
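The statistics method now rejects `disabled_apps` entries that are not in `available_classes` before any computation starts. A hypothetical invocation (the public method name and the application tags are assumptions based on the wrapped `compute_dataset_statistics` call):

```python
unwanted = ["example-app-1", "example-app-2"]  # illustrative application tags
# Mirrors the new validation: every tag must be a known available class.
assert all(app in dataset.available_classes for app in unwanted)
dataset.compute_dataset_statistics(disabled_apps=unwanted)
```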
@@ -486,172 +493,193 @@ class CesnetDataset():
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None
-        self.known_apps_database_enum = None
-        self.unknown_apps_database_enum = None
        self.known_app_counts = None
        self.unknown_app_counts = None
-
-        self.collate_fn = None
-        self.encoder = None
-        self.flowstats_scaler = None
-        self.psizes_scaler = None
-        self.ipt_scaler = None
-
        self.train_dataloader = None
        self.train_dataloader_sampler = None
        self.train_dataloader_drop_last = True
        self.val_dataloader = None
        self.test_dataloader = None
+        self._collate_fn = None
 
-    def _check_before_dataframe(self,
+    def _check_before_dataframe(self, check_train: bool = False, check_val: bool = False, check_test: bool = False) -> None:
        if self.dataset_config is None:
            raise ValueError("Dataset is not initialized, use set_dataset_config_and_initialize() before getting a dataframe")
-        if self.dataset_config.
-            raise ValueError("Dataframes are not available when
-        if
-            raise ValueError("
-        if
-            raise ValueError("
+        if self.dataset_config.return_tensors:
+            raise ValueError("Dataframes are not available when return_tensors is set. Use a dataloader instead.")
+        if check_train and not self.dataset_config.need_train_set:
+            raise ValueError("Train dataframe is not available when need_train_set is false")
+        if check_val and not self.dataset_config.need_val_set:
+            raise ValueError("Validation dataframe is not available when need_val_set is false")
+        if check_test and not self.dataset_config.need_test_set:
+            raise ValueError("Test dataframe is not available when need_test_set is false")
 
    def _initialize_train_val_test(self, disable_indices_cache: bool = False) -> None:
        assert self.dataset_config is not None
        dataset_config = self.dataset_config
        servicemap = pd.read_csv(dataset_config.servicemap_path, index_col="Tag")
-        # Initialize train
-
-
-
-
-
-
-
+        # Initialize train set
+        if dataset_config.need_train_set:
+            train_indices, train_unknown_indices, known_apps, unknown_apps = init_or_load_train_indices(dataset_config=dataset_config,
+                                                                                                        tables_app_enum=self._tables_app_enum,
+                                                                                                        servicemap=servicemap,
+                                                                                                        disable_indices_cache=disable_indices_cache,)
+            # Date weight sampling of train indices
+            if dataset_config.train_dates_weigths is not None:
+                assert dataset_config.train_size != "all"
+                if dataset_config.val_approach == ValidationApproach.SPLIT_FROM_TRAIN:
+                    # requested number of samples is train_size + val_known_size when using the split-from-train validation approach
+                    assert dataset_config.val_known_size != "all"
+                    num_samples = dataset_config.train_size + dataset_config.val_known_size
+                else:
+                    num_samples = dataset_config.train_size
+                if num_samples > len(train_indices):
+                    raise ValueError(f"Requested number of samples for weight sampling ({num_samples}) is larger than the number of available train samples ({len(train_indices)})")
+                train_indices = date_weight_sample_train_indices(dataset_config=dataset_config, train_indices=train_indices, num_samples=num_samples)
+        elif dataset_config.apps_selection == AppSelection.FIXED:
+            known_apps = dataset_config.apps_selection_fixed_known
+            unknown_apps = dataset_config.apps_selection_fixed_unknown
+            train_indices = np.zeros((0,3), dtype=np.int64)
+            train_unknown_indices = np.zeros((0,3), dtype=np.int64)
        else:
-
-
-
+            raise ValueError("Either need train set or the fixed application selection")
+        # Initialize validation set
+        if dataset_config.need_val_set:
+            if dataset_config.val_approach == ValidationApproach.VALIDATION_DATES:
+                val_known_indices, val_unknown_indices, val_data_path = init_or_load_val_indices(dataset_config=dataset_config,
+                                                                                                 known_apps=known_apps,
+                                                                                                 unknown_apps=unknown_apps,
+                                                                                                 tables_app_enum=self._tables_app_enum,
                                                                                                 disable_indices_cache=disable_indices_cache,)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            train_indices, val_known_indices = train_test_split(train_indices, test_size=dataset_config.val_known_size, stratify=train_labels, shuffle=True, random_state=train_val_rng)
-            dataset_config.train_size = len(train_indices)
-        elif dataset_config.train_size == "all" and dataset_config.val_known_size == "all":
-            train_indices, val_known_indices = train_test_split(train_indices, test_size=dataset_config.train_val_split_fraction, stratify=train_labels, shuffle=True, random_state=train_val_rng)
-        else:
-            if dataset_config.val_known_size != "all" and dataset_config.train_size != "all" and dataset_config.train_size + dataset_config.val_known_size > len(train_indices):
-                raise ValueError(f"Requested train size + validation size ({dataset_config.train_size + dataset_config.val_known_size}) is larger than the number of available train samples ({len(train_indices)})")
-            if dataset_config.train_size != "all" and dataset_config.train_size > len(train_indices):
-                raise ValueError(f"Requested train size ({dataset_config.train_size}) is larger than the number of available train samples ({len(train_indices)})")
-            if dataset_config.val_known_size != "all" and dataset_config.val_known_size > len(train_indices):
-                raise ValueError(f"Requested validation size ({dataset_config.val_known_size}) is larger than the number of available train samples ({len(train_indices)})")
-            train_indices, val_known_indices = train_test_split(train_indices,
-                                                                train_size=dataset_config.train_size if dataset_config.train_size != "all" else None,
-                                                                test_size=dataset_config.val_known_size if dataset_config.val_known_size != "all" else None,
-                                                                stratify=train_labels, shuffle=True, random_state=train_val_rng)
-        elif dataset_config.val_approach == ValidationApproach.NO_VALIDATION:
-            val_known_indices = np.empty((0,3), dtype=np.int64)
-            val_unknown_indices = np.empty((0,3), dtype=np.int64)
+            elif dataset_config.val_approach == ValidationApproach.SPLIT_FROM_TRAIN:
+                train_val_rng = get_fresh_random_generator(dataset_config=dataset_config, section=RandomizedSection.TRAIN_VAL_SPLIT)
+                val_data_path = dataset_config._get_train_data_path()
+                val_unknown_indices = train_unknown_indices
+                train_labels = train_indices[:, INDICES_LABEL_POS]
+                if dataset_config.train_dates_weigths is not None:
+                    assert dataset_config.val_known_size != "all"
+                    # When weight sampling is used, val_known_size is kept but the resulting train size can be smaller due to not enough samples in some train dates
+                    if dataset_config.val_known_size > len(train_indices):
+                        raise ValueError(f"Requested validation size ({dataset_config.val_known_size}) is larger than the number of available train samples after weight sampling ({len(train_indices)})")
+                    train_indices, val_known_indices = train_test_split(train_indices, test_size=dataset_config.val_known_size, stratify=train_labels, shuffle=True, random_state=train_val_rng)
+                    dataset_config.train_size = len(train_indices)
+                elif dataset_config.train_size == "all" and dataset_config.val_known_size == "all":
+                    train_indices, val_known_indices = train_test_split(train_indices, test_size=dataset_config.train_val_split_fraction, stratify=train_labels, shuffle=True, random_state=train_val_rng)
+                else:
+                    if dataset_config.val_known_size != "all" and dataset_config.train_size != "all" and dataset_config.train_size + dataset_config.val_known_size > len(train_indices):
+                        raise ValueError(f"Requested train size + validation size ({dataset_config.train_size + dataset_config.val_known_size}) is larger than the number of available train samples ({len(train_indices)})")
+                    if dataset_config.train_size != "all" and dataset_config.train_size > len(train_indices):
+                        raise ValueError(f"Requested train size ({dataset_config.train_size}) is larger than the number of available train samples ({len(train_indices)})")
+                    if dataset_config.val_known_size != "all" and dataset_config.val_known_size > len(train_indices):
+                        raise ValueError(f"Requested validation size ({dataset_config.val_known_size}) is larger than the number of available train samples ({len(train_indices)})")
+                    train_indices, val_known_indices = train_test_split(train_indices,
+                                                                        train_size=dataset_config.train_size if dataset_config.train_size != "all" else None,
+                                                                        test_size=dataset_config.val_known_size if dataset_config.val_known_size != "all" else None,
+                                                                        stratify=train_labels, shuffle=True, random_state=train_val_rng)
+        else:
+            val_known_indices = np.zeros((0,3), dtype=np.int64)
+            val_unknown_indices = np.zeros((0,3), dtype=np.int64)
            val_data_path = None
-
-
-
-
-
-
+        # Initialize test set
+        if dataset_config.need_test_set:
+            test_known_indices, test_unknown_indices, test_data_path = init_or_load_test_indices(dataset_config=dataset_config,
+                                                                                                 known_apps=known_apps,
+                                                                                                 unknown_apps=unknown_apps,
+                                                                                                 tables_app_enum=self._tables_app_enum,
+                                                                                                 disable_indices_cache=disable_indices_cache,)
+        else:
+            test_known_indices = np.zeros((0,3), dtype=np.int64)
+            test_unknown_indices = np.zeros((0,3), dtype=np.int64)
+            test_data_path = None
+        # Fit scalers if needed
+        if (dataset_config.ppi_transform is not None and dataset_config.ppi_transform.needs_fitting or
+                dataset_config.flowstats_transform is not None and dataset_config.flowstats_transform.needs_fitting):
+            if not dataset_config.need_train_set:
+                raise ValueError("Train set is needed to fit the scalers. Provide pre-fitted scalers.")
+            fit_scalers(dataset_config=dataset_config, train_indices=train_indices)
        # Subset dataset indices based on the selected sizes and compute application counts
        dataset_indices = IndicesTuple(train_indices=train_indices, val_known_indices=val_known_indices, val_unknown_indices=val_unknown_indices, test_known_indices=test_known_indices, test_unknown_indices=test_unknown_indices)
        dataset_indices = subset_and_sort_indices(dataset_config=dataset_config, dataset_indices=dataset_indices)
-        known_app_counts = compute_known_app_counts(dataset_indices=dataset_indices,
-        unknown_app_counts = compute_unknown_app_counts(dataset_indices=dataset_indices,
+        known_app_counts = compute_known_app_counts(dataset_indices=dataset_indices, tables_app_enum=self._tables_app_enum)
+        unknown_app_counts = compute_unknown_app_counts(dataset_indices=dataset_indices, tables_app_enum=self._tables_app_enum)
        # Combine known and unknown test indices to create a single dataloader
        assert isinstance(dataset_config.test_unknown_size, int)
-        if dataset_config.test_unknown_size > 0 and len(
+        if dataset_config.test_unknown_size > 0 and len(unknown_apps) > 0:
            test_combined_indices = np.concatenate((dataset_indices.test_known_indices, dataset_indices.test_unknown_indices))
        else:
            test_combined_indices = dataset_indices.test_known_indices
-
+        # Create encoder and the class info structure
+        encoder = LabelEncoder().fit(known_apps)
+        encoder.classes_ = np.append(encoder.classes_, UNKNOWN_STR_LABEL)
+        class_info = create_class_info(servicemap=servicemap, encoder=encoder, known_apps=known_apps, unknown_apps=unknown_apps)
+        encode_labels_with_unknown_fn = partial(_encode_labels_with_unknown, encoder=encoder, class_info=class_info)
        # Create train, validation, and test datasets
-        train_dataset =
-
-
-                indices=dataset_indices.train_indices,
-                flowstats_features=dataset_config.flowstats_features,
-                other_fields=self.dataset_config.other_fields,)
-        if dataset_config.no_test_set:
-            test_dataset = None
-        else:
-            assert test_data_path is not None
-            test_dataset = PyTablesDataset(
+        train_dataset = val_dataset = test_dataset = None
+        if dataset_config.need_train_set:
+            train_dataset = PyTablesDataset(
                database_path=dataset_config.database_path,
-                tables_paths=dataset_config.
-                indices=
+                tables_paths=dataset_config._get_train_tables_paths(),
+                indices=dataset_indices.train_indices,
+                tables_app_enum=self._tables_app_enum,
+                tables_cat_enum=self._tables_cat_enum,
                flowstats_features=dataset_config.flowstats_features,
+                flowstats_features_boolean=dataset_config.flowstats_features_boolean,
+                flowstats_features_phist=dataset_config.flowstats_features_phist,
                other_fields=self.dataset_config.other_fields,
-
-
-
-
-
+                ppi_channels=dataset_config.get_ppi_channels(),
+                ppi_transform=dataset_config.ppi_transform,
+                flowstats_transform=dataset_config.flowstats_transform,
+                flowstats_phist_transform=dataset_config.flowstats_phist_transform,
+                target_transform=encode_labels_with_unknown_fn,
+                return_tensors=dataset_config.return_tensors,)
+        if dataset_config.need_val_set:
            assert val_data_path is not None
            val_dataset = PyTablesDataset(
                database_path=dataset_config.database_path,
                tables_paths=dataset_config._get_train_tables_paths(),
                indices=dataset_indices.val_known_indices,
+                tables_app_enum=self._tables_app_enum,
+                tables_cat_enum=self._tables_cat_enum,
                flowstats_features=dataset_config.flowstats_features,
+                flowstats_features_boolean=dataset_config.flowstats_features_boolean,
+                flowstats_features_phist=dataset_config.flowstats_features_phist,
                other_fields=self.dataset_config.other_fields,
+                ppi_channels=dataset_config.get_ppi_channels(),
+                ppi_transform=dataset_config.ppi_transform,
+                flowstats_transform=dataset_config.flowstats_transform,
+                flowstats_phist_transform=dataset_config.flowstats_phist_transform,
+                target_transform=encode_labels_with_unknown_fn,
+                return_tensors=dataset_config.return_tensors,
                preload=dataset_config.preload_val,
                preload_blob=os.path.join(val_data_path, "preload", f"val_dataset-{dataset_config.val_known_size}.npz"),)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if dataset_config.need_test_set:
+            assert test_data_path is not None
+            test_dataset = PyTablesDataset(
+                database_path=dataset_config.database_path,
+                tables_paths=dataset_config._get_test_tables_paths(),
+                indices=test_combined_indices,
+                tables_app_enum=self._tables_app_enum,
+                tables_cat_enum=self._tables_cat_enum,
+                flowstats_features=dataset_config.flowstats_features,
+                flowstats_features_boolean=dataset_config.flowstats_features_boolean,
+                flowstats_features_phist=dataset_config.flowstats_features_phist,
+                other_fields=self.dataset_config.other_fields,
+                ppi_channels=dataset_config.get_ppi_channels(),
+                ppi_transform=dataset_config.ppi_transform,
+                flowstats_transform=dataset_config.flowstats_transform,
+                flowstats_phist_transform=dataset_config.flowstats_phist_transform,
+                target_transform=encode_labels_with_unknown_fn,
+                return_tensors=dataset_config.return_tensors,
+                preload=dataset_config.preload_test,
+                preload_blob=os.path.join(test_data_path, "preload", f"test_dataset-{dataset_config.test_known_size}-{dataset_config.test_unknown_size}.npz"),)
        self.class_info = class_info
        self.dataset_indices = dataset_indices
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset
-        self.known_apps_database_enum = known_apps_database_enum
-        self.unknown_apps_database_enum = unknown_apps_database_enum
        self.known_app_counts = known_app_counts
        self.unknown_app_counts = unknown_app_counts
-        self.
-
-
-
-        self.ipt_scaler = ipt_scaler
+        self._collate_fn = collate_fn_simple
+
+def _encode_labels_with_unknown(labels, encoder: LabelEncoder, class_info: ClassInfo):
+    return encoder.transform(np.where(np.isin(labels, class_info.known_apps), labels, UNKNOWN_STR_LABEL))
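The new module-level helper replaces the old database-enum bookkeeping: any label outside `class_info.known_apps` is mapped to `UNKNOWN_STR_LABEL` before encoding, which works because the unknown label was appended to `encoder.classes_` during initialization. A self-contained sketch of the same mechanics (the constant's value and the application names are illustrative):

```python
import numpy as np
from sklearn.preprocessing import LabelEncoder

UNKNOWN_STR_LABEL = "unknown"  # stand-in for the constant from cesnet_datazoo.constants

known_apps = ["discord", "google-www", "spotify"]
encoder = LabelEncoder().fit(known_apps)
# Append the unknown class after fitting, as in _initialize_train_val_test above.
encoder.classes_ = np.append(encoder.classes_, UNKNOWN_STR_LABEL)

labels = np.array(["spotify", "tiktok", "discord"])  # "tiktok" is not a known app
encoded = encoder.transform(np.where(np.isin(labels, known_apps), labels, UNKNOWN_STR_LABEL))
print(encoded)  # [2 3 0] -- every unknown application collapses to the appended index 3
```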
|