cesnet-datazoo 0.0.17__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in the public registry. It is provided for informational purposes only.
@@ -15,11 +15,14 @@ from sklearn.preprocessing import LabelEncoder
 from torch.utils.data import BatchSampler, DataLoader, RandomSampler, Sampler, SequentialSampler
 from typing_extensions import assert_never
 
-from cesnet_datazoo.config import DataLoaderOrder, DatasetConfig, Scaler, ValidationApproach
-from cesnet_datazoo.constants import DATASET_SIZES, INDICES_LABEL_POS, SERVICEMAP_FILE
-from cesnet_datazoo.datasets.loaders import create_df_from_dataloader
+from cesnet_datazoo.config import AppSelection, DataLoaderOrder, DatasetConfig, ValidationApproach
+from cesnet_datazoo.constants import (APP_COLUMN, CATEGORY_COLUMN, DATASET_SIZES, INDICES_LABEL_POS,
+                                      SERVICEMAP_FILE, UNKNOWN_STR_LABEL)
+from cesnet_datazoo.datasets.loaders import collate_fn_simple, create_df_from_dataloader
 from cesnet_datazoo.datasets.metadata.dataset_metadata import DatasetMetadata, load_metadata
 from cesnet_datazoo.datasets.statistics import compute_dataset_statistics
+from cesnet_datazoo.pytables_data.apps_split import is_background_app
+from cesnet_datazoo.pytables_data.data_scalers import fit_scalers
 from cesnet_datazoo.pytables_data.indices_setup import (IndicesTuple, compute_known_app_counts,
                                                         compute_unknown_app_counts,
                                                         date_weight_sample_train_indices,
@@ -27,10 +30,8 @@ from cesnet_datazoo.pytables_data.indices_setup import (IndicesTuple, compute_kn
                                                         init_or_load_train_indices,
                                                         init_or_load_val_indices,
                                                         subset_and_sort_indices)
-from cesnet_datazoo.pytables_data.pytables_dataset import (PyTablesDataset, pytables_collate_fn,
-                                                           worker_init_fn)
+from cesnet_datazoo.pytables_data.pytables_dataset import PyTablesDataset, worker_init_fn
 from cesnet_datazoo.utils.class_info import ClassInfo, create_class_info
-from cesnet_datazoo.pytables_data.data_scalers import fit_or_load_scalers
 from cesnet_datazoo.utils.download import resumable_download, simple_download
 from cesnet_datazoo.utils.random import RandomizedSection, get_fresh_random_generator
 
@@ -39,8 +40,7 @@ DATAFRAME_SAMPLES_WARNING_THRESHOLD = 20_000_000
 
 class CesnetDataset():
     """
-    The main class for accessing CESNET datasets. It handles downloading, data preprocessing,
-    train/validation/test splitting, and class selection. Access to data is provided through:
+    The main class for accessing CESNET datasets. It handles downloading, train/validation/test splitting, and class selection. Access to data is provided through:
 
     - Iterable PyTorch DataLoader for batch processing. See [using dataloaders][using-dataloaders] for more details.
     - Pandas DataFrame for loading the entire train, validation, or test set at once.
@@ -54,7 +54,7 @@ class CesnetDataset():
 
     1. Create an instance of the [dataset class][dataset-classes] with the desired size and data root. This will download the dataset if it has not already been downloaded.
     2. Create an instance of [`DatasetConfig`][config.DatasetConfig] and set it with [`set_dataset_config_and_initialize`][datasets.cesnet_dataset.CesnetDataset.set_dataset_config_and_initialize].
-    This will initialize the dataset — select classes, split data into train/validation/test sets, and fit data scalers. All is done according to the provided configuration and is cached for later use.
+    This will initialize the dataset — select classes, split data into train/validation/test sets, and fit data scalers if needed. All is done according to the provided configuration and is cached for later use.
     3. Use [`get_train_dataloader`][datasets.cesnet_dataset.CesnetDataset.get_train_dataloader] or [`get_train_df`][datasets.cesnet_dataset.CesnetDataset.get_train_df] to get training data for a classification model.
     4. Validate the model and perform the hyperparameter optimization on [`get_val_dataloader`][datasets.cesnet_dataset.CesnetDataset.get_val_dataloader] or [`get_val_df`][datasets.cesnet_dataset.CesnetDataset.get_val_df].
    5. Evaluate the model on [`get_test_dataloader`][datasets.cesnet_dataset.CesnetDataset.get_test_dataloader] or [`get_test_df`][datasets.cesnet_dataset.CesnetDataset.get_test_df].
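For orientation, the five steps above (unchanged in this release) correspond to roughly the following usage. This is a minimal sketch: the CESNET_QUIC22 class, data root path, and size are illustrative and not part of this diff.

    from cesnet_datazoo.config import DatasetConfig
    from cesnet_datazoo.datasets import CESNET_QUIC22

    # Step 1: instantiating a dataset class downloads the data on first use
    dataset = CESNET_QUIC22("/datasets/CESNET-QUIC22/", size="XS")
    # Step 2: select classes, split train/validation/test, and fit scalers if needed
    dataset.set_dataset_config_and_initialize(DatasetConfig(dataset=dataset))
    # Steps 3-5: fetch data for training, validation, and final evaluation
    train_df = dataset.get_train_df()
    val_df = dataset.get_val_df()
    test_df = dataset.get_test_df()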
@@ -69,9 +69,10 @@ class CesnetDataset():
         database_filename: Name of the database file.
         database_path: Path to the database file.
         servicemap_path: Path to the servicemap file.
-        statistics_path: Path to the dataset statistics.
+        statistics_path: Path to the dataset statistics folder.
         bucket_url: URL of the bucket where the database is stored.
         metadata: Additional [dataset metadata][metadata].
+        available_classes: List of all available classes in the dataset.
         available_dates: List of all available dates in the dataset.
         time_periods: Predefined time periods. Each time period is a list of dates.
         default_train_period_name: Default time period for training.
@@ -86,36 +87,30 @@ class CesnetDataset():
         train_dataset: Train set in the form of `PyTablesDataset` instance wrapping the PyTables database.
         val_dataset: Validation set in the form of `PyTablesDataset` instance wrapping the PyTables database.
         test_dataset: Test set in the form of `PyTablesDataset` instance wrapping the PyTables database.
-        known_apps_database_enum: Dictionary that maps the database integer labels (different to those from `encoder`) of known applications to their names.
-        unknown_apps_database_enum: Dictionary that maps the database integer labels (different to those from `encoder`) of unknown applications to their names.
         known_app_counts: Known application counts in the train, validation, and test sets.
         unknown_app_counts: Unknown application counts in the validation and test sets.
-        collate_fn: Collate function used for creating batches in dataloaders.
-        encoder: Scikit-learn [`LabelEncoder`](https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.LabelEncoder.html) used to encode class names into integers. It is fitted during the initialization of the dataset.
-        flowstats_scaler: Scaler for flow statistics. It is fitted during the initialization of the dataset.
-        psizes_scaler: Scaler for packet sizes.
-        ipt_scaler: Scaler for inter-packet times.
-        flowstats_quantiles: Quantiles of flow statistics used for clipping.
         train_dataloader: Iterable PyTorch [`DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) for training.
         train_dataloader_sampler: Sampler used for iterating the training dataloader. Either [`RandomSampler`](https://pytorch.org/docs/stable/data.html#torch.utils.data.RandomSampler) or [`SequentialSampler`](https://pytorch.org/docs/stable/data.html#torch.utils.data.SequentialSampler).
+        train_dataloader_drop_last: Whether to drop the last incomplete batch when iterating the training dataloader.
         val_dataloader: Iterable PyTorch [`DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) for validation.
         test_dataloader: Iterable PyTorch [`DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) for testing.
     """
-    name: str
-    size: str
     data_root: str
+    size: str
+    silent: bool = False
+
+    name: str
     database_filename: str
     database_path: str
     servicemap_path: str
     statistics_path: str
     bucket_url: str
     metadata: DatasetMetadata
+    available_classes: list[str]
     available_dates: list[str]
     time_periods: dict[str, list[str]]
     default_train_period_name: str
     default_test_period_name: str
-    time_periods_gen: bool = False
-    silent: bool = False
 
     dataset_config: Optional[DatasetConfig] = None
     class_info: Optional[ClassInfo] = None
@@ -123,25 +118,19 @@ class CesnetDataset():
     train_dataset: Optional[PyTablesDataset] = None
     val_dataset: Optional[PyTablesDataset] = None
     test_dataset: Optional[PyTablesDataset] = None
-    known_apps_database_enum: Optional[dict[int, str]] = None
-    unknown_apps_database_enum: Optional[dict[int, str]] = None
     known_app_counts: Optional[pd.DataFrame] = None
     unknown_app_counts: Optional[pd.DataFrame] = None
-
-    collate_fn: Optional[Callable] = None
-    encoder: Optional[LabelEncoder] = None
-    flowstats_scaler: Scaler = None
-    psizes_scaler: Scaler = None
-    ipt_scaler: Scaler = None
-    flowstats_quantiles: Optional[np.ndarray] = None
-
     train_dataloader: Optional[DataLoader] = None
     train_dataloader_sampler: Optional[Sampler] = None
     train_dataloader_drop_last: bool = True
     val_dataloader: Optional[DataLoader] = None
     test_dataloader: Optional[DataLoader] = None
 
-    def __init__(self, data_root: str, size: str = "S", skip_dataset_read_at_init: bool = False, silent: bool = False) -> None:
+    _collate_fn: Optional[Callable] = None
+    _tables_app_enum: dict[int, str]
+    _tables_cat_enum: dict[int, str]
+
+    def __init__(self, data_root: str, size: str = "S", database_checks_at_init: bool = False, silent: bool = False) -> None:
         self.silent = silent
         self.metadata = load_metadata(self.name)
         self.size = size
@@ -159,24 +148,31 @@ class CesnetDataset():
             os.makedirs(self.data_root)
         if not self._is_downloaded():
             self._download()
-        if not skip_dataset_read_at_init:
+        if database_checks_at_init:
             with tb.open_file(self.database_path, mode="r") as database:
                 tables_paths = list(map(lambda x: x._v_pathname, iter(database.get_node(f"/flows"))))
                 num_samples = 0
                 for p in tables_paths:
-                    num_samples += len(database.get_node(p))
+                    table = database.get_node(p)
+                    assert isinstance(table, tb.Table)
+                    if self._tables_app_enum != {v: k for k, v in dict(table.get_enum(APP_COLUMN)).items()}:
+                        raise ValueError(f"Found mismatch between _tables_app_enum and the PyTables database enum in table {p}. Please report this issue.")
+                    if self._tables_cat_enum != {v: k for k, v in dict(table.get_enum(CATEGORY_COLUMN)).items()}:
+                        raise ValueError(f"Found mismatch between _tables_cat_enum and the PyTables database enum in table {p}. Please report this issue.")
+                    num_samples += len(table)
                 if self.size == "ORIG" and num_samples != self.metadata.available_samples:
                     raise ValueError(f"Expected {self.metadata.available_samples} samples, but got {num_samples} in the database. Please delete the data root folder, update cesnet-datazoo, and redownload the dataset.")
                 if self.size != "ORIG" and num_samples != DATASET_SIZES[self.size]:
                     raise ValueError(f"Expected {DATASET_SIZES[self.size]} samples, but got {num_samples} in the database. Please delete the data root folder, update cesnet-datazoo, and redownload the dataset.")
-            self.available_dates = list(map(lambda x: x.removeprefix("/flows/D"), tables_paths))
-        else:
-            self.available_dates = []
-        if self.time_periods_gen:
-            self._generate_time_periods()
+                if self.available_dates != list(map(lambda x: x.removeprefix("/flows/D"), tables_paths)):
+                    raise ValueError(f"Found mismatch between available_dates and the dates available in the PyTables database. Please report this issue.")
         # Add all available dates as single date time periods
         for d in self.available_dates:
             self.time_periods[d] = [d]
+        available_applications = sorted([app for app in pd.read_csv(self.servicemap_path, index_col="Tag").index if not is_background_app(app)])
+        if len(available_applications) != self.metadata.application_count:
+            raise ValueError(f"Found {len(available_applications)} applications in the servicemap (omitting background traffic classes), but expected {self.metadata.application_count}. Please report this issue.")
+        self.available_classes = available_applications + self.metadata.background_traffic_classes
 
     def set_dataset_config_and_initialize(self, dataset_config: DatasetConfig, disable_indices_cache: bool = False) -> None:
         """
@@ -208,6 +204,8 @@ class CesnetDataset():
         """
         if self.dataset_config is None:
             raise ValueError("Dataset is not initialized, use set_dataset_config_and_initialize() before getting train dataloader")
+        if not self.dataset_config.need_train_set:
+            raise ValueError("Train dataloader is not available when need_train_set is false")
         assert self.train_dataset
         if self.train_dataloader:
             return self.train_dataloader
@@ -230,7 +228,7 @@ class CesnetDataset():
             self.train_dataset,
             num_workers=self.dataset_config.train_workers,
             worker_init_fn=worker_init_fn,
-            collate_fn=self.collate_fn,
+            collate_fn=self._collate_fn,
             persistent_workers=self.dataset_config.train_workers > 0,
             batch_size=None,
             sampler=batch_sampler,)
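Beyond the rename to the private _collate_fn, note the batching pattern this construction relies on (unchanged here): the object passed as sampler is a BatchSampler, and batch_size=None disables PyTorch's automatic batching, so each fetch hands a whole list of indices to the dataset at once, which suits PyTables-style bulk row reads. A standalone sketch of the same wiring; the dataset, collate function, and batch size are placeholders:

    from torch.utils.data import BatchSampler, DataLoader, RandomSampler

    # batch_sampler yields lists of indices; with batch_size=None, each list is passed
    # directly to dataset.__getitem__ instead of being split into single-item fetches
    batch_sampler = BatchSampler(RandomSampler(dataset), batch_size=192, drop_last=True)
    loader = DataLoader(dataset, sampler=batch_sampler, batch_size=None,
                        collate_fn=collate_fn, num_workers=0)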
@@ -255,8 +253,8 @@ class CesnetDataset():
         """
         if self.dataset_config is None:
             raise ValueError("Dataset is not initialized, use set_dataset_config_and_initialize() before getting validation dataloader")
-        if self.dataset_config.val_approach == ValidationApproach.NO_VALIDATION:
-            raise ValueError("Validation dataloader is not available when using no-validation")
+        if not self.dataset_config.need_val_set:
+            raise ValueError("Validation dataloader is not available when need_val_set is false")
         assert self.val_dataset is not None
         if self.val_dataloader:
             return self.val_dataloader
@@ -265,7 +263,7 @@ class CesnetDataset():
             self.val_dataset,
             num_workers=self.dataset_config.val_workers,
             worker_init_fn=worker_init_fn,
-            collate_fn=self.collate_fn,
+            collate_fn=self._collate_fn,
             persistent_workers=self.dataset_config.val_workers > 0,
             batch_size=None,
             sampler=batch_sampler,)
@@ -294,8 +292,8 @@ class CesnetDataset():
         """
         if self.dataset_config is None:
             raise ValueError("Dataset is not initialized, use set_dataset_config_and_initialize() before getting test dataloader")
-        if self.dataset_config.no_test_set:
-            raise ValueError("Test dataloader is not available when no_test_set is true")
+        if not self.dataset_config.need_test_set:
+            raise ValueError("Test dataloader is not available when need_test_set is false")
         assert self.test_dataset is not None
         if self.test_dataloader:
             return self.test_dataloader
@@ -304,7 +302,7 @@ class CesnetDataset():
             self.test_dataset,
             num_workers=self.dataset_config.test_workers,
             worker_init_fn=worker_init_fn,
-            collate_fn=self.collate_fn,
+            collate_fn=self._collate_fn,
             persistent_workers=False,
             batch_size=None,
             sampler=batch_sampler,)
@@ -336,7 +334,7 @@ class CesnetDataset():
         Returns:
             Train data as a dataframe.
         """
-        self._check_before_dataframe()
+        self._check_before_dataframe(check_train=True)
         assert self.dataset_config is not None and self.train_dataset is not None
         if len(self.train_dataset) > DATAFRAME_SAMPLES_WARNING_THRESHOLD:
             warnings.warn(f"Train set has {len(self.train_dataset)} samples, consider using get_train_dataloader() instead")
@@ -369,7 +367,7 @@ class CesnetDataset():
         Returns:
             Validation data as a dataframe.
         """
-        self._check_before_dataframe(check_no_val=True)
+        self._check_before_dataframe(check_val=True)
         assert self.dataset_config is not None and self.val_dataset is not None
         if len(self.val_dataset) > DATAFRAME_SAMPLES_WARNING_THRESHOLD:
             warnings.warn(f"Validation set has {len(self.val_dataset)} samples, consider using get_val_dataloader() instead")
@@ -398,7 +396,7 @@ class CesnetDataset():
         Returns:
             Test data as a dataframe.
         """
-        self._check_before_dataframe(check_no_test=True)
+        self._check_before_dataframe(check_test=True)
         assert self.dataset_config is not None and self.test_dataset is not None
         if len(self.test_dataset) > DATAFRAME_SAMPLES_WARNING_THRESHOLD:
             warnings.warn(f"Test set has {len(self.test_dataset)} samples, consider using get_test_dataloader() instead")
@@ -436,12 +434,18 @@ class CesnetDataset():
             batch_size: Number of samples per batch for loading data.
             disabled_apps: List of applications to exclude from the statistics.
         """
-        flowstats_features = self.metadata.flowstats_features + self.metadata.packet_histogram_features + self.metadata.tcp_features
+        if disabled_apps:
+            bad_disabled_apps = [a for a in disabled_apps if a not in self.available_classes]
+            if len(bad_disabled_apps) > 0:
+                raise ValueError(f"Bad applications in disabled_apps {bad_disabled_apps}. Use applications available in dataset.available_classes")
         if not os.path.exists(self.statistics_path):
             os.mkdir(self.statistics_path)
         compute_dataset_statistics(database_path=self.database_path,
+                                   tables_app_enum=self._tables_app_enum,
+                                   tables_cat_enum=self._tables_cat_enum,
                                    output_dir=self.statistics_path,
-                                   flowstats_features=flowstats_features,
+                                   packet_histograms=self.metadata.packet_histograms,
+                                   flowstats_features_boolean=self.metadata.flowstats_features_boolean,
                                    protocol=self.metadata.protocol,
                                    extra_fields=not self.name.startswith("CESNET-TLS22"),
                                    disabled_apps=disabled_apps if disabled_apps is not None else [],
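The added guard validates disabled_apps against available_classes up front, so a misspelled class name now fails fast instead of being silently ignored. A hypothetical call (the enclosing method's name is not shown in this hunk and is assumed here):

    # Hypothetical: a tag missing from dataset.available_classes now raises ValueError
    # before any statistics are computed
    dataset.compute_dataset_statistics(disabled_apps=["no-such-app"])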
@@ -489,174 +493,193 @@ class CesnetDataset():
         self.train_dataset = None
         self.val_dataset = None
         self.test_dataset = None
-        self.known_apps_database_enum = None
-        self.unknown_apps_database_enum = None
         self.known_app_counts = None
         self.unknown_app_counts = None
-
-        self.collate_fn = None
-        self.encoder = None
-        self.flowstats_scaler = None
-        self.psizes_scaler = None
-        self.ipt_scaler = None
-        self.flowstats_quantiles = None
-
         self.train_dataloader = None
         self.train_dataloader_sampler = None
         self.train_dataloader_drop_last = True
         self.val_dataloader = None
         self.test_dataloader = None
+        self._collate_fn = None
 
-    def _check_before_dataframe(self, check_no_val: bool = False, check_no_test: bool = False) -> None:
+    def _check_before_dataframe(self, check_train: bool = False, check_val: bool = False, check_test: bool = False) -> None:
         if self.dataset_config is None:
             raise ValueError("Dataset is not initialized, use set_dataset_config_and_initialize() before getting a dataframe")
-        if self.dataset_config.return_torch:
-            raise ValueError("Dataframes are not available when return_torch is set. Use a dataloader instead.")
-        if check_no_val and self.dataset_config.val_approach == ValidationApproach.NO_VALIDATION:
-            raise ValueError("Validation dataframe is not available when using no-validation")
-        if check_no_test and self.dataset_config.no_test_set:
-            raise ValueError("Test dataframe is not available when no_test_set is true")
+        if self.dataset_config.return_tensors:
+            raise ValueError("Dataframes are not available when return_tensors is set. Use a dataloader instead.")
+        if check_train and not self.dataset_config.need_train_set:
+            raise ValueError("Train dataframe is not available when need_train_set is false")
+        if check_val and not self.dataset_config.need_val_set:
+            raise ValueError("Validation dataframe is not available when need_val_set is false")
+        if check_test and not self.dataset_config.need_test_set:
+            raise ValueError("Test dataframe is not available when need_test_set is false")
 
     def _initialize_train_val_test(self, disable_indices_cache: bool = False) -> None:
         assert self.dataset_config is not None
         dataset_config = self.dataset_config
         servicemap = pd.read_csv(dataset_config.servicemap_path, index_col="Tag")
-        # Initialize train and test indices
-        train_indices, train_unknown_indices, encoder, known_apps_database_enum, unknown_apps_database_enum = init_or_load_train_indices(dataset_config=dataset_config,
-                                                                                                                                         servicemap=servicemap,
-                                                                                                                                         disable_indices_cache=disable_indices_cache,)
-        if self.dataset_config.no_test_set:
-            test_known_indices = np.empty((0,3), dtype=np.int64)
-            test_unknown_indices = np.empty((0,3), dtype=np.int64)
-            test_data_path = None
+        # Initialize train set
+        if dataset_config.need_train_set:
+            train_indices, train_unknown_indices, known_apps, unknown_apps = init_or_load_train_indices(dataset_config=dataset_config,
+                                                                                                        tables_app_enum=self._tables_app_enum,
+                                                                                                        servicemap=servicemap,
+                                                                                                        disable_indices_cache=disable_indices_cache,)
+            # Date weight sampling of train indices
+            if dataset_config.train_dates_weigths is not None:
+                assert dataset_config.train_size != "all"
+                if dataset_config.val_approach == ValidationApproach.SPLIT_FROM_TRAIN:
+                    # requested number of samples is train_size + val_known_size when using the split-from-train validation approach
+                    assert dataset_config.val_known_size != "all"
+                    num_samples = dataset_config.train_size + dataset_config.val_known_size
+                else:
+                    num_samples = dataset_config.train_size
+                if num_samples > len(train_indices):
+                    raise ValueError(f"Requested number of samples for weight sampling ({num_samples}) is larger than the number of available train samples ({len(train_indices)})")
+                train_indices = date_weight_sample_train_indices(dataset_config=dataset_config, train_indices=train_indices, num_samples=num_samples)
+        elif dataset_config.apps_selection == AppSelection.FIXED:
+            known_apps = dataset_config.apps_selection_fixed_known
+            unknown_apps = dataset_config.apps_selection_fixed_unknown
+            train_indices = np.zeros((0,3), dtype=np.int64)
+            train_unknown_indices = np.zeros((0,3), dtype=np.int64)
         else:
-            test_known_indices, test_unknown_indices, test_data_path = init_or_load_test_indices(dataset_config=dataset_config,
-                                                                                                 known_apps_database_enum=known_apps_database_enum,
-                                                                                                 unknown_apps_database_enum=unknown_apps_database_enum,
+            raise ValueError("Either need train set or the fixed application selection")
+        # Initialize validation set
+        if dataset_config.need_val_set:
+            if dataset_config.val_approach == ValidationApproach.VALIDATION_DATES:
+                val_known_indices, val_unknown_indices, val_data_path = init_or_load_val_indices(dataset_config=dataset_config,
+                                                                                                 known_apps=known_apps,
+                                                                                                 unknown_apps=unknown_apps,
+                                                                                                 tables_app_enum=self._tables_app_enum,
                                                                                                  disable_indices_cache=disable_indices_cache,)
-        # Date weight sampling of train indices
-        if dataset_config.train_dates_weigths is not None:
-            assert dataset_config.train_size != "all"
-            if dataset_config.val_approach == ValidationApproach.SPLIT_FROM_TRAIN:
-                # requested number of samples is train_size + val_known_size when using the split-from-train validation approach
-                assert dataset_config.val_known_size != "all"
-                num_samples = dataset_config.train_size + dataset_config.val_known_size
-            else:
-                num_samples = dataset_config.train_size
-            if num_samples > len(train_indices):
-                raise ValueError(f"Requested number of samples for weight sampling ({num_samples}) is larger than the number of available train samples ({len(train_indices)})")
-            train_indices = date_weight_sample_train_indices(dataset_config=dataset_config, train_indices=train_indices, num_samples=num_samples)
-        # Obtain validation indices based on the selected approach
-        if dataset_config.val_approach == ValidationApproach.VALIDATION_DATES:
-            val_known_indices, val_unknown_indices, val_data_path = init_or_load_val_indices(dataset_config=dataset_config,
-                                                                                             known_apps_database_enum=known_apps_database_enum,
-                                                                                             unknown_apps_database_enum=unknown_apps_database_enum,
-                                                                                             disable_indices_cache=disable_indices_cache,)
-        elif dataset_config.val_approach == ValidationApproach.SPLIT_FROM_TRAIN:
-            train_val_rng = get_fresh_random_generator(dataset_config=dataset_config, section=RandomizedSection.TRAIN_VAL_SPLIT)
-            val_data_path = dataset_config._get_train_data_path()
-            val_unknown_indices = train_unknown_indices
-            train_labels = train_indices[:, INDICES_LABEL_POS]
-            if dataset_config.train_dates_weigths is not None:
-                assert dataset_config.val_known_size != "all"
-                # When weight sampling is used, val_known_size is kept but the resulting train size can be smaller due to not enough samples in some train dates
-                if dataset_config.val_known_size > len(train_indices):
-                    raise ValueError(f"Requested validation size ({dataset_config.val_known_size}) is larger than the number of available train samples after weight sampling ({len(train_indices)})")
-                train_indices, val_known_indices = train_test_split(train_indices, test_size=dataset_config.val_known_size, stratify=train_labels, shuffle=True, random_state=train_val_rng)
-                dataset_config.train_size = len(train_indices)
-            elif dataset_config.train_size == "all" and dataset_config.val_known_size == "all":
-                train_indices, val_known_indices = train_test_split(train_indices, test_size=dataset_config.train_val_split_fraction, stratify=train_labels, shuffle=True, random_state=train_val_rng)
-            else:
-                if dataset_config.val_known_size != "all" and dataset_config.train_size != "all" and dataset_config.train_size + dataset_config.val_known_size > len(train_indices):
-                    raise ValueError(f"Requested train size + validation size ({dataset_config.train_size + dataset_config.val_known_size}) is larger than the number of available train samples ({len(train_indices)})")
-                if dataset_config.train_size != "all" and dataset_config.train_size > len(train_indices):
-                    raise ValueError(f"Requested train size ({dataset_config.train_size}) is larger than the number of available train samples ({len(train_indices)})")
-                if dataset_config.val_known_size != "all" and dataset_config.val_known_size > len(train_indices):
-                    raise ValueError(f"Requested validation size ({dataset_config.val_known_size}) is larger than the number of available train samples ({len(train_indices)})")
-                train_indices, val_known_indices = train_test_split(train_indices,
-                                                                    train_size=dataset_config.train_size if dataset_config.train_size != "all" else None,
-                                                                    test_size=dataset_config.val_known_size if dataset_config.val_known_size != "all" else None,
-                                                                    stratify=train_labels, shuffle=True, random_state=train_val_rng)
-        elif dataset_config.val_approach == ValidationApproach.NO_VALIDATION:
-            val_known_indices = np.empty((0,3), dtype=np.int64)
-            val_unknown_indices = np.empty((0,3), dtype=np.int64)
+            elif dataset_config.val_approach == ValidationApproach.SPLIT_FROM_TRAIN:
+                train_val_rng = get_fresh_random_generator(dataset_config=dataset_config, section=RandomizedSection.TRAIN_VAL_SPLIT)
+                val_data_path = dataset_config._get_train_data_path()
+                val_unknown_indices = train_unknown_indices
+                train_labels = train_indices[:, INDICES_LABEL_POS]
+                if dataset_config.train_dates_weigths is not None:
+                    assert dataset_config.val_known_size != "all"
+                    # When weight sampling is used, val_known_size is kept but the resulting train size can be smaller due to not enough samples in some train dates
+                    if dataset_config.val_known_size > len(train_indices):
+                        raise ValueError(f"Requested validation size ({dataset_config.val_known_size}) is larger than the number of available train samples after weight sampling ({len(train_indices)})")
+                    train_indices, val_known_indices = train_test_split(train_indices, test_size=dataset_config.val_known_size, stratify=train_labels, shuffle=True, random_state=train_val_rng)
+                    dataset_config.train_size = len(train_indices)
+                elif dataset_config.train_size == "all" and dataset_config.val_known_size == "all":
+                    train_indices, val_known_indices = train_test_split(train_indices, test_size=dataset_config.train_val_split_fraction, stratify=train_labels, shuffle=True, random_state=train_val_rng)
+                else:
+                    if dataset_config.val_known_size != "all" and dataset_config.train_size != "all" and dataset_config.train_size + dataset_config.val_known_size > len(train_indices):
+                        raise ValueError(f"Requested train size + validation size ({dataset_config.train_size + dataset_config.val_known_size}) is larger than the number of available train samples ({len(train_indices)})")
+                    if dataset_config.train_size != "all" and dataset_config.train_size > len(train_indices):
+                        raise ValueError(f"Requested train size ({dataset_config.train_size}) is larger than the number of available train samples ({len(train_indices)})")
+                    if dataset_config.val_known_size != "all" and dataset_config.val_known_size > len(train_indices):
+                        raise ValueError(f"Requested validation size ({dataset_config.val_known_size}) is larger than the number of available train samples ({len(train_indices)})")
+                    train_indices, val_known_indices = train_test_split(train_indices,
+                                                                        train_size=dataset_config.train_size if dataset_config.train_size != "all" else None,
+                                                                        test_size=dataset_config.val_known_size if dataset_config.val_known_size != "all" else None,
+                                                                        stratify=train_labels, shuffle=True, random_state=train_val_rng)
+        else:
+            val_known_indices = np.zeros((0,3), dtype=np.int64)
+            val_unknown_indices = np.zeros((0,3), dtype=np.int64)
             val_data_path = None
-        else: assert_never(dataset_config.val_approach)
-
-        # Create class info
-        class_info = create_class_info(servicemap=servicemap, encoder=encoder, known_apps_database_enum=known_apps_database_enum, unknown_apps_database_enum=unknown_apps_database_enum)
-        # Load or fit data scalers
-        flowstats_scaler, psizes_scaler, ipt_scaler, flowstats_quantiles = fit_or_load_scalers(dataset_config=dataset_config, train_indices=train_indices)
+        # Initialize test set
+        if dataset_config.need_test_set:
+            test_known_indices, test_unknown_indices, test_data_path = init_or_load_test_indices(dataset_config=dataset_config,
+                                                                                                 known_apps=known_apps,
+                                                                                                 unknown_apps=unknown_apps,
+                                                                                                 tables_app_enum=self._tables_app_enum,
+                                                                                                 disable_indices_cache=disable_indices_cache,)
+        else:
+            test_known_indices = np.zeros((0,3), dtype=np.int64)
+            test_unknown_indices = np.zeros((0,3), dtype=np.int64)
+            test_data_path = None
+        # Fit scalers if needed
+        if (dataset_config.ppi_transform is not None and dataset_config.ppi_transform.needs_fitting or
+                dataset_config.flowstats_transform is not None and dataset_config.flowstats_transform.needs_fitting):
+            if not dataset_config.need_train_set:
+                raise ValueError("Train set is needed to fit the scalers. Provide pre-fitted scalers.")
+            fit_scalers(dataset_config=dataset_config, train_indices=train_indices)
         # Subset dataset indices based on the selected sizes and compute application counts
         dataset_indices = IndicesTuple(train_indices=train_indices, val_known_indices=val_known_indices, val_unknown_indices=val_unknown_indices, test_known_indices=test_known_indices, test_unknown_indices=test_unknown_indices)
         dataset_indices = subset_and_sort_indices(dataset_config=dataset_config, dataset_indices=dataset_indices)
-        known_app_counts = compute_known_app_counts(dataset_indices=dataset_indices, database_enum=known_apps_database_enum)
-        unknown_app_counts = compute_unknown_app_counts(dataset_indices=dataset_indices, database_enum=unknown_apps_database_enum)
+        known_app_counts = compute_known_app_counts(dataset_indices=dataset_indices, tables_app_enum=self._tables_app_enum)
+        unknown_app_counts = compute_unknown_app_counts(dataset_indices=dataset_indices, tables_app_enum=self._tables_app_enum)
         # Combine known and unknown test indices to create a single dataloader
         assert isinstance(dataset_config.test_unknown_size, int)
-        if dataset_config.test_unknown_size > 0 and len(unknown_apps_database_enum) > 0:
+        if dataset_config.test_unknown_size > 0 and len(unknown_apps) > 0:
             test_combined_indices = np.concatenate((dataset_indices.test_known_indices, dataset_indices.test_unknown_indices))
         else:
             test_combined_indices = dataset_indices.test_known_indices
-
+        # Create the encoder and the class info structure
+        encoder = LabelEncoder().fit(known_apps)
+        encoder.classes_ = np.append(encoder.classes_, UNKNOWN_STR_LABEL)
+        class_info = create_class_info(servicemap=servicemap, encoder=encoder, known_apps=known_apps, unknown_apps=unknown_apps)
+        encode_labels_with_unknown_fn = partial(_encode_labels_with_unknown, encoder=encoder, class_info=class_info)
         # Create train, validation, and test datasets
-        train_dataset = PyTablesDataset(
-            database_path=dataset_config.database_path,
-            tables_paths=dataset_config._get_train_tables_paths(),
-            indices=dataset_indices.train_indices,
-            flowstats_features=dataset_config.flowstats_features,
-            other_fields=self.dataset_config.other_fields,)
-        if dataset_config.no_test_set:
-            test_dataset = None
-        else:
-            assert test_data_path is not None
-            test_dataset = PyTablesDataset(
+        train_dataset = val_dataset = test_dataset = None
+        if dataset_config.need_train_set:
+            train_dataset = PyTablesDataset(
                 database_path=dataset_config.database_path,
-                tables_paths=dataset_config._get_test_tables_paths(),
-                indices=test_combined_indices,
+                tables_paths=dataset_config._get_train_tables_paths(),
+                indices=dataset_indices.train_indices,
+                tables_app_enum=self._tables_app_enum,
+                tables_cat_enum=self._tables_cat_enum,
                 flowstats_features=dataset_config.flowstats_features,
+                flowstats_features_boolean=dataset_config.flowstats_features_boolean,
+                flowstats_features_phist=dataset_config.flowstats_features_phist,
                 other_fields=self.dataset_config.other_fields,
-                preload=dataset_config.preload_test,
-                preload_blob=os.path.join(test_data_path, "preload", f"test_dataset-{dataset_config.test_known_size}-{dataset_config.test_unknown_size}.npz"),)
-        if dataset_config.val_approach == ValidationApproach.NO_VALIDATION:
-            val_dataset = None
-        else:
+                ppi_channels=dataset_config.get_ppi_channels(),
+                ppi_transform=dataset_config.ppi_transform,
+                flowstats_transform=dataset_config.flowstats_transform,
+                flowstats_phist_transform=dataset_config.flowstats_phist_transform,
+                target_transform=encode_labels_with_unknown_fn,
+                return_tensors=dataset_config.return_tensors,)
+        if dataset_config.need_val_set:
             assert val_data_path is not None
             val_dataset = PyTablesDataset(
                 database_path=dataset_config.database_path,
                 tables_paths=dataset_config._get_train_tables_paths(),
                 indices=dataset_indices.val_known_indices,
+                tables_app_enum=self._tables_app_enum,
+                tables_cat_enum=self._tables_cat_enum,
                 flowstats_features=dataset_config.flowstats_features,
+                flowstats_features_boolean=dataset_config.flowstats_features_boolean,
+                flowstats_features_phist=dataset_config.flowstats_features_phist,
                 other_fields=self.dataset_config.other_fields,
+                ppi_channels=dataset_config.get_ppi_channels(),
+                ppi_transform=dataset_config.ppi_transform,
+                flowstats_transform=dataset_config.flowstats_transform,
+                flowstats_phist_transform=dataset_config.flowstats_phist_transform,
+                target_transform=encode_labels_with_unknown_fn,
+                return_tensors=dataset_config.return_tensors,
                 preload=dataset_config.preload_val,
                 preload_blob=os.path.join(val_data_path, "preload", f"val_dataset-{dataset_config.val_known_size}.npz"),)
-        collate_fn = partial(pytables_collate_fn,
-                             flowstats_scaler=flowstats_scaler,
-                             flowstats_quantiles=flowstats_quantiles,
-                             psizes_scaler=psizes_scaler,
-                             psizes_max=dataset_config.psizes_max,
-                             ipt_scaler=ipt_scaler,
-                             ipt_min=dataset_config.ipt_min,
-                             ipt_max=dataset_config.ipt_max,
-                             use_push_flags=dataset_config.use_push_flags,
-                             use_packet_histograms=dataset_config.use_packet_histograms,
-                             normalize_packet_histograms=dataset_config.normalize_packet_histograms,
-                             zero_ppi_start=dataset_config.zero_ppi_start,
-                             encoder=encoder,
-                             known_apps=class_info.known_apps,
-                             return_torch=dataset_config.return_torch,)
+        if dataset_config.need_test_set:
+            assert test_data_path is not None
+            test_dataset = PyTablesDataset(
+                database_path=dataset_config.database_path,
+                tables_paths=dataset_config._get_test_tables_paths(),
+                indices=test_combined_indices,
+                tables_app_enum=self._tables_app_enum,
+                tables_cat_enum=self._tables_cat_enum,
+                flowstats_features=dataset_config.flowstats_features,
+                flowstats_features_boolean=dataset_config.flowstats_features_boolean,
+                flowstats_features_phist=dataset_config.flowstats_features_phist,
+                other_fields=self.dataset_config.other_fields,
+                ppi_channels=dataset_config.get_ppi_channels(),
+                ppi_transform=dataset_config.ppi_transform,
+                flowstats_transform=dataset_config.flowstats_transform,
+                flowstats_phist_transform=dataset_config.flowstats_phist_transform,
+                target_transform=encode_labels_with_unknown_fn,
+                return_tensors=dataset_config.return_tensors,
+                preload=dataset_config.preload_test,
+                preload_blob=os.path.join(test_data_path, "preload", f"test_dataset-{dataset_config.test_known_size}-{dataset_config.test_unknown_size}.npz"),)
         self.class_info = class_info
         self.dataset_indices = dataset_indices
         self.train_dataset = train_dataset
         self.val_dataset = val_dataset
         self.test_dataset = test_dataset
-        self.known_apps_database_enum = known_apps_database_enum
-        self.unknown_apps_database_enum = unknown_apps_database_enum
         self.known_app_counts = known_app_counts
         self.unknown_app_counts = unknown_app_counts
-        self.collate_fn = collate_fn
-        self.encoder = encoder
-        self.flowstats_scaler = flowstats_scaler
-        self.psizes_scaler = psizes_scaler
-        self.ipt_scaler = ipt_scaler
-        self.flowstats_quantiles = flowstats_quantiles
+        self._collate_fn = collate_fn_simple
+
+def _encode_labels_with_unknown(labels, encoder: LabelEncoder, class_info: ClassInfo):
+    return encoder.transform(np.where(np.isin(labels, class_info.known_apps), labels, UNKNOWN_STR_LABEL))
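This new module-level helper takes over the label handling that previously lived in pytables_collate_fn: labels outside the known set are first mapped to the unknown sentinel, which occupies the extra class id appended to the fitted encoder, and only then integer-encoded. A self-contained illustration; the class names are made up and the sentinel value stands in for the real UNKNOWN_STR_LABEL constant:

    import numpy as np
    from sklearn.preprocessing import LabelEncoder

    UNKNOWN_STR_LABEL = "unknown"                # stand-in for the package constant
    known_apps = ["AppleiCloud", "GoogleDrive"]  # illustrative known classes

    encoder = LabelEncoder().fit(known_apps)
    encoder.classes_ = np.append(encoder.classes_, UNKNOWN_STR_LABEL)

    labels = np.array(["GoogleDrive", "SomeNewApp", "AppleiCloud"])
    # anything not in known_apps is replaced by the sentinel before encoding
    encoded = encoder.transform(np.where(np.isin(labels, known_apps), labels, UNKNOWN_STR_LABEL))
    # encoded -> array([1, 2, 0]); the unseen "SomeNewApp" maps to the unknown id 2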