cesnet-datazoo 0.0.16__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -15,11 +15,14 @@ from sklearn.preprocessing import LabelEncoder
  from torch.utils.data import BatchSampler, DataLoader, RandomSampler, Sampler, SequentialSampler
  from typing_extensions import assert_never

- from cesnet_datazoo.config import DataLoaderOrder, DatasetConfig, Scaler, ValidationApproach
- from cesnet_datazoo.constants import DATASET_SIZES, INDICES_LABEL_POS, SERVICEMAP_FILE
- from cesnet_datazoo.datasets.loaders import create_df_from_dataloader
+ from cesnet_datazoo.config import AppSelection, DataLoaderOrder, DatasetConfig, ValidationApproach
+ from cesnet_datazoo.constants import (APP_COLUMN, CATEGORY_COLUMN, DATASET_SIZES, INDICES_LABEL_POS,
+                                       SERVICEMAP_FILE, UNKNOWN_STR_LABEL)
+ from cesnet_datazoo.datasets.loaders import collate_fn_simple, create_df_from_dataloader
  from cesnet_datazoo.datasets.metadata.dataset_metadata import DatasetMetadata, load_metadata
  from cesnet_datazoo.datasets.statistics import compute_dataset_statistics
+ from cesnet_datazoo.pytables_data.apps_split import is_background_app
+ from cesnet_datazoo.pytables_data.data_scalers import fit_scalers
  from cesnet_datazoo.pytables_data.indices_setup import (IndicesTuple, compute_known_app_counts,
                                                          compute_unknown_app_counts,
                                                          date_weight_sample_train_indices,
@@ -27,8 +30,7 @@ from cesnet_datazoo.pytables_data.indices_setup import (IndicesTuple, compute_kn
                                                          init_or_load_train_indices,
                                                          init_or_load_val_indices,
                                                          subset_and_sort_indices)
- from cesnet_datazoo.pytables_data.pytables_dataset import (PyTablesDataset, fit_or_load_scalers,
-                                                            pytables_collate_fn, worker_init_fn)
+ from cesnet_datazoo.pytables_data.pytables_dataset import PyTablesDataset, worker_init_fn
  from cesnet_datazoo.utils.class_info import ClassInfo, create_class_info
  from cesnet_datazoo.utils.download import resumable_download, simple_download
  from cesnet_datazoo.utils.random import RandomizedSection, get_fresh_random_generator
@@ -38,8 +40,7 @@ DATAFRAME_SAMPLES_WARNING_THRESHOLD = 20_000_000

  class CesnetDataset():
      """
-     The main class for accessing CESNET datasets. It handles downloading, data preprocessing,
-     train/validation/test splitting, and class selection. Access to data is provided through:
+     The main class for accessing CESNET datasets. It handles downloading, train/validation/test splitting, and class selection. Access to data is provided through:

      - Iterable PyTorch DataLoader for batch processing. See [using dataloaders][using-dataloaders] for more details.
      - Pandas DataFrame for loading the entire train, validation, or test set at once.
@@ -53,7 +54,7 @@ class CesnetDataset():

      1. Create an instance of the [dataset class][dataset-classes] with the desired size and data root. This will download the dataset if it has not already been downloaded.
      2. Create an instance of [`DatasetConfig`][config.DatasetConfig] and set it with [`set_dataset_config_and_initialize`][datasets.cesnet_dataset.CesnetDataset.set_dataset_config_and_initialize].
-        This will initialize the dataset — select classes, split data into train/validation/test sets, and fit data scalers. All is done according to the provided configuration and is cached for later use.
+        This will initialize the dataset — select classes, split data into train/validation/test sets, and fit data scalers if needed. All is done according to the provided configuration and is cached for later use.
      3. Use [`get_train_dataloader`][datasets.cesnet_dataset.CesnetDataset.get_train_dataloader] or [`get_train_df`][datasets.cesnet_dataset.CesnetDataset.get_train_df] to get training data for a classification model.
      4. Validate the model and perform hyperparameter optimization on [`get_val_dataloader`][datasets.cesnet_dataset.CesnetDataset.get_val_dataloader] or [`get_val_df`][datasets.cesnet_dataset.CesnetDataset.get_val_df].
      5. Evaluate the model on [`get_test_dataloader`][datasets.cesnet_dataset.CesnetDataset.get_test_dataloader] or [`get_test_df`][datasets.cesnet_dataset.CesnetDataset.get_test_df].
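The workflow documented in this hunk maps to a short script. A minimal sketch, assuming the CESNET_QUIC22 dataset class from cesnet_datazoo.datasets and default DatasetConfig arguments; the concrete class and constructor defaults used here are illustrative, not guaranteed by this diff:

    # Minimal workflow sketch for cesnet-datazoo 0.1.0 (names are illustrative).
    from cesnet_datazoo.config import DatasetConfig
    from cesnet_datazoo.datasets import CESNET_QUIC22

    # Step 1: downloads the database into data_root on first use
    dataset = CESNET_QUIC22(data_root="data/CESNET-QUIC22", size="XS")

    # Step 2: class selection, train/val/test split, and scaler fitting happen here
    config = DatasetConfig(dataset=dataset)  # hypothetical defaults
    dataset.set_dataset_config_and_initialize(config)

    # Steps 3-5: fetch data either as dataloaders or as dataframes
    train_dataloader = dataset.get_train_dataloader()
    val_df = dataset.get_val_df()
    test_df = dataset.get_test_df()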
@@ -68,9 +69,10 @@ class CesnetDataset():
          database_filename: Name of the database file.
          database_path: Path to the database file.
          servicemap_path: Path to the servicemap file.
-         statistics_path: Path to the dataset statistics.
+         statistics_path: Path to the dataset statistics folder.
          bucket_url: URL of the bucket where the database is stored.
          metadata: Additional [dataset metadata][metadata].
+         available_classes: List of all available classes in the dataset.
          available_dates: List of all available dates in the dataset.
          time_periods: Predefined time periods. Each time period is a list of dates.
          default_train_period_name: Default time period for training.
@@ -85,35 +87,30 @@ class CesnetDataset():
          train_dataset: Train set in the form of `PyTablesDataset` instance wrapping the PyTables database.
          val_dataset: Validation set in the form of `PyTablesDataset` instance wrapping the PyTables database.
          test_dataset: Test set in the form of `PyTablesDataset` instance wrapping the PyTables database.
-         known_apps_database_enum: Dictionary that maps the database integer labels (different to those from `encoder`) of known applications to their names.
-         unknown_apps_database_enum: Dictionary that maps the database integer labels (different to those from `encoder`) of unknown applications to their names.
          known_app_counts: Known application counts in the train, validation, and test sets.
          unknown_app_counts: Unknown application counts in the validation and test sets.
-         collate_fn: Collate function used for creating batches in dataloaders.
-         encoder: Scikit-learn [`LabelEncoder`](https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.LabelEncoder.html) used to encode class names into integers. It is fitted during the initialization of the dataset.
-         flowstats_scaler: Scaler for flow statistics. It is fitted during the initialization of the dataset.
-         psizes_scaler: Scaler for packet sizes.
-         ipt_scaler: Scaler for inter-packet times.
          train_dataloader: Iterable PyTorch [`DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) for training.
          train_dataloader_sampler: Sampler used for iterating the training dataloader. Either [`RandomSampler`](https://pytorch.org/docs/stable/data.html#torch.utils.data.RandomSampler) or [`SequentialSampler`](https://pytorch.org/docs/stable/data.html#torch.utils.data.SequentialSampler).
+         train_dataloader_drop_last: Whether to drop the last incomplete batch when iterating the training dataloader.
          val_dataloader: Iterable PyTorch [`DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) for validation.
          test_dataloader: Iterable PyTorch [`DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) for testing.
      """
-     name: str
-     size: str
      data_root: str
+     size: str
+     silent: bool = False
+
+     name: str
      database_filename: str
      database_path: str
      servicemap_path: str
      statistics_path: str
      bucket_url: str
      metadata: DatasetMetadata
+     available_classes: list[str]
      available_dates: list[str]
      time_periods: dict[str, list[str]]
      default_train_period_name: str
      default_test_period_name: str
-     time_periods_gen: bool = False
-     silent: bool = False

      dataset_config: Optional[DatasetConfig] = None
      class_info: Optional[ClassInfo] = None
@@ -121,24 +118,19 @@ class CesnetDataset():
      train_dataset: Optional[PyTablesDataset] = None
      val_dataset: Optional[PyTablesDataset] = None
      test_dataset: Optional[PyTablesDataset] = None
-     known_apps_database_enum: Optional[dict[int, str]] = None
-     unknown_apps_database_enum: Optional[dict[int, str]] = None
      known_app_counts: Optional[pd.DataFrame] = None
      unknown_app_counts: Optional[pd.DataFrame] = None
-
-     collate_fn: Optional[Callable] = None
-     encoder: Optional[LabelEncoder] = None
-     flowstats_scaler: Scaler = None
-     psizes_scaler: Scaler = None
-     ipt_scaler: Scaler = None
-
      train_dataloader: Optional[DataLoader] = None
      train_dataloader_sampler: Optional[Sampler] = None
      train_dataloader_drop_last: bool = True
      val_dataloader: Optional[DataLoader] = None
      test_dataloader: Optional[DataLoader] = None

-     def __init__(self, data_root: str, size: str = "S", skip_dataset_read_at_init: bool = False, silent: bool = False) -> None:
+     _collate_fn: Optional[Callable] = None
+     _tables_app_enum: dict[int, str]
+     _tables_cat_enum: dict[int, str]
+
+     def __init__(self, data_root: str, size: str = "S", database_checks_at_init: bool = False, silent: bool = False) -> None:
          self.silent = silent
          self.metadata = load_metadata(self.name)
          self.size = size
@@ -156,24 +148,31 @@ class CesnetDataset():
              os.makedirs(self.data_root)
          if not self._is_downloaded():
              self._download()
-         if not skip_dataset_read_at_init:
+         if database_checks_at_init:
              with tb.open_file(self.database_path, mode="r") as database:
                  tables_paths = list(map(lambda x: x._v_pathname, iter(database.get_node(f"/flows"))))
                  num_samples = 0
                  for p in tables_paths:
-                     num_samples += len(database.get_node(p))
+                     table = database.get_node(p)
+                     assert isinstance(table, tb.Table)
+                     if self._tables_app_enum != {v: k for k, v in dict(table.get_enum(APP_COLUMN)).items()}:
+                         raise ValueError(f"Found mismatch between _tables_app_enum and the PyTables database enum in table {p}. Please report this issue.")
+                     if self._tables_cat_enum != {v: k for k, v in dict(table.get_enum(CATEGORY_COLUMN)).items()}:
+                         raise ValueError(f"Found mismatch between _tables_cat_enum and the PyTables database enum in table {p}. Please report this issue.")
+                     num_samples += len(table)
                  if self.size == "ORIG" and num_samples != self.metadata.available_samples:
                      raise ValueError(f"Expected {self.metadata.available_samples} samples, but got {num_samples} in the database. Please delete the data root folder, update cesnet-datazoo, and redownload the dataset.")
                  if self.size != "ORIG" and num_samples != DATASET_SIZES[self.size]:
                      raise ValueError(f"Expected {DATASET_SIZES[self.size]} samples, but got {num_samples} in the database. Please delete the data root folder, update cesnet-datazoo, and redownload the dataset.")
-             self.available_dates = list(map(lambda x: x.removeprefix("/flows/D"), tables_paths))
-         else:
-             self.available_dates = []
-         if self.time_periods_gen:
-             self._generate_time_periods()
+                 if self.available_dates != list(map(lambda x: x.removeprefix("/flows/D"), tables_paths)):
+                     raise ValueError(f"Found mismatch between available_dates and the dates available in the PyTables database. Please report this issue.")
          # Add all available dates as single date time periods
          for d in self.available_dates:
              self.time_periods[d] = [d]
+         available_applications = sorted([app for app in pd.read_csv(self.servicemap_path, index_col="Tag").index if not is_background_app(app)])
+         if len(available_applications) != self.metadata.application_count:
+             raise ValueError(f"Found {len(available_applications)} applications in the servicemap (omitting background traffic classes), but expected {self.metadata.application_count}. Please report this issue.")
+         self.available_classes = available_applications + self.metadata.background_traffic_classes

      def set_dataset_config_and_initialize(self, dataset_config: DatasetConfig, disable_indices_cache: bool = False) -> None:
          """
@@ -205,6 +204,8 @@ class CesnetDataset():
          """
          if self.dataset_config is None:
              raise ValueError("Dataset is not initialized, use set_dataset_config_and_initialize() before getting train dataloader")
+         if not self.dataset_config.need_train_set:
+             raise ValueError("Train dataloader is not available when need_train_set is false")
          assert self.train_dataset
          if self.train_dataloader:
              return self.train_dataloader
@@ -227,7 +228,7 @@ class CesnetDataset():
              self.train_dataset,
              num_workers=self.dataset_config.train_workers,
              worker_init_fn=worker_init_fn,
-             collate_fn=self.collate_fn,
+             collate_fn=self._collate_fn,
              persistent_workers=self.dataset_config.train_workers > 0,
              batch_size=None,
              sampler=batch_sampler,)
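Because batch_size=None and the sampler is a BatchSampler, the dataloader yields whole batches rather than single samples. A hedged sketch of consuming the training dataloader, continuing from the earlier example; the (ppi, flowstats, labels) batch layout is an assumption, since the exact structure depends on DatasetConfig (other_fields, return_tensors, and so on):

    # Hedged sketch: each item from the dataloader is already a full batch.
    train_dataloader = dataset.get_train_dataloader()
    for batch_ppi, batch_flowstats, batch_labels in train_dataloader:
        # batch_ppi: per-packet information; batch_flowstats: flow statistics
        print(batch_ppi.shape, batch_flowstats.shape, batch_labels.shape)
        break  # just peek at the first batch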
@@ -252,8 +253,8 @@ class CesnetDataset():
          """
          if self.dataset_config is None:
              raise ValueError("Dataset is not initialized, use set_dataset_config_and_initialize() before getting validation dataloader")
-         if self.dataset_config.val_approach == ValidationApproach.NO_VALIDATION:
-             raise ValueError("Validation dataloader is not available when using no-validation")
+         if not self.dataset_config.need_val_set:
+             raise ValueError("Validation dataloader is not available when need_val_set is false")
          assert self.val_dataset is not None
          if self.val_dataloader:
              return self.val_dataloader
@@ -262,7 +263,7 @@ class CesnetDataset():
              self.val_dataset,
              num_workers=self.dataset_config.val_workers,
              worker_init_fn=worker_init_fn,
-             collate_fn=self.collate_fn,
+             collate_fn=self._collate_fn,
              persistent_workers=self.dataset_config.val_workers > 0,
              batch_size=None,
              sampler=batch_sampler,)
@@ -291,8 +292,8 @@ class CesnetDataset():
          """
          if self.dataset_config is None:
              raise ValueError("Dataset is not initialized, use set_dataset_config_and_initialize() before getting test dataloader")
-         if self.dataset_config.no_test_set:
-             raise ValueError("Test dataloader is not available when no_test_set is true")
+         if not self.dataset_config.need_test_set:
+             raise ValueError("Test dataloader is not available when need_test_set is false")
          assert self.test_dataset is not None
          if self.test_dataloader:
              return self.test_dataloader
@@ -301,7 +302,7 @@ class CesnetDataset():
              self.test_dataset,
              num_workers=self.dataset_config.test_workers,
              worker_init_fn=worker_init_fn,
-             collate_fn=self.collate_fn,
+             collate_fn=self._collate_fn,
              persistent_workers=False,
              batch_size=None,
              sampler=batch_sampler,)
@@ -333,7 +334,7 @@ class CesnetDataset():
          Returns:
              Train data as a dataframe.
          """
-         self._check_before_dataframe()
+         self._check_before_dataframe(check_train=True)
          assert self.dataset_config is not None and self.train_dataset is not None
          if len(self.train_dataset) > DATAFRAME_SAMPLES_WARNING_THRESHOLD:
              warnings.warn(f"Train set has ({len(self.train_dataset)} samples), consider using get_train_dataloader() instead")
@@ -366,7 +367,7 @@ class CesnetDataset():
          Returns:
              Validation data as a dataframe.
          """
-         self._check_before_dataframe(check_no_val=True)
+         self._check_before_dataframe(check_val=True)
          assert self.dataset_config is not None and self.val_dataset is not None
          if len(self.val_dataset) > DATAFRAME_SAMPLES_WARNING_THRESHOLD:
              warnings.warn(f"Validation set has ({len(self.val_dataset)} samples), consider using get_val_dataloader() instead")
@@ -395,7 +396,7 @@ class CesnetDataset():
          Returns:
              Test data as a dataframe.
          """
-         self._check_before_dataframe(check_no_test=True)
+         self._check_before_dataframe(check_test=True)
          assert self.dataset_config is not None and self.test_dataset is not None
          if len(self.test_dataset) > DATAFRAME_SAMPLES_WARNING_THRESHOLD:
              warnings.warn(f"Test set has ({len(self.test_dataset)} samples), consider using get_test_dataloader() instead")
@@ -433,12 +434,18 @@ class CesnetDataset():
              batch_size: Number of samples per batch for loading data.
              disabled_apps: List of applications to exclude from the statistics.
          """
-         flowstats_features = self.metadata.flowstats_features + self.metadata.packet_histogram_features + self.metadata.tcp_features
+         if disabled_apps:
+             bad_disabled_apps = [a for a in disabled_apps if a not in self.available_classes]
+             if len(bad_disabled_apps) > 0:
+                 raise ValueError(f"Bad applications in disabled_apps {bad_disabled_apps}. Use applications available in dataset.available_classes")
          if not os.path.exists(self.statistics_path):
              os.mkdir(self.statistics_path)
          compute_dataset_statistics(database_path=self.database_path,
+                                    tables_app_enum=self._tables_app_enum,
+                                    tables_cat_enum=self._tables_cat_enum,
                                     output_dir=self.statistics_path,
-                                    flowstats_features=flowstats_features,
+                                    packet_histograms=self.metadata.packet_histograms,
+                                    flowstats_features_boolean=self.metadata.flowstats_features_boolean,
                                     protocol=self.metadata.protocol,
                                     extra_fields=not self.name.startswith("CESNET-TLS22"),
                                     disabled_apps=disabled_apps if disabled_apps is not None else [],
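The statistics entry point now validates disabled_apps against available_classes before computing anything. A hedged usage sketch; calling compute_dataset_statistics as a method on the dataset instance is assumed from this hunk's docstring context:

    # Hedged sketch: compute statistics while excluding one class. Passing a name
    # that is not in dataset.available_classes raises a ValueError up front.
    app_to_disable = dataset.available_classes[0]  # illustrative choice
    dataset.compute_dataset_statistics(disabled_apps=[app_to_disable])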
@@ -486,172 +493,193 @@ class CesnetDataset():
          self.train_dataset = None
          self.val_dataset = None
          self.test_dataset = None
-         self.known_apps_database_enum = None
-         self.unknown_apps_database_enum = None
          self.known_app_counts = None
          self.unknown_app_counts = None
-
-         self.collate_fn = None
-         self.encoder = None
-         self.flowstats_scaler = None
-         self.psizes_scaler = None
-         self.ipt_scaler = None
-
          self.train_dataloader = None
          self.train_dataloader_sampler = None
          self.train_dataloader_drop_last = True
          self.val_dataloader = None
          self.test_dataloader = None
+         self._collate_fn = None

-     def _check_before_dataframe(self, check_no_val: bool = False, check_no_test: bool = False) -> None:
+     def _check_before_dataframe(self, check_train: bool = False, check_val: bool = False, check_test: bool = False) -> None:
          if self.dataset_config is None:
              raise ValueError("Dataset is not initialized, use set_dataset_config_and_initialize() before getting a dataframe")
-         if self.dataset_config.return_torch:
-             raise ValueError("Dataframes are not available when return_torch is set. Use a dataloader instead.")
-         if check_no_val and self.dataset_config.val_approach == ValidationApproach.NO_VALIDATION:
-             raise ValueError("Validation dataframe is not available when using no-validation")
-         if check_no_test and self.dataset_config.no_test_set:
-             raise ValueError("Test dataframe is not available when no_test_set is true")
+         if self.dataset_config.return_tensors:
+             raise ValueError("Dataframes are not available when return_tensors is set. Use a dataloader instead.")
+         if check_train and not self.dataset_config.need_train_set:
+             raise ValueError("Train dataframe is not available when need_train_set is false")
+         if check_val and not self.dataset_config.need_val_set:
+             raise ValueError("Validation dataframe is not available when need_val_set is false")
+         if check_test and not self.dataset_config.need_test_set:
+             raise ValueError("Test dataframe is not available when need_test_set is false")

      def _initialize_train_val_test(self, disable_indices_cache: bool = False) -> None:
          assert self.dataset_config is not None
          dataset_config = self.dataset_config
          servicemap = pd.read_csv(dataset_config.servicemap_path, index_col="Tag")
-         # Initialize train and test indices
-         train_indices, train_unknown_indices, encoder, known_apps_database_enum, unknown_apps_database_enum = init_or_load_train_indices(dataset_config=dataset_config,
-                                                                                                                                          servicemap=servicemap,
-                                                                                                                                          disable_indices_cache=disable_indices_cache,)
-         if self.dataset_config.no_test_set:
-             test_known_indices = np.empty((0,3), dtype=np.int64)
-             test_unknown_indices = np.empty((0,3), dtype=np.int64)
-             test_data_path = None
+         # Initialize train set
+         if dataset_config.need_train_set:
+             train_indices, train_unknown_indices, known_apps, unknown_apps = init_or_load_train_indices(dataset_config=dataset_config,
+                                                                                                        tables_app_enum=self._tables_app_enum,
+                                                                                                        servicemap=servicemap,
+                                                                                                        disable_indices_cache=disable_indices_cache,)
+             # Date weight sampling of train indices
+             if dataset_config.train_dates_weigths is not None:
+                 assert dataset_config.train_size != "all"
+                 if dataset_config.val_approach == ValidationApproach.SPLIT_FROM_TRAIN:
+                     # requested number of samples is train_size + val_known_size when using the split-from-train validation approach
+                     assert dataset_config.val_known_size != "all"
+                     num_samples = dataset_config.train_size + dataset_config.val_known_size
+                 else:
+                     num_samples = dataset_config.train_size
+                 if num_samples > len(train_indices):
+                     raise ValueError(f"Requested number of samples for weight sampling ({num_samples}) is larger than the number of available train samples ({len(train_indices)})")
+                 train_indices = date_weight_sample_train_indices(dataset_config=dataset_config, train_indices=train_indices, num_samples=num_samples)
+         elif dataset_config.apps_selection == AppSelection.FIXED:
+             known_apps = dataset_config.apps_selection_fixed_known
+             unknown_apps = dataset_config.apps_selection_fixed_unknown
+             train_indices = np.zeros((0,3), dtype=np.int64)
+             train_unknown_indices = np.zeros((0,3), dtype=np.int64)
          else:
-             test_known_indices, test_unknown_indices, test_data_path = init_or_load_test_indices(dataset_config=dataset_config,
-                                                                                                  known_apps_database_enum=known_apps_database_enum,
-                                                                                                  unknown_apps_database_enum=unknown_apps_database_enum,
+             raise ValueError("Either need train set or the fixed application selection")
+         # Initialize validation set
+         if dataset_config.need_val_set:
+             if dataset_config.val_approach == ValidationApproach.VALIDATION_DATES:
+                 val_known_indices, val_unknown_indices, val_data_path = init_or_load_val_indices(dataset_config=dataset_config,
+                                                                                                  known_apps=known_apps,
+                                                                                                  unknown_apps=unknown_apps,
+                                                                                                  tables_app_enum=self._tables_app_enum,
                                                                                                   disable_indices_cache=disable_indices_cache,)
-         # Date weight sampling of train indices
-         if dataset_config.train_dates_weigths is not None:
-             assert dataset_config.train_size != "all"
-             if dataset_config.val_approach == ValidationApproach.SPLIT_FROM_TRAIN:
-                 # requested number of samples is train_size + val_known_size when using the split-from-train validation approach
-                 assert dataset_config.val_known_size != "all"
-                 num_samples = dataset_config.train_size + dataset_config.val_known_size
-             else:
-                 num_samples = dataset_config.train_size
-             if num_samples > len(train_indices):
-                 raise ValueError(f"Requested number of samples for weight sampling ({num_samples}) is larger than the number of available train samples ({len(train_indices)})")
-             train_indices = date_weight_sample_train_indices(dataset_config=dataset_config, train_indices=train_indices, num_samples=num_samples)
-         # Obtain validation indices based on the selected approach
-         if dataset_config.val_approach == ValidationApproach.VALIDATION_DATES:
-             val_known_indices, val_unknown_indices, val_data_path = init_or_load_val_indices(dataset_config=dataset_config,
-                                                                                              known_apps_database_enum=known_apps_database_enum,
-                                                                                              unknown_apps_database_enum=unknown_apps_database_enum,
-                                                                                              disable_indices_cache=disable_indices_cache,)
-         elif dataset_config.val_approach == ValidationApproach.SPLIT_FROM_TRAIN:
-             train_val_rng = get_fresh_random_generator(dataset_config=dataset_config, section=RandomizedSection.TRAIN_VAL_SPLIT)
-             val_data_path = dataset_config._get_train_data_path()
-             val_unknown_indices = train_unknown_indices
-             train_labels = train_indices[:, INDICES_LABEL_POS]
-             if dataset_config.train_dates_weigths is not None:
-                 assert dataset_config.val_known_size != "all"
-                 # When weight sampling is used, val_known_size is kept but the resulting train size can be smaller due to not enough samples in some train dates
-                 if dataset_config.val_known_size > len(train_indices):
-                     raise ValueError(f"Requested validation size ({dataset_config.val_known_size}) is larger than the number of available train samples after weight sampling ({len(train_indices)})")
-                 train_indices, val_known_indices = train_test_split(train_indices, test_size=dataset_config.val_known_size, stratify=train_labels, shuffle=True, random_state=train_val_rng)
-                 dataset_config.train_size = len(train_indices)
-             elif dataset_config.train_size == "all" and dataset_config.val_known_size == "all":
-                 train_indices, val_known_indices = train_test_split(train_indices, test_size=dataset_config.train_val_split_fraction, stratify=train_labels, shuffle=True, random_state=train_val_rng)
-             else:
-                 if dataset_config.val_known_size != "all" and dataset_config.train_size != "all" and dataset_config.train_size + dataset_config.val_known_size > len(train_indices):
-                     raise ValueError(f"Requested train size + validation size ({dataset_config.train_size + dataset_config.val_known_size}) is larger than the number of available train samples ({len(train_indices)})")
-                 if dataset_config.train_size != "all" and dataset_config.train_size > len(train_indices):
-                     raise ValueError(f"Requested train size ({dataset_config.train_size}) is larger than the number of available train samples ({len(train_indices)})")
-                 if dataset_config.val_known_size != "all" and dataset_config.val_known_size > len(train_indices):
-                     raise ValueError(f"Requested validation size ({dataset_config.val_known_size}) is larger than the number of available train samples ({len(train_indices)})")
-                 train_indices, val_known_indices = train_test_split(train_indices,
-                                                                     train_size=dataset_config.train_size if dataset_config.train_size != "all" else None,
-                                                                     test_size=dataset_config.val_known_size if dataset_config.val_known_size != "all" else None,
-                                                                     stratify=train_labels, shuffle=True, random_state=train_val_rng)
-         elif dataset_config.val_approach == ValidationApproach.NO_VALIDATION:
-             val_known_indices = np.empty((0,3), dtype=np.int64)
-             val_unknown_indices = np.empty((0,3), dtype=np.int64)
+             elif dataset_config.val_approach == ValidationApproach.SPLIT_FROM_TRAIN:
+                 train_val_rng = get_fresh_random_generator(dataset_config=dataset_config, section=RandomizedSection.TRAIN_VAL_SPLIT)
+                 val_data_path = dataset_config._get_train_data_path()
+                 val_unknown_indices = train_unknown_indices
+                 train_labels = train_indices[:, INDICES_LABEL_POS]
+                 if dataset_config.train_dates_weigths is not None:
+                     assert dataset_config.val_known_size != "all"
+                     # When weight sampling is used, val_known_size is kept but the resulting train size can be smaller due to not enough samples in some train dates
+                     if dataset_config.val_known_size > len(train_indices):
+                         raise ValueError(f"Requested validation size ({dataset_config.val_known_size}) is larger than the number of available train samples after weight sampling ({len(train_indices)})")
+                     train_indices, val_known_indices = train_test_split(train_indices, test_size=dataset_config.val_known_size, stratify=train_labels, shuffle=True, random_state=train_val_rng)
+                     dataset_config.train_size = len(train_indices)
+                 elif dataset_config.train_size == "all" and dataset_config.val_known_size == "all":
+                     train_indices, val_known_indices = train_test_split(train_indices, test_size=dataset_config.train_val_split_fraction, stratify=train_labels, shuffle=True, random_state=train_val_rng)
+                 else:
+                     if dataset_config.val_known_size != "all" and dataset_config.train_size != "all" and dataset_config.train_size + dataset_config.val_known_size > len(train_indices):
+                         raise ValueError(f"Requested train size + validation size ({dataset_config.train_size + dataset_config.val_known_size}) is larger than the number of available train samples ({len(train_indices)})")
+                     if dataset_config.train_size != "all" and dataset_config.train_size > len(train_indices):
+                         raise ValueError(f"Requested train size ({dataset_config.train_size}) is larger than the number of available train samples ({len(train_indices)})")
+                     if dataset_config.val_known_size != "all" and dataset_config.val_known_size > len(train_indices):
+                         raise ValueError(f"Requested validation size ({dataset_config.val_known_size}) is larger than the number of available train samples ({len(train_indices)})")
+                     train_indices, val_known_indices = train_test_split(train_indices,
+                                                                         train_size=dataset_config.train_size if dataset_config.train_size != "all" else None,
+                                                                         test_size=dataset_config.val_known_size if dataset_config.val_known_size != "all" else None,
+                                                                         stratify=train_labels, shuffle=True, random_state=train_val_rng)
+         else:
+             val_known_indices = np.zeros((0,3), dtype=np.int64)
+             val_unknown_indices = np.zeros((0,3), dtype=np.int64)
              val_data_path = None
-         else: assert_never(dataset_config.val_approach)
-
-         # Create class info
-         class_info = create_class_info(servicemap=servicemap, encoder=encoder, known_apps_database_enum=known_apps_database_enum, unknown_apps_database_enum=unknown_apps_database_enum)
-         # Load or fit data scalers
-         flowstats_scaler, flowstats_quantiles, ipt_scaler, psizes_scaler = fit_or_load_scalers(dataset_config=dataset_config, train_indices=train_indices)
+         # Initialize test set
+         if dataset_config.need_test_set:
+             test_known_indices, test_unknown_indices, test_data_path = init_or_load_test_indices(dataset_config=dataset_config,
+                                                                                                  known_apps=known_apps,
+                                                                                                  unknown_apps=unknown_apps,
+                                                                                                  tables_app_enum=self._tables_app_enum,
+                                                                                                  disable_indices_cache=disable_indices_cache,)
+         else:
+             test_known_indices = np.zeros((0,3), dtype=np.int64)
+             test_unknown_indices = np.zeros((0,3), dtype=np.int64)
+             test_data_path = None
+         # Fit scalers if needed
+         if (dataset_config.ppi_transform is not None and dataset_config.ppi_transform.needs_fitting or
+                 dataset_config.flowstats_transform is not None and dataset_config.flowstats_transform.needs_fitting):
+             if not dataset_config.need_train_set:
+                 raise ValueError("Train set is needed to fit the scalers. Provide pre-fitted scalers.")
+             fit_scalers(dataset_config=dataset_config, train_indices=train_indices)
          # Subset dataset indices based on the selected sizes and compute application counts
          dataset_indices = IndicesTuple(train_indices=train_indices, val_known_indices=val_known_indices, val_unknown_indices=val_unknown_indices, test_known_indices=test_known_indices, test_unknown_indices=test_unknown_indices)
          dataset_indices = subset_and_sort_indices(dataset_config=dataset_config, dataset_indices=dataset_indices)
-         known_app_counts = compute_known_app_counts(dataset_indices=dataset_indices, database_enum=known_apps_database_enum)
-         unknown_app_counts = compute_unknown_app_counts(dataset_indices=dataset_indices, database_enum=unknown_apps_database_enum)
+         known_app_counts = compute_known_app_counts(dataset_indices=dataset_indices, tables_app_enum=self._tables_app_enum)
+         unknown_app_counts = compute_unknown_app_counts(dataset_indices=dataset_indices, tables_app_enum=self._tables_app_enum)
          # Combine known and unknown test indices to create a single dataloader
          assert isinstance(dataset_config.test_unknown_size, int)
-         if dataset_config.test_unknown_size > 0 and len(unknown_apps_database_enum) > 0:
+         if dataset_config.test_unknown_size > 0 and len(unknown_apps) > 0:
              test_combined_indices = np.concatenate((dataset_indices.test_known_indices, dataset_indices.test_unknown_indices))
          else:
              test_combined_indices = dataset_indices.test_known_indices
-
+         # Create the encoder and class info structure
+         encoder = LabelEncoder().fit(known_apps)
+         encoder.classes_ = np.append(encoder.classes_, UNKNOWN_STR_LABEL)
+         class_info = create_class_info(servicemap=servicemap, encoder=encoder, known_apps=known_apps, unknown_apps=unknown_apps)
+         encode_labels_with_unknown_fn = partial(_encode_labels_with_unknown, encoder=encoder, class_info=class_info)
          # Create train, validation, and test datasets
-         train_dataset = PyTablesDataset(
-             database_path=dataset_config.database_path,
-             tables_paths=dataset_config._get_train_tables_paths(),
-             indices=dataset_indices.train_indices,
-             flowstats_features=dataset_config.flowstats_features,
-             other_fields=self.dataset_config.other_fields,)
-         if dataset_config.no_test_set:
-             test_dataset = None
-         else:
-             assert test_data_path is not None
-             test_dataset = PyTablesDataset(
+         train_dataset = val_dataset = test_dataset = None
+         if dataset_config.need_train_set:
+             train_dataset = PyTablesDataset(
                  database_path=dataset_config.database_path,
-                 tables_paths=dataset_config._get_test_tables_paths(),
-                 indices=test_combined_indices,
+                 tables_paths=dataset_config._get_train_tables_paths(),
+                 indices=dataset_indices.train_indices,
+                 tables_app_enum=self._tables_app_enum,
+                 tables_cat_enum=self._tables_cat_enum,
                  flowstats_features=dataset_config.flowstats_features,
+                 flowstats_features_boolean=dataset_config.flowstats_features_boolean,
+                 flowstats_features_phist=dataset_config.flowstats_features_phist,
                  other_fields=self.dataset_config.other_fields,
-                 preload=dataset_config.preload_test,
-                 preload_blob=os.path.join(test_data_path, "preload", f"test_dataset-{dataset_config.test_known_size}-{dataset_config.test_unknown_size}.npz"),)
-         if dataset_config.val_approach == ValidationApproach.NO_VALIDATION:
-             val_dataset = None
-         else:
+                 ppi_channels=dataset_config.get_ppi_channels(),
+                 ppi_transform=dataset_config.ppi_transform,
+                 flowstats_transform=dataset_config.flowstats_transform,
+                 flowstats_phist_transform=dataset_config.flowstats_phist_transform,
+                 target_transform=encode_labels_with_unknown_fn,
+                 return_tensors=dataset_config.return_tensors,)
+         if dataset_config.need_val_set:
              assert val_data_path is not None
              val_dataset = PyTablesDataset(
                  database_path=dataset_config.database_path,
                  tables_paths=dataset_config._get_train_tables_paths(),
                  indices=dataset_indices.val_known_indices,
+                 tables_app_enum=self._tables_app_enum,
+                 tables_cat_enum=self._tables_cat_enum,
                  flowstats_features=dataset_config.flowstats_features,
+                 flowstats_features_boolean=dataset_config.flowstats_features_boolean,
+                 flowstats_features_phist=dataset_config.flowstats_features_phist,
                  other_fields=self.dataset_config.other_fields,
+                 ppi_channels=dataset_config.get_ppi_channels(),
+                 ppi_transform=dataset_config.ppi_transform,
+                 flowstats_transform=dataset_config.flowstats_transform,
+                 flowstats_phist_transform=dataset_config.flowstats_phist_transform,
+                 target_transform=encode_labels_with_unknown_fn,
+                 return_tensors=dataset_config.return_tensors,
                  preload=dataset_config.preload_val,
                  preload_blob=os.path.join(val_data_path, "preload", f"val_dataset-{dataset_config.val_known_size}.npz"),)
-         collate_fn = partial(pytables_collate_fn,
-                              flowstats_scaler=flowstats_scaler,
-                              flowstats_quantiles=flowstats_quantiles,
-                              psizes_scaler=psizes_scaler,
-                              psizes_max=dataset_config.psizes_max,
-                              ipt_scaler=ipt_scaler,
-                              ipt_min=dataset_config.ipt_min,
-                              ipt_max=dataset_config.ipt_max,
-                              use_push_flags=dataset_config.use_push_flags,
-                              use_packet_histograms=dataset_config.use_packet_histograms,
-                              normalize_packet_histograms=dataset_config.normalize_packet_histograms,
-                              zero_ppi_start=dataset_config.zero_ppi_start,
-                              encoder=encoder,
-                              known_apps=class_info.known_apps,
-                              return_torch=dataset_config.return_torch,)
+         if dataset_config.need_test_set:
+             assert test_data_path is not None
+             test_dataset = PyTablesDataset(
+                 database_path=dataset_config.database_path,
+                 tables_paths=dataset_config._get_test_tables_paths(),
+                 indices=test_combined_indices,
+                 tables_app_enum=self._tables_app_enum,
+                 tables_cat_enum=self._tables_cat_enum,
+                 flowstats_features=dataset_config.flowstats_features,
+                 flowstats_features_boolean=dataset_config.flowstats_features_boolean,
+                 flowstats_features_phist=dataset_config.flowstats_features_phist,
+                 other_fields=self.dataset_config.other_fields,
+                 ppi_channels=dataset_config.get_ppi_channels(),
+                 ppi_transform=dataset_config.ppi_transform,
+                 flowstats_transform=dataset_config.flowstats_transform,
+                 flowstats_phist_transform=dataset_config.flowstats_phist_transform,
+                 target_transform=encode_labels_with_unknown_fn,
+                 return_tensors=dataset_config.return_tensors,
+                 preload=dataset_config.preload_test,
+                 preload_blob=os.path.join(test_data_path, "preload", f"test_dataset-{dataset_config.test_known_size}-{dataset_config.test_unknown_size}.npz"),)
          self.class_info = class_info
          self.dataset_indices = dataset_indices
          self.train_dataset = train_dataset
          self.val_dataset = val_dataset
          self.test_dataset = test_dataset
-         self.known_apps_database_enum = known_apps_database_enum
-         self.unknown_apps_database_enum = unknown_apps_database_enum
          self.known_app_counts = known_app_counts
          self.unknown_app_counts = unknown_app_counts
-         self.collate_fn = collate_fn
-         self.encoder = encoder
-         self.flowstats_scaler = flowstats_scaler
-         self.psizes_scaler = psizes_scaler
-         self.ipt_scaler = ipt_scaler
+         self._collate_fn = collate_fn_simple
+
+ def _encode_labels_with_unknown(labels, encoder: LabelEncoder, class_info: ClassInfo):
+     return encoder.transform(np.where(np.isin(labels, class_info.known_apps), labels, UNKNOWN_STR_LABEL))
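The new module-level _encode_labels_with_unknown helper takes over the label handling that was previously buried in the removed scaler-aware collate function: the encoder is fitted on the known application names, an extra UNKNOWN_STR_LABEL class is appended, and every label outside known_apps collapses to it. A self-contained sketch of that behavior with illustrative names (the real constant lives in cesnet_datazoo.constants):

    import numpy as np
    from sklearn.preprocessing import LabelEncoder

    UNKNOWN_STR_LABEL = "unknown"  # stand-in for the real constant
    known_apps = ["GoogleDrive", "Spotify"]  # illustrative application names

    encoder = LabelEncoder().fit(known_apps)
    # Appending here keeps classes_ sorted, which transform() relies on.
    encoder.classes_ = np.append(encoder.classes_, UNKNOWN_STR_LABEL)

    labels = np.array(["Spotify", "SomeNewApp", "GoogleDrive"])
    encoded = encoder.transform(np.where(np.isin(labels, known_apps), labels, UNKNOWN_STR_LABEL))
    print(encoded)  # [1 2 0] -- "SomeNewApp" maps to the appended unknown class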