replay-rec 0.16.0rc0__py3-none-any.whl → 0.17.0__py3-none-any.whl

This diff compares the contents of two package versions publicly released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registries.
Files changed (162)
  1. replay/__init__.py +1 -1
  2. replay/data/__init__.py +1 -1
  3. replay/data/dataset.py +45 -42
  4. replay/data/dataset_utils/dataset_label_encoder.py +6 -7
  5. replay/data/nn/__init__.py +1 -1
  6. replay/data/nn/schema.py +20 -33
  7. replay/data/nn/sequence_tokenizer.py +217 -87
  8. replay/data/nn/sequential_dataset.py +6 -22
  9. replay/data/nn/torch_sequential_dataset.py +20 -11
  10. replay/data/nn/utils.py +7 -9
  11. replay/data/schema.py +17 -17
  12. replay/data/spark_schema.py +0 -1
  13. replay/metrics/base_metric.py +38 -79
  14. replay/metrics/categorical_diversity.py +24 -58
  15. replay/metrics/coverage.py +25 -49
  16. replay/metrics/descriptors.py +4 -13
  17. replay/metrics/experiment.py +3 -8
  18. replay/metrics/hitrate.py +3 -6
  19. replay/metrics/map.py +3 -6
  20. replay/metrics/mrr.py +1 -4
  21. replay/metrics/ndcg.py +4 -7
  22. replay/metrics/novelty.py +10 -29
  23. replay/metrics/offline_metrics.py +26 -61
  24. replay/metrics/precision.py +3 -6
  25. replay/metrics/recall.py +3 -6
  26. replay/metrics/rocauc.py +7 -10
  27. replay/metrics/surprisal.py +13 -30
  28. replay/metrics/torch_metrics_builder.py +0 -4
  29. replay/metrics/unexpectedness.py +15 -20
  30. replay/models/__init__.py +1 -2
  31. replay/models/als.py +7 -15
  32. replay/models/association_rules.py +12 -28
  33. replay/models/base_neighbour_rec.py +21 -36
  34. replay/models/base_rec.py +92 -215
  35. replay/models/cat_pop_rec.py +9 -22
  36. replay/models/cluster.py +17 -28
  37. replay/models/extensions/ann/ann_mixin.py +7 -12
  38. replay/models/extensions/ann/entities/base_hnsw_param.py +1 -1
  39. replay/models/extensions/ann/entities/hnswlib_param.py +0 -6
  40. replay/models/extensions/ann/entities/nmslib_hnsw_param.py +0 -6
  41. replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +4 -10
  42. replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py +7 -11
  43. replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +5 -12
  44. replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +11 -18
  45. replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py +1 -4
  46. replay/models/extensions/ann/index_inferers/base_inferer.py +3 -10
  47. replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py +7 -17
  48. replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py +6 -14
  49. replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +14 -28
  50. replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +15 -25
  51. replay/models/extensions/ann/index_inferers/utils.py +2 -9
  52. replay/models/extensions/ann/index_stores/hdfs_index_store.py +4 -9
  53. replay/models/extensions/ann/index_stores/shared_disk_index_store.py +2 -6
  54. replay/models/extensions/ann/index_stores/spark_files_index_store.py +8 -14
  55. replay/models/extensions/ann/index_stores/utils.py +5 -2
  56. replay/models/extensions/ann/utils.py +3 -5
  57. replay/models/kl_ucb.py +16 -22
  58. replay/models/knn.py +37 -59
  59. replay/models/nn/optimizer_utils/__init__.py +1 -6
  60. replay/models/nn/optimizer_utils/optimizer_factory.py +3 -6
  61. replay/models/nn/sequential/bert4rec/__init__.py +1 -1
  62. replay/models/nn/sequential/bert4rec/dataset.py +6 -7
  63. replay/models/nn/sequential/bert4rec/lightning.py +53 -56
  64. replay/models/nn/sequential/bert4rec/model.py +12 -25
  65. replay/models/nn/sequential/callbacks/__init__.py +1 -1
  66. replay/models/nn/sequential/callbacks/prediction_callbacks.py +23 -25
  67. replay/models/nn/sequential/callbacks/validation_callback.py +27 -30
  68. replay/models/nn/sequential/postprocessors/postprocessors.py +1 -1
  69. replay/models/nn/sequential/sasrec/dataset.py +8 -7
  70. replay/models/nn/sequential/sasrec/lightning.py +53 -48
  71. replay/models/nn/sequential/sasrec/model.py +4 -17
  72. replay/models/pop_rec.py +9 -10
  73. replay/models/query_pop_rec.py +7 -15
  74. replay/models/random_rec.py +10 -18
  75. replay/models/slim.py +8 -13
  76. replay/models/thompson_sampling.py +13 -14
  77. replay/models/ucb.py +11 -22
  78. replay/models/wilson.py +5 -14
  79. replay/models/word2vec.py +24 -69
  80. replay/optimization/optuna_objective.py +13 -27
  81. replay/preprocessing/__init__.py +1 -2
  82. replay/preprocessing/converter.py +2 -7
  83. replay/preprocessing/filters.py +67 -142
  84. replay/preprocessing/history_based_fp.py +44 -116
  85. replay/preprocessing/label_encoder.py +106 -68
  86. replay/preprocessing/sessionizer.py +1 -11
  87. replay/scenarios/fallback.py +3 -8
  88. replay/splitters/base_splitter.py +43 -15
  89. replay/splitters/cold_user_random_splitter.py +18 -31
  90. replay/splitters/k_folds.py +14 -24
  91. replay/splitters/last_n_splitter.py +33 -43
  92. replay/splitters/new_users_splitter.py +31 -55
  93. replay/splitters/random_splitter.py +16 -23
  94. replay/splitters/ratio_splitter.py +30 -54
  95. replay/splitters/time_splitter.py +13 -18
  96. replay/splitters/two_stage_splitter.py +44 -79
  97. replay/utils/__init__.py +1 -1
  98. replay/utils/common.py +65 -0
  99. replay/utils/dataframe_bucketizer.py +25 -31
  100. replay/utils/distributions.py +3 -15
  101. replay/utils/model_handler.py +36 -33
  102. replay/utils/session_handler.py +11 -15
  103. replay/utils/spark_utils.py +51 -85
  104. replay/utils/time.py +8 -22
  105. replay/utils/types.py +1 -3
  106. {replay_rec-0.16.0rc0.dist-info → replay_rec-0.17.0.dist-info}/METADATA +2 -10
  107. replay_rec-0.17.0.dist-info/RECORD +127 -0
  108. {replay_rec-0.16.0rc0.dist-info → replay_rec-0.17.0.dist-info}/WHEEL +1 -1
  109. replay/experimental/__init__.py +0 -0
  110. replay/experimental/metrics/__init__.py +0 -61
  111. replay/experimental/metrics/base_metric.py +0 -661
  112. replay/experimental/metrics/coverage.py +0 -117
  113. replay/experimental/metrics/experiment.py +0 -200
  114. replay/experimental/metrics/hitrate.py +0 -27
  115. replay/experimental/metrics/map.py +0 -31
  116. replay/experimental/metrics/mrr.py +0 -19
  117. replay/experimental/metrics/ncis_precision.py +0 -32
  118. replay/experimental/metrics/ndcg.py +0 -50
  119. replay/experimental/metrics/precision.py +0 -23
  120. replay/experimental/metrics/recall.py +0 -26
  121. replay/experimental/metrics/rocauc.py +0 -50
  122. replay/experimental/metrics/surprisal.py +0 -102
  123. replay/experimental/metrics/unexpectedness.py +0 -74
  124. replay/experimental/models/__init__.py +0 -10
  125. replay/experimental/models/admm_slim.py +0 -216
  126. replay/experimental/models/base_neighbour_rec.py +0 -222
  127. replay/experimental/models/base_rec.py +0 -1361
  128. replay/experimental/models/base_torch_rec.py +0 -247
  129. replay/experimental/models/cql.py +0 -468
  130. replay/experimental/models/ddpg.py +0 -1007
  131. replay/experimental/models/dt4rec/__init__.py +0 -0
  132. replay/experimental/models/dt4rec/dt4rec.py +0 -193
  133. replay/experimental/models/dt4rec/gpt1.py +0 -411
  134. replay/experimental/models/dt4rec/trainer.py +0 -128
  135. replay/experimental/models/dt4rec/utils.py +0 -274
  136. replay/experimental/models/extensions/spark_custom_models/__init__.py +0 -0
  137. replay/experimental/models/extensions/spark_custom_models/als_extension.py +0 -733
  138. replay/experimental/models/implicit_wrap.py +0 -138
  139. replay/experimental/models/lightfm_wrap.py +0 -327
  140. replay/experimental/models/mult_vae.py +0 -374
  141. replay/experimental/models/neuromf.py +0 -462
  142. replay/experimental/models/scala_als.py +0 -311
  143. replay/experimental/nn/data/__init__.py +0 -1
  144. replay/experimental/nn/data/schema_builder.py +0 -58
  145. replay/experimental/preprocessing/__init__.py +0 -3
  146. replay/experimental/preprocessing/data_preparator.py +0 -929
  147. replay/experimental/preprocessing/padder.py +0 -231
  148. replay/experimental/preprocessing/sequence_generator.py +0 -218
  149. replay/experimental/scenarios/__init__.py +0 -1
  150. replay/experimental/scenarios/obp_wrapper/__init__.py +0 -8
  151. replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +0 -86
  152. replay/experimental/scenarios/obp_wrapper/replay_offline.py +0 -271
  153. replay/experimental/scenarios/obp_wrapper/utils.py +0 -88
  154. replay/experimental/scenarios/two_stages/reranker.py +0 -116
  155. replay/experimental/scenarios/two_stages/two_stages_scenario.py +0 -843
  156. replay/experimental/utils/__init__.py +0 -0
  157. replay/experimental/utils/logger.py +0 -24
  158. replay/experimental/utils/model_handler.py +0 -213
  159. replay/experimental/utils/session_handler.py +0 -47
  160. replay_rec-0.16.0rc0.dist-info/NOTICE +0 -41
  161. replay_rec-0.16.0rc0.dist-info/RECORD +0 -178
  162. {replay_rec-0.16.0rc0.dist-info → replay_rec-0.17.0.dist-info}/LICENSE +0 -0
replay/data/nn/sequence_tokenizer.py
@@ -1,4 +1,7 @@
+import json
 import pickle
+import warnings
+from pathlib import Path
 from typing import Dict, List, Optional, Sequence, Set, Tuple, Union

 import numpy as np
@@ -6,14 +9,15 @@ import polars as pl
 from pandas import DataFrame as PandasDataFrame
 from polars import DataFrame as PolarsDataFrame

-from replay.data import Dataset, FeatureSchema, FeatureSource
+from replay.data import Dataset, FeatureHint, FeatureSchema, FeatureSource, FeatureType
 from replay.data.dataset_utils import DatasetLabelEncoder
-from .schema import TensorFeatureInfo, TensorFeatureSource, TensorSchema
-from .sequential_dataset import PandasSequentialDataset, SequentialDataset, PolarsSequentialDataset
-from .utils import ensure_pandas, groupby_sequences
-from replay.preprocessing import LabelEncoder
+from replay.preprocessing import LabelEncoder, LabelEncodingRule
 from replay.preprocessing.label_encoder import HandleUnknownStrategies
+from replay.utils.model_handler import deprecation_warning

+from .schema import TensorFeatureInfo, TensorFeatureSource, TensorSchema
+from .sequential_dataset import PandasSequentialDataset, PolarsSequentialDataset, SequentialDataset
+from .utils import ensure_pandas, groupby_sequences


 SequenceDataFrameLike = Union[PandasDataFrame, PolarsDataFrame]

@@ -33,7 +37,7 @@ class SequenceTokenizer:
     """
     :param tensor_schema: tensor schema of tensor features
     :param handle_unknown_rule: handle unknown labels rule for LabelEncoder,
-        values are in ('error', 'use_default_value').
+        values are in ('error', 'use_default_value', 'drop').
         Default: `error`
     :param default_value: Default value that will fill the unknown labels after transform.
         When the parameter handle_unknown is set to ``use_default_value``,
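
Note: the docstring change above adds a third unknown-label strategy, `drop`. A minimal construction sketch, assuming these classes are exported from `replay.data` and `replay.data.nn` as in earlier releases (the column name is illustrative):

    from replay.data import FeatureHint, FeatureSource, FeatureType
    from replay.data.nn import SequenceTokenizer, TensorFeatureInfo, TensorFeatureSource, TensorSchema

    # One categorical item-id feature sourced from the interactions log.
    schema = TensorSchema(
        [
            TensorFeatureInfo(
                name="item_id",
                feature_type=FeatureType.CATEGORICAL,
                is_seq=True,
                feature_hint=FeatureHint.ITEM_ID,
                feature_sources=[TensorFeatureSource(source=FeatureSource.INTERACTIONS, column="item_id")],
            )
        ]
    )
    # New in 0.17.0: interactions whose labels were unseen at fit time are dropped
    # instead of raising an error or being mapped to a default value.
    tokenizer = SequenceTokenizer(schema, handle_unknown_rule="drop")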
@@ -60,6 +64,7 @@ class SequenceTokenizer:
         :returns: fitted SequenceTokenizer
         """
         self._check_if_tensor_schema_matches_data(dataset, self._tensor_schema)
+        self._assign_tensor_features_cardinality(dataset)
         self._encoder.fit(dataset)
         return self

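
Note: `fit` now assigns cardinalities eagerly. The new `_assign_tensor_features_cardinality` (shown in full further down) copies the cardinality of every categorical tensor feature from the `Dataset`'s feature schema, warns if a value was set by hand, and raises `RuntimeError` when the dataset-side feature is not categorical. This replaces the lazy per-column callback removed from `sequential_dataset.py` below. A hedged sketch, assuming the `tensor_schema` accessor hinted at by the `@property` context in this file:

    # `dataset` is a replay Dataset whose feature schema marks item_id as categorical.
    tokenizer = SequenceTokenizer(schema).fit(dataset)
    # After fit, categorical features mirror the dataset's cardinalities; any cardinality
    # set manually before fit has been overwritten (with a warning).
    assert tokenizer.tensor_schema["item_id"].cardinality is not None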
@@ -84,7 +89,6 @@ class SequenceTokenizer:
         :param dataset: input dataset to transform
         :returns: SequentialDataset
         """
-        # pylint: disable=protected-access
         return self.fit(dataset)._transform_unchecked(dataset)

     @property
@@ -161,10 +165,7 @@ class SequenceTokenizer:

         assert self._tensor_schema.item_id_feature_name

-        if is_polars:
-            dataset_type = PolarsSequentialDataset
-        else:
-            dataset_type = PandasSequentialDataset
+        dataset_type = PolarsSequentialDataset if is_polars else PandasSequentialDataset

         return dataset_type(
             tensor_schema=schema,
@@ -191,7 +192,7 @@ class SequenceTokenizer:
             return (
                 grouped_interactions.sort(dataset.feature_schema.query_id_column),
                 dataset.query_features,
-                dataset.item_features
+                dataset.item_features,
             )

         # We sort by QUERY_ID to make sure order is deterministic
@@ -211,7 +212,6 @@ class SequenceTokenizer:

         return grouped_interactions_pd, query_features_pd, item_features_pd

-    # pylint: disable=too-many-arguments
     def _make_sequence_features(
         self,
         schema: TensorSchema,
@@ -298,24 +298,27 @@ class SequenceTokenizer:
         for tensor_feature in tensor_schema.all_features:
             feature_sources = tensor_feature.feature_sources
             if not feature_sources:
-                raise ValueError("All tensor features must have sources defined")
+                msg = "All tensor features must have sources defined"
+                raise ValueError(msg)

             source_tables: List[FeatureSource] = [s.source for s in feature_sources]

             unexpected_tables = list(filter(lambda x: not isinstance(x, FeatureSource), source_tables))
             if len(unexpected_tables) > 0:
-                raise ValueError(f"Found unexpected source tables: {unexpected_tables}")
+                msg = f"Found unexpected source tables: {unexpected_tables}"
+                raise ValueError(msg)

             if not tensor_feature.is_seq:
                 if FeatureSource.INTERACTIONS in source_tables:
-                    raise ValueError("Interaction features must be treated as sequential")
+                    msg = "Interaction features must be treated as sequential"
+                    raise ValueError(msg)

                 if FeatureSource.ITEM_FEATURES in source_tables:
-                    raise ValueError("Item features must be treated as sequential")
+                    msg = "Item features must be treated as sequential"
+                    raise ValueError(msg)

-    # pylint: disable=too-many-branches
     @classmethod
-    def _check_if_tensor_schema_matches_data(
+    def _check_if_tensor_schema_matches_data(  # noqa: C901
         cls,
         dataset: Dataset,
         tensor_schema: TensorSchema,
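
Note: the recurring `msg = ...` / `raise ValueError(msg)` rewrites in this hunk and the ones below match ruff's flake8-errmsg rules (EM101 for string literals, EM102 for f-strings); together with the removed `# pylint: disable` comments and the new `# noqa: C901`, the project appears to have moved from pylint to ruff. The pattern in isolation:

    # Flagged by EM101/EM102: a literal inside raise is repeated verbatim
    # on the traceback's raise line.
    def check_sources_old(feature_sources):
        if not feature_sources:
            raise ValueError("All tensor features must have sources defined")

    # Preferred form: bind the message to a variable first.
    def check_sources_new(feature_sources):
        if not feature_sources:
            msg = "All tensor features must have sources defined"
            raise ValueError(msg)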
@@ -324,77 +327,205 @@ class SequenceTokenizer:
         # Check if all source columns specified in tensor schema exist in provided data frames
         sources_for_tensors: List[TensorFeatureSource] = []
         for tensor_feature_name, tensor_feature in tensor_schema.items():
-            if (tensor_features_to_keep is not None) and (tensor_feature_name not in tensor_features_to_keep):
+            if tensor_features_to_keep is not None and tensor_feature_name not in tensor_features_to_keep:
                 continue

-            feature_sources = tensor_feature.feature_sources
-            if feature_sources:
-                sources_for_tensors += feature_sources
+            if tensor_feature.feature_sources:
+                sources_for_tensors += tensor_feature.feature_sources

         query_id_column = dataset.feature_schema.query_id_column
         item_id_column = dataset.feature_schema.item_id_column

-        interaction_feature_columns = set(
-            list(dataset.feature_schema.interaction_features.columns) + [query_id_column, item_id_column]
-        )
-        query_feature_columns = set(list(dataset.feature_schema.query_features.columns) + [query_id_column])
-        item_feature_columns = set(list(dataset.feature_schema.item_features.columns) + [item_id_column])
+        interaction_feature_columns = {
+            *dataset.feature_schema.interaction_features.columns,
+            query_id_column,
+            item_id_column,
+        }
+        query_feature_columns = {*dataset.feature_schema.query_features.columns, query_id_column}
+        item_feature_columns = {*dataset.feature_schema.item_features.columns, item_id_column}

         for feature_source in sources_for_tensors:
             assert feature_source is not None
             if feature_source.source == FeatureSource.INTERACTIONS:
                 if feature_source.column not in interaction_feature_columns:
-                    raise ValueError(f"Expected column '{feature_source.column}' in dataset")
+                    msg = f"Expected column '{feature_source.column}' in dataset"
+                    raise ValueError(msg)
             elif feature_source.source == FeatureSource.QUERY_FEATURES:
                 if dataset.query_features is None:
-                    raise ValueError(f"Expected column '{feature_source.column}', but query features are not specified")
+                    msg = f"Expected column '{feature_source.column}', but query features are not specified"
+                    raise ValueError(msg)
                 if feature_source.column not in query_feature_columns:
-                    raise ValueError(f"Expected column '{feature_source.column}' in query features data frame")
+                    msg = f"Expected column '{feature_source.column}' in query features data frame"
+                    raise ValueError(msg)
             elif feature_source.source == FeatureSource.ITEM_FEATURES:
                 if dataset.item_features is None:
-                    raise ValueError(f"Expected column '{feature_source.column}', but item features are not specified")
+                    msg = f"Expected column '{feature_source.column}', but item features are not specified"
+                    raise ValueError(msg)
                 if feature_source.column not in item_feature_columns:
-                    raise ValueError(f"Expected column '{feature_source.column}' in item features data frame")
+                    msg = f"Expected column '{feature_source.column}' in item features data frame"
+                    raise ValueError(msg)
             else:
-                raise ValueError(f"Found unexpected table '{feature_source.source}' in tensor schema")
+                msg = f"Found unexpected table '{feature_source.source}' in tensor schema"
+                raise ValueError(msg)

         # Check if user ID and item ID columns are consistent with tensor schema
         if tensor_schema.query_id_feature_name is not None:
             tensor_feature = tensor_schema.query_id_features.item()
             assert tensor_feature.feature_source
             if tensor_feature.feature_source.column != dataset.feature_schema.query_id_column:
-                raise ValueError("Tensor schema query ID source colum does not match query ID in data frame")
+                msg = "Tensor schema query ID source colum does not match query ID in data frame"
+                raise ValueError(msg)

         if tensor_schema.item_id_feature_name is None:
-            raise ValueError("Tensor schema must have item id feature defined")
+            msg = "Tensor schema must have item id feature defined"
+            raise ValueError(msg)

         tensor_feature = tensor_schema.item_id_features.item()
         assert tensor_feature.feature_source
         if tensor_feature.feature_source.column != dataset.feature_schema.item_id_column:
-            raise ValueError("Tensor schema item ID source colum does not match item ID in data frame")
+            msg = "Tensor schema item ID source colum does not match item ID in data frame"
+            raise ValueError(msg)
+
+    def _assign_tensor_features_cardinality(self, dataset: Dataset) -> None:
+        for tensor_feature in self._tensor_schema.categorical_features.all_features:
+            dataset_feature = dataset.feature_schema[tensor_feature.feature_source.column]
+            if tensor_feature.cardinality is not None:
+                warnings.warn(
+                    f"The specified cardinality of {tensor_feature.name} "
+                    f"will be replaced by {dataset_feature.column} from Dataset"
+                )
+            if dataset_feature.feature_type != FeatureType.CATEGORICAL:
+                error_msg = (
+                    f"TensorFeatureInfo {tensor_feature.name} "
+                    f"and FeatureInfo {dataset_feature.column} must be the same FeatureType"
+                )
+                raise RuntimeError(error_msg)
+            tensor_feature._set_cardinality(dataset_feature.cardinality)

     @classmethod
-    def load(cls, path: str) -> "SequenceTokenizer":
+    @deprecation_warning("with `use_pickle` equals to `True` will be deprecated in future versions")
+    def load(cls, path: str, use_pickle: bool = False) -> "SequenceTokenizer":
         """
         Load tokenizer object from the given path.

         :param path: Path to load the tokenizer.
+        :param use_pickle: If `False` - tokenizer will be loaded from `.replay` directory.
+            If `True` - tokenizer will be loaded with pickle.
+            Default: `False`.

         :returns: Loaded tokenizer object.
         """
-        with open(path, "rb") as file:
-            tokenizer = pickle.load(file)
+        if not use_pickle:
+            base_path = Path(path).with_suffix(".replay").resolve()
+            with open(base_path / "init_args.json", "r") as file:
+                tokenizer_dict = json.loads(file.read())
+
+            # load tensor_schema, tensor_features
+            tensor_schema_data = tokenizer_dict["init_args"]["tensor_schema"]
+            features_list = []
+            for feature_data in tensor_schema_data:
+                feature_data["feature_sources"] = [
+                    TensorFeatureSource(source=FeatureSource[x["source"]], column=x["column"], index=x["index"])
+                    for x in feature_data["feature_sources"]
+                ]
+                f_type = feature_data["feature_type"]
+                f_hint = feature_data["feature_hint"]
+                feature_data["feature_type"] = FeatureType[f_type] if f_type else None
+                feature_data["feature_hint"] = FeatureHint[f_hint] if f_hint else None
+                features_list.append(TensorFeatureInfo(**feature_data))
+            tokenizer_dict["init_args"]["tensor_schema"] = TensorSchema(features_list)
+
+            # Load encoder columns and rules
+            types = list(FeatureHint) + list(FeatureSource)
+            map_types = {x.name: x for x in types}
+            encoder_features_columns = {
+                map_types[key]: value for key, value in tokenizer_dict["encoder"]["features_columns"].items()
+            }
+
+            rules_dict = tokenizer_dict["encoder"]["encoding_rules"]
+            for rule in rules_dict:
+                rule_data = rules_dict[rule]
+                if rule_data["mapping"] and rule_data["is_int"]:
+                    rule_data["mapping"] = {int(key): value for key, value in rule_data["mapping"].items()}
+                del rule_data["is_int"]

+                tokenizer_dict["encoder"]["encoding_rules"][rule] = LabelEncodingRule(**rule_data)
+
+            # Init tokenizer
+            tokenizer = cls(**tokenizer_dict["init_args"])
+            tokenizer._encoder._features_columns = encoder_features_columns
+            tokenizer._encoder._encoding_rules = tokenizer_dict["encoder"]["encoding_rules"]
+        else:
+            with open(path, "rb") as file:
+                tokenizer = pickle.load(file)

         return tokenizer

-    def save(self, path: str) -> None:
+    @deprecation_warning("with `use_pickle` equals to `True` will be deprecated in future versions")
+    def save(self, path: str, use_pickle: bool = False) -> None:
         """
         Save the tokenizer to the given path.

         :param path: Path to save the tokenizer.
-        """
-        with open(path, "wb") as file:
-            pickle.dump(self, file)
+        :param use_pickle: If `False` - tokenizer will be saved in `.replay` directory.
+            If `True` - tokenizer will be saved with pickle.
+            Default: `False`.
+        """
+        if not use_pickle:
+            tokenizer_dict = {}
+            tokenizer_dict["_class_name"] = self.__class__.__name__
+            tokenizer_dict["init_args"] = {
+                "allow_collect_to_master": self._allow_collect_to_master,
+                "handle_unknown_rule": self._encoder._handle_unknown_rule,
+                "default_value_rule": self._encoder._default_value_rule,
+                "tensor_schema": [],
+            }
+
+            # save tensor schema
+            for feature in list(self._tensor_schema.values()):
+                tokenizer_dict["init_args"]["tensor_schema"].append(
+                    {
+                        "name": feature.name,
+                        "feature_type": feature.feature_type.name,
+                        "is_seq": feature.is_seq,
+                        "feature_hint": feature.feature_hint.name if feature.feature_hint else None,
+                        "feature_sources": [
+                            {"source": x.source.name, "column": x.column, "index": x.index}
+                            for x in feature.feature_sources
+                        ]
+                        if feature.feature_sources
+                        else None,
+                        "cardinality": feature.cardinality if feature.feature_type == FeatureType.CATEGORICAL else None,
+                        "embedding_dim": feature.embedding_dim
+                        if feature.feature_type == FeatureType.CATEGORICAL
+                        else None,
+                        "tensor_dim": feature.tensor_dim if feature.feature_type == FeatureType.NUMERICAL else None,
+                    }
+                )
+
+            # save DatasetLabelEncoder
+            tokenizer_dict["encoder"] = {
+                "features_columns": {key.name: value for key, value in self._encoder._features_columns.items()},
+                "encoding_rules": {
+                    key: {
+                        "column": value.column,
+                        "mapping": value._mapping,
+                        "handle_unknown": value._handle_unknown,
+                        "default_value": value._default_value,
+                        "is_int": isinstance(next(iter(value._mapping.keys())), int),
+                    }
+                    for key, value in self._encoder._encoding_rules.items()
+                },
+            }
+
+            base_path = Path(path).with_suffix(".replay").resolve()
+            base_path.mkdir(parents=True, exist_ok=True)
+
+            with open(base_path / "init_args.json", "w+") as file:
+                json.dump(tokenizer_dict, file)
+        else:
+            with open(path, "wb") as file:
+                pickle.dump(self, file)


 class _SequenceProcessor:
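
Note: `save`/`load` now default to a JSON format written to a `<path>.replay` directory; `use_pickle=True` keeps the old behavior, which the new `deprecation_warning` decorator flags for removal. Because JSON object keys are always strings, `save` records an `is_int` flag per encoding rule and `load` casts mapping keys back to `int` before rebuilding each `LabelEncodingRule`. A hedged round-trip sketch, continuing the fitted `tokenizer` from the examples above:

    tokenizer.save("tokenizer")                     # writes tokenizer.replay/init_args.json
    restored = SequenceTokenizer.load("tokenizer")  # rebuilds TensorSchema and encoder rules from JSON

    # Legacy pickle path, slated for deprecation:
    tokenizer.save("tok.pkl", use_pickle=True)
    restored = SequenceTokenizer.load("tok.pkl", use_pickle=True)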
@@ -409,7 +540,6 @@ class _SequenceProcessor:
     with passing all tensor features one by one.
     """

-    # pylint: disable=too-many-arguments
     def __init__(
         self,
         tensor_schema: TensorSchema,
@@ -462,13 +592,9 @@ class _SequenceProcessor:
         for tensor_feature_name in self._tensor_schema:
             tensor_feature = self._tensor_schema[tensor_feature_name]
             if tensor_feature.is_cat:
-                data = data.join(
-                    self._process_cat_feature(tensor_feature), on=self._query_id_column, how="left"
-                )
+                data = data.join(self._process_cat_feature(tensor_feature), on=self._query_id_column, how="left")
             elif tensor_feature.is_num:
-                data = data.join(
-                    self._process_num_feature(tensor_feature), on=self._query_id_column, how="left"
-                )
+                data = data.join(self._process_num_feature(tensor_feature), on=self._query_id_column, how="left")
             else:
                 assert False, "Unknown tensor feature type"
         return data
@@ -490,38 +616,40 @@ class _SequenceProcessor:
         def get_sequence(user, source, data):
             if source.source == FeatureSource.INTERACTIONS:
                 return np.array(
-                    self._grouped_interactions
-                    .filter(pl.col(self._query_id_column) == user)[source.column][0],
-                    dtype=np.float32
+                    self._grouped_interactions.filter(pl.col(self._query_id_column) == user)[source.column][0],
+                    dtype=np.float32,
                 ).tolist()
             elif source.source == FeatureSource.ITEM_FEATURES:
                 return (
                     pl.DataFrame({self._item_id_column: data})
                     .join(self._item_features, on=self._item_id_column, how="left")
-                    .select(source.column).to_numpy().reshape(-1).tolist()
+                    .select(source.column)
+                    .to_numpy()
+                    .reshape(-1)
+                    .tolist()
                 )
             else:
                 assert False, "Unknown tensor feature source table"
+
         result = (
-            self._grouped_interactions
-            .select(self._query_id_column, self._item_id_column)
-            .map_rows(
-                lambda x:
-                (
-                    x[0],
-                    [get_sequence(x[0], source, x[1])
-                     for source in tensor_feature.feature_sources]
-                )
+            self._grouped_interactions.select(self._query_id_column, self._item_id_column).map_rows(
+                lambda x: (x[0], [get_sequence(x[0], source, x[1]) for source in tensor_feature.feature_sources])
             )
         ).rename({"column_0": self._query_id_column, "column_1": tensor_feature.name})

-        return pl.DataFrame({
-            self._query_id_column: result[self._query_id_column].to_list(),
-            tensor_feature.name: list(map(
-                lambda x: np.array(x).reshape(-1, len(tensor_feature.feature_sources)).tolist(),
-                result[tensor_feature.name].to_list()
-            ))
-        })
+        if tensor_feature.feature_hint == FeatureHint.TIMESTAMP:
+            reshape_size = -1
+        else:
+            reshape_size = (-1, len(tensor_feature.feature_sources))
+
+        return pl.DataFrame(
+            {
+                self._query_id_column: result[self._query_id_column].to_list(),
+                tensor_feature.name: [
+                    np.array(x).reshape(reshape_size).tolist() for x in result[tensor_feature.name].to_list()
+                ],
+            }
+        )

     def _process_num_feature(self, tensor_feature: TensorFeatureInfo) -> List[np.ndarray]:
         """
@@ -554,7 +682,10 @@ class _SequenceProcessor:
             else:
                 assert False, "Unknown tensor feature source table"
             all_seqs = np.array(all_features_for_user, dtype=np.float32)
-            all_seqs = all_seqs.reshape(-1, (len(tensor_feature.feature_sources)))
+            if tensor_feature.feature_hint == FeatureHint.TIMESTAMP:
+                all_seqs = all_seqs.reshape(-1)
+            else:
+                all_seqs = all_seqs.reshape(-1, (len(tensor_feature.feature_sources)))
             values.append(all_seqs)
         return values

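
Note: the pandas path above and the polars path in the hunk before it now special-case `FeatureHint.TIMESTAMP`, keeping timestamp sequences one-dimensional instead of reshaping them into one column per feature source. The shape difference in plain numpy:

    import numpy as np

    seq = np.array([3, 1, 4, 1, 5], dtype=np.float32)
    print(seq.reshape(-1).shape)     # (5,)   - new shape for TIMESTAMP features
    print(seq.reshape(-1, 1).shape)  # (5, 1) - previous behavior, one column per source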
@@ -572,9 +703,9 @@ class _SequenceProcessor:
         assert source is not None

         if self._is_polars:
-            return self._grouped_interactions.select(
-                self._query_id_column, source.column
-            ).rename({source.column: tensor_feature.name})
+            return self._grouped_interactions.select(self._query_id_column, source.column).rename(
+                {source.column: tensor_feature.name}
+            )

         return [np.array(sequence, dtype=np.int64) for sequence in self._grouped_interactions[source.column]]

@@ -603,9 +734,9 @@ class _SequenceProcessor:
                 result = self._query_features
                 repeat_value = 1

-            return result.select(
-                self._query_id_column, pl.col(source.column).repeat_by(repeat_value)
-            ).rename({source.column: tensor_feature.name})
+            return result.select(self._query_id_column, pl.col(source.column).repeat_by(repeat_value)).rename(
+                {source.column: tensor_feature.name}
+            )

         query_feature = self._query_features[source.column].values
         if tensor_feature.is_seq:
@@ -632,20 +763,19 @@ class _SequenceProcessor:

         if self._is_polars:
             return (
-                self._grouped_interactions
-                .select(self._query_id_column, self._item_id_column)
+                self._grouped_interactions.select(self._query_id_column, self._item_id_column)
                 .map_rows(
-                    lambda x:
-                    (
+                    lambda x: (
                         x[0],
                         pl.DataFrame({self._item_id_column: x[1]})
                         .join(self._item_features, on=self._item_id_column, how="left")
-                        .select(source.column).to_numpy().reshape(-1).tolist(),
+                        .select(source.column)
+                        .to_numpy()
+                        .reshape(-1)
+                        .tolist(),
                     )
-                ).rename({
-                    "column_0": self._query_id_column,
-                    "column_1": tensor_feature.name
-                })
+                )
+                .rename({"column_0": self._query_id_column, "column_1": tensor_feature.name})
             )

         item_feature = self._item_features[source.column]
replay/data/nn/sequential_dataset.py
@@ -6,11 +6,9 @@ import polars as pl
 from pandas import DataFrame as PandasDataFrame
 from polars import DataFrame as PolarsDataFrame

-from replay.data.schema import FeatureType
 from .schema import TensorSchema


-# pylint: disable=missing-function-docstring
 class SequentialDataset(abc.ABC):
     """
     Abstract base class for sequential dataset
@@ -132,19 +130,9 @@ class PandasSequentialDataset(SequentialDataset):

         self._sequences = sequences

-        for feature in tensor_schema.all_features:
-            if feature.feature_type == FeatureType.CATEGORICAL:
-                # pylint: disable=protected-access
-                feature._set_cardinality_callback(self.cardinality_callback)
-
     def __len__(self) -> int:
         return len(self._sequences)

-    def cardinality_callback(self, column: str) -> int:
-        if self._query_id_column == column:
-            return self._sequences.index.nunique()
-        return len({x for seq in self._sequences[column] for x in seq})
-
     def get_query_id(self, index: int) -> int:
         return self._sequences.index[index]

@@ -181,12 +169,12 @@

     @classmethod
     def _check_if_schema_matches_data(cls, tensor_schema: TensorSchema, data: PandasDataFrame) -> None:
-        for tensor_feature_name in tensor_schema.keys():
+        for tensor_feature_name in tensor_schema:
             if tensor_feature_name not in data:
-                raise ValueError("Tensor schema does not match with provided data frame")
+                msg = "Tensor schema does not match with provided data frame"
+                raise ValueError(msg)


-# pylint:disable=super-init-not-called
 class PolarsSequentialDataset(PandasSequentialDataset):
     """
     Sequential dataset that stores sequences in PolarsDataFrame format.
@@ -215,11 +203,6 @@ class PolarsSequentialDataset(PandasSequentialDataset):
         if self._sequences.index.name != query_id_column:
             self._sequences = self._sequences.set_index(query_id_column)

-        for feature in tensor_schema.all_features:
-            if feature.feature_type == FeatureType.CATEGORICAL:
-                # pylint: disable=protected-access
-                feature._set_cardinality_callback(self.cardinality_callback)
-
     def filter_by_query_id(self, query_ids_to_keep: np.ndarray) -> "PolarsSequentialDataset":
         filtered_sequences = self._sequences.loc[query_ids_to_keep]
         if filtered_sequences.index.name == self._query_id_column:
@@ -233,6 +216,7 @@ class PolarsSequentialDataset(PandasSequentialDataset):

     @classmethod
     def _check_if_schema_matches_data(cls, tensor_schema: TensorSchema, data: PolarsDataFrame) -> None:
-        for tensor_feature_name in tensor_schema.keys():
+        for tensor_feature_name in tensor_schema:
             if tensor_feature_name not in data:
-                raise ValueError("Tensor schema does not match with provided data frame")
+                msg = "Tensor schema does not match with provided data frame"
+                raise ValueError(msg)
replay/data/nn/torch_sequential_dataset.py
@@ -14,6 +14,7 @@ class TorchSequentialBatch(NamedTuple):
     """
     Batch of TorchSequentialDataset
     """
+
     query_id: torch.LongTensor
     padding_mask: torch.BoolTensor
     features: TensorMap
@@ -88,7 +89,7 @@ class TorchSequentialDataset(TorchDataset):
     ) -> torch.Tensor:
         sequence = self._sequential.get_sequence(sequence_index, feature.name)
         if feature.is_seq:
-            sequence = sequence[sequence_offset : sequence_offset + self._max_sequence_length]  # noqa: E203
+            sequence = sequence[sequence_offset : sequence_offset + self._max_sequence_length]

         tensor_dtype = self._get_tensor_dtype(feature)
         tensor_sequence = torch.tensor(sequence, dtype=tensor_dtype)
@@ -109,14 +110,15 @@ class TorchSequentialDataset(TorchDataset):
         elif len(sequence.shape) == 2:
             padded_sequence_shape = (self._max_sequence_length, sequence.shape[1])
         else:
-            raise ValueError(f"Unsupported shape for sequence: {len(sequence.shape)}")
+            msg = f"Unsupported shape for sequence: {len(sequence.shape)}"
+            raise ValueError(msg)

         padded_sequence = torch.full(
             padded_sequence_shape,
             self._padding_value,
             dtype=sequence.dtype,
         )
-        padded_sequence[-len(sequence) :].copy_(sequence)  # noqa: E203
+        padded_sequence[-len(sequence) :].copy_(sequence)
         return padded_sequence

     def _get_tensor_dtype(self, feature: TensorFeatureInfo) -> torch.dtype:
@@ -151,6 +153,7 @@ class TorchSequentialValidationBatch(NamedTuple):
     """
     Batch of TorchSequentialValidationDataset
     """
+
     query_id: torch.LongTensor
     padding_mask: torch.BoolTensor
     features: TensorMap
@@ -167,7 +170,6 @@ class TorchSequentialValidationDataset(TorchDataset):
     Torch dataset for sequential recommender models that additionally stores ground truth
     """

-    # pylint: disable=too-many-arguments
     def __init__(
         self,
         sequential: SequentialDataset,
@@ -195,19 +197,24 @@ class TorchSequentialValidationDataset(TorchDataset):

         if label_feature_name:
             if label_feature_name not in ground_truth.schema:
-                raise ValueError("Label feature name not found in ground truth schema")
+                msg = "Label feature name not found in ground truth schema"
+                raise ValueError(msg)

             if label_feature_name not in train.schema:
-                raise ValueError("Label feature name not found in train schema")
+                msg = "Label feature name not found in train schema"
+                raise ValueError(msg)

             if not ground_truth.schema[label_feature_name].is_cat:
-                raise ValueError("Label feature must be categorical")
+                msg = "Label feature must be categorical"
+                raise ValueError(msg)

             if not ground_truth.schema[label_feature_name].is_seq:
-                raise ValueError("Label feature must be sequential")
+                msg = "Label feature must be sequential"
+                raise ValueError(msg)

         if len(np.intersect1d(sequential.get_all_query_ids(), ground_truth.get_all_query_ids())) == 0:
-            raise ValueError("Sequential data and ground truth must contain the same query IDs")
+            msg = "Sequential data and ground truth must contain the same query IDs"
+            raise ValueError(msg)

         self._ground_truth = ground_truth
         self._train = train
@@ -271,7 +278,9 @@ class TorchSequentialValidationDataset(TorchDataset):
         ground_truth_item_feature = ground_truth_schema.item_id_features.item()

         if sequential_item_feature.name != ground_truth_item_feature.name:
-            raise ValueError("Schema mismatch: item feature name does not match ground truth")
+            msg = "Schema mismatch: item feature name does not match ground truth"
+            raise ValueError(msg)

         if sequential_item_feature.cardinality != ground_truth_item_feature.cardinality:
-            raise ValueError("Schema mismatch: item feature cardinality does not match ground truth")
+            msg = "Schema mismatch: item feature cardinality does not match ground truth"
+            raise ValueError(msg)