replay_rec-0.16.0-py3-none-any.whl → replay_rec-0.17.0-py3-none-any.whl
This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
- replay/__init__.py +1 -1
- replay/data/__init__.py +1 -1
- replay/data/dataset.py +45 -42
- replay/data/dataset_utils/dataset_label_encoder.py +6 -7
- replay/data/nn/__init__.py +1 -1
- replay/data/nn/schema.py +20 -33
- replay/data/nn/sequence_tokenizer.py +217 -87
- replay/data/nn/sequential_dataset.py +6 -22
- replay/data/nn/torch_sequential_dataset.py +20 -11
- replay/data/nn/utils.py +7 -9
- replay/data/schema.py +17 -17
- replay/data/spark_schema.py +0 -1
- replay/metrics/base_metric.py +38 -79
- replay/metrics/categorical_diversity.py +24 -58
- replay/metrics/coverage.py +25 -49
- replay/metrics/descriptors.py +4 -13
- replay/metrics/experiment.py +3 -8
- replay/metrics/hitrate.py +3 -6
- replay/metrics/map.py +3 -6
- replay/metrics/mrr.py +1 -4
- replay/metrics/ndcg.py +4 -7
- replay/metrics/novelty.py +10 -29
- replay/metrics/offline_metrics.py +26 -61
- replay/metrics/precision.py +3 -6
- replay/metrics/recall.py +3 -6
- replay/metrics/rocauc.py +7 -10
- replay/metrics/surprisal.py +13 -30
- replay/metrics/torch_metrics_builder.py +0 -4
- replay/metrics/unexpectedness.py +15 -20
- replay/models/__init__.py +1 -2
- replay/models/als.py +7 -15
- replay/models/association_rules.py +12 -28
- replay/models/base_neighbour_rec.py +21 -36
- replay/models/base_rec.py +92 -215
- replay/models/cat_pop_rec.py +9 -22
- replay/models/cluster.py +17 -28
- replay/models/extensions/ann/ann_mixin.py +7 -12
- replay/models/extensions/ann/entities/base_hnsw_param.py +1 -1
- replay/models/extensions/ann/entities/hnswlib_param.py +0 -6
- replay/models/extensions/ann/entities/nmslib_hnsw_param.py +0 -6
- replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +4 -10
- replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py +7 -11
- replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +5 -12
- replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +11 -18
- replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py +1 -4
- replay/models/extensions/ann/index_inferers/base_inferer.py +3 -10
- replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py +7 -17
- replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py +6 -14
- replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +14 -28
- replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +15 -25
- replay/models/extensions/ann/index_inferers/utils.py +2 -9
- replay/models/extensions/ann/index_stores/hdfs_index_store.py +4 -9
- replay/models/extensions/ann/index_stores/shared_disk_index_store.py +2 -6
- replay/models/extensions/ann/index_stores/spark_files_index_store.py +8 -14
- replay/models/extensions/ann/index_stores/utils.py +5 -2
- replay/models/extensions/ann/utils.py +3 -5
- replay/models/kl_ucb.py +16 -22
- replay/models/knn.py +37 -59
- replay/models/nn/optimizer_utils/__init__.py +1 -6
- replay/models/nn/optimizer_utils/optimizer_factory.py +3 -6
- replay/models/nn/sequential/bert4rec/__init__.py +1 -1
- replay/models/nn/sequential/bert4rec/dataset.py +6 -7
- replay/models/nn/sequential/bert4rec/lightning.py +53 -56
- replay/models/nn/sequential/bert4rec/model.py +12 -25
- replay/models/nn/sequential/callbacks/__init__.py +1 -1
- replay/models/nn/sequential/callbacks/prediction_callbacks.py +23 -25
- replay/models/nn/sequential/callbacks/validation_callback.py +27 -30
- replay/models/nn/sequential/postprocessors/postprocessors.py +1 -1
- replay/models/nn/sequential/sasrec/dataset.py +8 -7
- replay/models/nn/sequential/sasrec/lightning.py +53 -48
- replay/models/nn/sequential/sasrec/model.py +4 -17
- replay/models/pop_rec.py +9 -10
- replay/models/query_pop_rec.py +7 -15
- replay/models/random_rec.py +10 -18
- replay/models/slim.py +8 -13
- replay/models/thompson_sampling.py +13 -14
- replay/models/ucb.py +11 -22
- replay/models/wilson.py +5 -14
- replay/models/word2vec.py +24 -69
- replay/optimization/optuna_objective.py +13 -27
- replay/preprocessing/__init__.py +1 -2
- replay/preprocessing/converter.py +2 -7
- replay/preprocessing/filters.py +67 -142
- replay/preprocessing/history_based_fp.py +44 -116
- replay/preprocessing/label_encoder.py +106 -68
- replay/preprocessing/sessionizer.py +1 -11
- replay/scenarios/fallback.py +3 -8
- replay/splitters/base_splitter.py +43 -15
- replay/splitters/cold_user_random_splitter.py +18 -31
- replay/splitters/k_folds.py +14 -24
- replay/splitters/last_n_splitter.py +33 -43
- replay/splitters/new_users_splitter.py +31 -55
- replay/splitters/random_splitter.py +16 -23
- replay/splitters/ratio_splitter.py +30 -54
- replay/splitters/time_splitter.py +13 -18
- replay/splitters/two_stage_splitter.py +44 -79
- replay/utils/__init__.py +1 -1
- replay/utils/common.py +65 -0
- replay/utils/dataframe_bucketizer.py +25 -31
- replay/utils/distributions.py +3 -15
- replay/utils/model_handler.py +36 -33
- replay/utils/session_handler.py +11 -15
- replay/utils/spark_utils.py +51 -85
- replay/utils/time.py +8 -22
- replay/utils/types.py +1 -3
- {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/METADATA +2 -2
- replay_rec-0.17.0.dist-info/RECORD +127 -0
- replay_rec-0.16.0.dist-info/RECORD +0 -126
- {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/LICENSE +0 -0
- {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/WHEEL +0 -0
Diff for replay/preprocessing/history_based_fp.py (+44 -116):

```diff
--- a/replay/preprocessing/history_based_fp.py
+++ b/replay/preprocessing/history_based_fp.py
@@ -29,7 +29,6 @@ class EmptyFeatureProcessor:
         :param features: DataFrame with ``user_idx/item_idx`` and feature columns
         """
 
-    # pylint: disable=no-self-use
     def transform(self, log: SparkDataFrame) -> SparkDataFrame:
         """
         Return log without any transformations
@@ -74,26 +73,16 @@ class LogStatFeaturesProcessor(EmptyFeatureProcessor):
         """
         prefix = agg_col[:1]
 
-        aggregates = [
-            sf.log(sf.count(sf.col("relevance"))).alias(
-                f"{prefix}_log_num_interact"
-            )
-        ]
+        aggregates = [sf.log(sf.count(sf.col("relevance"))).alias(f"{prefix}_log_num_interact")]
 
         if self.calc_timestamp_based:
             aggregates.extend(
                 [
-                    sf.log(
-                        sf.countDistinct(
-                            sf.date_trunc("dd", sf.col("timestamp"))
-                        )
-                    ).alias(f"{prefix}_log_interact_days_count"),
-                    sf.min(sf.col("timestamp")).alias(
-                        f"{prefix}_min_interact_date"
-                    ),
-                    sf.max(sf.col("timestamp")).alias(
-                        f"{prefix}_max_interact_date"
+                    sf.log(sf.countDistinct(sf.date_trunc("dd", sf.col("timestamp")))).alias(
+                        f"{prefix}_log_interact_days_count"
                     ),
+                    sf.min(sf.col("timestamp")).alias(f"{prefix}_min_interact_date"),
+                    sf.max(sf.col("timestamp")).alias(f"{prefix}_max_interact_date"),
                 ]
             )
 
```
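For context, a minimal sketch (toy data, not from the package) of what the rewritten days-count aggregate computes: timestamps are truncated to day granularity, distinct days are counted per group, and the natural log is taken.

```python
import pyspark.sql.functions as sf
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()
log_df = spark.createDataFrame(
    [
        (1, 10, "2023-01-01 09:00:00"),
        (1, 11, "2023-01-01 18:30:00"),  # same calendar day as the row above
        (1, 12, "2023-01-03 12:00:00"),
    ],
    ["user_idx", "item_idx", "timestamp"],
).withColumn("timestamp", sf.col("timestamp").cast("timestamp"))

log_df.groupBy("user_idx").agg(
    sf.log(sf.countDistinct(sf.date_trunc("dd", sf.col("timestamp")))).alias("u_log_interact_days_count")
).show()  # two distinct days -> ln(2) ≈ 0.693
```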
```diff
@@ -102,8 +91,7 @@ class LogStatFeaturesProcessor(EmptyFeatureProcessor):
                 [
                     (
                         sf.when(
-                            sf.stddev(sf.col("relevance")).isNull()
-                            | sf.isnan(sf.stddev(sf.col("relevance"))),
+                            sf.stddev(sf.col("relevance")).isNull() | sf.isnan(sf.stddev(sf.col("relevance"))),
                             0,
                         )
                         .otherwise(sf.stddev(sf.col("relevance")))
```
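The reflowed condition is worth a note: `sf.stddev` is a sample standard deviation, so a group with a single interaction yields null, and the expression maps the null/NaN cases to 0 instead of letting them propagate into the feature. The same expression, standalone (a log DataFrame with a `relevance` column is assumed):

```python
import pyspark.sql.functions as sf

# Maps degenerate groups (e.g., a single interaction) to 0.
u_std = (
    sf.when(
        sf.stddev(sf.col("relevance")).isNull() | sf.isnan(sf.stddev(sf.col("relevance"))),
        0,
    )
    .otherwise(sf.stddev(sf.col("relevance")))
    .alias("u_std")
)
# usage: log_df.groupBy("user_idx").agg(u_std)
```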
```diff
@@ -112,19 +100,15 @@ class LogStatFeaturesProcessor(EmptyFeatureProcessor):
                 sf.mean(sf.col("relevance")).alias(f"{prefix}_mean"),
             ]
         )
-        for percentile in [0.05, 0.5, 0.95]:
-            aggregates.append(
-                sf.expr(
-                    f"percentile_approx(relevance, {percentile})"
-                ).alias(f"{prefix}_quantile_{str(percentile)[2:]}")
-            )
+        aggregates.extend(
+            sf.expr(f"percentile_approx(relevance, {percentile})").alias(f"{prefix}_quantile_{str(percentile)[2:]}")
+            for percentile in [0.05, 0.5, 0.95]
+        )
 
         return aggregates
 
     @staticmethod
-    def _add_ts_based(
-        features: SparkDataFrame, max_log_date: datetime, prefix: str
-    ) -> SparkDataFrame:
+    def _add_ts_based(features: SparkDataFrame, max_log_date: datetime, prefix: str) -> SparkDataFrame:
         """
         Add history length (max - min timestamp) and difference in days between
         last date in log and last interaction of the user/item
```
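The quantile aggregates still go through `sf.expr`, presumably because `percentile_approx` only gained a `pyspark.sql.functions` wrapper in PySpark 3.1; `extend` now consumes a generator instead of an explicit loop. A toy equivalent:

```python
import pyspark.sql.functions as sf

# str(0.05)[2:] == "05", so the columns come out as
# u_quantile_05, u_quantile_5, u_quantile_95.
aggregates = []
aggregates.extend(
    sf.expr(f"percentile_approx(relevance, {percentile})").alias(f"u_quantile_{str(percentile)[2:]}")
    for percentile in [0.05, 0.5, 0.95]
)
# usage: log_df.groupBy("user_idx").agg(*aggregates)
```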
```diff
@@ -142,15 +126,11 @@ class LogStatFeaturesProcessor(EmptyFeatureProcessor):
             ),
         ).withColumn(
             f"{prefix}_last_interaction_gap_days",
-            sf.datediff(
-                sf.lit(max_log_date), sf.col(f"{prefix}_max_interact_date")
-            ),
+            sf.datediff(sf.lit(max_log_date), sf.col(f"{prefix}_max_interact_date")),
         )
 
     @staticmethod
-    def _cals_cross_interactions_count(
-        log: SparkDataFrame, features: SparkDataFrame
-    ) -> SparkDataFrame:
+    def _cals_cross_interactions_count(log: SparkDataFrame, features: SparkDataFrame) -> SparkDataFrame:
         """
         Calculate difference between the log number of interactions by the user
         and average log number of interactions users interacted with the item has.
@@ -165,9 +145,7 @@ class LogStatFeaturesProcessor(EmptyFeatureProcessor):
             new_feature_entity, calc_by_entity = "user_idx", "item_idx"
 
         mean_log_num_interact = log.join(
-            features.select(
-                calc_by_entity, f"{calc_by_entity[0]}_log_num_interact"
-            ),
+            features.select(calc_by_entity, f"{calc_by_entity[0]}_log_num_interact"),
             on=calc_by_entity,
             how="left",
         )
@@ -178,9 +156,7 @@ class LogStatFeaturesProcessor(EmptyFeatureProcessor):
         )
 
     @staticmethod
-    def _calc_abnormality(
-        log: SparkDataFrame, item_features: SparkDataFrame
-    ) -> SparkDataFrame:
+    def _calc_abnormality(log: SparkDataFrame, item_features: SparkDataFrame) -> SparkDataFrame:
         """
         Calculate discrepancy between a rating on a resource
         and the average rating of this resource (Abnormality) and
@@ -198,13 +174,9 @@ class LogStatFeaturesProcessor(EmptyFeatureProcessor):
             on_col_name="item_idx",
             how="left",
         )
-        abnormality_df = abnormality_df.withColumn(
-            "abnormality", sf.abs(sf.col("relevance") - sf.col("i_mean"))
-        )
+        abnormality_df = abnormality_df.withColumn("abnormality", sf.abs(sf.col("relevance") - sf.col("i_mean")))
 
-        abnormality_aggs = [
-            sf.mean(sf.col("abnormality")).alias("abnormality")
-        ]
+        abnormality_aggs = [sf.mean(sf.col("abnormality")).alias("abnormality")]
 
         # Abnormality CR:
         max_std = item_features.select(sf.max("i_std")).collect()[0][0]
```
```diff
@@ -212,80 +184,53 @@ class LogStatFeaturesProcessor(EmptyFeatureProcessor):
         if max_std - min_std != 0:
             abnormality_df = abnormality_df.withColumn(
                 "controversy",
-                1
-                - (sf.col("i_std") - sf.lit(min_std))
-                / (sf.lit(max_std - min_std)),
+                1 - (sf.col("i_std") - sf.lit(min_std)) / (sf.lit(max_std - min_std)),
             )
             abnormality_df = abnormality_df.withColumn(
                 "abnormalityCR",
                 (sf.col("abnormality") * sf.col("controversy")) ** 2,
             )
-            abnormality_aggs.append(
-                sf.mean(sf.col("abnormalityCR")).alias("abnormalityCR")
-            )
+            abnormality_aggs.append(sf.mean(sf.col("abnormalityCR")).alias("abnormalityCR"))
 
         return abnormality_df.groupBy("user_idx").agg(*abnormality_aggs)
 
-    def fit(
-        self, log: SparkDataFrame, features: Optional[SparkDataFrame] = None
-    ) -> None:
+    def fit(self, log: SparkDataFrame, features: Optional[SparkDataFrame] = None) -> None:  # noqa: ARG002
         """
         Calculate log-based features for users and items
 
         :param log: input SparkDataFrame ``[user_idx, item_idx, timestamp, relevance]``
-        :param features: not required
         """
-        self.calc_timestamp_based = (
-            isinstance(log.schema["timestamp"].dataType, TimestampType)
-        ) & (
-            log.select(sf.countDistinct(sf.col("timestamp"))).collect()[0][0]
-            > 1
-        )
-        self.calc_relevance_based = (
-            log.select(sf.countDistinct(sf.col("relevance"))).collect()[0][0]
-            > 1
+        self.calc_timestamp_based = (isinstance(log.schema["timestamp"].dataType, TimestampType)) & (
+            log.select(sf.countDistinct(sf.col("timestamp"))).collect()[0][0] > 1
         )
+        self.calc_relevance_based = log.select(sf.countDistinct(sf.col("relevance"))).collect()[0][0] > 1
 
-        user_log_features = log.groupBy("user_idx").agg(
-            *self._create_log_aggregates(agg_col="user_idx")
-        )
-        item_log_features = log.groupBy("item_idx").agg(
-            *self._create_log_aggregates(agg_col="item_idx")
-        )
+        user_log_features = log.groupBy("user_idx").agg(*self._create_log_aggregates(agg_col="user_idx"))
+        item_log_features = log.groupBy("item_idx").agg(*self._create_log_aggregates(agg_col="item_idx"))
 
         if self.calc_timestamp_based:
             last_date = log.select(sf.max("timestamp")).collect()[0][0]
-            user_log_features = self._add_ts_based(
-                features=user_log_features, max_log_date=last_date, prefix="u"
-            )
+            user_log_features = self._add_ts_based(features=user_log_features, max_log_date=last_date, prefix="u")
 
-            item_log_features = self._add_ts_based(
-                features=item_log_features, max_log_date=last_date, prefix="i"
-            )
+            item_log_features = self._add_ts_based(features=item_log_features, max_log_date=last_date, prefix="i")
 
         if self.calc_relevance_based:
             user_log_features = user_log_features.join(
-                self._calc_abnormality(
-                    log=log, item_features=item_log_features
-                ),
+                self._calc_abnormality(log=log, item_features=item_log_features),
                 on="user_idx",
                 how="left",
             ).cache()
 
         self.user_log_features = join_with_col_renaming(
             left=user_log_features,
-            right=self._cals_cross_interactions_count(
-                log=log, features=item_log_features
-            ),
+            right=self._cals_cross_interactions_count(log=log, features=item_log_features),
             on_col_name="user_idx",
             how="left",
         ).cache()
 
         self.item_log_features = join_with_col_renaming(
             left=item_log_features,
-            right=self._cals_cross_interactions_count(
-                log=log, features=user_log_features
-            ),
+            right=self._cals_cross_interactions_count(log=log, features=user_log_features),
             on_col_name="item_idx",
             how="left",
         ).cache()
```
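The abnormality block reads more easily as plain arithmetic: per interaction, abnormality = |relevance − i_mean|; controversy rescales the item's rating stddev to [0, 1] and inverts it (items with consistent ratings score close to 1); abnormalityCR = (abnormality × controversy)²; both are then averaged per user. A worked example with made-up numbers:

```python
# One interaction: a user rates 5.0 an item whose mean rating is 3.0 and
# whose rating stddev is 0.5; stddevs across all items span [0.2, 1.2].
relevance, i_mean, i_std = 5.0, 3.0, 0.5
min_std, max_std = 0.2, 1.2

abnormality = abs(relevance - i_mean)                      # 2.0
controversy = 1 - (i_std - min_std) / (max_std - min_std)  # 1 - 0.3 = 0.7
abnormality_cr = (abnormality * controversy) ** 2          # (2.0 * 0.7) ** 2 = 1.96
# The user features are the per-user means of these two values
# (groupBy("user_idx").agg(...) in the code above).
```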
```diff
@@ -311,25 +256,15 @@ class LogStatFeaturesProcessor(EmptyFeatureProcessor):
             )
             .withColumn(
                 "na_u_log_features",
-                sf.when(sf.col("u_log_num_interact").isNull(), 1.0).otherwise(
-                    0.0
-                ),
+                sf.when(sf.col("u_log_num_interact").isNull(), 1.0).otherwise(0.0),
             )
             .withColumn(
                 "na_i_log_features",
-                sf.when(sf.col("i_log_num_interact").isNull(), 1.0).otherwise(
-                    0.0
-                ),
+                sf.when(sf.col("i_log_num_interact").isNull(), 1.0).otherwise(0.0),
             )
             # TO DO: std and date diff are replaced with inf; will the date features work correctly?
             # And if we don't replace them, will it work correctly?
-            .fillna(
-                {
-                    col_name: 0
-                    for col_name in self.user_log_features.columns
-                    + self.item_log_features.columns
-                }
-            )
+            .fillna({col_name: 0 for col_name in self.user_log_features.columns + self.item_log_features.columns})
         )
 
         joined = joined.withColumn(
```
```diff
@@ -375,19 +310,16 @@ class ConditionalPopularityProcessor(EmptyFeatureProcessor):
         :param log: input SparkDataFrame ``[user_idx, item_idx, timestamp, relevance]``
         :param features: SparkDataFrame with ``user_idx/item_idx`` and feature columns
         """
-        if len(
-            set(self.cat_features_list).intersection(features.columns)
-        ) != len(self.cat_features_list):
-            raise ValueError(
+        if len(set(self.cat_features_list).intersection(features.columns)) != len(self.cat_features_list):
+            msg = (
                 f"Columns {set(self.cat_features_list).difference(features.columns)} "
                 f"defined in `cat_features_list` are absent in features. "
                 f"features columns are: {features.columns}."
             )
+            raise ValueError(msg)
 
         join_col, self.entity_name = (
-            ("item_idx", "user_idx")
-            if "item_idx" in features.columns
-            else ("user_idx", "item_idx")
+            ("item_idx", "user_idx") if "item_idx" in features.columns else ("user_idx", "item_idx")
         )
 
         self.conditional_pop_dict = {}
```
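The `msg = ...` / `raise ValueError(msg)` shape recurs throughout 0.17.0 and, together with the `# noqa: ARG002` marker earlier in this file, suggests a move to Ruff: the flake8-errmsg rules (EM101/EM102) flag string literals and f-strings built directly inside `raise`. A hypothetical helper (not in the package) using the same pattern:

```python
def require_columns(required: set, available: list) -> None:
    # Build the message first, then raise: the message is not duplicated in
    # the traceback display, and flake8-errmsg (Ruff EM101/EM102) is satisfied.
    if not required.issubset(available):
        msg = f"Columns {required.difference(set(available))} are absent in features."
        raise ValueError(msg)

require_columns({"user_idx"}, ["user_idx", "age"])  # passes silently
```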
```diff
@@ -400,9 +332,9 @@ class ConditionalPopularityProcessor(EmptyFeatureProcessor):
 
         for cat_col in self.cat_features_list:
             col_name = f"{self.entity_name[0]}_pop_by_{cat_col}"
-            intermediate_df = log_with_features.groupBy(
-                self.entity_name, cat_col
-            ).agg(sf.count("relevance").alias(col_name))
+            intermediate_df = log_with_features.groupBy(self.entity_name, cat_col).agg(
+                sf.count("relevance").alias(col_name)
+            )
             intermediate_df = intermediate_df.join(
                 sf.broadcast(count_by_entity_col),
                 on=self.entity_name,
```
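For orientation: `ConditionalPopularityProcessor` counts interactions per (entity, category value) pair and joins in the entity's total interaction count (the broadcast `count_by_entity_col` above), which presumably yields a popularity share conditioned on the categorical value. A rough standalone sketch with illustrative names (`log_with_features` and `category` are assumptions, not the package's API):

```python
import pyspark.sql.functions as sf

# Share of a user's interactions falling into each item category.
by_cat = log_with_features.groupBy("user_idx", "category").agg(
    sf.count("relevance").alias("u_pop_by_category")
)
totals = log_with_features.groupBy("user_idx").agg(sf.count("relevance").alias("count"))
conditional_pop = by_cat.join(sf.broadcast(totals), on="user_idx").withColumn(
    "u_pop_by_category", sf.col("u_pop_by_category") / sf.col("count")
)
```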
```diff
@@ -447,7 +379,6 @@ class ConditionalPopularityProcessor(EmptyFeatureProcessor):
             unpersist_if_exists(df)
 
 
-# pylint: disable=too-many-instance-attributes, too-many-arguments
 class HistoryBasedFeaturesProcessor:
     """
     Calculate user and item features based on interactions history (log).
@@ -484,13 +415,9 @@ class HistoryBasedFeaturesProcessor:
 
         if use_conditional_popularity and user_cat_features_list:
             if user_cat_features_list:
-                self.user_cond_pop_proc = ConditionalPopularityProcessor(
-                    cat_features_list=user_cat_features_list
-                )
+                self.user_cond_pop_proc = ConditionalPopularityProcessor(cat_features_list=user_cat_features_list)
             if item_cat_features_list:
-                self.item_cond_pop_proc = ConditionalPopularityProcessor(
-                    cat_features_list=item_cat_features_list
-                )
+                self.item_cond_pop_proc = ConditionalPopularityProcessor(cat_features_list=item_cat_features_list)
         self.fitted: bool = False
 
     def fit(
```
```diff
@@ -524,7 +451,8 @@ class HistoryBasedFeaturesProcessor:
         :return: augmented SparkDataFrame
         """
         if not self.fitted:
-            raise AttributeError("Call fit before running transform")
+            msg = "Call fit before running transform"
+            raise AttributeError(msg)
         joined = self.log_processor.transform(log)
         joined = self.user_cond_pop_proc.transform(joined)
         joined = self.item_cond_pop_proc.transform(joined)
```
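The guard's behavior is unchanged; only the message moves into a variable. A hedged usage sketch (constructor and `fit` arguments abbreviated; `log` stands for a hypothetical ``[user_idx, item_idx, timestamp, relevance]`` SparkDataFrame):

```python
from replay.preprocessing.history_based_fp import HistoryBasedFeaturesProcessor

proc = HistoryBasedFeaturesProcessor()
try:
    proc.transform(log)  # not fitted yet
except AttributeError as err:
    print(err)  # Call fit before running transform
proc.fit(log)
augmented = proc.transform(log)  # log plus user/item history features
```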