replay-rec 0.16.0-py3-none-any.whl → 0.17.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. replay/__init__.py +1 -1
  2. replay/data/__init__.py +1 -1
  3. replay/data/dataset.py +45 -42
  4. replay/data/dataset_utils/dataset_label_encoder.py +6 -7
  5. replay/data/nn/__init__.py +1 -1
  6. replay/data/nn/schema.py +20 -33
  7. replay/data/nn/sequence_tokenizer.py +217 -87
  8. replay/data/nn/sequential_dataset.py +6 -22
  9. replay/data/nn/torch_sequential_dataset.py +20 -11
  10. replay/data/nn/utils.py +7 -9
  11. replay/data/schema.py +17 -17
  12. replay/data/spark_schema.py +0 -1
  13. replay/metrics/base_metric.py +38 -79
  14. replay/metrics/categorical_diversity.py +24 -58
  15. replay/metrics/coverage.py +25 -49
  16. replay/metrics/descriptors.py +4 -13
  17. replay/metrics/experiment.py +3 -8
  18. replay/metrics/hitrate.py +3 -6
  19. replay/metrics/map.py +3 -6
  20. replay/metrics/mrr.py +1 -4
  21. replay/metrics/ndcg.py +4 -7
  22. replay/metrics/novelty.py +10 -29
  23. replay/metrics/offline_metrics.py +26 -61
  24. replay/metrics/precision.py +3 -6
  25. replay/metrics/recall.py +3 -6
  26. replay/metrics/rocauc.py +7 -10
  27. replay/metrics/surprisal.py +13 -30
  28. replay/metrics/torch_metrics_builder.py +0 -4
  29. replay/metrics/unexpectedness.py +15 -20
  30. replay/models/__init__.py +1 -2
  31. replay/models/als.py +7 -15
  32. replay/models/association_rules.py +12 -28
  33. replay/models/base_neighbour_rec.py +21 -36
  34. replay/models/base_rec.py +92 -215
  35. replay/models/cat_pop_rec.py +9 -22
  36. replay/models/cluster.py +17 -28
  37. replay/models/extensions/ann/ann_mixin.py +7 -12
  38. replay/models/extensions/ann/entities/base_hnsw_param.py +1 -1
  39. replay/models/extensions/ann/entities/hnswlib_param.py +0 -6
  40. replay/models/extensions/ann/entities/nmslib_hnsw_param.py +0 -6
  41. replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +4 -10
  42. replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py +7 -11
  43. replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +5 -12
  44. replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +11 -18
  45. replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py +1 -4
  46. replay/models/extensions/ann/index_inferers/base_inferer.py +3 -10
  47. replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py +7 -17
  48. replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py +6 -14
  49. replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +14 -28
  50. replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +15 -25
  51. replay/models/extensions/ann/index_inferers/utils.py +2 -9
  52. replay/models/extensions/ann/index_stores/hdfs_index_store.py +4 -9
  53. replay/models/extensions/ann/index_stores/shared_disk_index_store.py +2 -6
  54. replay/models/extensions/ann/index_stores/spark_files_index_store.py +8 -14
  55. replay/models/extensions/ann/index_stores/utils.py +5 -2
  56. replay/models/extensions/ann/utils.py +3 -5
  57. replay/models/kl_ucb.py +16 -22
  58. replay/models/knn.py +37 -59
  59. replay/models/nn/optimizer_utils/__init__.py +1 -6
  60. replay/models/nn/optimizer_utils/optimizer_factory.py +3 -6
  61. replay/models/nn/sequential/bert4rec/__init__.py +1 -1
  62. replay/models/nn/sequential/bert4rec/dataset.py +6 -7
  63. replay/models/nn/sequential/bert4rec/lightning.py +53 -56
  64. replay/models/nn/sequential/bert4rec/model.py +12 -25
  65. replay/models/nn/sequential/callbacks/__init__.py +1 -1
  66. replay/models/nn/sequential/callbacks/prediction_callbacks.py +23 -25
  67. replay/models/nn/sequential/callbacks/validation_callback.py +27 -30
  68. replay/models/nn/sequential/postprocessors/postprocessors.py +1 -1
  69. replay/models/nn/sequential/sasrec/dataset.py +8 -7
  70. replay/models/nn/sequential/sasrec/lightning.py +53 -48
  71. replay/models/nn/sequential/sasrec/model.py +4 -17
  72. replay/models/pop_rec.py +9 -10
  73. replay/models/query_pop_rec.py +7 -15
  74. replay/models/random_rec.py +10 -18
  75. replay/models/slim.py +8 -13
  76. replay/models/thompson_sampling.py +13 -14
  77. replay/models/ucb.py +11 -22
  78. replay/models/wilson.py +5 -14
  79. replay/models/word2vec.py +24 -69
  80. replay/optimization/optuna_objective.py +13 -27
  81. replay/preprocessing/__init__.py +1 -2
  82. replay/preprocessing/converter.py +2 -7
  83. replay/preprocessing/filters.py +67 -142
  84. replay/preprocessing/history_based_fp.py +44 -116
  85. replay/preprocessing/label_encoder.py +106 -68
  86. replay/preprocessing/sessionizer.py +1 -11
  87. replay/scenarios/fallback.py +3 -8
  88. replay/splitters/base_splitter.py +43 -15
  89. replay/splitters/cold_user_random_splitter.py +18 -31
  90. replay/splitters/k_folds.py +14 -24
  91. replay/splitters/last_n_splitter.py +33 -43
  92. replay/splitters/new_users_splitter.py +31 -55
  93. replay/splitters/random_splitter.py +16 -23
  94. replay/splitters/ratio_splitter.py +30 -54
  95. replay/splitters/time_splitter.py +13 -18
  96. replay/splitters/two_stage_splitter.py +44 -79
  97. replay/utils/__init__.py +1 -1
  98. replay/utils/common.py +65 -0
  99. replay/utils/dataframe_bucketizer.py +25 -31
  100. replay/utils/distributions.py +3 -15
  101. replay/utils/model_handler.py +36 -33
  102. replay/utils/session_handler.py +11 -15
  103. replay/utils/spark_utils.py +51 -85
  104. replay/utils/time.py +8 -22
  105. replay/utils/types.py +1 -3
  106. {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/METADATA +2 -2
  107. replay_rec-0.17.0.dist-info/RECORD +127 -0
  108. replay_rec-0.16.0.dist-info/RECORD +0 -126
  109. {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/LICENSE +0 -0
  110. {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/WHEEL +0 -0
@@ -29,7 +29,6 @@ class EmptyFeatureProcessor:
         :param features: DataFrame with ``user_idx/item_idx`` and feature columns
         """

-    # pylint: disable=no-self-use
     def transform(self, log: SparkDataFrame) -> SparkDataFrame:
         """
         Return log without any transformations
@@ -74,26 +73,16 @@ class LogStatFeaturesProcessor(EmptyFeatureProcessor):
         """
         prefix = agg_col[:1]

-        aggregates = [
-            sf.log(sf.count(sf.col("relevance"))).alias(
-                f"{prefix}_log_num_interact"
-            )
-        ]
+        aggregates = [sf.log(sf.count(sf.col("relevance"))).alias(f"{prefix}_log_num_interact")]

         if self.calc_timestamp_based:
             aggregates.extend(
                 [
-                    sf.log(
-                        sf.countDistinct(
-                            sf.date_trunc("dd", sf.col("timestamp"))
-                        )
-                    ).alias(f"{prefix}_log_interact_days_count"),
-                    sf.min(sf.col("timestamp")).alias(
-                        f"{prefix}_min_interact_date"
-                    ),
-                    sf.max(sf.col("timestamp")).alias(
-                        f"{prefix}_max_interact_date"
+                    sf.log(sf.countDistinct(sf.date_trunc("dd", sf.col("timestamp")))).alias(
+                        f"{prefix}_log_interact_days_count"
                     ),
+                    sf.min(sf.col("timestamp")).alias(f"{prefix}_min_interact_date"),
+                    sf.max(sf.col("timestamp")).alias(f"{prefix}_max_interact_date"),
                 ]
             )

@@ -102,8 +91,7 @@ class LogStatFeaturesProcessor(EmptyFeatureProcessor):
                 [
                     (
                         sf.when(
-                            sf.stddev(sf.col("relevance")).isNull()
-                            | sf.isnan(sf.stddev(sf.col("relevance"))),
+                            sf.stddev(sf.col("relevance")).isNull() | sf.isnan(sf.stddev(sf.col("relevance"))),
                             0,
                         )
                         .otherwise(sf.stddev(sf.col("relevance")))
@@ -112,19 +100,15 @@ class LogStatFeaturesProcessor(EmptyFeatureProcessor):
                     sf.mean(sf.col("relevance")).alias(f"{prefix}_mean"),
                 ]
             )
-        for percentile in [0.05, 0.5, 0.95]:
-            aggregates.append(
-                sf.expr(
-                    f"percentile_approx(relevance, {percentile})"
-                ).alias(f"{prefix}_quantile_{str(percentile)[2:]}")
-            )
+        aggregates.extend(
+            sf.expr(f"percentile_approx(relevance, {percentile})").alias(f"{prefix}_quantile_{str(percentile)[2:]}")
+            for percentile in [0.05, 0.5, 0.95]
+        )

         return aggregates

     @staticmethod
-    def _add_ts_based(
-        features: SparkDataFrame, max_log_date: datetime, prefix: str
-    ) -> SparkDataFrame:
+    def _add_ts_based(features: SparkDataFrame, max_log_date: datetime, prefix: str) -> SparkDataFrame:
         """
         Add history length (max - min timestamp) and difference in days between
         last date in log and last interaction of the user/item
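
The two refactors above only restructure how the aggregate list is built; the generated column names are unchanged. `prefix = agg_col[:1]` turns `user_idx`/`item_idx` into a `u_`/`i_` prefix, and the quantile suffix comes from `str(percentile)[2:]`, i.e. the digits after "0.". A minimal sketch in plain Python (no Spark needed; the loop below mirrors the removed code only to show the two forms agree, and it builds name strings rather than Spark columns):

    # Sketch: the generator passed to extend() yields the same names as the removed append loop.
    prefix = "user_idx"[:1]  # -> "u"; "item_idx"[:1] -> "i"

    loop_style = []
    for percentile in [0.05, 0.5, 0.95]:
        loop_style.append(f"{prefix}_quantile_{str(percentile)[2:]}")

    extend_style = []
    extend_style.extend(
        f"{prefix}_quantile_{str(percentile)[2:]}" for percentile in [0.05, 0.5, 0.95]
    )

    assert loop_style == extend_style == ["u_quantile_05", "u_quantile_5", "u_quantile_95"]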
@@ -142,15 +126,11 @@ class LogStatFeaturesProcessor(EmptyFeatureProcessor):
             ),
         ).withColumn(
             f"{prefix}_last_interaction_gap_days",
-            sf.datediff(
-                sf.lit(max_log_date), sf.col(f"{prefix}_max_interact_date")
-            ),
+            sf.datediff(sf.lit(max_log_date), sf.col(f"{prefix}_max_interact_date")),
         )

     @staticmethod
-    def _cals_cross_interactions_count(
-        log: SparkDataFrame, features: SparkDataFrame
-    ) -> SparkDataFrame:
+    def _cals_cross_interactions_count(log: SparkDataFrame, features: SparkDataFrame) -> SparkDataFrame:
         """
         Calculate difference between the log number of interactions by the user
         and average log number of interactions users interacted with the item has.
@@ -165,9 +145,7 @@ class LogStatFeaturesProcessor(EmptyFeatureProcessor):
             new_feature_entity, calc_by_entity = "user_idx", "item_idx"

         mean_log_num_interact = log.join(
-            features.select(
-                calc_by_entity, f"{calc_by_entity[0]}_log_num_interact"
-            ),
+            features.select(calc_by_entity, f"{calc_by_entity[0]}_log_num_interact"),
             on=calc_by_entity,
             how="left",
         )
@@ -178,9 +156,7 @@ class LogStatFeaturesProcessor(EmptyFeatureProcessor):
         )

     @staticmethod
-    def _calc_abnormality(
-        log: SparkDataFrame, item_features: SparkDataFrame
-    ) -> SparkDataFrame:
+    def _calc_abnormality(log: SparkDataFrame, item_features: SparkDataFrame) -> SparkDataFrame:
         """
         Calculate discrepancy between a rating on a resource
         and the average rating of this resource (Abnormality) and
@@ -198,13 +174,9 @@ class LogStatFeaturesProcessor(EmptyFeatureProcessor):
             on_col_name="item_idx",
             how="left",
         )
-        abnormality_df = abnormality_df.withColumn(
-            "abnormality", sf.abs(sf.col("relevance") - sf.col("i_mean"))
-        )
+        abnormality_df = abnormality_df.withColumn("abnormality", sf.abs(sf.col("relevance") - sf.col("i_mean")))

-        abnormality_aggs = [
-            sf.mean(sf.col("abnormality")).alias("abnormality")
-        ]
+        abnormality_aggs = [sf.mean(sf.col("abnormality")).alias("abnormality")]

         # Abnormality CR:
         max_std = item_features.select(sf.max("i_std")).collect()[0][0]
@@ -212,80 +184,53 @@ class LogStatFeaturesProcessor(EmptyFeatureProcessor):
         if max_std - min_std != 0:
             abnormality_df = abnormality_df.withColumn(
                 "controversy",
-                1
-                - (sf.col("i_std") - sf.lit(min_std))
-                / (sf.lit(max_std - min_std)),
+                1 - (sf.col("i_std") - sf.lit(min_std)) / (sf.lit(max_std - min_std)),
             )
             abnormality_df = abnormality_df.withColumn(
                 "abnormalityCR",
                 (sf.col("abnormality") * sf.col("controversy")) ** 2,
             )
-            abnormality_aggs.append(
-                sf.mean(sf.col("abnormalityCR")).alias("abnormalityCR")
-            )
+            abnormality_aggs.append(sf.mean(sf.col("abnormalityCR")).alias("abnormalityCR"))

         return abnormality_df.groupBy("user_idx").agg(*abnormality_aggs)

-    def fit(
-        self, log: SparkDataFrame, features: Optional[SparkDataFrame] = None
-    ) -> None:
+    def fit(self, log: SparkDataFrame, features: Optional[SparkDataFrame] = None) -> None:  # noqa: ARG002
         """
         Calculate log-based features for users and items

         :param log: input SparkDataFrame ``[user_idx, item_idx, timestamp, relevance]``
-        :param features: not required
         """
-        self.calc_timestamp_based = (
-            isinstance(log.schema["timestamp"].dataType, TimestampType)
-        ) & (
-            log.select(sf.countDistinct(sf.col("timestamp"))).collect()[0][0]
-            > 1
-        )
-        self.calc_relevance_based = (
-            log.select(sf.countDistinct(sf.col("relevance"))).collect()[0][0]
-            > 1
+        self.calc_timestamp_based = (isinstance(log.schema["timestamp"].dataType, TimestampType)) & (
+            log.select(sf.countDistinct(sf.col("timestamp"))).collect()[0][0] > 1
         )
+        self.calc_relevance_based = log.select(sf.countDistinct(sf.col("relevance"))).collect()[0][0] > 1

-        user_log_features = log.groupBy("user_idx").agg(
-            *self._create_log_aggregates(agg_col="user_idx")
-        )
-        item_log_features = log.groupBy("item_idx").agg(
-            *self._create_log_aggregates(agg_col="item_idx")
-        )
+        user_log_features = log.groupBy("user_idx").agg(*self._create_log_aggregates(agg_col="user_idx"))
+        item_log_features = log.groupBy("item_idx").agg(*self._create_log_aggregates(agg_col="item_idx"))

         if self.calc_timestamp_based:
             last_date = log.select(sf.max("timestamp")).collect()[0][0]
-            user_log_features = self._add_ts_based(
-                features=user_log_features, max_log_date=last_date, prefix="u"
-            )
+            user_log_features = self._add_ts_based(features=user_log_features, max_log_date=last_date, prefix="u")

-            item_log_features = self._add_ts_based(
-                features=item_log_features, max_log_date=last_date, prefix="i"
-            )
+            item_log_features = self._add_ts_based(features=item_log_features, max_log_date=last_date, prefix="i")

         if self.calc_relevance_based:
             user_log_features = user_log_features.join(
-                self._calc_abnormality(
-                    log=log, item_features=item_log_features
-                ),
+                self._calc_abnormality(log=log, item_features=item_log_features),
                 on="user_idx",
                 how="left",
             ).cache()

         self.user_log_features = join_with_col_renaming(
             left=user_log_features,
-            right=self._cals_cross_interactions_count(
-                log=log, features=item_log_features
-            ),
+            right=self._cals_cross_interactions_count(log=log, features=item_log_features),
             on_col_name="user_idx",
             how="left",
         ).cache()

         self.item_log_features = join_with_col_renaming(
             left=item_log_features,
-            right=self._cals_cross_interactions_count(
-                log=log, features=user_log_features
-            ),
+            right=self._cals_cross_interactions_count(log=log, features=user_log_features),
             on_col_name="item_idx",
             how="left",
         ).cache()
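
For readers of `_calc_abnormality`, the reshuffled lines above still implement the same per-interaction formulas: `abnormality = |relevance - i_mean|`, and, when item rating spread varies, `controversy = 1 - (i_std - min_std) / (max_std - min_std)` with `abnormalityCR = (abnormality * controversy) ** 2`; both are then averaged per `user_idx`. A hand-worked sketch with made-up scalar values standing in for the Spark columns:

    # Toy recomputation of the abnormality formulas; all numbers are illustrative.
    relevance = 5.0              # the user's rating of the item
    i_mean, i_std = 3.2, 1.1     # the item's mean rating and rating std
    min_std, max_std = 0.4, 2.0  # min/max rating std across all items

    abnormality = abs(relevance - i_mean)                      # 1.8
    controversy = 1 - (i_std - min_std) / (max_std - min_std)  # 1 - 0.7 / 1.6 = 0.5625
    abnormality_cr = (abnormality * controversy) ** 2          # (1.8 * 0.5625) ** 2 ≈ 1.0252

    print(abnormality, controversy, abnormality_cr)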
@@ -311,25 +256,15 @@ class LogStatFeaturesProcessor(EmptyFeatureProcessor):
             )
             .withColumn(
                 "na_u_log_features",
-                sf.when(sf.col("u_log_num_interact").isNull(), 1.0).otherwise(
-                    0.0
-                ),
+                sf.when(sf.col("u_log_num_interact").isNull(), 1.0).otherwise(0.0),
             )
             .withColumn(
                 "na_i_log_features",
-                sf.when(sf.col("i_log_num_interact").isNull(), 1.0).otherwise(
-                    0.0
-                ),
+                sf.when(sf.col("i_log_num_interact").isNull(), 1.0).otherwise(0.0),
             )
             # TODO: std and date diff are replaced with inf; will date features work correctly?
             # and will it work correctly if they are not replaced?
-            .fillna(
-                {
-                    col_name: 0
-                    for col_name in self.user_log_features.columns
-                    + self.item_log_features.columns
-                }
-            )
+            .fillna({col_name: 0 for col_name in self.user_log_features.columns + self.item_log_features.columns})
         )

         joined = joined.withColumn(
@@ -375,19 +310,16 @@ class ConditionalPopularityProcessor(EmptyFeatureProcessor):
         :param log: input SparkDataFrame ``[user_idx, item_idx, timestamp, relevance]``
         :param features: SparkDataFrame with ``user_idx/item_idx`` and feature columns
         """
-        if len(
-            set(self.cat_features_list).intersection(features.columns)
-        ) != len(self.cat_features_list):
-            raise ValueError(
+        if len(set(self.cat_features_list).intersection(features.columns)) != len(self.cat_features_list):
+            msg = (
                 f"Columns {set(self.cat_features_list).difference(features.columns)} "
                 f"defined in `cat_features_list` are absent in features. "
                 f"features columns are: {features.columns}."
             )
+            raise ValueError(msg)

         join_col, self.entity_name = (
-            ("item_idx", "user_idx")
-            if "item_idx" in features.columns
-            else ("user_idx", "item_idx")
+            ("item_idx", "user_idx") if "item_idx" in features.columns else ("user_idx", "item_idx")
         )

         self.conditional_pop_dict = {}
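
The `msg = (...)` / `raise ValueError(msg)` rewrite above, like the matching `AttributeError` change at the end of this file, assigns the exception message to a variable before raising. Together with the `# noqa: ARG002` added to `fit`, this suggests the project now lints with ruff, whose EM101/EM102 rules discourage string literals and f-strings placed directly inside `raise`; that motivation is an inference and is not stated in the diff itself. A minimal sketch of the pattern, using a hypothetical validator:

    # Hypothetical helper showing the "build the message, then raise it" pattern.
    def require_columns(required: set, available: set) -> None:
        missing = required - available
        if missing:
            msg = f"Columns {missing} are absent; available columns are: {sorted(available)}."
            raise ValueError(msg)

    require_columns({"user_idx", "item_idx"}, {"user_idx", "item_idx", "timestamp"})  # passes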
@@ -400,9 +332,9 @@ class ConditionalPopularityProcessor(EmptyFeatureProcessor):

         for cat_col in self.cat_features_list:
             col_name = f"{self.entity_name[0]}_pop_by_{cat_col}"
-            intermediate_df = log_with_features.groupBy(
-                self.entity_name, cat_col
-            ).agg(sf.count("relevance").alias(col_name))
+            intermediate_df = log_with_features.groupBy(self.entity_name, cat_col).agg(
+                sf.count("relevance").alias(col_name)
+            )
             intermediate_df = intermediate_df.join(
                 sf.broadcast(count_by_entity_col),
                 on=self.entity_name,
@@ -447,7 +379,6 @@ class ConditionalPopularityProcessor(EmptyFeatureProcessor):
             unpersist_if_exists(df)


-# pylint: disable=too-many-instance-attributes, too-many-arguments
 class HistoryBasedFeaturesProcessor:
     """
     Calculate user and item features based on interactions history (log).
@@ -484,13 +415,9 @@ class HistoryBasedFeaturesProcessor:

         if use_conditional_popularity and user_cat_features_list:
             if user_cat_features_list:
-                self.user_cond_pop_proc = ConditionalPopularityProcessor(
-                    cat_features_list=user_cat_features_list
-                )
+                self.user_cond_pop_proc = ConditionalPopularityProcessor(cat_features_list=user_cat_features_list)
             if item_cat_features_list:
-                self.item_cond_pop_proc = ConditionalPopularityProcessor(
-                    cat_features_list=item_cat_features_list
-                )
+                self.item_cond_pop_proc = ConditionalPopularityProcessor(cat_features_list=item_cat_features_list)
         self.fitted: bool = False

     def fit(
@@ -524,7 +451,8 @@
         :return: augmented SparkDataFrame
         """
         if not self.fitted:
-            raise AttributeError("Call fit before running transform")
+            msg = "Call fit before running transform"
+            raise AttributeError(msg)
         joined = self.log_processor.transform(log)
         joined = self.user_cond_pop_proc.transform(joined)
         joined = self.item_cond_pop_proc.transform(joined)
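
Read together with the constructor changes above, this guard means `HistoryBasedFeaturesProcessor.transform` refuses to run until `fit` has been called. A hedged usage sketch: the import path follows the file list entry for `replay/preprocessing/history_based_fp.py`, the constructor keywords are taken from the names visible in this diff, and the feature-column values are purely illustrative; the exact `fit` signature is not shown in these hunks.

    # Usage sketch under the assumptions stated above.
    from replay.preprocessing.history_based_fp import HistoryBasedFeaturesProcessor

    processor = HistoryBasedFeaturesProcessor(
        use_conditional_popularity=True,      # name visible in the diff above
        user_cat_features_list=["gender"],    # illustrative categorical columns
        item_cat_features_list=["category"],
    )
    print(processor.fitted)  # False: calling transform() now would raise
                             # AttributeError("Call fit before running transform")

    # After processor.fit(...) on a log with [user_idx, item_idx, timestamp, relevance]
    # columns, processor.transform(log) returns the log augmented with u_*/i_* features.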