replay-rec 0.17.1rc0__py3-none-any.whl → 0.18.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. replay/__init__.py +2 -1
  2. replay/data/dataset.py +3 -2
  3. replay/data/dataset_utils/dataset_label_encoder.py +1 -0
  4. replay/data/nn/schema.py +5 -5
  5. replay/experimental/metrics/__init__.py +1 -0
  6. replay/experimental/metrics/base_metric.py +1 -0
  7. replay/experimental/models/base_rec.py +7 -7
  8. replay/experimental/models/cql.py +2 -0
  9. replay/experimental/models/ddpg.py +6 -4
  10. replay/experimental/models/lightfm_wrap.py +2 -2
  11. replay/experimental/models/mult_vae.py +1 -0
  12. replay/experimental/models/neuromf.py +1 -0
  13. replay/experimental/models/scala_als.py +2 -2
  14. replay/experimental/preprocessing/data_preparator.py +2 -1
  15. replay/experimental/preprocessing/padder.py +1 -1
  16. replay/experimental/scenarios/two_stages/two_stages_scenario.py +1 -1
  17. replay/experimental/utils/model_handler.py +7 -2
  18. replay/metrics/__init__.py +1 -0
  19. replay/models/als.py +1 -1
  20. replay/models/base_rec.py +7 -7
  21. replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +3 -3
  22. replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +3 -3
  23. replay/models/nn/sequential/bert4rec/model.py +5 -112
  24. replay/models/nn/sequential/sasrec/model.py +8 -5
  25. replay/optimization/optuna_objective.py +1 -0
  26. replay/preprocessing/converter.py +1 -1
  27. replay/preprocessing/filters.py +19 -18
  28. replay/preprocessing/history_based_fp.py +5 -5
  29. replay/preprocessing/label_encoder.py +1 -0
  30. replay/scenarios/__init__.py +1 -0
  31. replay/splitters/last_n_splitter.py +1 -1
  32. replay/splitters/time_splitter.py +1 -1
  33. replay/splitters/two_stage_splitter.py +8 -6
  34. replay/utils/distributions.py +1 -0
  35. replay/utils/session_handler.py +3 -3
  36. replay/utils/spark_utils.py +2 -2
  37. {replay_rec-0.17.1rc0.dist-info → replay_rec-0.18.0rc0.dist-info}/METADATA +13 -11
  38. {replay_rec-0.17.1rc0.dist-info → replay_rec-0.18.0rc0.dist-info}/RECORD +41 -41
  39. {replay_rec-0.17.1rc0.dist-info → replay_rec-0.18.0rc0.dist-info}/LICENSE +0 -0
  40. {replay_rec-0.17.1rc0.dist-info → replay_rec-0.18.0rc0.dist-info}/NOTICE +0 -0
  41. {replay_rec-0.17.1rc0.dist-info → replay_rec-0.18.0rc0.dist-info}/WHEEL +0 -0
replay/__init__.py CHANGED
@@ -1,2 +1,3 @@
  """ RecSys library """
- __version__ = "0.17.1.preview"
+
+ __version__ = "0.18.0.preview"
replay/data/dataset.py CHANGED
@@ -1,6 +1,7 @@
  """
  ``Dataset`` universal dataset class for manipulating interactions and feed data to models.
  """
+
  from __future__ import annotations

  import json
@@ -606,7 +607,7 @@ class Dataset:
  if self.is_pandas:
  min_id = data[column].min()
  elif self.is_spark:
- min_id = data.agg(sf.min(column).alias("min_index")).collect()[0][0]
+ min_id = data.agg(sf.min(column).alias("min_index")).first()[0]
  else:
  min_id = data[column].min()
  if min_id < 0:
@@ -616,7 +617,7 @@ class Dataset:
  if self.is_pandas:
  max_id = data[column].max()
  elif self.is_spark:
- max_id = data.agg(sf.max(column).alias("max_index")).collect()[0][0]
+ max_id = data.agg(sf.max(column).alias("max_index")).first()[0]
  else:
  max_id = data[column].max()

replay/data/dataset_utils/dataset_label_encoder.py CHANGED
@@ -4,6 +4,7 @@ Contains classes for encoding categorical data
  ``LabelEncoderTransformWarning`` new category of warning for DatasetLabelEncoder.
  ``DatasetLabelEncoder`` to encode categorical features in `Dataset` objects.
  """
+
  import warnings
  from typing import Dict, Iterable, Iterator, Optional, Sequence, Set, Union

replay/data/nn/schema.py CHANGED
@@ -418,11 +418,11 @@ class TensorSchema(Mapping[str, TensorFeatureInfo]):
  "feature_type": feature.feature_type.name,
  "is_seq": feature.is_seq,
  "feature_hint": feature.feature_hint.name if feature.feature_hint else None,
- "feature_sources": [
- {"source": x.source.name, "column": x.column, "index": x.index} for x in feature.feature_sources
- ]
- if feature.feature_sources
- else None,
+ "feature_sources": (
+ [{"source": x.source.name, "column": x.column, "index": x.index} for x in feature.feature_sources]
+ if feature.feature_sources
+ else None
+ ),
  "cardinality": feature.cardinality if feature.feature_type == FeatureType.CATEGORICAL else None,
  "embedding_dim": feature.embedding_dim if feature.feature_type == FeatureType.CATEGORICAL else None,
  "tensor_dim": feature.tensor_dim if feature.feature_type == FeatureType.NUMERICAL else None,
replay/experimental/metrics/__init__.py CHANGED
@@ -47,6 +47,7 @@ For each metric, a formula for its calculation is given, because this is
  important for the correct comparison of algorithms, as mentioned in our
  `article <https://arxiv.org/abs/2206.12858>`_.
  """
+
  from replay.experimental.metrics.base_metric import Metric, NCISMetric
  from replay.experimental.metrics.coverage import Coverage
  from replay.experimental.metrics.hitrate import HitRate
replay/experimental/metrics/base_metric.py CHANGED
@@ -1,6 +1,7 @@
  """
  Base classes for quality and diversity metrics.
  """
+
  import logging
  from abc import ABC, abstractmethod
  from typing import Dict, List, Optional, Union
replay/experimental/models/base_rec.py CHANGED
@@ -86,8 +86,8 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
  self.fit_items = sf.broadcast(items)
  self._num_users = self.fit_users.count()
  self._num_items = self.fit_items.count()
- self._user_dim_size = self.fit_users.agg({"user_idx": "max"}).collect()[0][0] + 1
- self._item_dim_size = self.fit_items.agg({"item_idx": "max"}).collect()[0][0] + 1
+ self._user_dim_size = self.fit_users.agg({"user_idx": "max"}).first()[0] + 1
+ self._item_dim_size = self.fit_items.agg({"item_idx": "max"}).first()[0] + 1
  self._fit(log, user_features, item_features)

  @abstractmethod
@@ -122,7 +122,7 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
  # count maximal number of items seen by users
  max_seen = 0
  if num_seen.count() > 0:
- max_seen = num_seen.select(sf.max("seen_count")).collect()[0][0]
+ max_seen = num_seen.select(sf.max("seen_count")).first()[0]

  # crop recommendations to first k + max_seen items for each user
  recs = recs.withColumn(
@@ -335,7 +335,7 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
  setattr(
  self,
  f"_{entity}_dim_size",
- getattr(self, f"fit_{entity}s").agg({f"{entity}_idx": "max"}).collect()[0][0] + 1,
+ getattr(self, f"fit_{entity}s").agg({f"{entity}_idx": "max"}).first()[0] + 1,
  )
  return getattr(self, f"_{entity}_dim_size")

@@ -1088,7 +1088,7 @@ class NonPersonalizedRecommender(Recommender, ABC):
  Calculating a fill value a the minimal relevance
  calculated during model training multiplied by weight.
  """
- return item_popularity.select(sf.min("relevance")).collect()[0][0] * weight
+ return item_popularity.select(sf.min("relevance")).first()[0] * weight

  @staticmethod
  def _check_relevance(log: SparkDataFrame):
@@ -1113,7 +1113,7 @@ class NonPersonalizedRecommender(Recommender, ABC):
  max_hist_len = (
  (log.join(users, on="user_idx").groupBy("user_idx").agg(sf.countDistinct("item_idx").alias("items_count")))
  .select(sf.max("items_count"))
- .collect()[0][0]
+ .first()[0]
  )
  # all users have empty history
  if max_hist_len is None:
@@ -1146,7 +1146,7 @@ class NonPersonalizedRecommender(Recommender, ABC):
  users = users.join(user_to_num_items, on="user_idx", how="left")
  users = users.fillna(0, "num_items")
  # 'selected_item_popularity' truncation by k + max_seen
- max_seen = users.select(sf.coalesce(sf.max("num_items"), sf.lit(0))).collect()[0][0]
+ max_seen = users.select(sf.coalesce(sf.max("num_items"), sf.lit(0))).first()[0]
  selected_item_popularity = selected_item_popularity.filter(sf.col("rank") <= k + max_seen)
  return users.join(selected_item_popularity, on=(sf.col("rank") <= k + sf.col("num_items")), how="left")

replay/experimental/models/cql.py CHANGED
@@ -1,6 +1,7 @@
  """
  Using CQL implementation from `d3rlpy` package.
  """
+
  import io
  import logging
  import tempfile
@@ -402,6 +403,7 @@ class MdpDatasetBuilder:
  top_k (int): the number of top user items to learn predicting.
  action_randomization_scale (float): the scale of action randomization gaussian noise.
  """
+
  logger: logging.Logger
  top_k: int
  action_randomization_scale: float
replay/experimental/models/ddpg.py CHANGED
@@ -704,13 +704,15 @@ class DDPG(Recommender):
  :param data: pandas DataFrame
  """
  data = data[["user_idx", "item_idx", "relevance"]]
- train_data = data.values.tolist()
+ users = data["user_idx"].values.tolist()
+ items = data["item_idx"].values.tolist()
+ scores = data["relevance"].values.tolist()

- user_num = data["user_idx"].max() + 1
- item_num = data["item_idx"].max() + 1
+ user_num = max(users) + 1
+ item_num = max(items) + 1

  train_mat = defaultdict(float)
- for user, item, rel in train_data:
+ for user, item, rel in zip(users, items, scores):
  train_mat[user, item] = rel
  train_matrix = sp.dok_matrix((user_num, item_num), dtype=np.float32)
  dict.update(train_matrix, train_mat)
replay/experimental/models/lightfm_wrap.py CHANGED
@@ -98,12 +98,12 @@ class LightFMWrap(HybridRecommender):
  fit_dim = getattr(self, f"_{entity}_dim")
  matrix_height = max(
  fit_dim,
- log_ids_list.select(sf.max(idx_col_name)).collect()[0][0] + 1,
+ log_ids_list.select(sf.max(idx_col_name)).first()[0] + 1,
  )
  if not feature_table.rdd.isEmpty():
  matrix_height = max(
  matrix_height,
- feature_table.select(sf.max(idx_col_name)).collect()[0][0] + 1,
+ feature_table.select(sf.max(idx_col_name)).first()[0] + 1,
  )

  features_np = (
replay/experimental/models/mult_vae.py CHANGED
@@ -2,6 +2,7 @@
  MultVAE implementation
  (Variational Autoencoders for Collaborative Filtering)
  """
+
  from typing import Optional, Tuple

  import numpy as np
replay/experimental/models/neuromf.py CHANGED
@@ -3,6 +3,7 @@ Generalized Matrix Factorization (GMF),
  Multi-Layer Perceptron (MLP),
  Neural Matrix Factorization (MLP + GMF).
  """
+
  from typing import List, Optional

  import numpy as np
replay/experimental/models/scala_als.py CHANGED
@@ -115,7 +115,7 @@ class ALSWrap(Recommender, ItemVectorModel):
  .groupBy("user_idx")
  .agg(sf.count("user_idx").alias("num_seen"))
  .select(sf.max("num_seen"))
- .collect()[0][0]
+ .first()[0]
  )
  max_seen = max_seen_in_log if max_seen_in_log is not None else 0

@@ -280,7 +280,7 @@ class ScalaALSWrap(ALSWrap, ANNMixin):
  .groupBy("user_idx")
  .agg(sf.count("user_idx").alias("num_seen"))
  .select(sf.max("num_seen"))
- .collect()[0][0]
+ .first()[0]
  )
  max_seen = max_seen_in_log if max_seen_in_log is not None else 0

replay/experimental/preprocessing/data_preparator.py CHANGED
@@ -6,6 +6,7 @@ Contains classes for data preparation and categorical features transformation.
  ``ToNumericFeatureTransformer`` leaves only numerical features
  by one-hot encoding of some features and deleting the others.
  """
+
  import json
  import logging
  import string
@@ -699,7 +700,7 @@ if PYSPARK_AVAILABLE:
  return

  cat_feat_values_dict = {
- name: (spark_df.select(sf.collect_set(sf.col(name))).collect()[0][0]) for name in self.cat_cols_list
+ name: (spark_df.select(sf.collect_set(sf.col(name))).first()[0]) for name in self.cat_cols_list
  }
  self.expressions_list = [
  sf.when(sf.col(col_name) == cur_name, 1)
replay/experimental/preprocessing/padder.py CHANGED
@@ -179,7 +179,7 @@ class Padder:
  self, df_transformed: SparkDataFrame, col: str, pad_value: Union[str, float, List, None]
  ) -> SparkDataFrame:
  if self.array_size == -1:
- max_array_size = df_transformed.agg(sf.max(sf.size(col)).alias("max_array_len")).collect()[0][0]
+ max_array_size = df_transformed.agg(sf.max(sf.size(col)).alias("max_array_len")).first()[0]
  else:
  max_array_size = self.array_size

replay/experimental/scenarios/two_stages/two_stages_scenario.py CHANGED
@@ -383,7 +383,7 @@ class TwoStagesScenario(HybridRecommender):
  log_to_filter_cached.groupBy("user_idx")
  .agg(sf.count("item_idx").alias("num_positives"))
  .select(sf.max("num_positives"))
- .collect()[0][0]
+ .first()[0]
  )

  pred = model._predict(
replay/experimental/utils/model_handler.py CHANGED
@@ -170,8 +170,13 @@ def load_indexer(path: str) -> Indexer:

  indexer = Indexer(**args)

- indexer.user_type = getattr(st, user_type)()
- indexer.item_type = getattr(st, item_type)()
+ if user_type.endswith("()"):
+ user_type = user_type[:-2]
+ item_type = item_type[:-2]
+ user_type = getattr(st, user_type)
+ item_type = getattr(st, item_type)
+ indexer.user_type = user_type()
+ indexer.item_type = item_type()

  indexer.user_indexer = StringIndexerModel.load(join(path, "user_indexer"))
  indexer.item_indexer = StringIndexerModel.load(join(path, "item_indexer"))
replay/metrics/__init__.py CHANGED
@@ -42,6 +42,7 @@ For each metric, a formula for its calculation is given, because this is
  important for the correct comparison of algorithms, as mentioned in our
  `article <https://arxiv.org/abs/2206.12858>`_.
  """
+
  from .base_metric import Metric
  from .categorical_diversity import CategoricalDiversity
  from .coverage import Coverage
replay/models/als.py CHANGED
@@ -115,7 +115,7 @@ class ALSWrap(Recommender, ItemVectorModel):
  .groupBy(self.query_column)
  .agg(sf.count(self.query_column).alias("num_seen"))
  .select(sf.max("num_seen"))
- .collect()[0][0]
+ .first()[0]
  )
  max_seen = max_seen_in_interactions if max_seen_in_interactions is not None else 0

replay/models/base_rec.py CHANGED
@@ -401,8 +401,8 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
  self.fit_items = sf.broadcast(items)
  self._num_queries = self.fit_queries.count()
  self._num_items = self.fit_items.count()
- self._query_dim_size = self.fit_queries.agg({self.query_column: "max"}).collect()[0][0] + 1
- self._item_dim_size = self.fit_items.agg({self.item_column: "max"}).collect()[0][0] + 1
+ self._query_dim_size = self.fit_queries.agg({self.query_column: "max"}).first()[0] + 1
+ self._item_dim_size = self.fit_items.agg({self.item_column: "max"}).first()[0] + 1
  self._fit(dataset)

  @abstractmethod
@@ -431,7 +431,7 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
  # count maximal number of items seen by queries
  max_seen = 0
  if num_seen.count() > 0:
- max_seen = num_seen.select(sf.max("seen_count")).collect()[0][0]
+ max_seen = num_seen.select(sf.max("seen_count")).first()[0]

  # crop recommendations to first k + max_seen items for each query
  recs = recs.withColumn(
@@ -708,7 +708,7 @@ class BaseRecommender(RecommenderCommons, IsSavable, ABC):
  setattr(
  self,
  dim_size,
- fit_entities.agg({column: "max"}).collect()[0][0] + 1,
+ fit_entities.agg({column: "max"}).first()[0] + 1,
  )
  return getattr(self, dim_size)

@@ -1426,7 +1426,7 @@ class NonPersonalizedRecommender(Recommender, ABC):
  Calculating a fill value a the minimal rating
  calculated during model training multiplied by weight.
  """
- return item_popularity.select(sf.min(rating_column)).collect()[0][0] * weight
+ return item_popularity.select(sf.min(rating_column)).first()[0] * weight

  @staticmethod
  def _check_rating(dataset: Dataset):
@@ -1460,7 +1460,7 @@ class NonPersonalizedRecommender(Recommender, ABC):
  .agg(sf.countDistinct(item_column).alias("items_count"))
  )
  .select(sf.max("items_count"))
- .collect()[0][0]
+ .first()[0]
  )
  # all queries have empty history
  if max_hist_len is None:
@@ -1495,7 +1495,7 @@ class NonPersonalizedRecommender(Recommender, ABC):
  queries = queries.join(query_to_num_items, on=self.query_column, how="left")
  queries = queries.fillna(0, "num_items")
  # 'selected_item_popularity' truncation by k + max_seen
- max_seen = queries.select(sf.coalesce(sf.max("num_items"), sf.lit(0))).collect()[0][0]
+ max_seen = queries.select(sf.coalesce(sf.max("num_items"), sf.lit(0))).first()[0]
  selected_item_popularity = selected_item_popularity.filter(sf.col("rank") <= k + max_seen)
  return queries.join(selected_item_popularity, on=(sf.col("rank") <= k + sf.col("num_items")), how="left")

replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py CHANGED
@@ -32,9 +32,9 @@ class NmslibFilterIndexInferer(IndexInferer):
  index = index_store.load_index(
  init_index=lambda: create_nmslib_index_instance(index_params),
  load_index=lambda index, path: index.loadIndex(path, load_data=True),
- configure_index=lambda index: index.setQueryTimeParams({"efSearch": index_params.ef_s})
- if index_params.ef_s
- else None,
+ configure_index=lambda index: (
+ index.setQueryTimeParams({"efSearch": index_params.ef_s}) if index_params.ef_s else None
+ ),
  )

  # max number of items to retrieve per batch
replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py CHANGED
@@ -30,9 +30,9 @@ class NmslibIndexInferer(IndexInferer):
  index = index_store.load_index(
  init_index=lambda: create_nmslib_index_instance(index_params),
  load_index=lambda index, path: index.loadIndex(path, load_data=True),
- configure_index=lambda index: index.setQueryTimeParams({"efSearch": index_params.ef_s})
- if index_params.ef_s
- else None,
+ configure_index=lambda index: (
+ index.setQueryTimeParams({"efSearch": index_params.ef_s}) if index_params.ef_s else None
+ ),
  )

  user_vectors = get_csr_matrix(user_idx, vector_items, vector_ratings)
replay/models/nn/sequential/bert4rec/model.py CHANGED
@@ -1,7 +1,7 @@
  import contextlib
  import math
  from abc import ABC, abstractmethod
- from typing import Dict, Optional, Tuple, Union, cast
+ from typing import Dict, Optional, Union

  import torch

@@ -115,13 +115,10 @@ class Bert4RecModel(torch.nn.Module):
  # (B x L x E)
  x = self.item_embedder(inputs, token_mask)

- # (B x 1 x L x L)
- pad_mask_for_attention = self._get_attention_mask_from_padding(pad_mask)
-
  # Running over multiple transformer blocks
  for transformer in self.transformer_blocks:
  for _ in range(self.num_passes_over_block):
- x = transformer(x, pad_mask_for_attention)
+ x = transformer(x, pad_mask)

  return x

@@ -147,11 +144,6 @@ class Bert4RecModel(torch.nn.Module):
  """
  return self.forward_step(inputs, pad_mask, token_mask)[:, -1, :]

- def _get_attention_mask_from_padding(self, pad_mask: torch.BoolTensor) -> torch.BoolTensor:
- # (B x L) -> (B x 1 x L x L)
- pad_mask_for_attention = pad_mask.unsqueeze(1).repeat(1, self.max_len, 1).unsqueeze(1)
- return cast(torch.BoolTensor, pad_mask_for_attention)
-
  def _init(self) -> None:
  for _, param in self.named_parameters():
  with contextlib.suppress(ValueError):
@@ -456,7 +448,7 @@ class TransformerBlock(torch.nn.Module):
  :param dropout: Dropout rate.
  """
  super().__init__()
- self.attention = MultiHeadedAttention(h=attn_heads, d_model=hidden_size, dropout=dropout)
+ self.attention = torch.nn.MultiheadAttention(hidden_size, attn_heads, dropout=dropout, batch_first=True)
  self.attention_dropout = torch.nn.Dropout(dropout)
  self.attention_norm = LayerNorm(hidden_size)

@@ -479,7 +471,8 @@ class TransformerBlock(torch.nn.Module):
  """
  # Attention + skip-connection
  x_norm = self.attention_norm(x)
- y = x + self.attention_dropout(self.attention(x_norm, x_norm, x_norm, mask))
+ attent_emb, _ = self.attention(x_norm, x_norm, x_norm, key_padding_mask=~mask, need_weights=False)
+ y = x + self.attention_dropout(attent_emb)

  # PFF + skip-connection
  z = y + self.pff_dropout(self.pff(self.pff_norm(y)))
@@ -487,106 +480,6 @@ class TransformerBlock(torch.nn.Module):
  return self.dropout(z)


- class Attention(torch.nn.Module):
- """
- Compute Scaled Dot Product Attention
- """
-
- def __init__(self, dropout: float) -> None:
- """
- :param dropout: Dropout rate.
- """
- super().__init__()
- self.dropout = torch.nn.Dropout(p=dropout)
-
- def forward(
- self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.BoolTensor
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """
- :param query: Query feature vector.
- :param key: Key feature vector.
- :param value: Value feature vector.
- :param mask: Mask where 0 - <MASK>, 1 - otherwise.
-
- :returns: Tuple of scaled dot product attention
- and attention logits for each element.
- """
- scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))
-
- scores = scores.masked_fill(mask == 0, -1e9)
- p_attn = torch.nn.functional.softmax(scores, dim=-1)
- p_attn = self.dropout(p_attn)
-
- return torch.matmul(p_attn, value), p_attn
-
-
- class MultiHeadedAttention(torch.nn.Module):
- """
- Take in model size and number of heads.
- """
-
- def __init__(self, h: int, d_model: int, dropout: float = 0.1) -> None:
- """
- :param h: Head sizes of multi-head attention.
- :param d_model: Embedding dimension.
- :param dropout: Dropout rate.
- Default: ``0.1``.
- """
- super().__init__()
- assert d_model % h == 0
-
- # We assume d_v always equals d_k
- self.d_k = d_model // h
- self.h = h
-
- # 3 linear projections for Q, K, V
- self.qkv_linear_layers = torch.nn.ModuleList([torch.nn.Linear(d_model, d_model) for _ in range(3)])
-
- # 2 linear projections for P -> P_q, P_k
- self.pos_linear_layers = torch.nn.ModuleList([torch.nn.Linear(d_model, d_model) for _ in range(2)])
-
- self.output_linear = torch.nn.Linear(d_model, d_model)
-
- self.attention = Attention(dropout)
-
- def forward(
- self,
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- mask: torch.BoolTensor,
- ) -> torch.Tensor:
- """
- :param query: Query feature vector.
- :param key: Key feature vector.
- :param value: Value feature vector.
- :param mask: Mask where 0 - <MASK>, 1 - otherwise.
-
- :returns: Attention outputs.
- """
- batch_size = query.size(0)
-
- # B - batch size
- # L - sequence length (max_len)
- # E - embedding size for tokens fed into transformer
- # K - max relative distance
- # H - attention head count
-
- # Do all the linear projections in batch from d_model => h x d_k
- # (B x L x E) -> (B x H x L x (E / H))
- query, key, value = [
- layer(x).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
- for layer, x in zip(self.qkv_linear_layers, (query, key, value))
- ]
-
- x, _ = self.attention(query, key, value, mask)
-
- # Concat using a view and apply a final linear.
- x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k)
-
- return self.output_linear(x)
-
-
  class LayerNorm(torch.nn.Module):
  """
  Construct a layernorm module (See citation for details).
replay/models/nn/sequential/sasrec/model.py CHANGED
@@ -401,7 +401,12 @@ class SasRecLayers(torch.nn.Module):
  """
  super().__init__()
  self.attention_layers = self._layers_stacker(
- num_blocks, torch.nn.MultiheadAttention, hidden_size, num_heads, dropout
+ num_blocks,
+ torch.nn.MultiheadAttention,
+ hidden_size,
+ num_heads,
+ dropout,
+ batch_first=True,
  )
  self.attention_layernorms = self._layers_stacker(num_blocks, torch.nn.LayerNorm, hidden_size, eps=1e-8)
  self.forward_layers = self._layers_stacker(num_blocks, SasRecPointWiseFeedForward, hidden_size, dropout)
@@ -422,11 +427,9 @@ class SasRecLayers(torch.nn.Module):
  """
  length = len(self.attention_layers)
  for i in range(length):
- seqs = torch.transpose(seqs, 0, 1)
  query = self.attention_layernorms[i](seqs)
- attent_emb, _ = self.attention_layers[i](query, seqs, seqs, attn_mask=attention_mask)
+ attent_emb, _ = self.attention_layers[i](query, seqs, seqs, attn_mask=attention_mask, need_weights=False)
  seqs = query + attent_emb
- seqs = torch.transpose(seqs, 0, 1)

  seqs = self.forward_layernorms[i](seqs)
  seqs = self.forward_layers[i](seqs)
@@ -492,7 +495,7 @@ class SasRecPointWiseFeedForward(torch.nn.Module):

  :returns: Output tensors.
  """
- outputs = self.dropout2(self.conv2(self.relu(self.dropout1(self.conv1(inputs.transpose(-1, -2))))))
+ outputs = self.dropout2(self.conv2(self.dropout1(self.relu(self.conv1(inputs.transpose(-1, -2))))))
  outputs = outputs.transpose(-1, -2)
  outputs += inputs

replay/optimization/optuna_objective.py CHANGED
@@ -1,6 +1,7 @@
  """
  This class calculates loss function for optimization process
  """
+
  import collections
  import logging
  from functools import partial
replay/preprocessing/converter.py CHANGED
@@ -102,6 +102,6 @@ class CSRConverter:
  row_count = self.row_count if self.row_count is not None else _get_max(rows_data) + 1
  col_count = self.column_count if self.column_count is not None else _get_max(cols_data) + 1
  return csr_matrix(
- (data, (rows_data, cols_data)),
+ (data.tolist(), (rows_data.tolist(), cols_data.tolist())),
  shape=(row_count, col_count),
  )
replay/preprocessing/filters.py CHANGED
@@ -1,6 +1,7 @@
  """
  Select or remove data by some criteria
  """
+
  from abc import ABC, abstractmethod
  from datetime import datetime, timedelta
  from typing import Callable, Optional, Tuple, Union
@@ -355,8 +356,8 @@ class NumInteractionsFilter(_BaseFilter):
  >>> log_pd = pd.DataFrame({"user_id": ["u1", "u2", "u2", "u3", "u3", "u3"],
  ... "item_id": ["i1", "i2","i3", "i1", "i2","i3"],
  ... "rating": [1., 0.5, 3, 1, 0, 1],
- ... "timestamp": ["2020-01-01 23:59:59", "2020-02-01",
- ... "2020-02-01", "2020-01-01 00:04:15",
+ ... "timestamp": ["2020-01-01 23:59:59", "2020-02-01 00:00:00",
+ ... "2020-02-01 00:00:01", "2020-01-01 00:04:15",
  ... "2020-01-02 00:04:14", "2020-01-05 23:59:59"]},
  ... )
  >>> log_pd["timestamp"] = pd.to_datetime(log_pd["timestamp"], format="ISO8601")
@@ -367,7 +368,7 @@ class NumInteractionsFilter(_BaseFilter):
  +-------+-------+------+-------------------+
  | u1| i1| 1.0|2020-01-01 23:59:59|
  | u2| i2| 0.5|2020-02-01 00:00:00|
- | u2| i3| 3.0|2020-02-01 00:00:00|
+ | u2| i3| 3.0|2020-02-01 00:00:01|
  | u3| i1| 1.0|2020-01-01 00:04:15|
  | u3| i2| 0.0|2020-01-02 00:04:14|
  | u3| i3| 1.0|2020-01-05 23:59:59|
@@ -393,7 +394,7 @@ class NumInteractionsFilter(_BaseFilter):
  |user_id|item_id|rating| timestamp|
  +-------+-------+------+-------------------+
  | u1| i1| 1.0|2020-01-01 23:59:59|
- | u2| i2| 0.5|2020-02-01 00:00:00|
+ | u2| i3| 3.0|2020-02-01 00:00:01|
  | u3| i3| 1.0|2020-01-05 23:59:59|
  +-------+-------+------+-------------------+
  <BLANKLINE>
@@ -403,7 +404,7 @@ class NumInteractionsFilter(_BaseFilter):
  |user_id|item_id|rating| timestamp|
  +-------+-------+------+-------------------+
  | u1| i1| 1.0|2020-01-01 23:59:59|
- | u2| i3| 3.0|2020-02-01 00:00:00|
+ | u2| i3| 3.0|2020-02-01 00:00:01|
  | u3| i3| 1.0|2020-01-05 23:59:59|
  +-------+-------+------+-------------------+
  <BLANKLINE>
@@ -482,7 +483,7 @@ class NumInteractionsFilter(_BaseFilter):

  return (
  interactions.sort(sorting_columns, descending=descending)
- .with_columns(pl.col(self.query_column).cumcount().over(self.query_column).alias("temp_rank"))
+ .with_columns(pl.col(self.query_column).cum_count().over(self.query_column).alias("temp_rank"))
  .filter(pl.col("temp_rank") <= self.num_interactions)
  .drop("temp_rank")
  )
@@ -497,8 +498,8 @@ class EntityDaysFilter(_BaseFilter):
  >>> log_pd = pd.DataFrame({"user_id": ["u1", "u2", "u2", "u3", "u3", "u3"],
  ... "item_id": ["i1", "i2","i3", "i1", "i2","i3"],
  ... "rating": [1., 0.5, 3, 1, 0, 1],
- ... "timestamp": ["2020-01-01 23:59:59", "2020-02-01",
- ... "2020-02-01", "2020-01-01 00:04:15",
+ ... "timestamp": ["2020-01-01 23:59:59", "2020-02-01 00:00:00",
+ ... "2020-02-01 00:00:01", "2020-01-01 00:04:15",
  ... "2020-01-02 00:04:14", "2020-01-05 23:59:59"]},
  ... )
  >>> log_pd["timestamp"] = pd.to_datetime(log_pd["timestamp"], format="ISO8601")
@@ -509,7 +510,7 @@ class EntityDaysFilter(_BaseFilter):
  +-------+-------+------+-------------------+
  | u1| i1| 1.0|2020-01-01 23:59:59|
  | u2| i2| 0.5|2020-02-01 00:00:00|
- | u2| i3| 3.0|2020-02-01 00:00:00|
+ | u2| i3| 3.0|2020-02-01 00:00:01|
  | u3| i1| 1.0|2020-01-01 00:04:15|
  | u3| i2| 0.0|2020-01-02 00:04:14|
  | u3| i3| 1.0|2020-01-05 23:59:59|
@@ -524,7 +525,7 @@ class EntityDaysFilter(_BaseFilter):
  +-------+-------+------+-------------------+
  | u1| i1| 1.0|2020-01-01 23:59:59|
  | u2| i2| 0.5|2020-02-01 00:00:00|
- | u2| i3| 3.0|2020-02-01 00:00:00|
+ | u2| i3| 3.0|2020-02-01 00:00:01|
  | u3| i1| 1.0|2020-01-01 00:04:15|
  | u3| i2| 0.0|2020-01-02 00:04:14|
  +-------+-------+------+-------------------+
@@ -539,7 +540,7 @@ class EntityDaysFilter(_BaseFilter):
  | u1| i1| 1.0|2020-01-01 23:59:59|
  | u3| i1| 1.0|2020-01-01 00:04:15|
  | u2| i2| 0.5|2020-02-01 00:00:00|
- | u2| i3| 3.0|2020-02-01 00:00:00|
+ | u2| i3| 3.0|2020-02-01 00:00:01|
  +-------+-------+------+-------------------+
  <BLANKLINE>
  """
@@ -636,8 +637,8 @@ class GlobalDaysFilter(_BaseFilter):
  >>> log_pd = pd.DataFrame({"user_id": ["u1", "u2", "u2", "u3", "u3", "u3"],
  ... "item_id": ["i1", "i2","i3", "i1", "i2","i3"],
  ... "rating": [1., 0.5, 3, 1, 0, 1],
- ... "timestamp": ["2020-01-01 23:59:59", "2020-02-01",
- ... "2020-02-01", "2020-01-01 00:04:15",
+ ... "timestamp": ["2020-01-01 23:59:59", "2020-02-01 00:00:00",
+ ... "2020-02-01 00:00:01", "2020-01-01 00:04:15",
  ... "2020-01-02 00:04:14", "2020-01-05 23:59:59"]},
  ... )
  >>> log_pd["timestamp"] = pd.to_datetime(log_pd["timestamp"], format="ISO8601")
@@ -648,7 +649,7 @@ class GlobalDaysFilter(_BaseFilter):
  +-------+-------+------+-------------------+
  | u1| i1| 1.0|2020-01-01 23:59:59|
  | u2| i2| 0.5|2020-02-01 00:00:00|
- | u2| i3| 3.0|2020-02-01 00:00:00|
+ | u2| i3| 3.0|2020-02-01 00:00:01|
  | u3| i1| 1.0|2020-01-01 00:04:15|
  | u3| i2| 0.0|2020-01-02 00:04:14|
  | u3| i3| 1.0|2020-01-05 23:59:59|
@@ -670,7 +671,7 @@ class GlobalDaysFilter(_BaseFilter):
  |user_id|item_id|rating| timestamp|
  +-------+-------+------+-------------------+
  | u2| i2| 0.5|2020-02-01 00:00:00|
- | u2| i3| 3.0|2020-02-01 00:00:00|
+ | u2| i3| 3.0|2020-02-01 00:00:01|
  +-------+-------+------+-------------------+
  <BLANKLINE>
  """
@@ -738,8 +739,8 @@ class TimePeriodFilter(_BaseFilter):
  >>> log_pd = pd.DataFrame({"user_id": ["u1", "u2", "u2", "u3", "u3", "u3"],
  ... "item_id": ["i1", "i2","i3", "i1", "i2","i3"],
  ... "rating": [1., 0.5, 3, 1, 0, 1],
- ... "timestamp": ["2020-01-01 23:59:59", "2020-02-01",
- ... "2020-02-01", "2020-01-01 00:04:15",
+ ... "timestamp": ["2020-01-01 23:59:59", "2020-02-01 00:00:00",
+ ... "2020-02-01 00:00:01", "2020-01-01 00:04:15",
  ... "2020-01-02 00:04:14", "2020-01-05 23:59:59"]},
  ... )
  >>> log_pd["timestamp"] = pd.to_datetime(log_pd["timestamp"], format="ISO8601")
@@ -750,7 +751,7 @@ class TimePeriodFilter(_BaseFilter):
  +-------+-------+------+-------------------+
  | u1| i1| 1.0|2020-01-01 23:59:59|
  | u2| i2| 0.5|2020-02-01 00:00:00|
- | u2| i3| 3.0|2020-02-01 00:00:00|
+ | u2| i3| 3.0|2020-02-01 00:00:01|
  | u3| i1| 1.0|2020-01-01 00:04:15|
  | u3| i2| 0.0|2020-01-02 00:04:14|
  | u3| i3| 1.0|2020-01-05 23:59:59|
replay/preprocessing/history_based_fp.py CHANGED
@@ -179,8 +179,8 @@ class LogStatFeaturesProcessor(EmptyFeatureProcessor):
  abnormality_aggs = [sf.mean(sf.col("abnormality")).alias("abnormality")]

  # Abnormality CR:
- max_std = item_features.select(sf.max("i_std")).collect()[0][0]
- min_std = item_features.select(sf.min("i_std")).collect()[0][0]
+ max_std = item_features.select(sf.max("i_std")).first()[0]
+ min_std = item_features.select(sf.min("i_std")).first()[0]
  if max_std - min_std != 0:
  abnormality_df = abnormality_df.withColumn(
  "controversy",
@@ -201,15 +201,15 @@ class LogStatFeaturesProcessor(EmptyFeatureProcessor):
  :param log: input SparkDataFrame ``[user_idx, item_idx, timestamp, relevance]``
  """
  self.calc_timestamp_based = (isinstance(log.schema["timestamp"].dataType, TimestampType)) & (
- log.select(sf.countDistinct(sf.col("timestamp"))).collect()[0][0] > 1
+ log.select(sf.countDistinct(sf.col("timestamp"))).first()[0] > 1
  )
- self.calc_relevance_based = log.select(sf.countDistinct(sf.col("relevance"))).collect()[0][0] > 1
+ self.calc_relevance_based = log.select(sf.countDistinct(sf.col("relevance"))).first()[0] > 1

  user_log_features = log.groupBy("user_idx").agg(*self._create_log_aggregates(agg_col="user_idx"))
  item_log_features = log.groupBy("item_idx").agg(*self._create_log_aggregates(agg_col="item_idx"))

  if self.calc_timestamp_based:
- last_date = log.select(sf.max("timestamp")).collect()[0][0]
+ last_date = log.select(sf.max("timestamp")).first()[0]
  user_log_features = self._add_ts_based(features=user_log_features, max_log_date=last_date, prefix="u")

  item_log_features = self._add_ts_based(features=item_log_features, max_log_date=last_date, prefix="i")
replay/preprocessing/label_encoder.py CHANGED
@@ -5,6 +5,7 @@ Contains classes for encoding categorical data
  Recommended to use together with the LabelEncoder.
  ``LabelEncoder`` to apply multiple LabelEncodingRule to dataframe.
  """
+
  import abc
  import warnings
  from typing import Dict, List, Literal, Mapping, Optional, Sequence, Union
replay/scenarios/__init__.py CHANGED
@@ -1,4 +1,5 @@
  """
  Scenarios are a series of actions for recommendations
  """
+
  from .fallback import Fallback
replay/splitters/last_n_splitter.py CHANGED
@@ -193,7 +193,7 @@ class LastNSplitter(Splitter):

  def _add_time_partition_to_polars(self, interactions: PolarsDataFrame) -> PolarsDataFrame:
  res = interactions.sort(self.timestamp_column).with_columns(
- pl.col(self.divide_column).cumcount().over(pl.col(self.divide_column)).alias("row_num")
+ pl.col(self.divide_column).cum_count().over(pl.col(self.divide_column)).alias("row_num")
  )

  return res
replay/splitters/time_splitter.py CHANGED
@@ -193,7 +193,7 @@ class TimeSplitter(Splitter):
  )
  test_start = int(dates.count() * (1 - threshold)) + 1
  test_start = (
- dates.filter(sf.col("_row_number_by_ts") == test_start).select(self.timestamp_column).collect()[0][0]
+ dates.filter(sf.col("_row_number_by_ts") == test_start).select(self.timestamp_column).first()[0]
  )
  res = interactions.withColumn("is_test", sf.col(self.timestamp_column) >= test_start)
  else:
replay/splitters/two_stage_splitter.py CHANGED
@@ -1,8 +1,10 @@
  """
  This splitter split data by two columns.
  """
+
  from typing import Optional, Tuple

+ import numpy as np
  import polars as pl

  from replay.utils import PYSPARK_AVAILABLE, DataFrameLike, PandasDataFrame, PolarsDataFrame, SparkDataFrame
@@ -124,15 +126,15 @@ class TwoStageSplitter(Splitter):
  :return: DataFrame with single column `first_divide_column`
  """
  if isinstance(interactions, SparkDataFrame):
- all_values = interactions.select(self.first_divide_column).distinct()
+ all_values = interactions.select(self.first_divide_column).distinct().sort(self.first_divide_column)
  user_count = all_values.count()
  elif isinstance(interactions, PandasDataFrame):
  all_values = PandasDataFrame(
- interactions[self.first_divide_column].unique(), columns=[self.first_divide_column]
+ np.sort(interactions[self.first_divide_column].unique()), columns=[self.first_divide_column]
  )
  user_count = len(all_values)
  else:
- all_values = interactions.select(self.first_divide_column).unique()
+ all_values = interactions.select(self.first_divide_column).unique().sort(self.first_divide_column)
  user_count = len(all_values)

  value_error = False
@@ -152,7 +154,7 @@ class TwoStageSplitter(Splitter):
  if isinstance(interactions, SparkDataFrame):
  test_users = (
  all_values.withColumn("_rand", sf.rand(self.seed))
- .withColumn("_row_num", sf.row_number().over(Window.orderBy("_rand")))
+ .withColumn("_row_num", sf.row_number().over(Window.partitionBy(sf.lit(0)).orderBy("_rand")))
  .filter(f"_row_num <= {test_user_count}")
  .drop("_rand", "_row_num")
  )
@@ -240,10 +242,10 @@ class TwoStageSplitter(Splitter):
  res = res.fill_null(False)

  train = res.filter((pl.col("_frac") > self.second_divide_size) | (~pl.col("is_test"))).drop(
- "_rand", "_row_num", "count", "_frac", "is_test"
+ "_row_num", "count", "_frac", "is_test"
  )
  test = res.filter((pl.col("_frac") <= self.second_divide_size) & pl.col("is_test")).drop(
- "_rand", "_row_num", "count", "_frac", "is_test"
+ "_row_num", "count", "_frac", "is_test"
  )

  return train, test
replay/utils/distributions.py CHANGED
@@ -1,4 +1,5 @@
  """Distribution calculations"""
+
  from .types import PYSPARK_AVAILABLE, DataFrameLike, PandasDataFrame

  if PYSPARK_AVAILABLE:
replay/utils/session_handler.py CHANGED
@@ -48,10 +48,10 @@ def get_spark_session(
  path_to_replay_jar = (
  "https://repo1.maven.org/maven2/io/github/sb-ai-lab/replay_2.12/3.1.3/replay_2.12-3.1.3.jar"
  )
- elif pyspark_version.startswith(("3.2", "3.3")):
+ elif pyspark_version.startswith(("3.2", "3.3")): # pragma: no cover
  path_to_replay_jar = "https://repo1.maven.org/maven2/io/github/sb-ai-lab/replay_2.12/3.2.0_als_metrics/replay_2.12-3.2.0_als_metrics.jar"
  elif pyspark_version.startswith("3.4"): # pragma: no cover
- path_to_replay_jar = "https://repo1.maven.org/maven2/io/github/sb-ai-lab/replay_2.12/3.4.0_als_metrics/replay_2.12-3.4.0_als_metrics.jar"
+ path_to_replay_jar = "https://repo1.maven.org/maven2/io/github/sb-ai-lab/replay_after_fix_2.12/0.1/replay_after_fix_2.12-0.1.jar"
  else: # pragma: no cover
  path_to_replay_jar = (
  "https://repo1.maven.org/maven2/io/github/sb-ai-lab/replay_2.12/3.1.3/replay_2.12-3.1.3.jar"
@@ -83,7 +83,7 @@ def get_spark_session(
  .config("spark.driver.maxResultSize", "4g")
  .config("spark.driver.bindAddress", "127.0.0.1")
  .config("spark.driver.host", "localhost")
- .config("spark.sql.execution.arrow.pyspark.enabled", "true")
+ .config("spark.sql.execution.arrow.enabled", "true")
  .config("spark.kryoserializer.buffer.max", "256m")
  .config("spark.files.overwrite", "true")
  .master(f"local[{'*' if core_count == -1 else core_count}]")
replay/utils/spark_utils.py CHANGED
@@ -459,8 +459,8 @@ def fallback(
  if base.count() == 0:
  return get_top_k_recs(fill, k, query_column=query_column, rating_column=rating_column)
  margin = 0.1
- min_in_base = base.agg({rating_column: "min"}).collect()[0][0]
- max_in_fill = fill.agg({rating_column: "max"}).collect()[0][0]
+ min_in_base = base.agg({rating_column: "min"}).first()[0]
+ max_in_fill = fill.agg({rating_column: "max"}).first()[0]
  diff = max_in_fill - min_in_base
  fill = fill.withColumnRenamed(rating_column, "relevance_fallback")
  if diff >= 0:
{replay_rec-0.17.1rc0.dist-info → replay_rec-0.18.0rc0.dist-info}/METADATA CHANGED
@@ -1,11 +1,11 @@
  Metadata-Version: 2.1
  Name: replay-rec
- Version: 0.17.1rc0
+ Version: 0.18.0rc0
  Summary: RecSys Library
  Home-page: https://sb-ai-lab.github.io/RePlay/
  License: Apache-2.0
  Author: AI Lab
- Requires-Python: >=3.8.1,<3.11
+ Requires-Python: >=3.8.1,<3.12
  Classifier: Development Status :: 4 - Beta
  Classifier: Environment :: Console
  Classifier: Intended Audience :: Developers
@@ -16,32 +16,34 @@ Classifier: Operating System :: Unix
  Classifier: Programming Language :: Python :: 3
  Classifier: Programming Language :: Python :: 3.9
  Classifier: Programming Language :: Python :: 3.10
+ Classifier: Programming Language :: Python :: 3.11
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
  Provides-Extra: all
  Provides-Extra: spark
  Provides-Extra: torch
  Requires-Dist: d3rlpy (>=2.0.4,<3.0.0)
+ Requires-Dist: fixed-install-nmslib (==2.1.2)
  Requires-Dist: gym (>=0.26.0,<0.27.0)
- Requires-Dist: hnswlib (==0.7.0)
+ Requires-Dist: hnswlib (>=0.7.0,<0.8.0)
  Requires-Dist: implicit (>=0.7.0,<0.8.0)
  Requires-Dist: lightautoml (>=0.3.1,<0.4.0)
  Requires-Dist: lightfm (==1.17)
- Requires-Dist: lightning (>=2.0.2,<3.0.0) ; extra == "torch" or extra == "all"
+ Requires-Dist: lightning (>=2.0.2,<=2.4.0) ; extra == "torch" or extra == "all"
  Requires-Dist: llvmlite (>=0.32.1)
- Requires-Dist: nmslib (==2.1.1)
  Requires-Dist: numba (>=0.50)
  Requires-Dist: numpy (>=1.20.0)
  Requires-Dist: optuna (>=3.2.0,<3.3.0)
  Requires-Dist: pandas (>=1.3.5,<=2.2.2)
- Requires-Dist: polars (>=0.20.7,<0.21.0)
- Requires-Dist: psutil (>=5.9.5,<5.10.0)
+ Requires-Dist: polars (>=1.0.0,<1.1.0)
+ Requires-Dist: psutil (>=6.0.0,<6.1.0)
  Requires-Dist: pyarrow (>=12.0.1)
- Requires-Dist: pyspark (>=3.0,<3.5) ; extra == "spark" or extra == "all"
+ Requires-Dist: pyspark (>=3.0,<3.5) ; (python_full_version >= "3.8.1" and python_version < "3.11") and (extra == "spark" or extra == "all")
+ Requires-Dist: pyspark (>=3.4,<3.5) ; (python_version >= "3.11" and python_version < "3.12") and (extra == "spark" or extra == "all")
  Requires-Dist: pytorch-ranger (>=0.1.1,<0.2.0) ; extra == "torch" or extra == "all"
- Requires-Dist: sb-obp (>=0.5.7,<0.6.0)
+ Requires-Dist: sb-obp (>=0.5.8,<0.6.0)
  Requires-Dist: scikit-learn (>=1.0.2,<2.0.0)
- Requires-Dist: scipy (>=1.8.1,<1.9.0)
- Requires-Dist: torch (>=1.8,<2.0) ; extra == "torch" or extra == "all"
+ Requires-Dist: scipy (>=1.8.1,<2.0.0)
+ Requires-Dist: torch (>=1.8,<=2.4.0) ; extra == "torch" or extra == "all"
  Project-URL: Repository, https://github.com/sb-ai-lab/RePlay
  Description-Content-Type: text/markdown

{replay_rec-0.17.1rc0.dist-info → replay_rec-0.18.0rc0.dist-info}/RECORD CHANGED
@@ -1,10 +1,10 @@
- replay/__init__.py,sha256=_PQ2zFERSGjgeThzFv3t6MPODgutry1eR82biGhB98o,54
+ replay/__init__.py,sha256=8QXsQRY27Ie9xmwimwzqKYG4KTLnxtZW0ns89LKKtUU,55
  replay/data/__init__.py,sha256=g5bKRyF76QL_BqlED-31RnS8pBdcyj9loMsx5vAG_0E,301
- replay/data/dataset.py,sha256=cSStvCqIc6WAJNtbmsxncSpcQZ1KfULMsrmf_V0UdPw,29490
+ replay/data/dataset.py,sha256=FnvsFeIcCMlq94_NDQRY3-jgpVvKN-4FdivABWVr8Pk,29481
  replay/data/dataset_utils/__init__.py,sha256=9wUvG8ZwGUvuzLU4zQI5FDcH0WVVo5YLN2ey3DterP0,55
- replay/data/dataset_utils/dataset_label_encoder.py,sha256=TEx2zLw5rJdIz1SRBEznyVv5x_Cs7o6QQbzMk-M1LU0,9598
+ replay/data/dataset_utils/dataset_label_encoder.py,sha256=o8p7XvQewKuqYY8anrUhuY8gTau1FbpPjnNSAwbDZTY,9599
  replay/data/nn/__init__.py,sha256=WxLsi4rgOuuvGYHN49xBPxP2Srhqf3NYgfBDVH-ZvBo,1122
- replay/data/nn/schema.py,sha256=pO4N7RgmgrqfD1-2d95OTeihKHTZ-5y2BG7CX_wBFi4,16198
+ replay/data/nn/schema.py,sha256=N6lBWC1Q_kX1s6oVdOaxxAYE2pWqwbkDK7LmLL8N1Ts,16208
  replay/data/nn/sequence_tokenizer.py,sha256=Ambrp3CMOp3JP68PiwmVh0m-_zNXiWzxxVreHkEwOyY,32592
  replay/data/nn/sequential_dataset.py,sha256=jCWxC0Pm1eQ5p8Y6_Bmg4fSEvPaecLrqz1iaWzaICdI,11014
  replay/data/nn/torch_sequential_dataset.py,sha256=BqrK_PtkhpsaY1zRIWGk4EgwPL31a7IWCc0hLDuwDQc,10984
@@ -12,8 +12,8 @@ replay/data/nn/utils.py,sha256=YKE9gkIHZDDiwv4THqOWL4PzsdOujnPuM97v79Mwq0E,2769
  replay/data/schema.py,sha256=F_cv6sYb6l23yuX5xWnbqoJ9oSeUT2NpIM19u8Lf2jA,15606
  replay/data/spark_schema.py,sha256=4o0Kn_fjwz2-9dBY3q46F9PL0F3E7jdVpIlX7SG3OZI,1111
  replay/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- replay/experimental/metrics/__init__.py,sha256=W6S9YTGCezLORyTKCqL4Y_PniC1k3Bu5XWIM3WVHg2Q,2860
- replay/experimental/metrics/base_metric.py,sha256=aYmKZ_336dRrlslBzYsgsOzmed54BNjNXsRcpzB5gyM,22648
+ replay/experimental/metrics/__init__.py,sha256=bdQogGbEDVAeH7Ejbb6vpw7bP6CYhftTu_DQuoFRuCA,2861
+ replay/experimental/metrics/base_metric.py,sha256=mWbkRGdHTF3ZHq9WSqTGGAX2XJtOSzwcefjSu1Mdl0Y,22649
  replay/experimental/metrics/coverage.py,sha256=3kVBAUhIEOuD8aJ6DShH2xh_1F61dcLZb001VCkmeJk,3154
  replay/experimental/metrics/experiment.py,sha256=Bd_XB9zbngcAwf5JLZKVPsFWQoz9pEGlPEUbkiR_MDc,7343
  replay/experimental/metrics/hitrate.py,sha256=TfWJrUyZXabdMr4tn8zqUPGDcYy2yphVCzXmLSHCxY0,675
@@ -29,10 +29,10 @@ replay/experimental/metrics/unexpectedness.py,sha256=JQQXEYHtQM8nqp7X2He4E9ZYwbp
  replay/experimental/models/__init__.py,sha256=R284PXgSxt-JWWwlSTLggchash0hrLfy4b2w-ySaQf4,588
  replay/experimental/models/admm_slim.py,sha256=Oz-x0aQAnGFN9z7PB7MiKfduBasc4KQrBT0JwtYdwLY,6581
  replay/experimental/models/base_neighbour_rec.py,sha256=pRcffr0cdRNZRVpzWb2Qv-UIsLkhbs7K1GRAmrSqPSM,7506
- replay/experimental/models/base_rec.py,sha256=rj2r7r_mmJdzKAkg5CHG1eqJhOpUHAETPe0NwfibFjU,49606
+ replay/experimental/models/base_rec.py,sha256=eTHQdjEaS_5e-8y7xB6tHlSObD0cbD66_NfFZJK2NxU,49571
  replay/experimental/models/base_torch_rec.py,sha256=oDkCxVFQjIHSWKlCns6mU3ECWbQW3mQZWvBHBxJQdwc,8111
- replay/experimental/models/cql.py,sha256=9ONDMblfxUgol5Pb2UInfSHVRbB2Ma15zAZC6valhtk,19628
- replay/experimental/models/ddpg.py,sha256=sZrGgwj_kKeUnwwT9qooc4Cxz2oVGkNfUwUe1N7mreI,31982
+ replay/experimental/models/cql.py,sha256=3IBQEqWfyHmvGxCvWtIbLgjuRWfd_8mySg8bVaI4KHQ,19630
+ replay/experimental/models/ddpg.py,sha256=uqWk235-YZ2na-NPN4TxUM9ZhogpLZEjivt1oSC2rtI,32080
  replay/experimental/models/dt4rec/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  replay/experimental/models/dt4rec/dt4rec.py,sha256=ZIHYonDubStN7Gb703csy86R7Q3_1fZc4zJf98HYFe4,5895
  replay/experimental/models/dt4rec/gpt1.py,sha256=T3buFtYyF6Fh6sW6f9dUZFcFEnQdljItbRa22CiKb0w,14044
@@ -41,15 +41,15 @@ replay/experimental/models/dt4rec/utils.py,sha256=jbCx2Xc85VtjQx-caYhJFfVuj1Wf86
  replay/experimental/models/extensions/spark_custom_models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  replay/experimental/models/extensions/spark_custom_models/als_extension.py,sha256=dKSVCMXWRB7IUnpEK_QNhSEuUSVcG793E8MT_AGXneY,25890
  replay/experimental/models/implicit_wrap.py,sha256=8F-f-CaStmlNHwphu-yu8o4Aft08NKDD_SqqH0zp1Uo,4655
- replay/experimental/models/lightfm_wrap.py,sha256=a2ctIEoZf7I0C_awiQI1lE4RGJ7ISs60znysgHRXZCw,11337
- replay/experimental/models/mult_vae.py,sha256=FdJ-GL6Jj2l5-38edKp_jsNfwFNGPxMHXKn8cG2tGJs,11607
- replay/experimental/models/neuromf.py,sha256=QRu--zIyOSQIp8R5Ksgiw7o0s5yOhQpuAX9YshKJs4w,14391
- replay/experimental/models/scala_als.py,sha256=PVf0YA3ii4iRwGqpYg6nStgaauyrm9QTzLtK_4f1En0,10985
+ replay/experimental/models/lightfm_wrap.py,sha256=8nuTpiBuddKlMFFpbUpRt5k_JiBGRjPpF_hNbKdLP4Q,11327
+ replay/experimental/models/mult_vae.py,sha256=BnnlUHPlNuvh7EFA8bjITRW_m8JQANRD6zvsNQ1SUXM,11608
+ replay/experimental/models/neuromf.py,sha256=Hr9qEKv1shkwAqCVCxfews1Pk3F6yni2WIZUGS2tNCE,14392
+ replay/experimental/models/scala_als.py,sha256=-sMZ8P_XbmVi-hApuS46MpaosVIXRED05cgsOI3ojvQ,10975
  replay/experimental/nn/data/__init__.py,sha256=5EAF-FNd7xhkUpTq_5MyVcPXBD81mJCwYrcbhdGOWjE,48
  replay/experimental/nn/data/schema_builder.py,sha256=5PphL9kK-tVm30nWdTjHUzqVOnTwKiU_MlxGdL5HJ8Y,1736
  replay/experimental/preprocessing/__init__.py,sha256=uMyeyQ_GKqjLhVGwhrEk3NLhhzS0DKi5xGo3VF4WkiA,130
- replay/experimental/preprocessing/data_preparator.py,sha256=fQ8Blo_uzA-2eC-_ViVeU26Tqj5lxLTCBoDJfEmiqUo,35968
- replay/experimental/preprocessing/padder.py,sha256=o7S_Zk-ne_jria3QhWCKkYa6bEqhCdtvCA-R0MjOvU4,9569
+ replay/experimental/preprocessing/data_preparator.py,sha256=SLyk4HWurLmUHuev5L_GmI3oVU-58lCflOExHJ7zCGw,35964
+ replay/experimental/preprocessing/padder.py,sha256=ROKnGA0136C9W9Qkky-1V5klcMxvwos5KL4_jMLOgwY,9564
  replay/experimental/preprocessing/sequence_generator.py,sha256=E1_0uZJLv8V_n6YzRlgUWtcrHIdjNwPeBN-BMbz0e-A,9053
  replay/experimental/scenarios/__init__.py,sha256=gWFLCkLyOmOppvbRMK7C3UMlMpcbIgiGVolSH6LPgWA,91
  replay/experimental/scenarios/obp_wrapper/__init__.py,sha256=rsRyfsTnVNp20LkTEugwoBrV9XWbIhR8tOqec_Au6dY,450
@@ -58,12 +58,12 @@ replay/experimental/scenarios/obp_wrapper/replay_offline.py,sha256=A6TPBFHj_UUL0
  replay/experimental/scenarios/obp_wrapper/utils.py,sha256=-ioWTb73NmHWxVxw4BdSolctqeeGIyjKtydwc45nrrk,3271
  replay/experimental/scenarios/two_stages/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  replay/experimental/scenarios/two_stages/reranker.py,sha256=tJtWhbHRNV4sJZ9RZzqIfylTplKh9QVwTIBhEGGnXq8,4244
- replay/experimental/scenarios/two_stages/two_stages_scenario.py,sha256=ZgflnQ6xuxDFphdKX6Q0jtXidHS7c2YvDaccoaL78Qo,29846
+ replay/experimental/scenarios/two_stages/two_stages_scenario.py,sha256=frwsST85YGMGEZPf4DZFp3kPKPEcVgaxOCEdtZywpkw,29841
  replay/experimental/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  replay/experimental/utils/logger.py,sha256=UwLowaeOG17sDEe32LiZel8MnjSTzeW7J3uLG1iwLuA,639
- replay/experimental/utils/model_handler.py,sha256=0ksSm5bJ1bL32VV5HI-KPe0a1EAzzOhMtmSYaM_zRrE,6271
+ replay/experimental/utils/model_handler.py,sha256=Rfj57E1R_XMEEigHNZa9a-rzEsyLWSDsgKfXoRzWWdg,6426
  replay/experimental/utils/session_handler.py,sha256=076TLpTOcnh13BznNTtJW6Zhrqvm9Ee1mlpP5YMD4No,1313
- replay/metrics/__init__.py,sha256=KDkxVnKa4ks9K9GmlrdTx1pkIl-MAmm78ZASsp2ZndE,2812
+ replay/metrics/__init__.py,sha256=j0PGvUehaPEZMNo9SQwJsnvzrS4bam9eHrRMQFLnMjY,2813
  replay/metrics/base_metric.py,sha256=uleW5vLrdA3iRx72tFyW0cxe6ne_ugQ1XaY_ZTcnAOo,15960
  replay/metrics/categorical_diversity.py,sha256=OYsF-Ng-WrF9CC-sKgQKngrA779NO8MtgRvvAyC8MXM,10781
  replay/metrics/coverage.py,sha256=wE1Y_TgKOzf_9ixeas-vsxANAHeHSGPuGrzKk8DklaY,8843
@@ -82,10 +82,10 @@ replay/metrics/surprisal.py,sha256=wj9Q5mAdECpl0LfykJWt8jgN3_CUSlai2fhiFgJr_Vw,7
  replay/metrics/torch_metrics_builder.py,sha256=2gcCcb0A-TVpYcBIYGhXrggyFX-M_T7Q1pQUiMpxEZE,13845
  replay/metrics/unexpectedness.py,sha256=cfDnkpK6nPeawwHDVNQAkUtsW0SvAttI84k4M5ttkyo,6888
  replay/models/__init__.py,sha256=_4gNsauyrVMYEoFDihPYY9kGuBGGFyy1krvxF7oEYjk,808
- replay/models/als.py,sha256=dpBwyg1ZBqtdgrFluHaq5nuPQT---fmA-N2TspJAM0U,6232
+ replay/models/als.py,sha256=eGiMok_zu5ZUKXU9i9feCP4RGMqSnlIGHjks6MqKzHw,6227
  replay/models/association_rules.py,sha256=cp4myXvMqro6zLMjJzJMb0DZ5DQFQEZvhqf5OBgBw8Y,14659
  replay/models/base_neighbour_rec.py,sha256=zMORSm4uMQSNj12v0n_6w8fVHgSYjeiqyYE9rrWgSfU,7887
- replay/models/base_rec.py,sha256=iF0eMlNQVcd-nb3aCRG3ObpmEi7P4-jP_5mKjwc6anc,66407
+ replay/models/base_rec.py,sha256=NFz_xcarDwhaB3fSa-5uLBa6tyATOwOQLD_yR445m5U,66372
  replay/models/cat_pop_rec.py,sha256=tzI1UMlC3kEOrtDZ1UPpCP13tX8CeDJP7PHwQKl9Mmo,11922
  replay/models/cluster.py,sha256=9JcpGnbfgFa4UsyxPAa4WMuJFa3rsuAxiKoy-s_UfyE,4970
  replay/models/extensions/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -106,8 +106,8 @@ replay/models/extensions/ann/index_inferers/__init__.py,sha256=47DEQpj8HBSa-_TIm
  replay/models/extensions/ann/index_inferers/base_inferer.py,sha256=I39aqEc2somfndrCd-KC3XYZnYSrJ2hGpR9y6wO93NA,2524
  replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py,sha256=JjT4l_XAjzUOsTAE7OS88zAgPd_h_O44oUnn2kVr8E0,2477
  replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py,sha256=CoY_oMfdcwnh87ceuSpHXu4Czle9xxeMisO8XJUuJLE,1717
- replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py,sha256=1bpBjRhj4J_ecaORRhkhEke7ImJcxVTFRmmGK2wISB4,3120
- replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py,sha256=TqyunbjMQp1bWltbouvqK2kr2cnER6_d75NuCTVB3O0,2195
+ replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py,sha256=tjuqbkztWBU4K6qp5LPFU_GOGJf2f4oXneExtUEVUzw,3128
+ replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py,sha256=S5eCBZlTXxEAeX6yeZGC7j56gOcJ7lMNb4Cs_5PEj9E,2203
  replay/models/extensions/ann/index_inferers/utils.py,sha256=6IST2FPSY3nuYu5KqzRpd4FgdaV3GnQRQlxp9LN_yyA,641
  replay/models/extensions/ann/index_stores/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  replay/models/extensions/ann/index_stores/base_index_store.py,sha256=u4l2ybAXX92ZMGK7NqqosbKF75QgFqhAMVadd5ePj6Y,910
@@ -125,7 +125,7 @@ replay/models/nn/sequential/__init__.py,sha256=CI2n0cxs_amqJrwBMq6n0Z_uBOu7CGXfa
  replay/models/nn/sequential/bert4rec/__init__.py,sha256=JfZqHOGxcvOkICl5cWmZbZhaKXpkIvua-Wj57VWWEhw,399
  replay/models/nn/sequential/bert4rec/dataset.py,sha256=sCnYGF-sQ1YlLq7vofQo2GIIlc59YlbUgmW7bHI6MPg,10324
  replay/models/nn/sequential/bert4rec/lightning.py,sha256=TqO0V-g0JA0D-L2t08AgAIQgBkDtLUgl4xqekSiDWJ4,22605
- replay/models/nn/sequential/bert4rec/model.py,sha256=tiAiKOUwk3iPPYWyWkfOF23IzfL1NbeaF-8kNt9uZlU,21303
+ replay/models/nn/sequential/bert4rec/model.py,sha256=lZJwJbWPjrcvQCpD2LULMva-nXaTL8PgZHkZ-8z9okU,17758
  replay/models/nn/sequential/callbacks/__init__.py,sha256=Q7mSZ_RB6iyD7QZaBL_NJ0uh8cRfgxq7gtPHbkSyhoo,282
  replay/models/nn/sequential/callbacks/prediction_callbacks.py,sha256=H4MZ87_N0hCKtHbsTuN-Cq_SJ-n9TSkvv2okuGnwo3M,9045
  replay/models/nn/sequential/callbacks/validation_callback.py,sha256=6TNl3NN9oahK1J7DT44461xqBuUCblCsLzUi2svlhF4,5825
@@ -135,7 +135,7 @@ replay/models/nn/sequential/postprocessors/postprocessors.py,sha256=V32xMyNPztJ5
  replay/models/nn/sequential/sasrec/__init__.py,sha256=c6130lRpPkcbuGgkM7slagBIgH7Uk5zUtSzFDEwAsik,250
  replay/models/nn/sequential/sasrec/dataset.py,sha256=ReGNc6t9jjXxMZJp0WqFj1jatJFHnWOrkK3W8lwBNIs,7036
  replay/models/nn/sequential/sasrec/lightning.py,sha256=DtLnNikTNvqroCzaVFw7u-QZpZdvwiYbCwJLE7FkHms,21397
- replay/models/nn/sequential/sasrec/model.py,sha256=DE9kaqlcL22v07kpi2IzIwZ4-3AXNBVTZCnfuTS5usg,27775
+ replay/models/nn/sequential/sasrec/model.py,sha256=EBAfDP3WHZC-Pyb8dm0mr3gpxhrCOFQDHbZ2itFPWmk,27780
  replay/models/pop_rec.py,sha256=Ju9y2rU2vW_jFU9-W15fbbr5_ZzYGihSjSxsqKsAf0Q,4964
  replay/models/query_pop_rec.py,sha256=UNsHtf3eQpJom73ZmEO5us4guI4SnCLJYTfuUpRgqes,4086
  replay/models/random_rec.py,sha256=9SC012_X3sNzrAjDG1CPGhjisZb6gnv4VCW7yIMSNpk,8066
@@ -145,36 +145,36 @@ replay/models/ucb.py,sha256=X98ulD8L3gWR3VA7rbQkXFqQyzWc-Nt12lp_gbLTfLQ,6964
  replay/models/wilson.py,sha256=o7aUWjq3648dAfgGBoWD5Gu-HzdyobPMaH2lzCLijiA,4558
  replay/models/word2vec.py,sha256=MgoRIS5vqW9cH1HKAGa2xsLLnTH6XC1EXk4Dzvn5lXA,9171
  replay/optimization/__init__.py,sha256=az6U10rF7X6rPRUUPwLyiM1WFNJ_6kl0imA5xLVWFLs,120
- replay/optimization/optuna_objective.py,sha256=Z-8X0_FT3BicVWj0UhxoLrvZAck3Dhn7jHDGo0i0hxA,7653
+ replay/optimization/optuna_objective.py,sha256=OUYlC3wQj4GmrSbE_z5IPPS6OEEPUoeRCWFJnIR1Na8,7654
  replay/preprocessing/__init__.py,sha256=TtBysFqYeDy4kZAEnWEaNSwPvbffYdfMkEs71YG51fM,411
- replay/preprocessing/converter.py,sha256=DczqsVLrwFi6EFhK2HR8rGiIxGCwXeY7QNgWorjA41g,4390
- replay/preprocessing/filters.py,sha256=wsXWQoZ-2aAecunLkaTxeLWi5ow4e3FAGcElx0iNx0w,41669
- replay/preprocessing/history_based_fp.py,sha256=tfgKJPKm53LSNqM6VmMXYsVrRDc-rP1Tbzn8s3mbziQ,18751
- replay/preprocessing/label_encoder.py,sha256=MLBavPD-dB644as0E9ZJSE9-8QxGCB_IHek1w3xtqDI,27040
+ replay/preprocessing/converter.py,sha256=JQ-4u5x0eXtswl1iH-bZITBXQov1nebnZ6XcvpD8Twk,4417
+ replay/preprocessing/filters.py,sha256=4Lk3gnNwksPscdW6a47qJ_r8QEpbYRuNqTPJ9-bvSRo,41743
+ replay/preprocessing/history_based_fp.py,sha256=Wb2DXHawE2dYghm1ARr05_5opd_TLfthZ7h5e0zbDjY,18726
+ replay/preprocessing/label_encoder.py,sha256=JrVNP93NVt630OFmacQ6MlkH7rTLIPog05-0vyBuQtQ,27041
  replay/preprocessing/sessionizer.py,sha256=G6i0K3FwqtweRxvcSYraJ-tBWAT2HnV-bWHHlIZJF-s,12217
- replay/scenarios/__init__.py,sha256=kw2wRkPPinw0IBA20D83XQ3xeSudk3KuYAAA1Wdr8xY,93
+ replay/scenarios/__init__.py,sha256=XXAKEQPTLlve-0O6NPwFgahFrb4oGcIq3HaYaaGxG2E,94
  replay/scenarios/fallback.py,sha256=EeBmIR-5igzKR2m55bQRFyhxTkpJez6ZkCW449n8hWs,7130
  replay/splitters/__init__.py,sha256=DnqVMelrzLwR8fGQgcWN_8FipGs8T4XGSPOMW-L_x2g,454
  replay/splitters/base_splitter.py,sha256=hj9_GYDWllzv3XnxN6WHu1JKRRVjXo77vZEOLbF9v-s,7761
  replay/splitters/cold_user_random_splitter.py,sha256=gVwBVdn_0IOaLGT_UzJoS9AMaPhelZy-FpC5JQS1PhA,4136
  replay/splitters/k_folds.py,sha256=WH02_DP18A2ae893ysonmfLPB56_i1ETllTAwaCYekg,6218
- replay/splitters/last_n_splitter.py,sha256=r9kdq2JPi508C9ywjwc68an-iq27KsigMfHWLz0YohE,15346
+ replay/splitters/last_n_splitter.py,sha256=ITq8yzd7PrbAi3yp5XJlBehq0E0boiPyTEn72sXZEOA,15347
  replay/splitters/new_users_splitter.py,sha256=bv_QCPkL7KFxJIovAXQbP3Rlty3My48YNTqrj-2ucFQ,9167
  replay/splitters/random_splitter.py,sha256=mbOcxeF0B9WQ9OSxK8CHkPtO8UzKCZJt3rRyFhn-hyQ,2996
  replay/splitters/ratio_splitter.py,sha256=8zvuCn16Icc4ntQPKXJ5ArAWuJzCZ9NHZtgWctKyBVY,17519
- replay/splitters/time_splitter.py,sha256=iXhuafjBx7dWyJSy-TEVy1IUQBwMpA1gAiF4-GtRe2g,9031
- replay/splitters/two_stage_splitter.py,sha256=PWozxjjgjrVzdz6Sm9dcDTeH0bOA24reFzkk_N_TgbQ,17734
+ replay/splitters/time_splitter.py,sha256=tsoK3Qg_pcYHDxBlv2xC8ohAikoIqac3fRGBvCb-QRo,9026
+ replay/splitters/two_stage_splitter.py,sha256=U90l1wfJnMAAW1j4YpJSd8zWvWB-LDUKFCifnanXraU,17830
  replay/utils/__init__.py,sha256=vDJgOWq81fbBs-QO4ZDpdqR4KDyO1kMOOxBRi-5Gp7E,253
  replay/utils/common.py,sha256=s4Pro3QCkPeVBsj-s0vrbhd_pkJD-_-2M_sIguxGzQQ,5411
  replay/utils/dataframe_bucketizer.py,sha256=LipmBBQkdkLGroZpbP9i7qvTombLdMxo2dUUys1m5OY,3748
- replay/utils/distributions.py,sha256=kGGq2KzQZ-yhTuw_vtOsKFXVpXUOQ2l4aIFBcaDufZ8,1202
+ replay/utils/distributions.py,sha256=UuhaC9HI6HnUXW97fEd-TsyDk4JT8t7k1T_6l5FpOMs,1203
  replay/utils/model_handler.py,sha256=V-mHDh8_UexjVSsMBBRA9yrjS_5MPHwYOwv_UrI-Zfs,6466
- replay/utils/session_handler.py,sha256=ijTvDSNAe1D9R1e-dhtd-r80tFNiIBsFdWZLgw-gLEo,5153
- replay/utils/spark_utils.py,sha256=k5lUFM2C9QZKQON3dqhgfswyUF4tsgJOn0U2wCKimqM,26901
+ replay/utils/session_handler.py,sha256=RYzQvvOnukundccEBnH4ghEdyUgiGB9etz5e3Elvfgw,5157
+ replay/utils/spark_utils.py,sha256=LBzS8PJc6Mq8q7S_f6BbQZkeOEW49briAdp--pwFWbs,26891
  replay/utils/time.py,sha256=J8asoQBytPcNw-BLGADYIsKeWhIoN1H5hKiX9t2AMqo,9376
  replay/utils/types.py,sha256=5sw0A7NG4ZgQKdWORnBy0wBZ5F98sP_Ju8SKQ6zbDS4,651
- replay_rec-0.17.1rc0.dist-info/LICENSE,sha256=rPmcA7UrHxBChEAAlJyE24qUWKKl9yLQXxFsKeg_LX4,11344
- replay_rec-0.17.1rc0.dist-info/METADATA,sha256=FgZduBS6AVq1qSNahVyNFCJILLPdVLVosbxjUxN7WkQ,10890
- replay_rec-0.17.1rc0.dist-info/NOTICE,sha256=k0bo4KHiHLRax5K3XKTTrf2Fi8V91mJ-R3FMdh6Reg0,2002
- replay_rec-0.17.1rc0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
- replay_rec-0.17.1rc0.dist-info/RECORD,,
+ replay_rec-0.18.0rc0.dist-info/LICENSE,sha256=rPmcA7UrHxBChEAAlJyE24qUWKKl9yLQXxFsKeg_LX4,11344
+ replay_rec-0.18.0rc0.dist-info/METADATA,sha256=u_aqIEAypmp3QkU8Jgt0knTsyUbgxcF2lBGz2evOdIg,11164
+ replay_rec-0.18.0rc0.dist-info/NOTICE,sha256=k0bo4KHiHLRax5K3XKTTrf2Fi8V91mJ-R3FMdh6Reg0,2002
+ replay_rec-0.18.0rc0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+ replay_rec-0.18.0rc0.dist-info/RECORD,,