replay-rec 0.16.0rc0__py3-none-any.whl → 0.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162)
  1. replay/__init__.py +1 -1
  2. replay/data/__init__.py +1 -1
  3. replay/data/dataset.py +45 -42
  4. replay/data/dataset_utils/dataset_label_encoder.py +6 -7
  5. replay/data/nn/__init__.py +1 -1
  6. replay/data/nn/schema.py +20 -33
  7. replay/data/nn/sequence_tokenizer.py +217 -87
  8. replay/data/nn/sequential_dataset.py +6 -22
  9. replay/data/nn/torch_sequential_dataset.py +20 -11
  10. replay/data/nn/utils.py +7 -9
  11. replay/data/schema.py +17 -17
  12. replay/data/spark_schema.py +0 -1
  13. replay/metrics/base_metric.py +38 -79
  14. replay/metrics/categorical_diversity.py +24 -58
  15. replay/metrics/coverage.py +25 -49
  16. replay/metrics/descriptors.py +4 -13
  17. replay/metrics/experiment.py +3 -8
  18. replay/metrics/hitrate.py +3 -6
  19. replay/metrics/map.py +3 -6
  20. replay/metrics/mrr.py +1 -4
  21. replay/metrics/ndcg.py +4 -7
  22. replay/metrics/novelty.py +10 -29
  23. replay/metrics/offline_metrics.py +26 -61
  24. replay/metrics/precision.py +3 -6
  25. replay/metrics/recall.py +3 -6
  26. replay/metrics/rocauc.py +7 -10
  27. replay/metrics/surprisal.py +13 -30
  28. replay/metrics/torch_metrics_builder.py +0 -4
  29. replay/metrics/unexpectedness.py +15 -20
  30. replay/models/__init__.py +1 -2
  31. replay/models/als.py +7 -15
  32. replay/models/association_rules.py +12 -28
  33. replay/models/base_neighbour_rec.py +21 -36
  34. replay/models/base_rec.py +92 -215
  35. replay/models/cat_pop_rec.py +9 -22
  36. replay/models/cluster.py +17 -28
  37. replay/models/extensions/ann/ann_mixin.py +7 -12
  38. replay/models/extensions/ann/entities/base_hnsw_param.py +1 -1
  39. replay/models/extensions/ann/entities/hnswlib_param.py +0 -6
  40. replay/models/extensions/ann/entities/nmslib_hnsw_param.py +0 -6
  41. replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +4 -10
  42. replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py +7 -11
  43. replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +5 -12
  44. replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +11 -18
  45. replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py +1 -4
  46. replay/models/extensions/ann/index_inferers/base_inferer.py +3 -10
  47. replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py +7 -17
  48. replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py +6 -14
  49. replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +14 -28
  50. replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +15 -25
  51. replay/models/extensions/ann/index_inferers/utils.py +2 -9
  52. replay/models/extensions/ann/index_stores/hdfs_index_store.py +4 -9
  53. replay/models/extensions/ann/index_stores/shared_disk_index_store.py +2 -6
  54. replay/models/extensions/ann/index_stores/spark_files_index_store.py +8 -14
  55. replay/models/extensions/ann/index_stores/utils.py +5 -2
  56. replay/models/extensions/ann/utils.py +3 -5
  57. replay/models/kl_ucb.py +16 -22
  58. replay/models/knn.py +37 -59
  59. replay/models/nn/optimizer_utils/__init__.py +1 -6
  60. replay/models/nn/optimizer_utils/optimizer_factory.py +3 -6
  61. replay/models/nn/sequential/bert4rec/__init__.py +1 -1
  62. replay/models/nn/sequential/bert4rec/dataset.py +6 -7
  63. replay/models/nn/sequential/bert4rec/lightning.py +53 -56
  64. replay/models/nn/sequential/bert4rec/model.py +12 -25
  65. replay/models/nn/sequential/callbacks/__init__.py +1 -1
  66. replay/models/nn/sequential/callbacks/prediction_callbacks.py +23 -25
  67. replay/models/nn/sequential/callbacks/validation_callback.py +27 -30
  68. replay/models/nn/sequential/postprocessors/postprocessors.py +1 -1
  69. replay/models/nn/sequential/sasrec/dataset.py +8 -7
  70. replay/models/nn/sequential/sasrec/lightning.py +53 -48
  71. replay/models/nn/sequential/sasrec/model.py +4 -17
  72. replay/models/pop_rec.py +9 -10
  73. replay/models/query_pop_rec.py +7 -15
  74. replay/models/random_rec.py +10 -18
  75. replay/models/slim.py +8 -13
  76. replay/models/thompson_sampling.py +13 -14
  77. replay/models/ucb.py +11 -22
  78. replay/models/wilson.py +5 -14
  79. replay/models/word2vec.py +24 -69
  80. replay/optimization/optuna_objective.py +13 -27
  81. replay/preprocessing/__init__.py +1 -2
  82. replay/preprocessing/converter.py +2 -7
  83. replay/preprocessing/filters.py +67 -142
  84. replay/preprocessing/history_based_fp.py +44 -116
  85. replay/preprocessing/label_encoder.py +106 -68
  86. replay/preprocessing/sessionizer.py +1 -11
  87. replay/scenarios/fallback.py +3 -8
  88. replay/splitters/base_splitter.py +43 -15
  89. replay/splitters/cold_user_random_splitter.py +18 -31
  90. replay/splitters/k_folds.py +14 -24
  91. replay/splitters/last_n_splitter.py +33 -43
  92. replay/splitters/new_users_splitter.py +31 -55
  93. replay/splitters/random_splitter.py +16 -23
  94. replay/splitters/ratio_splitter.py +30 -54
  95. replay/splitters/time_splitter.py +13 -18
  96. replay/splitters/two_stage_splitter.py +44 -79
  97. replay/utils/__init__.py +1 -1
  98. replay/utils/common.py +65 -0
  99. replay/utils/dataframe_bucketizer.py +25 -31
  100. replay/utils/distributions.py +3 -15
  101. replay/utils/model_handler.py +36 -33
  102. replay/utils/session_handler.py +11 -15
  103. replay/utils/spark_utils.py +51 -85
  104. replay/utils/time.py +8 -22
  105. replay/utils/types.py +1 -3
  106. {replay_rec-0.16.0rc0.dist-info → replay_rec-0.17.0.dist-info}/METADATA +2 -10
  107. replay_rec-0.17.0.dist-info/RECORD +127 -0
  108. {replay_rec-0.16.0rc0.dist-info → replay_rec-0.17.0.dist-info}/WHEEL +1 -1
  109. replay/experimental/__init__.py +0 -0
  110. replay/experimental/metrics/__init__.py +0 -61
  111. replay/experimental/metrics/base_metric.py +0 -661
  112. replay/experimental/metrics/coverage.py +0 -117
  113. replay/experimental/metrics/experiment.py +0 -200
  114. replay/experimental/metrics/hitrate.py +0 -27
  115. replay/experimental/metrics/map.py +0 -31
  116. replay/experimental/metrics/mrr.py +0 -19
  117. replay/experimental/metrics/ncis_precision.py +0 -32
  118. replay/experimental/metrics/ndcg.py +0 -50
  119. replay/experimental/metrics/precision.py +0 -23
  120. replay/experimental/metrics/recall.py +0 -26
  121. replay/experimental/metrics/rocauc.py +0 -50
  122. replay/experimental/metrics/surprisal.py +0 -102
  123. replay/experimental/metrics/unexpectedness.py +0 -74
  124. replay/experimental/models/__init__.py +0 -10
  125. replay/experimental/models/admm_slim.py +0 -216
  126. replay/experimental/models/base_neighbour_rec.py +0 -222
  127. replay/experimental/models/base_rec.py +0 -1361
  128. replay/experimental/models/base_torch_rec.py +0 -247
  129. replay/experimental/models/cql.py +0 -468
  130. replay/experimental/models/ddpg.py +0 -1007
  131. replay/experimental/models/dt4rec/__init__.py +0 -0
  132. replay/experimental/models/dt4rec/dt4rec.py +0 -193
  133. replay/experimental/models/dt4rec/gpt1.py +0 -411
  134. replay/experimental/models/dt4rec/trainer.py +0 -128
  135. replay/experimental/models/dt4rec/utils.py +0 -274
  136. replay/experimental/models/extensions/spark_custom_models/__init__.py +0 -0
  137. replay/experimental/models/extensions/spark_custom_models/als_extension.py +0 -733
  138. replay/experimental/models/implicit_wrap.py +0 -138
  139. replay/experimental/models/lightfm_wrap.py +0 -327
  140. replay/experimental/models/mult_vae.py +0 -374
  141. replay/experimental/models/neuromf.py +0 -462
  142. replay/experimental/models/scala_als.py +0 -311
  143. replay/experimental/nn/data/__init__.py +0 -1
  144. replay/experimental/nn/data/schema_builder.py +0 -58
  145. replay/experimental/preprocessing/__init__.py +0 -3
  146. replay/experimental/preprocessing/data_preparator.py +0 -929
  147. replay/experimental/preprocessing/padder.py +0 -231
  148. replay/experimental/preprocessing/sequence_generator.py +0 -218
  149. replay/experimental/scenarios/__init__.py +0 -1
  150. replay/experimental/scenarios/obp_wrapper/__init__.py +0 -8
  151. replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +0 -86
  152. replay/experimental/scenarios/obp_wrapper/replay_offline.py +0 -271
  153. replay/experimental/scenarios/obp_wrapper/utils.py +0 -88
  154. replay/experimental/scenarios/two_stages/reranker.py +0 -116
  155. replay/experimental/scenarios/two_stages/two_stages_scenario.py +0 -843
  156. replay/experimental/utils/__init__.py +0 -0
  157. replay/experimental/utils/logger.py +0 -24
  158. replay/experimental/utils/model_handler.py +0 -213
  159. replay/experimental/utils/session_handler.py +0 -47
  160. replay_rec-0.16.0rc0.dist-info/NOTICE +0 -41
  161. replay_rec-0.16.0rc0.dist-info/RECORD +0 -178
  162. {replay_rec-0.16.0rc0.dist-info → replay_rec-0.17.0.dist-info}/LICENSE +0 -0
replay/__init__.py CHANGED
@@ -1,2 +1,2 @@
  """ RecSys library """
- __version__ = "0.16.0.preview"
+ __version__ = "0.17.0"
replay/data/__init__.py CHANGED
@@ -1,6 +1,6 @@
- from .spark_schema import get_schema
  from .dataset import Dataset
  from .schema import FeatureHint, FeatureInfo, FeatureSchema, FeatureSource, FeatureType
+ from .spark_schema import get_schema
 
  __all__ = [
  "Dataset",
replay/data/dataset.py CHANGED
@@ -7,21 +7,20 @@ from typing import Callable, Dict, Iterable, List, Optional, Sequence
 
  import numpy as np
 
- from .schema import FeatureHint, FeatureInfo, FeatureSchema, FeatureSource, FeatureType
  from replay.utils import PYSPARK_AVAILABLE, DataFrameLike, PandasDataFrame, PolarsDataFrame, SparkDataFrame
 
+ from .schema import FeatureHint, FeatureInfo, FeatureSchema, FeatureSource, FeatureType
+
  if PYSPARK_AVAILABLE:
- import pyspark.sql.functions as F
+ import pyspark.sql.functions as sf
  from pyspark.storagelevel import StorageLevel
 
 
- # pylint: disable=too-many-instance-attributes
  class Dataset:
  """
  Universal dataset for feeding data to models.
  """
 
- # pylint: disable=too-many-arguments
  def __init__(
  self,
  feature_schema: FeatureSchema,
@@ -57,23 +56,23 @@ class Dataset:
  try:
  feature_schema.item_id_column
  except Exception as exception:
- raise ValueError("Item id column is not set.") from exception
+ msg = "Item id column is not set."
+ raise ValueError(msg) from exception
 
  try:
  feature_schema.query_id_column
  except Exception as exception:
- raise ValueError("Query id column is not set.") from exception
-
- if (
- self.item_features is not None
- and not check_dataframes_types_equal(self._interactions, self.item_features)
- ):
- raise TypeError("Interactions and item features should have the same type.")
- if (
- self.query_features is not None
- and not check_dataframes_types_equal(self._interactions, self.query_features)
+ msg = "Query id column is not set."
+ raise ValueError(msg) from exception
+
+ if self.item_features is not None and not check_dataframes_types_equal(self._interactions, self.item_features):
+ msg = "Interactions and item features should have the same type."
+ raise TypeError(msg)
+ if self.query_features is not None and not check_dataframes_types_equal(
+ self._interactions, self.query_features
  ):
- raise TypeError("Interactions and query features should have the same type.")
+ msg = "Interactions and query features should have the same type."
+ raise TypeError(msg)
 
  self._feature_source_map: Dict[FeatureSource, DataFrameLike] = {
  FeatureSource.INTERACTIONS: self.interactions,
@@ -191,6 +190,7 @@ class Dataset:
  return self._feature_schema
 
  if PYSPARK_AVAILABLE:
+
  def persist(self, storage_level: StorageLevel = StorageLevel(True, True, False, True, 1)) -> None:
  """
  Sets the storage level to persist SparkDataFrame for interactions, item_features
@@ -295,7 +295,6 @@ class Dataset:
  def _set_cardinality(self, features_list: Sequence[FeatureInfo]) -> None:
  for feature in features_list:
  if feature.feature_type == FeatureType.CATEGORICAL:
- # pylint: disable=protected-access
  feature._set_cardinality_callback(self._get_cardinality(feature))
 
  def _fill_feature_schema(self, feature_schema: FeatureSchema) -> FeatureSchema:
@@ -333,15 +332,14 @@ class Dataset:
 
  for feature in features_list:
  if feature.feature_hint in [FeatureHint.QUERY_ID, FeatureHint.ITEM_ID]:
- # pylint: disable=protected-access
  feature._set_feature_source(source=FeatureSource.INTERACTIONS)
  continue
- source = source_mapping.get(feature.column) # type: ignore
+ source = source_mapping.get(feature.column)
  if source:
- # pylint: disable=protected-access
  feature._set_feature_source(source=source_mapping[feature.column])
  else:
- raise ValueError(f"{feature.column} doesn't exist in provided dataframes")
+ msg = f"{feature.column} doesn't exist in provided dataframes"
+ raise ValueError(msg)
 
  self._set_cardinality(features_list=features_list)
  return features_list
@@ -362,10 +360,8 @@ class Dataset:
  self._set_cardinality(features_list=unlabeled_columns)
  return unlabeled_columns
 
- # pylint: disable=no-self-use
  def _set_features_source(self, feature_list: List[FeatureInfo], source: FeatureSource) -> None:
  for feature in feature_list:
- # pylint: disable=protected-access
  feature._set_feature_source(source)
 
  def _check_ids_consistency(self, hint: FeatureHint) -> None:
@@ -377,8 +373,8 @@ class Dataset:
  self.feature_schema.item_id_column if hint == FeatureHint.ITEM_ID else self.feature_schema.query_id_column
  )
  if self.is_pandas:
- interactions_unique_ids = set(self.interactions[ids_column].unique()) # type: ignore
- features_df_unique_ids = set(features_df[ids_column].unique()) # type: ignore # pylint: disable=E1136
+ interactions_unique_ids = set(self.interactions[ids_column].unique())
+ features_df_unique_ids = set(features_df[ids_column].unique())
  in_interactions_not_in_features_ids = interactions_unique_ids - features_df_unique_ids
  is_consistent = len(in_interactions_not_in_features_ids) == 0
  elif self.is_spark:
@@ -389,14 +385,18 @@ class Dataset:
  .count()
  ) == 0
  else:
- is_consistent = len(
- self.interactions.select(ids_column)
- .unique()
- .join(features_df.select(ids_column).unique(), on=ids_column, how="anti")
- ) == 0
+ is_consistent = (
+ len(
+ self.interactions.select(ids_column)
+ .unique()
+ .join(features_df.select(ids_column).unique(), on=ids_column, how="anti")
+ )
+ == 0
+ )
 
  if not is_consistent:
- raise ValueError(f"There are IDs in the interactions that are missing in the {hint.name} dataframe.")
+ msg = f"There are IDs in the interactions that are missing in the {hint.name} dataframe."
+ raise ValueError(msg)
 
  def _check_column_encoded(
  self, data: DataFrameLike, column: str, source: FeatureSource, cardinality: Optional[int]
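
The `else` branch above is the Polars path: an anti-join keeps the IDs that appear in the interactions but have no match in the features table, so an empty result means the two frames are consistent. A minimal standalone sketch of the same check (illustrative column names and data):

    import polars as pl

    interactions = pl.DataFrame({"user_id": [1, 2, 3]})
    features = pl.DataFrame({"user_id": [1, 2]})

    # Anti-join: rows of the left frame with no matching key in the right frame.
    missing = (
        interactions.select("user_id").unique()
        .join(features.select("user_id").unique(), on="user_id", how="anti")
    )
    is_consistent = len(missing) == 0  # False here: user 3 has no feature row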
@@ -419,26 +419,29 @@ class Dataset:
  is_int = data[column].dtype.is_integer()
 
  if not is_int:
- raise ValueError(f"IDs in {source.name}.{column} are not encoded. They are not int.")
+ msg = f"IDs in {source.name}.{column} are not encoded. They are not int."
+ raise ValueError(msg)
 
  if self.is_pandas:
- min_id = data[column].min() # type: ignore
+ min_id = data[column].min()
  elif self.is_spark:
- min_id = data.agg(F.min(column).alias("min_index")).collect()[0][0]
+ min_id = data.agg(sf.min(column).alias("min_index")).collect()[0][0]
  else:
- min_id = data[column].min() # type: ignore
+ min_id = data[column].min()
  if min_id < 0:
- raise ValueError(f"IDs in {source.name}.{column} are not encoded. Min ID is less than 0.")
+ msg = f"IDs in {source.name}.{column} are not encoded. Min ID is less than 0."
+ raise ValueError(msg)
 
  if self.is_pandas:
- max_id = data[column].max() # type: ignore
+ max_id = data[column].max()
  elif self.is_spark:
- max_id = data.agg(F.max(column).alias("max_index")).collect()[0][0]
+ max_id = data.agg(sf.max(column).alias("max_index")).collect()[0][0]
  else:
- max_id = data[column].max() # type: ignore
+ max_id = data[column].max()
 
  if max_id >= cardinality:
- raise ValueError(f"IDs in {source.name}.{column} are not encoded. Max ID is more than quantity of IDs.")
+ msg = f"IDs in {source.name}.{column} are not encoded. Max ID is more than quantity of IDs."
+ raise ValueError(msg)
 
  def _check_encoded(self) -> None:
  for feature in self.feature_schema.categorical_features.all_features:
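
`_check_column_encoded` enforces three invariants on an encoded categorical ID column: integer dtype, minimum ID of at least 0, and maximum ID strictly below the cardinality. A compact pandas illustration of the same invariants, with hypothetical values:

    import pandas as pd

    ids = pd.Series([0, 1, 2, 5])
    cardinality = 4  # number of distinct encoded IDs

    is_int = pd.api.types.is_integer_dtype(ids)  # True: integer dtype
    non_negative = ids.min() >= 0                # True: no negative IDs
    in_range = ids.max() < cardinality           # False: 5 >= 4, so the column is not encoded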
@@ -471,11 +474,11 @@ class Dataset:
  feature.cardinality,
  )
  else:
- data = self._feature_source_map[feature.feature_source] # type: ignore
+ data = self._feature_source_map[feature.feature_source]
  self._check_column_encoded(
  data,
  feature.column,
- feature.feature_source, # type: ignore
+ feature.feature_source,
  feature.cardinality,
  )
 
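A pattern repeated throughout this release: exception messages are bound to a `msg` variable before raising, instead of being passed as literals to the exception constructor. This is likely done to satisfy a linter rule such as Ruff's EM checks (the diff itself does not name the tool); a minimal before/after sketch with a hypothetical helper:

    def get_source(source_mapping: dict, column: str):
        source = source_mapping.get(column)
        if source is None:
            # 0.16.0rc0 style: raise ValueError(f"{column} doesn't exist ...")
            # 0.17.0 style: bind the message first, keeping the raise line short
            msg = f"{column} doesn't exist in provided dataframes"
            raise ValueError(msg)
        return source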
replay/data/dataset_utils/dataset_label_encoder.py CHANGED
@@ -31,6 +31,7 @@ class DatasetLabelEncoder:
  When set to ``error`` an error will be raised in case an unknown label is present during transform.
  When set to ``use_default_value``, the encoded value of unknown label will be set
  to the value given for the parameter default_value.
+ When set to ``drop``, the unknown labels will be dropped.
  Default: ``error``.
  :param default_value: Default value that will fill the unknown labels after transform.
  When the parameter handle_unknown is set to ``use_default_value``,
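
The new ``drop`` strategy complements ``error`` and ``use_default_value``: unknown labels are removed instead of raising or being mapped to a default. A sketch of the three behaviors for one column, using a plain dict to stand in for the fitted encoder:

    mapping = {"a": 0, "b": 1}   # learned at fit time
    seen = ["a", "b", "c"]       # "c" is unknown at transform time

    # error: raise on "c"; use_default_value: map "c" to default_value; drop: remove it
    with_default = [mapping.get(x, -1) for x in seen]        # [0, 1, -1]
    with_drop = [mapping[x] for x in seen if x in mapping]   # [0, 1]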
@@ -105,7 +106,7 @@ class DatasetLabelEncoder:
  for column, feature_info in dataset.feature_schema.categorical_features.items():
  if column not in self._encoding_rules:
  warnings.warn(
- f"Cannot transform feature '{column}' " "as it was not present at the fit stage",
+ f"Cannot transform feature '{column}' as it was not present at the fit stage",
  LabelEncoderTransformWarning,
  )
  continue
@@ -157,10 +158,7 @@ class DatasetLabelEncoder:
  self._check_if_initialized()
 
  columns_set: Set[str]
- if isinstance(columns, str):
- columns_set = set([columns])
- else:
- columns_set = set(columns)
+ columns_set = {columns} if isinstance(columns, str) else {*columns}
 
  def get_encoding_rules() -> Iterator[LabelEncodingRule]:
  for column, rule in self._encoding_rules.items():
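
The four-line if/else collapses into one conditional expression. The string branch matters because set construction iterates its argument; a minimal sketch of the difference, using a hypothetical to_set helper:

    def to_set(columns):
        # {*"user_id"} would explode the string into characters; the str branch avoids that
        return {columns} if isinstance(columns, str) else {*columns}

    assert to_set("user_id") == {"user_id"}
    assert to_set(["user_id", "item_id"]) == {"user_id", "item_id"}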
@@ -200,7 +198,7 @@ class DatasetLabelEncoder:
  """
  query_id_column = self._features_columns[FeatureHint.QUERY_ID]
  item_id_column = self._features_columns[FeatureHint.ITEM_ID]
- encoder = self.get_encoder(query_id_column + item_id_column) # type: ignore
+ encoder = self.get_encoder(query_id_column + item_id_column)
  assert encoder is not None
  return encoder
 
@@ -231,7 +229,8 @@ class DatasetLabelEncoder:
 
  def _check_if_initialized(self) -> None:
  if not self._encoding_rules:
- raise ValueError("Encoder is not initialized")
+ msg = "Encoder is not initialized"
+ raise ValueError(msg)
 
  def _fill_features_columns(self, feature_info: FeatureSchema) -> None:
  self._features_columns = {
replay/data/nn/__init__.py CHANGED
@@ -3,7 +3,7 @@ from replay.utils import TORCH_AVAILABLE
  if TORCH_AVAILABLE:
  from .schema import MutableTensorMap, TensorFeatureInfo, TensorFeatureSource, TensorMap, TensorSchema
  from .sequence_tokenizer import SequenceTokenizer
- from .sequential_dataset import PandasSequentialDataset, SequentialDataset, PolarsSequentialDataset
+ from .sequential_dataset import PandasSequentialDataset, PolarsSequentialDataset, SequentialDataset
  from .torch_sequential_dataset import (
  DEFAULT_GROUND_TRUTH_PADDING_VALUE,
  DEFAULT_TRAIN_PADDING_VALUE,
replay/data/nn/schema.py CHANGED
@@ -11,7 +11,6 @@ from typing import (
  Set,
  Union,
  ValuesView,
- Callable
  )
 
  import torch
@@ -23,7 +22,6 @@ TensorMap = Mapping[str, torch.Tensor]
  MutableTensorMap = Dict[str, torch.Tensor]
 
 
- # pylint: disable=too-many-instance-attributes
  class TensorFeatureSource:
  """
  Describes source of a feature
@@ -72,7 +70,6 @@ class TensorFeatureInfo:
  Information about a tensor feature.
  """
 
- # pylint: disable=too-many-arguments
  def __init__(
  self,
  name: str,
@@ -108,15 +105,18 @@ class TensorFeatureInfo:
  self._is_seq = is_seq
 
  if not isinstance(feature_type, FeatureType):
- raise ValueError("Unknown feature type")
+ msg = "Unknown feature type"
+ raise ValueError(msg)
  self._feature_type = feature_type
 
  if feature_type == FeatureType.NUMERICAL and (cardinality or embedding_dim):
- raise ValueError("Cardinality and embedding dimensions are needed only with categorical feature type.")
+ msg = "Cardinality and embedding dimensions are needed only with categorical feature type."
+ raise ValueError(msg)
  self._cardinality = cardinality
 
  if feature_type == FeatureType.CATEGORICAL and tensor_dim:
- raise ValueError("Tensor dimensions is needed only with numerical feature type.")
+ msg = "Tensor dimensions is needed only with numerical feature type."
+ raise ValueError(msg)
 
  if feature_type == FeatureType.CATEGORICAL:
  default_embedding_dim = 64
@@ -168,7 +168,8 @@ class TensorFeatureInfo:
  return None
 
  if len(source) > 1:
- raise ValueError("Only one element feature sources can be converted to single feature source.")
+ msg = "Only one element feature sources can be converted to single feature source."
+ raise ValueError(msg)
  assert isinstance(self.feature_sources, list)
  return self.feature_sources[0]
 
@@ -199,35 +200,21 @@ class TensorFeatureInfo:
  :returns: Cardinality of the feature.
  """
  if self.feature_type != FeatureType.CATEGORICAL:
- raise RuntimeError(
- f"Can not get cardinality because feature type of {self.name} column is not categorical."
- )
- if hasattr(self, "_cardinality_callback") and self._cardinality is None:
- self._set_cardinality(self._cardinality_callback(self._name))
+ msg = f"Can not get cardinality because feature type of {self.name} column is not categorical."
+ raise RuntimeError(msg)
  return self._cardinality
 
- # pylint: disable=attribute-defined-outside-init
- def _set_cardinality_callback(self, callback: Callable) -> None:
- self._cardinality_callback = callback
-
  def _set_cardinality(self, cardinality: int) -> None:
  self._cardinality = cardinality
 
- def reset_cardinality(self) -> None:
- """
- Reset cardinality of the feature to None.
- """
- self._cardinality = None
-
  @property
  def tensor_dim(self) -> Optional[int]:
  """
  :returns: Dimensions of the numerical feature.
  """
  if self.feature_type != FeatureType.NUMERICAL:
- raise RuntimeError(
- f"Can not get tensor dimensions because feature type of {self.name} feature is not numerical."
- )
+ msg = f"Can not get tensor dimensions because feature type of {self.name} feature is not numerical."
+ raise RuntimeError(msg)
  return self._tensor_dim
 
  def _set_tensor_dim(self, tensor_dim: int) -> None:
@@ -239,9 +226,8 @@ class TensorFeatureInfo:
  :returns: Embedding dimensions of the feature.
  """
  if self.feature_type != FeatureType.CATEGORICAL:
- raise RuntimeError(
- f"Can not get embedding dimensions because feature type of {self.name} feature is not categorical."
- )
+ msg = f"Can not get embedding dimensions because feature type of {self.name} feature is not categorical."
+ raise RuntimeError(msg)
  return self._embedding_dim
 
  def _set_embedding_dim(self, embedding_dim: int) -> None:
@@ -278,8 +264,9 @@ class TensorSchema(Mapping[str, TensorFeatureInfo]):
  :returns: Extract single feature from a schema.
  """
  if len(self._tensor_schema) != 1:
- raise ValueError("Only one element tensor schema can be converted to single feature")
- return list(self._tensor_schema.values())[0]
+ msg = "Only one element tensor schema can be converted to single feature"
+ raise ValueError(msg)
+ return next(iter(self._tensor_schema.values()))
 
  def items(self) -> ItemsView[str, TensorFeatureInfo]:
  return self._tensor_schema.items()
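
`next(iter(...))` pulls the single value lazily, where `list(...)[0]` first materialized every value into a list only to index element zero. A one-line comparison on a stand-in dict:

    schema = {"item_id": "categorical"}  # stand-in for a one-feature tensor schema

    first = next(iter(schema.values()))  # O(1), no intermediate list
    same = list(schema.values())[0]      # builds the whole list first
    assert first == same == "categorical"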
@@ -290,7 +277,7 @@ class TensorSchema(Mapping[str, TensorFeatureInfo]):
  def values(self) -> ValuesView[TensorFeatureInfo]:
  return self._tensor_schema.values()
 
- def get( # type: ignore
+ def get(
  self,
  key: str,
  default: Optional[TensorFeatureInfo] = None,
@@ -377,7 +364,7 @@ class TensorSchema(Mapping[str, TensorFeatureInfo]):
  @property
  def names(self) -> Sequence[str]:
  """
- :returns: List of all feature's names.
+ :returns: List of all feature's names.
  """
  return list(self._tensor_schema)
 
@@ -447,7 +434,7 @@ class TensorSchema(Mapping[str, TensorFeatureInfo]):
  for filtration_func, filtration_param in zip(filter_functions, filter_parameters):
  filtered_features = list(
  filter(
- lambda x: filtration_func(x, filtration_param), # type: ignore # pylint: disable=W0640
+ lambda x: filtration_func(x, filtration_param),
  filtered_features,
  )
  )