replay-rec 0.16.0__py3-none-any.whl → 0.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. replay/__init__.py +1 -1
  2. replay/data/__init__.py +1 -1
  3. replay/data/dataset.py +45 -42
  4. replay/data/dataset_utils/dataset_label_encoder.py +6 -7
  5. replay/data/nn/__init__.py +1 -1
  6. replay/data/nn/schema.py +20 -33
  7. replay/data/nn/sequence_tokenizer.py +217 -87
  8. replay/data/nn/sequential_dataset.py +6 -22
  9. replay/data/nn/torch_sequential_dataset.py +20 -11
  10. replay/data/nn/utils.py +7 -9
  11. replay/data/schema.py +17 -17
  12. replay/data/spark_schema.py +0 -1
  13. replay/metrics/base_metric.py +38 -79
  14. replay/metrics/categorical_diversity.py +24 -58
  15. replay/metrics/coverage.py +25 -49
  16. replay/metrics/descriptors.py +4 -13
  17. replay/metrics/experiment.py +3 -8
  18. replay/metrics/hitrate.py +3 -6
  19. replay/metrics/map.py +3 -6
  20. replay/metrics/mrr.py +1 -4
  21. replay/metrics/ndcg.py +4 -7
  22. replay/metrics/novelty.py +10 -29
  23. replay/metrics/offline_metrics.py +26 -61
  24. replay/metrics/precision.py +3 -6
  25. replay/metrics/recall.py +3 -6
  26. replay/metrics/rocauc.py +7 -10
  27. replay/metrics/surprisal.py +13 -30
  28. replay/metrics/torch_metrics_builder.py +0 -4
  29. replay/metrics/unexpectedness.py +15 -20
  30. replay/models/__init__.py +1 -2
  31. replay/models/als.py +7 -15
  32. replay/models/association_rules.py +12 -28
  33. replay/models/base_neighbour_rec.py +21 -36
  34. replay/models/base_rec.py +92 -215
  35. replay/models/cat_pop_rec.py +9 -22
  36. replay/models/cluster.py +17 -28
  37. replay/models/extensions/ann/ann_mixin.py +7 -12
  38. replay/models/extensions/ann/entities/base_hnsw_param.py +1 -1
  39. replay/models/extensions/ann/entities/hnswlib_param.py +0 -6
  40. replay/models/extensions/ann/entities/nmslib_hnsw_param.py +0 -6
  41. replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +4 -10
  42. replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py +7 -11
  43. replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +5 -12
  44. replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +11 -18
  45. replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py +1 -4
  46. replay/models/extensions/ann/index_inferers/base_inferer.py +3 -10
  47. replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py +7 -17
  48. replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py +6 -14
  49. replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +14 -28
  50. replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +15 -25
  51. replay/models/extensions/ann/index_inferers/utils.py +2 -9
  52. replay/models/extensions/ann/index_stores/hdfs_index_store.py +4 -9
  53. replay/models/extensions/ann/index_stores/shared_disk_index_store.py +2 -6
  54. replay/models/extensions/ann/index_stores/spark_files_index_store.py +8 -14
  55. replay/models/extensions/ann/index_stores/utils.py +5 -2
  56. replay/models/extensions/ann/utils.py +3 -5
  57. replay/models/kl_ucb.py +16 -22
  58. replay/models/knn.py +37 -59
  59. replay/models/nn/optimizer_utils/__init__.py +1 -6
  60. replay/models/nn/optimizer_utils/optimizer_factory.py +3 -6
  61. replay/models/nn/sequential/bert4rec/__init__.py +1 -1
  62. replay/models/nn/sequential/bert4rec/dataset.py +6 -7
  63. replay/models/nn/sequential/bert4rec/lightning.py +53 -56
  64. replay/models/nn/sequential/bert4rec/model.py +12 -25
  65. replay/models/nn/sequential/callbacks/__init__.py +1 -1
  66. replay/models/nn/sequential/callbacks/prediction_callbacks.py +23 -25
  67. replay/models/nn/sequential/callbacks/validation_callback.py +27 -30
  68. replay/models/nn/sequential/postprocessors/postprocessors.py +1 -1
  69. replay/models/nn/sequential/sasrec/dataset.py +8 -7
  70. replay/models/nn/sequential/sasrec/lightning.py +53 -48
  71. replay/models/nn/sequential/sasrec/model.py +4 -17
  72. replay/models/pop_rec.py +9 -10
  73. replay/models/query_pop_rec.py +7 -15
  74. replay/models/random_rec.py +10 -18
  75. replay/models/slim.py +8 -13
  76. replay/models/thompson_sampling.py +13 -14
  77. replay/models/ucb.py +11 -22
  78. replay/models/wilson.py +5 -14
  79. replay/models/word2vec.py +24 -69
  80. replay/optimization/optuna_objective.py +13 -27
  81. replay/preprocessing/__init__.py +1 -2
  82. replay/preprocessing/converter.py +2 -7
  83. replay/preprocessing/filters.py +67 -142
  84. replay/preprocessing/history_based_fp.py +44 -116
  85. replay/preprocessing/label_encoder.py +106 -68
  86. replay/preprocessing/sessionizer.py +1 -11
  87. replay/scenarios/fallback.py +3 -8
  88. replay/splitters/base_splitter.py +43 -15
  89. replay/splitters/cold_user_random_splitter.py +18 -31
  90. replay/splitters/k_folds.py +14 -24
  91. replay/splitters/last_n_splitter.py +33 -43
  92. replay/splitters/new_users_splitter.py +31 -55
  93. replay/splitters/random_splitter.py +16 -23
  94. replay/splitters/ratio_splitter.py +30 -54
  95. replay/splitters/time_splitter.py +13 -18
  96. replay/splitters/two_stage_splitter.py +44 -79
  97. replay/utils/__init__.py +1 -1
  98. replay/utils/common.py +65 -0
  99. replay/utils/dataframe_bucketizer.py +25 -31
  100. replay/utils/distributions.py +3 -15
  101. replay/utils/model_handler.py +36 -33
  102. replay/utils/session_handler.py +11 -15
  103. replay/utils/spark_utils.py +51 -85
  104. replay/utils/time.py +8 -22
  105. replay/utils/types.py +1 -3
  106. {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/METADATA +2 -2
  107. replay_rec-0.17.0.dist-info/RECORD +127 -0
  108. replay_rec-0.16.0.dist-info/RECORD +0 -126
  109. {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/LICENSE +0 -0
  110. {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/WHEEL +0 -0
@@ -9,39 +9,37 @@ if PYSPARK_AVAILABLE:
     from replay.utils.session_handler import State


-class DataframeBucketizer(
-    Transformer, DefaultParamsWritable, DefaultParamsReadable
-):  # pylint: disable=R0901
+class DataframeBucketizer(Transformer, DefaultParamsWritable, DefaultParamsReadable):
     """
     Buckets the input dataframe, dumps it to spark warehouse directory,
     and returns a bucketed dataframe.
     """

-    bucketingKey = Param(
+    bucketing_key = Param(
         Params._dummy(),
-        "bucketingKey",
+        "bucketing_key",
         "bucketing key (also used as sort key)",
         typeConverter=TypeConverters.toString,
     )

-    partitionNum = Param(
+    partition_num = Param(
         Params._dummy(),
-        "partitionNum",
+        "partition_num",
         "number of buckets",
         typeConverter=TypeConverters.toInt,
     )

-    tableName = Param(
+    table_name = Param(
         Params._dummy(),
-        "tableName",
+        "table_name",
         "parquet file name (for storage in 'spark-warehouse') and spark table name",
         typeConverter=TypeConverters.toString,
     )

-    sparkWarehouseDir = Param(
+    spark_warehouse_dir = Param(
         Params._dummy(),
-        "sparkWarehouseDir",
-        "sparkWarehouseDir",
+        "spark_warehouse_dir",
+        "spark_warehouse_dir",
         typeConverter=TypeConverters.toString,
     )

@@ -62,10 +60,10 @@ class DataframeBucketizer(
             i.e. value of 'spark.sql.warehouse.dir' property
         """
         super().__init__()
-        self.set(self.bucketingKey, bucketing_key)
-        self.set(self.partitionNum, partition_num)
-        self.set(self.tableName, table_name)
-        self.set(self.sparkWarehouseDir, spark_warehouse_dir)
+        self.set(self.bucketing_key, bucketing_key)
+        self.set(self.partition_num, partition_num)
+        self.set(self.table_name, table_name)
+        self.set(self.spark_warehouse_dir, spark_warehouse_dir)

     def __enter__(self):
         return self
@@ -76,31 +74,27 @@ class DataframeBucketizer(
     def remove_parquet(self):
         """Removes parquets where bucketed dataset is stored"""
         spark = State().session
-        spark_warehouse_dir = self.getOrDefault(self.sparkWarehouseDir)
-        table_name = self.getOrDefault(self.tableName)
-        fs = get_fs(spark)  # pylint: disable=invalid-name
-        fs_path = spark._jvm.org.apache.hadoop.fs.Path(
-            f"{spark_warehouse_dir}/{table_name}"
-        )
+        spark_warehouse_dir = self.getOrDefault(self.spark_warehouse_dir)
+        table_name = self.getOrDefault(self.table_name)
+        fs = get_fs(spark)
+        fs_path = spark._jvm.org.apache.hadoop.fs.Path(f"{spark_warehouse_dir}/{table_name}")
         is_exists = fs.exists(fs_path)
         if is_exists:
             fs.delete(fs_path, True)

     def set_table_name(self, table_name: str):
         """Sets table name"""
-        self.set(self.tableName, table_name)
+        self.set(self.table_name, table_name)

     def _transform(self, dataset: SparkDataFrame):
-        bucketing_key = self.getOrDefault(self.bucketingKey)
-        partition_num = self.getOrDefault(self.partitionNum)
-        table_name = self.getOrDefault(self.tableName)
-        spark_warehouse_dir = self.getOrDefault(self.sparkWarehouseDir)
+        bucketing_key = self.getOrDefault(self.bucketing_key)
+        partition_num = self.getOrDefault(self.partition_num)
+        table_name = self.getOrDefault(self.table_name)
+        spark_warehouse_dir = self.getOrDefault(self.spark_warehouse_dir)

         if not table_name:
-            raise ValueError(
-                "Parameter 'table_name' is not set! "
-                "Please set it via method 'set_table_name'."
-            )
+            msg = "Parameter 'table_name' is not set! Please set it via method 'set_table_name'."
+            raise ValueError(msg)

         (
             dataset.repartition(partition_num, bucketing_key)
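
These hunks rename the bucketizer's Spark ML Param attributes from camelCase to snake_case (bucketingKey → bucketing_key, partitionNum → partition_num, tableName → table_name, sparkWarehouseDir → spark_warehouse_dir) while keeping the constructor keywords unchanged, so only code that touches the Param objects directly is affected. A minimal sketch of the caller-visible difference; the argument values are illustrative, not taken from the package:

from replay.utils.dataframe_bucketizer import DataframeBucketizer  # module path per the file list above

# Constructor keywords were already snake_case in 0.16.0; values here are made up.
bucketizer = DataframeBucketizer(
    bucketing_key="user_idx",
    partition_num=16,
    table_name="interactions_bucketed",
    spark_warehouse_dir="/tmp/spark-warehouse",
)

# Only direct Param access changes:
# 0.16.0: bucketizer.getOrDefault(bucketizer.tableName)
# 0.17.0: the Param attribute now matches the keyword argument name.
table_name = bucketizer.getOrDefault(bucketizer.table_name)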
@@ -22,23 +22,11 @@ def item_distribution(
     :return: DataFrame with results
     """
     log = convert2spark(log)
-    res = (
-        log.groupBy("item_idx")
-        .agg(sf.countDistinct("user_idx").alias("user_count"))
-        .select("item_idx", "user_count")
-    )
+    res = log.groupBy("item_idx").agg(sf.countDistinct("user_idx").alias("user_count")).select("item_idx", "user_count")

     rec = convert2spark(recommendations)
     rec = get_top_k_recs(rec, k)
-    rec = (
-        rec.groupBy("item_idx")
-        .agg(sf.countDistinct("user_idx").alias("rec_count"))
-        .select("item_idx", "rec_count")
-    )
+    rec = rec.groupBy("item_idx").agg(sf.countDistinct("user_idx").alias("rec_count")).select("item_idx", "rec_count")

-    res = (
-        res.join(rec, on="item_idx", how="outer")
-        .fillna(0)
-        .orderBy(["user_count", "item_idx"])
-    )
+    res = res.join(rec, on="item_idx", how="outer").fillna(0).orderBy(["user_count", "item_idx"])
     return spark_to_pandas(res, allow_collect_to_master)
@@ -1,17 +1,18 @@
-# pylint: disable=wildcard-import,invalid-name,unused-wildcard-import,unspecified-encoding
+import functools
 import json
 import os
 import pickle
+import warnings
 from os.path import join
 from pathlib import Path
-from typing import Union
+from typing import Any, Callable, Optional, Union

 from replay.data.dataset_utils import DatasetLabelEncoder
 from replay.models import *
 from replay.models.base_rec import BaseRecommender
 from replay.splitters import *
-from .session_handler import State

+from .session_handler import State
 from .types import PYSPARK_AVAILABLE

 if PYSPARK_AVAILABLE:
@@ -26,9 +27,7 @@ if PYSPARK_AVAILABLE:
         :param spark: spark session
         :return:
         """
-        fs = spark._jvm.org.apache.hadoop.fs.FileSystem.get(
-            spark._jsc.hadoopConfiguration()
-        )
+        fs = spark._jvm.org.apache.hadoop.fs.FileSystem.get(spark._jsc.hadoopConfiguration())
         return fs

     def get_list_of_paths(spark: SparkSession, dir_path: str):
@@ -44,9 +43,7 @@ if PYSPARK_AVAILABLE:
         return [str(f.getPath()) for f in statuses]


-def save(
-    model: BaseRecommender, path: Union[str, Path], overwrite: bool = False
-):
+def save(model: BaseRecommender, path: Union[str, Path], overwrite: bool = False):
     """
     Save fitted model to disk as a folder

@@ -63,9 +60,8 @@ def save(
     if not overwrite:
         is_exists = fs.exists(spark._jvm.org.apache.hadoop.fs.Path(path))
         if is_exists:
-            raise FileExistsError(
-                f"Path '{path}' already exists. Mode is 'overwrite = False'."
-            )
+            msg = f"Path '{path}' already exists. Mode is 'overwrite = False'."
+            raise FileExistsError(msg)

     fs.mkdirs(spark._jvm.org.apache.hadoop.fs.Path(path))
     model._save_model(join(path, "model"))
@@ -74,9 +70,7 @@ def save(
     init_args["_model_name"] = str(model)
     sc = spark.sparkContext
     df = spark.read.json(sc.parallelize([json.dumps(init_args)]))
-    df.coalesce(1).write.mode("overwrite").option(
-        "ignoreNullFields", "false"
-    ).json(join(path, "init_args.json"))
+    df.coalesce(1).write.mode("overwrite").option("ignoreNullFields", "false").json(join(path, "init_args.json"))

     dataframes = model._dataframes
     df_path = join(path, "dataframes")
@@ -85,13 +79,9 @@ def save(
         df.write.mode("overwrite").parquet(join(df_path, name))

     if hasattr(model, "fit_queries"):
-        model.fit_queries.write.mode("overwrite").parquet(
-            join(df_path, "fit_queries")
-        )
+        model.fit_queries.write.mode("overwrite").parquet(join(df_path, "fit_queries"))
     if hasattr(model, "fit_items"):
-        model.fit_items.write.mode("overwrite").parquet(
-            join(df_path, "fit_items")
-        )
+        model.fit_items.write.mode("overwrite").parquet(join(df_path, "fit_items"))
     if hasattr(model, "study"):
         save_picklable_to_parquet(model.study, join(path, "study"))

@@ -104,18 +94,11 @@ def load(path: str, model_type=None) -> BaseRecommender:
     :return: Restored trained model
     """
     spark = State().session
-    args = (
-        spark.read.json(join(path, "init_args.json"))
-        .first()
-        .asDict(recursive=True)
-    )
+    args = spark.read.json(join(path, "init_args.json")).first().asDict(recursive=True)
     name = args["_model_name"]
     del args["_model_name"]

-    if model_type is not None:
-        model_class = model_type
-    else:
-        model_class = globals()[name]
+    model_class = model_type if model_type is not None else globals()[name]

     model = model_class(**args)

@@ -180,9 +163,7 @@ def save_splitter(splitter: Splitter, path: str, overwrite: bool = False):
     sc = spark.sparkContext
     df = spark.read.json(sc.parallelize([json.dumps(init_args)]))
     if overwrite:
-        df.coalesce(1).write.mode("overwrite").json(
-            join(path, "init_args.json")
-        )
+        df.coalesce(1).write.mode("overwrite").json(join(path, "init_args.json"))
     else:
         df.coalesce(1).write.json(join(path, "init_args.json"))

@@ -200,3 +181,25 @@ def load_splitter(path: str) -> Splitter:
     del args["_splitter_name"]
     splitter = globals()[name]
     return splitter(**args)
+
+
+def deprecation_warning(message: Optional[str] = None) -> Callable[..., Any]:
+    """
+    Decorator that throws deprecation warnings.
+
+    :param message: message to deprecation warning without func name.
+    """
+    base_msg = "will be deprecated in future versions."
+
+    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
+        @functools.wraps(func)
+        def wrapper(*args: Any, **kwargs: Any) -> Any:
+            msg = f"{func.__qualname__} {message if message else base_msg}"
+            warnings.simplefilter("always", DeprecationWarning)  # turn off filter
+            warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
+            warnings.simplefilter("default", DeprecationWarning)  # reset filter
+            return func(*args, **kwargs)
+
+        return wrapper
+
+    return decorator
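
The last hunk adds a small deprecation_warning decorator: it wraps a callable, temporarily forces DeprecationWarning to be shown, and emits "<qualified name> <message>" (or the built-in fallback text) on every call. A usage sketch with a hypothetical function of our own; the import path is assumed from where the hunk places the decorator:

import warnings

from replay.utils.model_handler import deprecation_warning  # import path assumed from this hunk's location


@deprecation_warning("is kept only for backward compatibility.")
def legacy_helper(x: int) -> int:
    """Hypothetical function used only to demonstrate the decorator."""
    return x * 2


with warnings.catch_warnings(record=True) as caught:
    legacy_helper(2)

# Expect: "DeprecationWarning legacy_helper is kept only for backward compatibility."
# With no message argument the decorator falls back to
# "<qualname> will be deprecated in future versions."
print(caught[0].category.__name__, caught[0].message)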
@@ -36,7 +36,6 @@ def get_spark_session(
         Default: ``None``.
     """
     if os.environ.get("SCRIPT_ENV", None) == "cluster":  # pragma: no cover
-        # pylint: disable=no-member
         return SparkSession.builder.getOrCreate()

     os.environ["PYSPARK_PYTHON"] = sys.executable
@@ -46,33 +45,32 @@ def get_spark_session(
         path_to_replay_jar = os.environ.get("REPLAY_JAR_PATH")
     else:
         if pyspark_version.startswith("3.1"):  # pragma: no cover
-            path_to_replay_jar = "https://repo1.maven.org/maven2/io/github/sb-ai-lab/replay_2.12/3.1.3/replay_2.12-3.1.3.jar"
-        elif pyspark_version.startswith("3.2") or pyspark_version.startswith(
-            "3.3"
-        ):
+            path_to_replay_jar = (
+                "https://repo1.maven.org/maven2/io/github/sb-ai-lab/replay_2.12/3.1.3/replay_2.12-3.1.3.jar"
+            )
+        elif pyspark_version.startswith(("3.2", "3.3")):
             path_to_replay_jar = "https://repo1.maven.org/maven2/io/github/sb-ai-lab/replay_2.12/3.2.0_als_metrics/replay_2.12-3.2.0_als_metrics.jar"
         elif pyspark_version.startswith("3.4"):  # pragma: no cover
             path_to_replay_jar = "https://repo1.maven.org/maven2/io/github/sb-ai-lab/replay_2.12/3.4.0_als_metrics/replay_2.12-3.4.0_als_metrics.jar"
         else:  # pragma: no cover
-            path_to_replay_jar = "https://repo1.maven.org/maven2/io/github/sb-ai-lab/replay_2.12/3.1.3/replay_2.12-3.1.3.jar"
+            path_to_replay_jar = (
+                "https://repo1.maven.org/maven2/io/github/sb-ai-lab/replay_2.12/3.1.3/replay_2.12-3.1.3.jar"
+            )
             logging.warning(
-                "Replay ALS model support only spark 3.1-3.4 versions! "
-                "Replay will use 'https://repo1.maven.org/maven2/io/github/sb-ai-lab/replay_2.12/3.1.3/replay_2.12-3.1.3.jar' in 'spark.jars' property."
+                "Replay ALS model support only spark 3.1-3.4 versions! Replay will use "
+                "'https://repo1.maven.org/maven2/io/github/sb-ai-lab/replay_2.12/3.1.3/replay_2.12-3.1.3.jar' "
+                "in 'spark.jars' property."
             )

     if core_count is None:  # checking out env variable
         core_count = int(os.environ.get("REPLAY_SPARK_CORE_COUNT", "-1"))
     if spark_memory is None:
         env_var = os.environ.get("REPLAY_SPARK_MEMORY")
-        if env_var is not None:  # pragma: no cover
-            spark_memory = int(env_var)
-        else:  # pragma: no cover
-            spark_memory = floor(psutil.virtual_memory().total / 1024**3 * 0.7)
+        spark_memory = int(env_var) if env_var is not None else floor(psutil.virtual_memory().total / 1024**3 * 0.7)
     if shuffle_partitions is None:
         shuffle_partitions = os.cpu_count() * 3
     driver_memory = f"{spark_memory}g"
     user_home = os.environ["HOME"]
-    # pylint: disable=no-member
     spark = (
         SparkSession.builder.config("spark.driver.memory", driver_memory)
         .config(
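
For context, get_spark_session (patched above) falls back to environment variables when its arguments are left unset: REPLAY_SPARK_CORE_COUNT, REPLAY_SPARK_MEMORY, and, on one branch, REPLAY_JAR_PATH. A hedged sketch of configuring a session through the environment; the values are made up and the guard around REPLAY_JAR_PATH is not visible in this hunk:

import os

from replay.utils.session_handler import get_spark_session  # same import as the doctest in spark_utils below

# Illustrative values only; these variables are read when the matching arguments stay None.
os.environ["REPLAY_SPARK_CORE_COUNT"] = "8"
os.environ["REPLAY_SPARK_MEMORY"] = "16"  # gigabytes: the code builds f"{spark_memory}g"
# os.environ["REPLAY_JAR_PATH"] = "/opt/jars/replay.jar"  # assumed local-JAR override, condition not shown here

spark = get_spark_session()  # assumes the optional arguments keep their None defaults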
@@ -111,7 +109,6 @@ def logger_with_settings() -> logging.Logger:
     return logger


-# pylint: disable=too-few-public-methods
 class Borg:
     """
     This class allows to share objects between instances.
@@ -123,7 +120,6 @@ class Borg:
         self.__dict__ = self._shared_state


-# pylint: disable=too-few-public-methods
 class State(Borg):
     """
     All modules look for Spark session via this class. You can put your own session here.
@@ -10,14 +10,17 @@ import pandas as pd
 from numpy.random import default_rng

 from .session_handler import State
-
 from .types import PYSPARK_AVAILABLE, DataFrameLike, MissingImportType, NumType, SparkDataFrame

 if PYSPARK_AVAILABLE:
     import pyspark.sql.types as st
     from pyspark.ml.linalg import DenseVector, Vectors, VectorUDT
-    from pyspark.sql import Column, SparkSession, Window
-    from pyspark.sql import functions as sf
+    from pyspark.sql import (
+        Column,
+        SparkSession,
+        Window,
+        functions as sf,
+    )
     from pyspark.sql.column import _to_java_column, _to_seq
     from pyspark.sql.types import DoubleType, IntegerType, StructField, StructType
 else:
@@ -48,7 +51,6 @@ def spark_to_pandas(data: SparkDataFrame, allow_collect_to_master: bool = False)
     return data.toPandas()


-# pylint: disable=invalid-name
 def convert2spark(data_frame: Optional[DataFrameLike]) -> Optional[SparkDataFrame]:
     """
     Converts Pandas DataFrame to Spark DataFrame
@@ -61,7 +63,7 @@ def convert2spark(data_frame: Optional[DataFrameLike]) -> Optional[SparkDataFram
     if isinstance(data_frame, SparkDataFrame):
         return data_frame
     spark = State().session
-    return spark.createDataFrame(data_frame)  # type: ignore
+    return spark.createDataFrame(data_frame)


 def get_top_k(
@@ -76,7 +78,11 @@ def get_top_k(

     >>> from replay.utils.session_handler import State
     >>> spark = State().session
-    >>> log = spark.createDataFrame([(1, 2, 1.), (1, 3, 1.), (1, 4, 0.5), (2, 1, 1.)]).toDF("user_id", "item_id", "relevance")
+    >>> log = (
+    ...     spark
+    ...     .createDataFrame([(1, 2, 1.), (1, 3, 1.), (1, 4, 0.5), (2, 1, 1.)])
+    ...     .toDF("user_id", "item_id", "relevance")
+    ... )
     >>> log.show()
     +-------+-------+---------+
     |user_id|item_id|relevance|
@@ -108,9 +114,7 @@ def get_top_k(
     return (
         dataframe.withColumn(
             "temp_rank",
-            sf.row_number().over(
-                Window.partitionBy(partition_by_col).orderBy(*order_by_col)
-            ),
+            sf.row_number().over(Window.partitionBy(partition_by_col).orderBy(*order_by_col)),
         )
         .filter(sf.col("temp_rank") <= k)
         .drop("temp_rank")
@@ -141,6 +145,7 @@ def get_top_k_recs(


 if PYSPARK_AVAILABLE:
+
     @sf.udf(returnType=st.DoubleType())
     def vector_dot(one: DenseVector, two: DenseVector) -> float:  # pragma: no cover
         """
@@ -179,10 +184,8 @@ if PYSPARK_AVAILABLE:
         """
         return float(one.dot(two))

-    @sf.udf(returnType=VectorUDT())  # type: ignore
-    def vector_mult(
-        one: Union[DenseVector, NumType], two: DenseVector
-    ) -> DenseVector:  # pragma: no cover
+    @sf.udf(returnType=VectorUDT())
+    def vector_mult(one: Union[DenseVector, NumType], two: DenseVector) -> DenseVector:  # pragma: no cover
         """
         elementwise vector multiplication

@@ -271,9 +274,7 @@ def multiply_scala_udf(scalar, vector):
     return Column(_f.apply(_to_seq(sc, [scalar, vector], _to_java_column)))


-def get_log_info(
-    log: SparkDataFrame, user_col="user_idx", item_col="item_idx"
-) -> str:
+def get_log_info(log: SparkDataFrame, user_col="user_idx", item_col="item_idx") -> str:
     """
     Basic log statistics

@@ -310,9 +311,7 @@ def get_log_info(
     )


-def get_stats(
-    log: SparkDataFrame, group_by: str = "user_id", target_column: str = "relevance"
-) -> SparkDataFrame:
+def get_stats(log: SparkDataFrame, group_by: str = "user_id", target_column: str = "relevance") -> SparkDataFrame:
     """
     Calculate log statistics: min, max, mean, median ratings, number of ratings.
     >>> from replay.utils.session_handler import get_spark_session, State
@@ -351,14 +350,9 @@ def get_stats(
         "count": sf.count,
     }
     agg_functions_list = [
-        func(target_column).alias(str(name + "_" + target_column))
-        for name, func in agg_functions.items()
+        func(target_column).alias(str(name + "_" + target_column)) for name, func in agg_functions.items()
     ]
-    agg_functions_list.append(
-        sf.expr(f"percentile_approx({target_column}, 0.5)").alias(
-            "median_" + target_column
-        )
-    )
+    agg_functions_list.append(sf.expr(f"percentile_approx({target_column}, 0.5)").alias("median_" + target_column))

     return log.groupBy(group_by).agg(*agg_functions_list)

@@ -369,13 +363,9 @@ def check_numeric(feature_table: SparkDataFrame) -> None:
     :param feature_table: spark DataFrame
     """
     for column in feature_table.columns:
-        if not isinstance(
-            feature_table.schema[column].dataType, st.NumericType
-        ):
-            raise ValueError(
-                f"""Column {column} has type {feature_table.schema[
-                    column].dataType}, that is not numeric."""
-            )
+        if not isinstance(feature_table.schema[column].dataType, st.NumericType):
+            msg = f"Column {column} has type {feature_table.schema[column].dataType}, that is not numeric."
+            raise ValueError(msg)


 def horizontal_explode(
@@ -420,10 +410,7 @@ def horizontal_explode(
     num_columns = len(data_frame.select(column_to_explode).head()[0])
     return data_frame.select(
         *other_columns,
-        *[
-            sf.element_at(column_to_explode, i + 1).alias(f"{prefix}_{i}")
-            for i in range(num_columns)
-        ],
+        *[sf.element_at(column_to_explode, i + 1).alias(f"{prefix}_{i}") for i in range(num_columns)],
     )


@@ -442,7 +429,6 @@ def join_or_return(first, second, on, how):
     return first.join(second, on=on, how=how)


-# pylint: disable=too-many-arguments
 def fallback(
     base: SparkDataFrame,
     fill: SparkDataFrame,
@@ -471,15 +457,11 @@ def fallback(
     diff = max_in_fill - min_in_base
     fill = fill.withColumnRenamed(rating_column, "relevance_fallback")
     if diff >= 0:
-        fill = fill.withColumn(
-            "relevance_fallback", sf.col("relevance_fallback") - diff - margin
-        )
-    recs = base.join(
-        fill, on=[query_column, item_column], how="full_outer"
+        fill = fill.withColumn("relevance_fallback", sf.col("relevance_fallback") - diff - margin)
+    recs = base.join(fill, on=[query_column, item_column], how="full_outer")
+    recs = recs.withColumn(rating_column, sf.coalesce(rating_column, "relevance_fallback")).select(
+        query_column, item_column, rating_column
     )
-    recs = recs.withColumn(
-        rating_column, sf.coalesce(rating_column, "relevance_fallback")
-    ).select(query_column, item_column, rating_column)
     recs = get_top_k_recs(recs, k, query_column=query_column, rating_column=rating_column)
     return recs

@@ -537,9 +519,7 @@ def join_with_col_renaming(
         right = right.withColumnRenamed(name, f"{name}_{suffix}")
         on_condition &= sf.col(name) == sf.col(f"{name}_{suffix}")

-    return (left.join(right, on=on_condition, how=how)).drop(
-        *[f"{name}_{suffix}" for name in on_col_name]
-    )
+    return (left.join(right, on=on_condition, how=how)).drop(*[f"{name}_{suffix}" for name in on_col_name])


 def process_timestamp_column(
@@ -562,7 +542,8 @@ def process_timestamp_column(
     :return: dataframe with updated column ``column_name``
     """
     if column_name not in dataframe.columns:
-        raise ValueError(f"Column {column_name} not found")
+        msg = f"Column {column_name} not found"
+        raise ValueError(msg)

     # no conversion needed
     if isinstance(dataframe.schema[column_name].dataType, st.TimestampType):
@@ -570,9 +551,7 @@

     # unix timestamp
     if isinstance(dataframe.schema[column_name].dataType, st.NumericType):
-        return dataframe.withColumn(
-            column_name, sf.to_timestamp(sf.from_unixtime(sf.col(column_name)))
-        )
+        return dataframe.withColumn(column_name, sf.to_timestamp(sf.from_unixtime(sf.col(column_name))))

     # datetime in string format
     dataframe = dataframe.withColumn(
@@ -583,6 +562,7 @@


 if PYSPARK_AVAILABLE:
+
     @sf.udf(returnType=VectorUDT())
     def list_to_vector_udf(array: st.ArrayType) -> DenseVector:  # pragma: no cover
         """
@@ -603,9 +583,7 @@ if PYSPARK_AVAILABLE:
         return float(first.squared_distance(second))

     @sf.udf(returnType=st.FloatType())
-    def vector_euclidean_distance_similarity(
-        first: DenseVector, second: DenseVector
-    ) -> float:  # pragma: no cover
+    def vector_euclidean_distance_similarity(first: DenseVector, second: DenseVector) -> float:  # pragma: no cover
         """
         :param first: first vector
         :param second: second vector
@@ -642,7 +620,7 @@ def drop_temp_view(temp_view_name: str) -> None:
     spark.catalog.dropTempView(temp_view_name)


-def sample_top_k_recs(pairs: SparkDataFrame, k: int, seed: int = None):
+def sample_top_k_recs(pairs: SparkDataFrame, k: int, seed: Optional[int] = None):
     """
     Sample k items for each user with probability proportional to the relevance score.

@@ -660,17 +638,13 @@ def sample_top_k_recs(pairs: SparkDataFrame, k: int, seed: int = None):
     """
     pairs = pairs.withColumn(
         "probability",
-        sf.col("relevance")
-        / sf.sum("relevance").over(Window.partitionBy("user_idx")),
+        sf.col("relevance") / sf.sum("relevance").over(Window.partitionBy("user_idx")),
     )

     def grouped_map(pandas_df: pd.DataFrame) -> pd.DataFrame:  # pragma: no cover
         user_idx = pandas_df["user_idx"][0]

-        if seed is not None:
-            local_rng = default_rng(seed + user_idx)
-        else:
-            local_rng = default_rng()
+        local_rng = default_rng(seed + user_idx) if seed is not None else default_rng()

         items_positions = local_rng.choice(
             np.arange(pandas_df.shape[0]),
@@ -686,6 +660,7 @@ def sample_top_k_recs(pairs: SparkDataFrame, k: int, seed: int = None):
                 "relevance": pandas_df["relevance"].values[items_positions],
             }
         )
+
     rec_schema = StructType(
         [
             StructField("user_idx", IntegerType()),
@@ -716,19 +691,12 @@ def filter_cold(
     if df is None:
         return 0, df

-    num_cold = (
-        df.select(col_name)
-        .distinct()
-        .join(warm_df, on=col_name, how="anti")
-        .count()
-    )
+    num_cold = df.select(col_name).distinct().join(warm_df, on=col_name, how="anti").count()

     if num_cold == 0:
         return 0, df

-    return num_cold, df.join(
-        warm_df.select(col_name), on=col_name, how="inner"
-    )
+    return num_cold, df.join(warm_df.select(col_name), on=col_name, how="inner")


 def get_unique_entities(
@@ -745,17 +713,14 @@ def get_unique_entities(
     if isinstance(df, SparkDataFrame):
         unique = df.select(column).distinct()
     elif isinstance(df, collections.abc.Iterable):
-        unique = spark.createDataFrame(
-            data=pd.DataFrame(pd.unique(list(df)), columns=[column])
-        )
+        unique = spark.createDataFrame(data=pd.DataFrame(pd.unique(list(df)), columns=[column]))
     else:
-        raise ValueError(f"Wrong type {type(df)}")
+        msg = f"Wrong type {type(df)}"
+        raise ValueError(msg)
     return unique


-def return_recs(
-    recs: SparkDataFrame, recs_file_path: Optional[str] = None
-) -> Optional[SparkDataFrame]:
+def return_recs(recs: SparkDataFrame, recs_file_path: Optional[str] = None) -> Optional[SparkDataFrame]:
     """
     Save dataframe `recs` to `recs_file_path` if presents otherwise cache
     and materialize the dataframe.
@@ -785,7 +750,7 @@ def save_picklable_to_parquet(obj: Any, path: str) -> None:
     sc = State().session.sparkContext
     # We can use `RDD.saveAsPickleFile`, but it has no "overwrite" parameter
     pickled_instance = pickle.dumps(obj)
-    Record = collections.namedtuple("Record", ["data"])
+    Record = collections.namedtuple("Record", ["data"])  # noqa: PYI024
     rdd = sc.parallelize([Record(pickled_instance)])
     instance_df = rdd.map(lambda rec: Record(bytearray(rec.data))).toDF()
     instance_df.write.mode("overwrite").parquet(path)
@@ -812,9 +777,10 @@ def assert_omp_single_thread():
     PyTorch uses multithreading for cpu math operations via OpenMP library. Sometimes this
     leads to failures when OpenMP multithreading is mixed with multiprocessing.
     """
-    omp_num_threads = os.environ.get('OMP_NUM_THREADS', None)
-    if omp_num_threads != '1':
-        logging.getLogger("replay").warning(
-            'Environment variable "OMP_NUM_THREADS" is set to "%s". '
-            'Set it to 1 if the working process freezes.', omp_num_threads
+    omp_num_threads = os.environ.get("OMP_NUM_THREADS", None)
+    if omp_num_threads != "1":
+        msg = (
+            f'Environment variable "OMP_NUM_THREADS" is set to "{omp_num_threads}". '
+            f"Set it to 1 if the working process freezes."
         )
+        logging.getLogger("replay").warning(msg)