cesnet-datazoo 0.0.17__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff shows the content changes between publicly released versions of the package as they appear in their respective public registries; it is provided for informational purposes only.
- cesnet_datazoo/config.py +173 -168
- cesnet_datazoo/constants.py +4 -6
- cesnet_datazoo/datasets/cesnet_dataset.py +200 -177
- cesnet_datazoo/datasets/datasets.py +22 -2
- cesnet_datazoo/datasets/datasets_constants.py +670 -0
- cesnet_datazoo/datasets/loaders.py +3 -0
- cesnet_datazoo/datasets/metadata/dataset_metadata.py +6 -5
- cesnet_datazoo/datasets/metadata/metadata.csv +4 -4
- cesnet_datazoo/datasets/statistics.py +36 -16
- cesnet_datazoo/pytables_data/data_scalers.py +68 -154
- cesnet_datazoo/pytables_data/indices_setup.py +29 -33
- cesnet_datazoo/pytables_data/pytables_dataset.py +99 -122
- cesnet_datazoo/utils/class_info.py +7 -5
- {cesnet_datazoo-0.0.17.dist-info → cesnet_datazoo-0.1.0.dist-info}/METADATA +2 -1
- cesnet_datazoo-0.1.0.dist-info/RECORD +30 -0
- {cesnet_datazoo-0.0.17.dist-info → cesnet_datazoo-0.1.0.dist-info}/WHEEL +1 -1
- cesnet_datazoo-0.0.17.dist-info/RECORD +0 -29
- {cesnet_datazoo-0.0.17.dist-info → cesnet_datazoo-0.1.0.dist-info}/LICENCE +0 -0
- {cesnet_datazoo-0.0.17.dist-info → cesnet_datazoo-0.1.0.dist-info}/top_level.txt +0 -0
cesnet_datazoo/datasets/cesnet_dataset.py

@@ -15,11 +15,14 @@ from sklearn.preprocessing import LabelEncoder
 from torch.utils.data import BatchSampler, DataLoader, RandomSampler, Sampler, SequentialSampler
 from typing_extensions import assert_never
 
-from cesnet_datazoo.config import DataLoaderOrder, DatasetConfig,
-from cesnet_datazoo.constants import DATASET_SIZES, INDICES_LABEL_POS,
-
+from cesnet_datazoo.config import AppSelection, DataLoaderOrder, DatasetConfig, ValidationApproach
+from cesnet_datazoo.constants import (APP_COLUMN, CATEGORY_COLUMN, DATASET_SIZES, INDICES_LABEL_POS,
+                                      SERVICEMAP_FILE, UNKNOWN_STR_LABEL)
+from cesnet_datazoo.datasets.loaders import collate_fn_simple, create_df_from_dataloader
 from cesnet_datazoo.datasets.metadata.dataset_metadata import DatasetMetadata, load_metadata
 from cesnet_datazoo.datasets.statistics import compute_dataset_statistics
+from cesnet_datazoo.pytables_data.apps_split import is_background_app
+from cesnet_datazoo.pytables_data.data_scalers import fit_scalers
 from cesnet_datazoo.pytables_data.indices_setup import (IndicesTuple, compute_known_app_counts,
                                                         compute_unknown_app_counts,
                                                         date_weight_sample_train_indices,
@@ -27,10 +30,8 @@ from cesnet_datazoo.pytables_data.indices_setup import (IndicesTuple, compute_kn
                                                         init_or_load_train_indices,
                                                         init_or_load_val_indices,
                                                         subset_and_sort_indices)
-from cesnet_datazoo.pytables_data.pytables_dataset import
-                                                          worker_init_fn)
+from cesnet_datazoo.pytables_data.pytables_dataset import PyTablesDataset, worker_init_fn
 from cesnet_datazoo.utils.class_info import ClassInfo, create_class_info
-from cesnet_datazoo.pytables_data.data_scalers import fit_or_load_scalers
 from cesnet_datazoo.utils.download import resumable_download, simple_download
 from cesnet_datazoo.utils.random import RandomizedSection, get_fresh_random_generator
 
@@ -39,8 +40,7 @@ DATAFRAME_SAMPLES_WARNING_THRESHOLD = 20_000_000
 
 class CesnetDataset():
     """
-    The main class for accessing CESNET datasets. It handles downloading, data
-    train/validation/test splitting, and class selection. Access to data is provided through:
+    The main class for accessing CESNET datasets. It handles downloading, train/validation/test splitting, and class selection. Access to data is provided through:
 
     - Iterable PyTorch DataLoader for batch processing. See [using dataloaders][using-dataloaders] for more details.
    - Pandas DataFrame for loading the entire train, validation, or test set at once.
@@ -54,7 +54,7 @@ class CesnetDataset():
 
    1. Create an instance of the [dataset class][dataset-classes] with the desired size and data root. This will download the dataset if it has not already been downloaded.
    2. Create an instance of [`DatasetConfig`][config.DatasetConfig] and set it with [`set_dataset_config_and_initialize`][datasets.cesnet_dataset.CesnetDataset.set_dataset_config_and_initialize].
-       This will initialize the dataset — select classes, split data into train/validation/test sets, and fit data scalers. All is done according to the provided configuration and is cached for later use.
+       This will initialize the dataset — select classes, split data into train/validation/test sets, and fit data scalers if needed. All is done according to the provided configuration and is cached for later use.
    3. Use [`get_train_dataloader`][datasets.cesnet_dataset.CesnetDataset.get_train_dataloader] or [`get_train_df`][datasets.cesnet_dataset.CesnetDataset.get_train_df] to get training data for a classification model.
    4. Validate the model and perform the hyperparameter optimalization on [`get_val_dataloader`][datasets.cesnet_dataset.CesnetDataset.get_val_dataloader] or [`get_val_df`][datasets.cesnet_dataset.CesnetDataset.get_val_df].
    5. Evaluate the model on [`get_test_dataloader`][datasets.cesnet_dataset.CesnetDataset.get_test_dataloader] or [`get_test_df`][datasets.cesnet_dataset.CesnetDataset.get_test_df].
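For orientation, the five steps in this docstring translate roughly into the sketch below. The dataset class name (`CESNET_QUIC22`) and the `DatasetConfig` arguments are illustrative assumptions based on the surrounding diff, not something this diff itself pins down.

```python
# Illustrative sketch only -- class and parameter names are assumptions;
# consult the cesnet-datazoo documentation for the exact API.
from cesnet_datazoo.config import DatasetConfig
from cesnet_datazoo.datasets import CESNET_QUIC22  # assumed dataset class

# 1. Download (if needed) and open the dataset
dataset = CESNET_QUIC22(data_root="data/CESNET-QUIC22", size="XS")

# 2. Configure and initialize: class selection, train/val/test split, scaler fitting
config = DatasetConfig(dataset=dataset)
dataset.set_dataset_config_and_initialize(config)

# 3.-5. Access training, validation, and test data
train_dataloader = dataset.get_train_dataloader()
val_df = dataset.get_val_df()
test_df = dataset.get_test_df()
```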
@@ -69,9 +69,10 @@ class CesnetDataset():
        database_filename: Name of the database file.
        database_path: Path to the database file.
        servicemap_path: Path to the servicemap file.
-       statistics_path: Path to the dataset statistics.
+       statistics_path: Path to the dataset statistics folder.
        bucket_url: URL of the bucket where the database is stored.
        metadata: Additional [dataset metadata][metadata].
+       available_classes: List of all available classes in the dataset.
        available_dates: List of all available dates in the dataset.
        time_periods: Predefined time periods. Each time period is a list of dates.
        default_train_period_name: Default time period for training.
@@ -86,36 +87,30 @@ class CesnetDataset():
        train_dataset: Train set in the form of `PyTablesDataset` instance wrapping the PyTables database.
        val_dataset: Validation set in the form of `PyTablesDataset` instance wrapping the PyTables database.
        test_dataset: Test set in the form of `PyTablesDataset` instance wrapping the PyTables database.
-       known_apps_database_enum: Dictionary that maps the database integer labels (different to those from `encoder`) of known applications to their names.
-       unknown_apps_database_enum: Dictionary that maps the database integer labels (different to those from `encoder`) of unknown applications to their names.
        known_app_counts: Known application counts in the train, validation, and test sets.
        unknown_app_counts: Unknown application counts in the validation and test sets.
-       collate_fn: Collate function used for creating batches in dataloaders.
-       encoder: Scikit-learn [`LabelEncoder`](https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.LabelEncoder.html) used to encode class names into integers. It is fitted during the initialization of the dataset.
-       flowstats_scaler: Scaler for flow statistics. It is fitted during the initialization of the dataset.
-       psizes_scaler: Scaler for packet sizes.
-       ipt_scaler: Scaler for inter-packet times.
-       flowstats_quantiles: Quantiles of flow statistics used for clipping.
        train_dataloader: Iterable PyTorch [`DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) for training.
        train_dataloader_sampler: Sampler used for iterating the training dataloader. Either [`RandomSampler`](https://pytorch.org/docs/stable/data.html#torch.utils.data.RandomSampler) or [`SequentialSampler`](https://pytorch.org/docs/stable/data.html#torch.utils.data.SequentialSampler).
+       train_dataloader_drop_last: Whether to drop the last incomplete batch when iterating the training dataloader.
        val_dataloader: Iterable PyTorch [`DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) for validation.
        test_dataloader: Iterable PyTorch [`DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) for testing.
    """
-   name: str
-   size: str
    data_root: str
+   size: str
+   silent: bool = False
+
+   name: str
    database_filename: str
    database_path: str
    servicemap_path: str
    statistics_path: str
    bucket_url: str
    metadata: DatasetMetadata
+   available_classes: list[str]
    available_dates: list[str]
    time_periods: dict[str, list[str]]
    default_train_period_name: str
    default_test_period_name: str
-   time_periods_gen: bool = False
-   silent: bool = False
 
    dataset_config: Optional[DatasetConfig] = None
    class_info: Optional[ClassInfo] = None
@@ -123,25 +118,19 @@ class CesnetDataset():
    train_dataset: Optional[PyTablesDataset] = None
    val_dataset: Optional[PyTablesDataset] = None
    test_dataset: Optional[PyTablesDataset] = None
-   known_apps_database_enum: Optional[dict[int, str]] = None
-   unknown_apps_database_enum: Optional[dict[int, str]] = None
    known_app_counts: Optional[pd.DataFrame] = None
    unknown_app_counts: Optional[pd.DataFrame] = None
-
-   collate_fn: Optional[Callable] = None
-   encoder: Optional[LabelEncoder] = None
-   flowstats_scaler: Scaler = None
-   psizes_scaler: Scaler = None
-   ipt_scaler: Scaler = None
-   flowstats_quantiles: Optional[np.ndarray] = None
-
    train_dataloader: Optional[DataLoader] = None
    train_dataloader_sampler: Optional[Sampler] = None
    train_dataloader_drop_last: bool = True
    val_dataloader: Optional[DataLoader] = None
    test_dataloader: Optional[DataLoader] = None
 
-   [1 deleted line: content not captured in the diff view]
+   _collate_fn: Optional[Callable] = None
+   _tables_app_enum: dict[int, str]
+   _tables_cat_enum: dict[int, str]
+
+   def __init__(self, data_root: str, size: str = "S", database_checks_at_init: bool = False, silent: bool = False) -> None:
        self.silent = silent
        self.metadata = load_metadata(self.name)
        self.size = size
@@ -159,24 +148,31 @@ class CesnetDataset():
            os.makedirs(self.data_root)
        if not self._is_downloaded():
            self._download()
-       if
+       if database_checks_at_init:
            with tb.open_file(self.database_path, mode="r") as database:
                tables_paths = list(map(lambda x: x._v_pathname, iter(database.get_node(f"/flows"))))
                num_samples = 0
                for p in tables_paths:
-                   [1 deleted line: content not captured in the diff view]
+                   table = database.get_node(p)
+                   assert isinstance(table, tb.Table)
+                   if self._tables_app_enum != {v: k for k, v in dict(table.get_enum(APP_COLUMN)).items()}:
+                       raise ValueError(f"Found mismatch between _tables_app_enum and the PyTables database enum in table {p}. Please report this issue.")
+                   if self._tables_cat_enum != {v: k for k, v in dict(table.get_enum(CATEGORY_COLUMN)).items()}:
+                       raise ValueError(f"Found mismatch between _tables_cat_enum and the PyTables database enum in table {p}. Please report this issue.")
+                   num_samples += len(table)
                if self.size == "ORIG" and num_samples != self.metadata.available_samples:
                    raise ValueError(f"Expected {self.metadata.available_samples} samples, but got {num_samples} in the database. Please delete the data root folder, update cesnet-datazoo, and redownload the dataset.")
                if self.size != "ORIG" and num_samples != DATASET_SIZES[self.size]:
                    raise ValueError(f"Expected {DATASET_SIZES[self.size]} samples, but got {num_samples} in the database. Please delete the data root folder, update cesnet-datazoo, and redownload the dataset.")
-       self.available_dates
-       [1 deleted line: content not captured in the diff view]
-       self.available_dates = []
-       if self.time_periods_gen:
-           self._generate_time_periods()
+               if self.available_dates != list(map(lambda x: x.removeprefix("/flows/D"), tables_paths)):
+                   raise ValueError(f"Found mismatch between available_dates and the dates available in the PyTables database. Please report this issue.")
        # Add all available dates as single date time periods
        for d in self.available_dates:
            self.time_periods[d] = [d]
+       available_applications = sorted([app for app in pd.read_csv(self.servicemap_path, index_col="Tag").index if not is_background_app(app)])
+       if len(available_applications) != self.metadata.application_count:
+           raise ValueError(f"Found {len(available_applications)} applications in the servicemap (omitting background traffic classes), but expected {self.metadata.application_count}. Please report this issue.")
+       self.available_classes = available_applications + self.metadata.background_traffic_classes
 
    def set_dataset_config_and_initialize(self, dataset_config: DatasetConfig, disable_indices_cache: bool = False) -> None:
        """
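The new `database_checks_at_init` branch compares the package's hard-coded `_tables_app_enum`/`_tables_cat_enum` mappings against the enums stored in the PyTables tables. A minimal sketch of that reversal step, with made-up class names (only the `dict(...)` reversal pattern is taken from the diff):

```python
import tables as tb

# A PyTables enum maps class names to the integer codes stored in the table
# column; the names below are made-up examples.
app_enum = tb.Enum({"AppA": 0, "AppB": 1, "Background": 2})

# The init check builds the reverse mapping (code -> name), which is what
# _tables_app_enum is expected to contain for every table in the database.
reversed_enum = {v: k for k, v in dict(app_enum).items()}
assert reversed_enum == {0: "AppA", 1: "AppB", 2: "Background"}
```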
@@ -208,6 +204,8 @@ class CesnetDataset():
        """
        if self.dataset_config is None:
            raise ValueError("Dataset is not initialized, use set_dataset_config_and_initialize() before getting train dataloader")
+       if not self.dataset_config.need_train_set:
+           raise ValueError("Train dataloader is not available when need_train_set is false")
        assert self.train_dataset
        if self.train_dataloader:
            return self.train_dataloader
@@ -230,7 +228,7 @@ class CesnetDataset():
            self.train_dataset,
            num_workers=self.dataset_config.train_workers,
            worker_init_fn=worker_init_fn,
-           collate_fn=self.
+           collate_fn=self._collate_fn,
            persistent_workers=self.dataset_config.train_workers > 0,
            batch_size=None,
            sampler=batch_sampler,)
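The dataloader here is built with `batch_size=None` and a batch sampler passed as `sampler`, so whole lists of indices reach the dataset at once and can be read from PyTables in bulk. A self-contained sketch of that PyTorch pattern (the toy dataset below merely stands in for `PyTablesDataset`):

```python
import numpy as np
from torch.utils.data import BatchSampler, DataLoader, Dataset, RandomSampler

class BulkReadDataset(Dataset):
    """Toy map-style dataset whose __getitem__ accepts a whole batch of indices."""
    def __init__(self, data: np.ndarray):
        self.data = data
    def __len__(self) -> int:
        return len(self.data)
    def __getitem__(self, batch_idx):
        # batch_idx is a list of indices, so one call fetches one whole batch
        return self.data[batch_idx]

dataset = BulkReadDataset(np.arange(1000))
batch_sampler = BatchSampler(RandomSampler(dataset), batch_size=128, drop_last=True)
# batch_size=None disables automatic batching; the sampler already yields batches
loader = DataLoader(dataset, batch_size=None, sampler=batch_sampler)
for batch in loader:
    assert len(batch) == 128
```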
@@ -255,8 +253,8 @@ class CesnetDataset():
        """
        if self.dataset_config is None:
            raise ValueError("Dataset is not initialized, use set_dataset_config_and_initialize() before getting validaion dataloader")
-       if self.dataset_config.
-           raise ValueError("Validation dataloader is not available when
+       if not self.dataset_config.need_val_set:
+           raise ValueError("Validation dataloader is not available when need_val_set is false")
        assert self.val_dataset is not None
        if self.val_dataloader:
            return self.val_dataloader
@@ -265,7 +263,7 @@
            self.val_dataset,
            num_workers=self.dataset_config.val_workers,
            worker_init_fn=worker_init_fn,
-           collate_fn=self.
+           collate_fn=self._collate_fn,
            persistent_workers=self.dataset_config.val_workers > 0,
            batch_size=None,
            sampler=batch_sampler,)
@@ -294,8 +292,8 @@
        """
        if self.dataset_config is None:
            raise ValueError("Dataset is not initialized, use set_dataset_config_and_initialize() before getting test dataloader")
-       if self.dataset_config.
-           raise ValueError("Test dataloader is not available when
+       if not self.dataset_config.need_test_set:
+           raise ValueError("Test dataloader is not available when need_test_set is false")
        assert self.test_dataset is not None
        if self.test_dataloader:
            return self.test_dataloader
@@ -304,7 +302,7 @@
            self.test_dataset,
            num_workers=self.dataset_config.test_workers,
            worker_init_fn=worker_init_fn,
-           collate_fn=self.
+           collate_fn=self._collate_fn,
            persistent_workers=False,
            batch_size=None,
            sampler=batch_sampler,)
@@ -336,7 +334,7 @@
        Returns:
            Train data as a dataframe.
        """
-       self._check_before_dataframe()
+       self._check_before_dataframe(check_train=True)
        assert self.dataset_config is not None and self.train_dataset is not None
        if len(self.train_dataset) > DATAFRAME_SAMPLES_WARNING_THRESHOLD:
            warnings.warn(f"Train set has ({len(self.train_dataset)} samples), consider using get_train_dataloader() instead")
@@ -369,7 +367,7 @@
        Returns:
            Validation data as a dataframe.
        """
-       self._check_before_dataframe(
+       self._check_before_dataframe(check_val=True)
        assert self.dataset_config is not None and self.val_dataset is not None
        if len(self.val_dataset) > DATAFRAME_SAMPLES_WARNING_THRESHOLD:
            warnings.warn(f"Validation set has ({len(self.val_dataset)} samples), consider using get_val_dataloader() instead")
@@ -398,7 +396,7 @@
        Returns:
            Test data as a dataframe.
        """
-       self._check_before_dataframe(
+       self._check_before_dataframe(check_test=True)
        assert self.dataset_config is not None and self.test_dataset is not None
        if len(self.test_dataset) > DATAFRAME_SAMPLES_WARNING_THRESHOLD:
            warnings.warn(f"Test set has ({len(self.test_dataset)} samples), consider using get_test_dataloader() instead")
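Continuing the earlier usage sketch, the DataFrame accessors pair naturally with the per-class count attributes; `DATAFRAME_SAMPLES_WARNING_THRESHOLD` (20 million samples) is the point at which the code above suggests switching to a dataloader. Only the attribute and method names come from this diff, the rest is illustrative:

```python
# Illustrative continuation of the earlier sketch.
test_df = dataset.get_test_df()
print(test_df.shape)
print(dataset.known_app_counts)    # per-class sample counts in train/val/test
print(dataset.unknown_app_counts)  # per-class counts of "unknown" classes in val/test
```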
@@ -436,12 +434,18 @@
            batch_size: Number of samples per batch for loading data.
            disabled_apps: List of applications to exclude from the statistics.
        """
-       [1 deleted line: content not captured in the diff view]
+       if disabled_apps:
+           bad_disabled_apps = [a for a in disabled_apps if a not in self.available_classes]
+           if len(bad_disabled_apps) > 0:
+               raise ValueError(f"Bad applications in disabled_apps {bad_disabled_apps}. Use applications available in dataset.available_classes")
        if not os.path.exists(self.statistics_path):
            os.mkdir(self.statistics_path)
        compute_dataset_statistics(database_path=self.database_path,
+                                  tables_app_enum=self._tables_app_enum,
+                                  tables_cat_enum=self._tables_cat_enum,
                                   output_dir=self.statistics_path,
-                                  [1 deleted line: content not captured in the diff view]
+                                  packet_histograms=self.metadata.packet_histograms,
+                                  flowstats_features_boolean=self.metadata.flowstats_features_boolean,
                                   protocol=self.metadata.protocol,
                                   extra_fields=not self.name.startswith("CESNET-TLS22"),
                                   disabled_apps=disabled_apps if disabled_apps is not None else [],
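The added check rejects `disabled_apps` entries that are not in `dataset.available_classes` (an attribute introduced in this release). A hedged sketch of calling the statistics computation; the public method name `compute_dataset_statistics` is assumed here from the wrapped call, not confirmed by the diff:

```python
# Assumed public wrapper around the compute_dataset_statistics() call above.
some_app = dataset.available_classes[0]
dataset.compute_dataset_statistics(disabled_apps=[some_app])
# Passing a name outside dataset.available_classes now raises ValueError.
```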
@@ -489,174 +493,193 @@
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None
-       self.known_apps_database_enum = None
-       self.unknown_apps_database_enum = None
        self.known_app_counts = None
        self.unknown_app_counts = None
-
-       self.collate_fn = None
-       self.encoder = None
-       self.flowstats_scaler = None
-       self.psizes_scaler = None
-       self.ipt_scaler = None
-       self.flowstats_quantiles = None
-
        self.train_dataloader = None
        self.train_dataloader_sampler = None
        self.train_dataloader_drop_last = True
        self.val_dataloader = None
        self.test_dataloader = None
+       self._collate_fn = None
 
-   def _check_before_dataframe(self,
+   def _check_before_dataframe(self, check_train: bool = False, check_val: bool = False, check_test: bool = False) -> None:
        if self.dataset_config is None:
            raise ValueError("Dataset is not initialized, use set_dataset_config_and_initialize() before getting a dataframe")
-       if self.dataset_config.
-           raise ValueError("Dataframes are not available when
-       if
-           raise ValueError("
-       if
-           raise ValueError("
+       if self.dataset_config.return_tensors:
+           raise ValueError("Dataframes are not available when return_tensors is set. Use a dataloader instead.")
+       if check_train and not self.dataset_config.need_train_set:
+           raise ValueError("Train dataframe is not available when need_train_set is false")
+       if check_val and not self.dataset_config.need_val_set:
+           raise ValueError("Validation dataframe is not available when need_val_set is false")
+       if check_test and not self.dataset_config.need_test_set:
+           raise ValueError("Test dataframe is not available when need_test_set is false")
 
    def _initialize_train_val_test(self, disable_indices_cache: bool = False) -> None:
        assert self.dataset_config is not None
        dataset_config = self.dataset_config
        servicemap = pd.read_csv(dataset_config.servicemap_path, index_col="Tag")
-       # Initialize train
-       [7 deleted lines: content not captured in the diff view]
+       # Initialize train set
+       if dataset_config.need_train_set:
+           train_indices, train_unknown_indices, known_apps, unknown_apps = init_or_load_train_indices(dataset_config=dataset_config,
+                                                                                                       tables_app_enum=self._tables_app_enum,
+                                                                                                       servicemap=servicemap,
+                                                                                                       disable_indices_cache=disable_indices_cache,)
+           # Date weight sampling of train indices
+           if dataset_config.train_dates_weigths is not None:
+               assert dataset_config.train_size != "all"
+               if dataset_config.val_approach == ValidationApproach.SPLIT_FROM_TRAIN:
+                   # requested number of samples is train_size + val_known_size when using the split-from-train validation approach
+                   assert dataset_config.val_known_size != "all"
+                   num_samples = dataset_config.train_size + dataset_config.val_known_size
+               else:
+                   num_samples = dataset_config.train_size
+               if num_samples > len(train_indices):
+                   raise ValueError(f"Requested number of samples for weight sampling ({num_samples}) is larger than the number of available train samples ({len(train_indices)})")
+               train_indices = date_weight_sample_train_indices(dataset_config=dataset_config, train_indices=train_indices, num_samples=num_samples)
+       elif dataset_config.apps_selection == AppSelection.FIXED:
+           known_apps = dataset_config.apps_selection_fixed_known
+           unknown_apps = dataset_config.apps_selection_fixed_unknown
+           train_indices = np.zeros((0,3), dtype=np.int64)
+           train_unknown_indices = np.zeros((0,3), dtype=np.int64)
        else:
-           [3 deleted lines: content not captured in the diff view]
+           raise ValueError("Either need train set or the fixed application selection")
+       # Initialize validation set
+       if dataset_config.need_val_set:
+           if dataset_config.val_approach == ValidationApproach.VALIDATION_DATES:
+               val_known_indices, val_unknown_indices, val_data_path = init_or_load_val_indices(dataset_config=dataset_config,
+                                                                                                known_apps=known_apps,
+                                                                                                unknown_apps=unknown_apps,
+                                                                                                tables_app_enum=self._tables_app_enum,
                                                                                                 disable_indices_cache=disable_indices_cache,)
-       [28 deleted lines of the old validation-split logic: content not captured in the diff view]
-           train_indices, val_known_indices = train_test_split(train_indices, test_size=dataset_config.val_known_size, stratify=train_labels, shuffle=True, random_state=train_val_rng)
-           dataset_config.train_size = len(train_indices)
-       elif dataset_config.train_size == "all" and dataset_config.val_known_size == "all":
-           train_indices, val_known_indices = train_test_split(train_indices, test_size=dataset_config.train_val_split_fraction, stratify=train_labels, shuffle=True, random_state=train_val_rng)
-       else:
-           if dataset_config.val_known_size != "all" and dataset_config.train_size != "all" and dataset_config.train_size + dataset_config.val_known_size > len(train_indices):
-               raise ValueError(f"Requested train size + validation size ({dataset_config.train_size + dataset_config.val_known_size}) is larger than the number of available train samples ({len(train_indices)})")
-           if dataset_config.train_size != "all" and dataset_config.train_size > len(train_indices):
-               raise ValueError(f"Requested train size ({dataset_config.train_size}) is larger than the number of available train samples ({len(train_indices)})")
-           if dataset_config.val_known_size != "all" and dataset_config.val_known_size > len(train_indices):
-               raise ValueError(f"Requested validation size ({dataset_config.val_known_size}) is larger than the number of available train samples ({len(train_indices)})")
-           train_indices, val_known_indices = train_test_split(train_indices,
-                                                               train_size=dataset_config.train_size if dataset_config.train_size != "all" else None,
-                                                               test_size=dataset_config.val_known_size if dataset_config.val_known_size != "all" else None,
-                                                               stratify=train_labels, shuffle=True, random_state=train_val_rng)
-       elif dataset_config.val_approach == ValidationApproach.NO_VALIDATION:
-           val_known_indices = np.empty((0,3), dtype=np.int64)
-           val_unknown_indices = np.empty((0,3), dtype=np.int64)
+           elif dataset_config.val_approach == ValidationApproach.SPLIT_FROM_TRAIN:
+               train_val_rng = get_fresh_random_generator(dataset_config=dataset_config, section=RandomizedSection.TRAIN_VAL_SPLIT)
+               val_data_path = dataset_config._get_train_data_path()
+               val_unknown_indices = train_unknown_indices
+               train_labels = train_indices[:, INDICES_LABEL_POS]
+               if dataset_config.train_dates_weigths is not None:
+                   assert dataset_config.val_known_size != "all"
+                   # When weight sampling is used, val_known_size is kept but the resulting train size can be smaller due to no enough samples in some train dates
+                   if dataset_config.val_known_size > len(train_indices):
+                       raise ValueError(f"Requested validation size ({dataset_config.val_known_size}) is larger than the number of available train samples after weight sampling ({len(train_indices)})")
+                   train_indices, val_known_indices = train_test_split(train_indices, test_size=dataset_config.val_known_size, stratify=train_labels, shuffle=True, random_state=train_val_rng)
+                   dataset_config.train_size = len(train_indices)
+               elif dataset_config.train_size == "all" and dataset_config.val_known_size == "all":
+                   train_indices, val_known_indices = train_test_split(train_indices, test_size=dataset_config.train_val_split_fraction, stratify=train_labels, shuffle=True, random_state=train_val_rng)
+               else:
+                   if dataset_config.val_known_size != "all" and dataset_config.train_size != "all" and dataset_config.train_size + dataset_config.val_known_size > len(train_indices):
+                       raise ValueError(f"Requested train size + validation size ({dataset_config.train_size + dataset_config.val_known_size}) is larger than the number of available train samples ({len(train_indices)})")
+                   if dataset_config.train_size != "all" and dataset_config.train_size > len(train_indices):
+                       raise ValueError(f"Requested train size ({dataset_config.train_size}) is larger than the number of available train samples ({len(train_indices)})")
+                   if dataset_config.val_known_size != "all" and dataset_config.val_known_size > len(train_indices):
+                       raise ValueError(f"Requested validation size ({dataset_config.val_known_size}) is larger than the number of available train samples ({len(train_indices)})")
+                   train_indices, val_known_indices = train_test_split(train_indices,
+                                                                       train_size=dataset_config.train_size if dataset_config.train_size != "all" else None,
+                                                                       test_size=dataset_config.val_known_size if dataset_config.val_known_size != "all" else None,
+                                                                       stratify=train_labels, shuffle=True, random_state=train_val_rng)
+       else:
+           val_known_indices = np.zeros((0,3), dtype=np.int64)
+           val_unknown_indices = np.zeros((0,3), dtype=np.int64)
            val_data_path = None
-       [6 deleted lines of the old test-set initialization: content not captured in the diff view]
+       # Initialize test set
+       if dataset_config.need_test_set:
+           test_known_indices, test_unknown_indices, test_data_path = init_or_load_test_indices(dataset_config=dataset_config,
+                                                                                                known_apps=known_apps,
+                                                                                                unknown_apps=unknown_apps,
+                                                                                                tables_app_enum=self._tables_app_enum,
+                                                                                                disable_indices_cache=disable_indices_cache,)
+       else:
+           test_known_indices = np.zeros((0,3), dtype=np.int64)
+           test_unknown_indices = np.zeros((0,3), dtype=np.int64)
+           test_data_path = None
+       # Fit scalers if needed
+       if (dataset_config.ppi_transform is not None and dataset_config.ppi_transform.needs_fitting or
+           dataset_config.flowstats_transform is not None and dataset_config.flowstats_transform.needs_fitting):
+           if not dataset_config.need_train_set:
+               raise ValueError("Train set is needed to fit the scalers. Provide pre-fitted scalers.")
+           fit_scalers(dataset_config=dataset_config, train_indices=train_indices)
        # Subset dataset indices based on the selected sizes and compute application counts
        dataset_indices = IndicesTuple(train_indices=train_indices, val_known_indices=val_known_indices, val_unknown_indices=val_unknown_indices, test_known_indices=test_known_indices, test_unknown_indices=test_unknown_indices)
        dataset_indices = subset_and_sort_indices(dataset_config=dataset_config, dataset_indices=dataset_indices)
-       known_app_counts = compute_known_app_counts(dataset_indices=dataset_indices,
-       unknown_app_counts = compute_unknown_app_counts(dataset_indices=dataset_indices,
+       known_app_counts = compute_known_app_counts(dataset_indices=dataset_indices, tables_app_enum=self._tables_app_enum)
+       unknown_app_counts = compute_unknown_app_counts(dataset_indices=dataset_indices, tables_app_enum=self._tables_app_enum)
        # Combine known and unknown test indicies to create a single dataloader
        assert isinstance(dataset_config.test_unknown_size, int)
-       if dataset_config.test_unknown_size > 0 and len(
+       if dataset_config.test_unknown_size > 0 and len(unknown_apps) > 0:
            test_combined_indices = np.concatenate((dataset_indices.test_known_indices, dataset_indices.test_unknown_indices))
        else:
            test_combined_indices = dataset_indices.test_known_indices
-
+       # Create encoder the class info structure
+       encoder = LabelEncoder().fit(known_apps)
+       encoder.classes_ = np.append(encoder.classes_, UNKNOWN_STR_LABEL)
+       class_info = create_class_info(servicemap=servicemap, encoder=encoder, known_apps=known_apps, unknown_apps=unknown_apps)
+       encode_labels_with_unknown_fn = partial(_encode_labels_with_unknown, encoder=encoder, class_info=class_info)
        # Create train, validation, and test datasets
-       train_dataset =
-       [2 deleted lines: content not captured in the diff view]
-           indices=dataset_indices.train_indices,
-           flowstats_features=dataset_config.flowstats_features,
-           other_fields=self.dataset_config.other_fields,)
-       if dataset_config.no_test_set:
-           test_dataset = None
-       else:
-           assert test_data_path is not None
-           test_dataset = PyTablesDataset(
+       train_dataset = val_dataset = test_dataset = None
+       if dataset_config.need_train_set:
+           train_dataset = PyTablesDataset(
                database_path=dataset_config.database_path,
-               tables_paths=dataset_config.
-               indices=
+               tables_paths=dataset_config._get_train_tables_paths(),
+               indices=dataset_indices.train_indices,
+               tables_app_enum=self._tables_app_enum,
+               tables_cat_enum=self._tables_cat_enum,
                flowstats_features=dataset_config.flowstats_features,
+               flowstats_features_boolean=dataset_config.flowstats_features_boolean,
+               flowstats_features_phist=dataset_config.flowstats_features_phist,
                other_fields=self.dataset_config.other_fields,
-               [5 deleted lines: content not captured in the diff view]
+               ppi_channels=dataset_config.get_ppi_channels(),
+               ppi_transform=dataset_config.ppi_transform,
+               flowstats_transform=dataset_config.flowstats_transform,
+               flowstats_phist_transform=dataset_config.flowstats_phist_transform,
+               target_transform=encode_labels_with_unknown_fn,
+               return_tensors=dataset_config.return_tensors,)
+       if dataset_config.need_val_set:
            assert val_data_path is not None
            val_dataset = PyTablesDataset(
                database_path=dataset_config.database_path,
                tables_paths=dataset_config._get_train_tables_paths(),
                indices=dataset_indices.val_known_indices,
+               tables_app_enum=self._tables_app_enum,
+               tables_cat_enum=self._tables_cat_enum,
                flowstats_features=dataset_config.flowstats_features,
+               flowstats_features_boolean=dataset_config.flowstats_features_boolean,
+               flowstats_features_phist=dataset_config.flowstats_features_phist,
                other_fields=self.dataset_config.other_fields,
+               ppi_channels=dataset_config.get_ppi_channels(),
+               ppi_transform=dataset_config.ppi_transform,
+               flowstats_transform=dataset_config.flowstats_transform,
+               flowstats_phist_transform=dataset_config.flowstats_phist_transform,
+               target_transform=encode_labels_with_unknown_fn,
+               return_tensors=dataset_config.return_tensors,
                preload=dataset_config.preload_val,
                preload_blob=os.path.join(val_data_path, "preload", f"val_dataset-{dataset_config.val_known_size}.npz"),)
-       [15 deleted lines of the old test dataset construction: content not captured in the diff view]
+       if dataset_config.need_test_set:
+           assert test_data_path is not None
+           test_dataset = PyTablesDataset(
+               database_path=dataset_config.database_path,
+               tables_paths=dataset_config._get_test_tables_paths(),
+               indices=test_combined_indices,
+               tables_app_enum=self._tables_app_enum,
+               tables_cat_enum=self._tables_cat_enum,
+               flowstats_features=dataset_config.flowstats_features,
+               flowstats_features_boolean=dataset_config.flowstats_features_boolean,
+               flowstats_features_phist=dataset_config.flowstats_features_phist,
+               other_fields=self.dataset_config.other_fields,
+               ppi_channels=dataset_config.get_ppi_channels(),
+               ppi_transform=dataset_config.ppi_transform,
+               flowstats_transform=dataset_config.flowstats_transform,
+               flowstats_phist_transform=dataset_config.flowstats_phist_transform,
+               target_transform=encode_labels_with_unknown_fn,
+               return_tensors=dataset_config.return_tensors,
+               preload=dataset_config.preload_test,
+               preload_blob=os.path.join(test_data_path, "preload", f"test_dataset-{dataset_config.test_known_size}-{dataset_config.test_unknown_size}.npz"),)
        self.class_info = class_info
        self.dataset_indices = dataset_indices
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset
-       self.known_apps_database_enum = known_apps_database_enum
-       self.unknown_apps_database_enum = unknown_apps_database_enum
        self.known_app_counts = known_app_counts
        self.unknown_app_counts = unknown_app_counts
-       self.
-       [3 deleted lines: content not captured in the diff view]
-       self.ipt_scaler = ipt_scaler
-       self.flowstats_quantiles = flowstats_quantiles
+       self._collate_fn = collate_fn_simple
+
+def _encode_labels_with_unknown(labels, encoder: LabelEncoder, class_info: ClassInfo):
+   return encoder.transform(np.where(np.isin(labels, class_info.known_apps), labels, UNKNOWN_STR_LABEL))
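The new module-level `_encode_labels_with_unknown` helper replaces the scaler and encoder attributes removed above: an encoder is fitted on the known applications, an extra unknown class is appended, and any label outside the known set is mapped to it. The behavior can be reproduced standalone; the class names and the literal "unknown" string below are placeholders for the real `UNKNOWN_STR_LABEL` constant from `cesnet_datazoo.constants`.

```python
import numpy as np
from sklearn.preprocessing import LabelEncoder

# Fit the encoder on the known classes and append the "unknown" class,
# mirroring the encoder setup in _initialize_train_val_test above.
known_apps = ["app-a", "app-b", "app-c"]
encoder = LabelEncoder().fit(known_apps)
encoder.classes_ = np.append(encoder.classes_, "unknown")

# Any label outside the known set is remapped to "unknown" before encoding.
labels = np.array(["app-b", "never-seen-app", "app-c"])
encoded = encoder.transform(np.where(np.isin(labels, known_apps), labels, "unknown"))
print(encoded)  # [1 3 2] -- the out-of-set label gets the "unknown" index
```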