replay-rec 0.16.0__py3-none-any.whl → 0.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. replay/__init__.py +1 -1
  2. replay/data/__init__.py +1 -1
  3. replay/data/dataset.py +45 -42
  4. replay/data/dataset_utils/dataset_label_encoder.py +6 -7
  5. replay/data/nn/__init__.py +1 -1
  6. replay/data/nn/schema.py +20 -33
  7. replay/data/nn/sequence_tokenizer.py +217 -87
  8. replay/data/nn/sequential_dataset.py +6 -22
  9. replay/data/nn/torch_sequential_dataset.py +20 -11
  10. replay/data/nn/utils.py +7 -9
  11. replay/data/schema.py +17 -17
  12. replay/data/spark_schema.py +0 -1
  13. replay/metrics/base_metric.py +38 -79
  14. replay/metrics/categorical_diversity.py +24 -58
  15. replay/metrics/coverage.py +25 -49
  16. replay/metrics/descriptors.py +4 -13
  17. replay/metrics/experiment.py +3 -8
  18. replay/metrics/hitrate.py +3 -6
  19. replay/metrics/map.py +3 -6
  20. replay/metrics/mrr.py +1 -4
  21. replay/metrics/ndcg.py +4 -7
  22. replay/metrics/novelty.py +10 -29
  23. replay/metrics/offline_metrics.py +26 -61
  24. replay/metrics/precision.py +3 -6
  25. replay/metrics/recall.py +3 -6
  26. replay/metrics/rocauc.py +7 -10
  27. replay/metrics/surprisal.py +13 -30
  28. replay/metrics/torch_metrics_builder.py +0 -4
  29. replay/metrics/unexpectedness.py +15 -20
  30. replay/models/__init__.py +1 -2
  31. replay/models/als.py +7 -15
  32. replay/models/association_rules.py +12 -28
  33. replay/models/base_neighbour_rec.py +21 -36
  34. replay/models/base_rec.py +92 -215
  35. replay/models/cat_pop_rec.py +9 -22
  36. replay/models/cluster.py +17 -28
  37. replay/models/extensions/ann/ann_mixin.py +7 -12
  38. replay/models/extensions/ann/entities/base_hnsw_param.py +1 -1
  39. replay/models/extensions/ann/entities/hnswlib_param.py +0 -6
  40. replay/models/extensions/ann/entities/nmslib_hnsw_param.py +0 -6
  41. replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +4 -10
  42. replay/models/extensions/ann/index_builders/driver_nmslib_index_builder.py +7 -11
  43. replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +5 -12
  44. replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +11 -18
  45. replay/models/extensions/ann/index_builders/nmslib_index_builder_mixin.py +1 -4
  46. replay/models/extensions/ann/index_inferers/base_inferer.py +3 -10
  47. replay/models/extensions/ann/index_inferers/hnswlib_filter_index_inferer.py +7 -17
  48. replay/models/extensions/ann/index_inferers/hnswlib_index_inferer.py +6 -14
  49. replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +14 -28
  50. replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +15 -25
  51. replay/models/extensions/ann/index_inferers/utils.py +2 -9
  52. replay/models/extensions/ann/index_stores/hdfs_index_store.py +4 -9
  53. replay/models/extensions/ann/index_stores/shared_disk_index_store.py +2 -6
  54. replay/models/extensions/ann/index_stores/spark_files_index_store.py +8 -14
  55. replay/models/extensions/ann/index_stores/utils.py +5 -2
  56. replay/models/extensions/ann/utils.py +3 -5
  57. replay/models/kl_ucb.py +16 -22
  58. replay/models/knn.py +37 -59
  59. replay/models/nn/optimizer_utils/__init__.py +1 -6
  60. replay/models/nn/optimizer_utils/optimizer_factory.py +3 -6
  61. replay/models/nn/sequential/bert4rec/__init__.py +1 -1
  62. replay/models/nn/sequential/bert4rec/dataset.py +6 -7
  63. replay/models/nn/sequential/bert4rec/lightning.py +53 -56
  64. replay/models/nn/sequential/bert4rec/model.py +12 -25
  65. replay/models/nn/sequential/callbacks/__init__.py +1 -1
  66. replay/models/nn/sequential/callbacks/prediction_callbacks.py +23 -25
  67. replay/models/nn/sequential/callbacks/validation_callback.py +27 -30
  68. replay/models/nn/sequential/postprocessors/postprocessors.py +1 -1
  69. replay/models/nn/sequential/sasrec/dataset.py +8 -7
  70. replay/models/nn/sequential/sasrec/lightning.py +53 -48
  71. replay/models/nn/sequential/sasrec/model.py +4 -17
  72. replay/models/pop_rec.py +9 -10
  73. replay/models/query_pop_rec.py +7 -15
  74. replay/models/random_rec.py +10 -18
  75. replay/models/slim.py +8 -13
  76. replay/models/thompson_sampling.py +13 -14
  77. replay/models/ucb.py +11 -22
  78. replay/models/wilson.py +5 -14
  79. replay/models/word2vec.py +24 -69
  80. replay/optimization/optuna_objective.py +13 -27
  81. replay/preprocessing/__init__.py +1 -2
  82. replay/preprocessing/converter.py +2 -7
  83. replay/preprocessing/filters.py +67 -142
  84. replay/preprocessing/history_based_fp.py +44 -116
  85. replay/preprocessing/label_encoder.py +106 -68
  86. replay/preprocessing/sessionizer.py +1 -11
  87. replay/scenarios/fallback.py +3 -8
  88. replay/splitters/base_splitter.py +43 -15
  89. replay/splitters/cold_user_random_splitter.py +18 -31
  90. replay/splitters/k_folds.py +14 -24
  91. replay/splitters/last_n_splitter.py +33 -43
  92. replay/splitters/new_users_splitter.py +31 -55
  93. replay/splitters/random_splitter.py +16 -23
  94. replay/splitters/ratio_splitter.py +30 -54
  95. replay/splitters/time_splitter.py +13 -18
  96. replay/splitters/two_stage_splitter.py +44 -79
  97. replay/utils/__init__.py +1 -1
  98. replay/utils/common.py +65 -0
  99. replay/utils/dataframe_bucketizer.py +25 -31
  100. replay/utils/distributions.py +3 -15
  101. replay/utils/model_handler.py +36 -33
  102. replay/utils/session_handler.py +11 -15
  103. replay/utils/spark_utils.py +51 -85
  104. replay/utils/time.py +8 -22
  105. replay/utils/types.py +1 -3
  106. {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/METADATA +2 -2
  107. replay_rec-0.17.0.dist-info/RECORD +127 -0
  108. replay_rec-0.16.0.dist-info/RECORD +0 -126
  109. {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/LICENSE +0 -0
  110. {replay_rec-0.16.0.dist-info → replay_rec-0.17.0.dist-info}/WHEEL +0 -0
replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py CHANGED
@@ -11,22 +11,17 @@ if PYSPARK_AVAILABLE:
     from pyspark.sql.pandas.functions import pandas_udf
 
 
-# pylint: disable=too-few-public-methods
 class NmslibFilterIndexInferer(IndexInferer):
     """Nmslib index inferer with filter seen items. Infers nmslib hnsw index."""
 
-    def infer(
-        self, vectors: SparkDataFrame, features_col: str, k: int
-    ) -> SparkDataFrame:
+    def infer(self, vectors: SparkDataFrame, features_col: str, k: int) -> SparkDataFrame:  # noqa: ARG002
         _index_store = self.index_store
         index_params = self.index_params
 
-        index_store_broadcast = State().session.sparkContext.broadcast(
-            _index_store
-        )
+        index_store_broadcast = State().session.sparkContext.broadcast(_index_store)
 
         @pandas_udf(self.udf_return_type)
-        def infer_index_udf(  # pylint: disable=too-many-locals
+        def infer_index_udf(
             user_idx: pd.Series,
             vector_items: pd.Series,
             vector_ratings: pd.Series,
@@ -36,12 +31,8 @@ class NmslibFilterIndexInferer(IndexInferer):
             index_store = index_store_broadcast.value
             index = index_store.load_index(
                 init_index=lambda: create_nmslib_index_instance(index_params),
-                load_index=lambda index, path: index.loadIndex(
-                    path, load_data=True
-                ),
-                configure_index=lambda index: index.setQueryTimeParams(
-                    {"efSearch": index_params.ef_s}
-                )
+                load_index=lambda index, path: index.loadIndex(path, load_data=True),
+                configure_index=lambda index: index.setQueryTimeParams({"efSearch": index_params.ef_s})
                 if index_params.ef_s
                 else None,
             )
@@ -49,9 +40,7 @@ class NmslibFilterIndexInferer(IndexInferer):
             # max number of items to retrieve per batch
             max_items_to_retrieve = num_items.max()
 
-            user_vectors = get_csr_matrix(
-                user_idx, vector_items, vector_ratings
-            )
+            user_vectors = get_csr_matrix(user_idx, vector_items, vector_ratings)
 
             neighbours = index.knnQueryBatch(
                 user_vectors[user_idx.values, :],
@@ -61,9 +50,7 @@ class NmslibFilterIndexInferer(IndexInferer):
 
             neighbours_filtered = []
             for i, (item_idxs, distances) in enumerate(neighbours):
-                non_seen_item_indexes = ~np.isin(
-                    item_idxs, seen_item_ids[i], assume_unique=True
-                )
+                non_seen_item_indexes = ~np.isin(item_idxs, seen_item_ids[i], assume_unique=True)
                 neighbours_filtered.append(
                     (
                         (item_idxs[non_seen_item_indexes])[:k],
@@ -71,14 +58,14 @@ class NmslibFilterIndexInferer(IndexInferer):
                     )
                 )
 
-            pd_res = PandasDataFrame(
-                neighbours_filtered, columns=["item_idx", "distance"]
-            )
+            pd_res = PandasDataFrame(neighbours_filtered, columns=["item_idx", "distance"])
 
-            # pd_res looks like
-            # item_idx distances
-            # [1, 2, 3, ...] [-0.5, -0.3, -0.1, ...]
-            # [1, 3, 4, ...] [-0.1, -0.8, -0.2, ...]
+            """
+            pd_res looks like
+            item_idx distances
+            [1, 2, 3, ...] [-0.5, -0.3, -0.1, ...]
+            [1, 3, 4, ...] [-0.1, -0.8, -0.2, ...]
+            """
 
             return pd_res
 
@@ -89,7 +76,6 @@ class NmslibFilterIndexInferer(IndexInferer):
             "num_items",
             "seen_item_idxs",
         ]
-        # cols = cols + ["num_items", "seen_item_idxs"]
 
         res = vectors.select(
             "user_idx",
replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py CHANGED
@@ -1,28 +1,24 @@
 import pandas as pd
 
-from .base_inferer import IndexInferer
-from .utils import get_csr_matrix
 from replay.models.extensions.ann.utils import create_nmslib_index_instance
 from replay.utils import PYSPARK_AVAILABLE, PandasDataFrame, SparkDataFrame
 from replay.utils.session_handler import State
 
+from .base_inferer import IndexInferer
+from .utils import get_csr_matrix
+
 if PYSPARK_AVAILABLE:
     from pyspark.sql.pandas.functions import pandas_udf
 
 
-# pylint: disable=too-few-public-methods
 class NmslibIndexInferer(IndexInferer):
     """Nmslib index inferer without filter seen items. Infers nmslib hnsw index."""
 
-    def infer(
-        self, vectors: SparkDataFrame, features_col: str, k: int
-    ) -> SparkDataFrame:
+    def infer(self, vectors: SparkDataFrame, features_col: str, k: int) -> SparkDataFrame:  # noqa: ARG002
         _index_store = self.index_store
         index_params = self.index_params
 
-        index_store_broadcast = State().session.sparkContext.broadcast(
-            _index_store
-        )
+        index_store_broadcast = State().session.sparkContext.broadcast(_index_store)
 
         @pandas_udf(self.udf_return_type)
         def infer_index_udf(
@@ -33,29 +29,23 @@ class NmslibIndexInferer(IndexInferer):
             index_store = index_store_broadcast.value
             index = index_store.load_index(
                 init_index=lambda: create_nmslib_index_instance(index_params),
-                load_index=lambda index, path: index.loadIndex(
-                    path, load_data=True
-                ),
-                configure_index=lambda index: index.setQueryTimeParams(
-                    {"efSearch": index_params.ef_s}
-                )
+                load_index=lambda index, path: index.loadIndex(path, load_data=True),
+                configure_index=lambda index: index.setQueryTimeParams({"efSearch": index_params.ef_s})
                 if index_params.ef_s
                 else None,
             )
 
-            user_vectors = get_csr_matrix(
-                user_idx, vector_items, vector_ratings
-            )
-            neighbours = index.knnQueryBatch(
-                user_vectors[user_idx.values, :], k=k, num_threads=1
-            )
+            user_vectors = get_csr_matrix(user_idx, vector_items, vector_ratings)
+            neighbours = index.knnQueryBatch(user_vectors[user_idx.values, :], k=k, num_threads=1)
 
             pd_res = PandasDataFrame(neighbours, columns=["item_idx", "distance"])
 
-            # pd_res looks like
-            # item_idx distances
-            # [1, 2, 3, ...] [-0.5, -0.3, -0.1, ...]
-            # [1, 3, 4, ...] [-0.1, -0.8, -0.2, ...]
+            """
+            pd_res looks like
+            item_idx distances
+            [1, 2, 3, ...] [-0.5, -0.3, -0.1, ...]
+            [1, 3, 4, ...] [-0.1, -0.8, -0.2, ...]
+            """
 
             return pd_res
 
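Both inferers above push the same logic into a pandas UDF: query the HNSW index in batches, then (in the filter variant) drop items the user has already seen before truncating to k. A minimal standalone sketch of that filtering step, using toy NumPy arrays rather than the package's data:

```python
import numpy as np

# Toy ANN output for one user: candidate item ids with their distances,
# plus the ids this user has already interacted with (all values invented).
item_idxs = np.array([7, 3, 9, 1, 5])
distances = np.array([-0.9, -0.7, -0.6, -0.4, -0.2])
seen_item_ids = np.array([3, 1])
k = 2

# Same idea as infer_index_udf: mask out seen items, keep the top-k of the rest.
non_seen = ~np.isin(item_idxs, seen_item_ids, assume_unique=True)
print(item_idxs[non_seen][:k], distances[non_seen][:k])  # [7 9] [-0.9 -0.6]
```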
replay/models/extensions/ann/index_inferers/utils.py CHANGED
@@ -12,19 +12,12 @@ def get_csr_matrix(
         (
             vector_ratings.explode().values.astype(float),
             (
-                user_idx.repeat(
-                    vector_items.apply(
-                        lambda x: len(x)  # pylint: disable=unnecessary-lambda
-                    )
-                ).values,
+                user_idx.repeat(vector_items.apply(lambda x: len(x))).values,
                 vector_items.explode().values.astype(int),
             ),
         ),
         shape=(
             user_idx.max() + 1,
-            vector_items.apply(
-                lambda x: max(x)  # pylint: disable=unnecessary-lambda
-            ).max()
-            + 1,
+            vector_items.apply(lambda x: max(x)).max() + 1,
         ),
     )
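get_csr_matrix, reformatted above, packs per-user item and rating lists into a sparse user-item matrix. A small self-contained sketch of the same construction with invented data:

```python
import pandas as pd
from scipy.sparse import csr_matrix

# Toy per-user lists: user 0 rated items [1, 3], user 1 rated item [2].
user_idx = pd.Series([0, 1])
vector_items = pd.Series([[1, 3], [2]])
vector_ratings = pd.Series([[5.0, 4.0], [3.0]])

# Same construction as the hunk: explode the lists into (data, (row, col)) triples.
rows = user_idx.repeat(vector_items.apply(len)).values
cols = vector_items.explode().values.astype(int)
data = vector_ratings.explode().values.astype(float)

matrix = csr_matrix(
    (data, (rows, cols)),
    shape=(user_idx.max() + 1, vector_items.apply(max).max() + 1),
)
print(matrix.toarray())
# [[0. 5. 0. 4.]
#  [0. 0. 3. 0.]]
```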
replay/models/extensions/ann/index_stores/hdfs_index_store.py CHANGED
@@ -19,12 +19,9 @@ class HdfsIndexStore(IndexStore):
         index_dir_path = os.path.join(warehouse_dir, index_dir)
         self._index_dir_info = get_filesystem(index_dir_path)
         if self._index_dir_info.filesystem != FileSystem.HDFS:
-            raise ValueError(
-                f"Can't recognize path {index_dir_path} as HDFS path!"
-            )
-        self._hadoop_fs = fs.HadoopFileSystem.from_uri(
-            self._index_dir_info.hdfs_uri
-        )
+            msg = f"Can't recognize path {index_dir_path} as HDFS path!"
+            raise ValueError(msg)
+        self._hadoop_fs = fs.HadoopFileSystem.from_uri(self._index_dir_info.hdfs_uri)
         super().__init__()
 
         if self.cleanup:
@@ -32,9 +29,7 @@ class HdfsIndexStore(IndexStore):
                 "Index directory %s is marked for deletion via weakref.finalize()",
                 self._index_dir_info.path,
             )
-            weakref.finalize(
-                self, self._hadoop_fs.delete_dir, self._index_dir_info.path
-            )
+            weakref.finalize(self, self._hadoop_fs.delete_dir, self._index_dir_info.path)
 
     def load_index(
         self,
replay/models/extensions/ann/index_stores/shared_disk_index_store.py CHANGED
@@ -17,9 +17,7 @@ class SharedDiskIndexStore(IndexStore):
     It can also be used with a local disk when the driver and executors
     are running on the same machine."""
 
-    def __init__(
-        self, warehouse_dir: str, index_dir: str, cleanup: bool = True
-    ):
+    def __init__(self, warehouse_dir: str, index_dir: str, cleanup: bool = True):
         self.index_dir_path = os.path.join(warehouse_dir, index_dir)
         super().__init__(cleanup)
         if self.cleanup:
@@ -52,9 +50,7 @@ class SharedDiskIndexStore(IndexStore):
         save_index(temp_file_path)
 
     def dump_index(self, target_path: str):
-        destination_filesystem, target_path = fs.FileSystem.from_uri(
-            target_path
-        )
+        destination_filesystem, target_path = fs.FileSystem.from_uri(target_path)
         target_path = os.path.join(target_path, "index_files")
         destination_filesystem.create_dir(target_path)
         fs.copy_files(
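The index stores above now build the destination in one call to pyarrow's FileSystem.from_uri, which splits a URI into a filesystem object plus a plain path. A minimal sketch of that dump pattern (temporary directories stand in for real index locations):

```python
import os
import tempfile

from pyarrow import fs

# Make a tiny local "index" directory to copy (placeholder content).
src_dir = tempfile.mkdtemp()
with open(os.path.join(src_dir, "index.bin"), "wb") as f:
    f.write(b"demo")

# FileSystem.from_uri() returns (filesystem, path); this is the pattern dump_index relies on.
destination_filesystem, target_path = fs.FileSystem.from_uri("file://" + tempfile.mkdtemp())
target_path = os.path.join(target_path, "index_files")
destination_filesystem.create_dir(target_path)

# Recursively copy the local directory into the destination filesystem.
fs.copy_files(src_dir, target_path, destination_filesystem=destination_filesystem)
```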
replay/models/extensions/ann/index_stores/spark_files_index_store.py CHANGED
@@ -7,11 +7,12 @@ from typing import Any, Callable
 
 from pyarrow import fs
 
-from .base_index_store import IndexStore
-from .utils import FileSystem, get_filesystem
 from replay.utils import PYSPARK_AVAILABLE
 from replay.utils.session_handler import State
 
+from .base_index_store import IndexStore
+from .utils import FileSystem, get_filesystem
+
 if PYSPARK_AVAILABLE:
     from pyspark import SparkFiles
 
@@ -20,6 +21,7 @@ logger = logging.getLogger("replay")
 
 
 if PYSPARK_AVAILABLE:
+
     class SparkFilesIndexStore(IndexStore):
         """Class that responsible for index store in spark files.
         Works through SparkContext.addFile()."""
@@ -62,14 +64,10 @@ if PYSPARK_AVAILABLE:
             for filename in os.listdir(self.index_dir_path):
                 index_file_path = os.path.join(self.index_dir_path, filename)
                 spark.sparkContext.addFile("file://" + index_file_path)
-                logger.info(
-                    "Index file %s transferred to executors", index_file_path
-                )
+                logger.info("Index file %s transferred to executors", index_file_path)
 
         def dump_index(self, target_path: str):
-            destination_filesystem, target_path = fs.FileSystem.from_uri(
-                target_path
-            )
+            destination_filesystem, target_path = fs.FileSystem.from_uri(target_path)
             target_path = os.path.join(target_path, "index_files")
             destination_filesystem.create_dir(target_path)
             fs.copy_files(
@@ -83,9 +81,7 @@ if PYSPARK_AVAILABLE:
             """Loads index from `path` directory to spark files."""
             path_info = get_filesystem(path)
             source_filesystem, path = fs.FileSystem.from_uri(
-                path_info.hdfs_uri + path_info.path
-                if path_info.filesystem == FileSystem.HDFS
-                else path_info.path
+                path_info.hdfs_uri + path_info.path if path_info.filesystem == FileSystem.HDFS else path_info.path
             )
             path = os.path.join(path, "index_files")
             self.index_dir_path: str = tempfile.mkdtemp()
@@ -100,6 +96,4 @@ if PYSPARK_AVAILABLE:
             for filename in os.listdir(self.index_dir_path):
                 index_file_path = os.path.join(self.index_dir_path, filename)
                 spark.sparkContext.addFile("file://" + index_file_path)
-                logger.info(
-                    "Index file %s transferred to executors", index_file_path
-                )
+                logger.info("Index file %s transferred to executors", index_file_path)
replay/models/extensions/ann/index_stores/utils.py CHANGED
@@ -9,6 +9,7 @@ if PYSPARK_AVAILABLE:
 
 class FileSystem(Enum):
     """File system types"""
+
     HDFS = 1
     LOCAL = 2
 
@@ -24,6 +25,7 @@ def get_default_fs() -> str:
 @dataclass(frozen=True)
 class FileInfo:
     """File meta-information: filesystem, path and hdfs_uri (optional)"""
+
     path: str
     filesystem: FileSystem
     hdfs_uri: str = None
@@ -83,15 +85,16 @@ or set 'fs.defaultFS' in hadoop configuration.
             if default_fs.startswith("hdfs://"):
                 return FileInfo(path[prefix_len:], FileSystem.HDFS, default_fs)
             else:
-                raise ValueError(
+                msg = (
                     f"Can't get default hdfs uri for path = '{path}'. "
                     "Specify an explicit path, such as 'hdfs://host:port/dir/file', "
                     "or set 'fs.defaultFS' in hadoop configuration."
                 )
+                raise ValueError(msg)
         else:
             hostname = path[prefix_len:].split("/", 1)[0]
             hdfs_uri = "hdfs://" + hostname
-            return FileInfo(path[len(hdfs_uri):], FileSystem.HDFS, hdfs_uri)
+            return FileInfo(path[len(hdfs_uri) :], FileSystem.HDFS, hdfs_uri)
     elif path.startswith("file://"):
         return FileInfo(path[prefix_len:], FileSystem.LOCAL)
     else:
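The get_filesystem hunk above only repackages the error message; the surrounding logic dispatches on the path's scheme prefix. A simplified, self-contained illustration of that dispatch (hypothetical names, and without the default-fs branch shown above):

```python
from dataclasses import dataclass
from enum import Enum


class Scheme(Enum):
    HDFS = 1
    LOCAL = 2


@dataclass(frozen=True)
class ParsedPath:
    path: str
    scheme: Scheme
    hdfs_uri: str = None


def parse_path(path: str) -> ParsedPath:
    prefix_len = len("hdfs://")  # same length as "file://"
    if path.startswith("hdfs://"):
        # Extract "host:port" and strip the full URI prefix from the path.
        hostname = path[prefix_len:].split("/", 1)[0]
        hdfs_uri = "hdfs://" + hostname
        return ParsedPath(path[len(hdfs_uri):], Scheme.HDFS, hdfs_uri)
    if path.startswith("file://"):
        return ParsedPath(path[prefix_len:], Scheme.LOCAL)
    return ParsedPath(path, Scheme.LOCAL)


print(parse_path("hdfs://namenode:9000/data/index"))
# ParsedPath(path='/data/index', scheme=<Scheme.HDFS: 1>, hdfs_uri='hdfs://namenode:9000')
```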
replay/models/extensions/ann/utils.py CHANGED
@@ -15,9 +15,7 @@ def create_hnswlib_index_instance(params: HnswlibParam, init: bool = False):
         If `False` then the index will be used to load index data from a file.
     :return: `hnswlib` index instance
     """
-    index = hnswlib.Index(  # pylint: disable=c-extension-no-member
-        space=params.space, dim=params.dim
-    )
+    index = hnswlib.Index(space=params.space, dim=params.dim)
 
     if init:
         # Initializing index - the maximum number of elements should be known beforehand
@@ -37,10 +35,10 @@ def create_nmslib_index_instance(params: NmslibHnswParam):
     :param params: `NmslibHnswParam`
     :return: `nmslib` index
     """
-    index = nmslib.init(  # pylint: disable=c-extension-no-member
+    index = nmslib.init(
         method=params.method,
         space=params.space,
-        data_type=nmslib.DataType.SPARSE_VECTOR,  # pylint: disable=c-extension-no-member
+        data_type=nmslib.DataType.SPARSE_VECTOR,
     )
 
     return index
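The pylint pragmas dropped above wrap ordinary hnswlib / nmslib constructor calls. For orientation, a self-contained hnswlib round trip; every parameter value here is arbitrary rather than a package default:

```python
import hnswlib
import numpy as np

dim, n = 16, 1000
vectors = np.random.rand(n, dim).astype(np.float32)

# Build: the maximum number of elements must be known when initializing.
index = hnswlib.Index(space="ip", dim=dim)
index.init_index(max_elements=n, ef_construction=100, M=16)
index.add_items(vectors, np.arange(n))

# Query: ef controls the recall/speed trade-off at search time
# (the nmslib analogue is the "efSearch" query-time parameter set in the hunks above).
index.set_ef(50)
labels, distances = index.knn_query(vectors[:5], k=10)
print(labels.shape)  # (5, 10)
```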
replay/models/kl_ucb.py CHANGED
@@ -1,13 +1,15 @@
 import math
-
 from typing import Optional
-from .ucb import UCB
-from replay.utils import PYSPARK_AVAILABLE
+
 from scipy.optimize import root_scalar
 
+from replay.utils import PYSPARK_AVAILABLE
+
+from .ucb import UCB
+
 if PYSPARK_AVAILABLE:
-    from pyspark.sql.types import DoubleType
     from pyspark.sql.functions import udf
+    from pyspark.sql.types import DoubleType
 
 
 class KLUCB(UCB):
@@ -17,7 +19,7 @@ class KLUCB(UCB):
     computes item relevance as an upper confidence bound of true fraction of
     positive interactions.
 
-    In a nutshell, KL-UCB сonsiders the data as the history of interactions
+    In a nutshell, KL-UCB considers the data as the history of interactions
     with items. The interaction may be either positive or negative. For each
     item the model computes empirical frequency of positive interactions
     and estimates the true frequency with an upper confidence bound. The higher
@@ -137,14 +139,11 @@ class KLUCB(UCB):
         super().__init__(exploration_coef, sample, seed)
 
     def _calc_item_popularity(self):
-
-        right_hand_side = math.log(self.full_count) \
-            + self.coef * math.log(math.log(self.full_count))
+        right_hand_side = math.log(self.full_count) + self.coef * math.log(math.log(self.full_count))
         eps = 1e-12
 
         def bernoulli_kl(proba_p, proba_q):  # pragma: no cover
-            return proba_p * math.log(proba_p / proba_q) +\
-                (1 - proba_p) * math.log((1 - proba_p) / (1 - proba_q))
+            return proba_p * math.log(proba_p / proba_q) + (1 - proba_p) * math.log((1 - proba_p) / (1 - proba_q))
 
         @udf(returnType=DoubleType())
         def get_ucb(pos, total):  # pragma: no cover
@@ -152,27 +151,22 @@ class KLUCB(UCB):
 
             if proba == 0:
                 ucb = root_scalar(
-                    f=lambda qq: math.log(1 / (1 - qq)) - right_hand_side,
-                    bracket=[0, 1 - eps],
-                    method='brentq').root
+                    f=lambda qq: math.log(1 / (1 - qq)) - right_hand_side, bracket=[0, 1 - eps], method="brentq"
+                ).root
                 return ucb
 
             if proba == 1:
                 ucb = root_scalar(
-                    f=lambda qq: math.log(1 / qq) - right_hand_side,
-                    bracket=[0 + eps, 1],
-                    method='brentq').root
+                    f=lambda qq: math.log(1 / qq) - right_hand_side, bracket=[0 + eps, 1], method="brentq"
+                ).root
                 return ucb
 
             ucb = root_scalar(
-                f=lambda q: total * bernoulli_kl(proba, q) - right_hand_side,
-                bracket=[proba, 1 - eps],
-                method='brentq').root
+                f=lambda q: total * bernoulli_kl(proba, q) - right_hand_side, bracket=[proba, 1 - eps], method="brentq"
+            ).root
             return ucb
 
-        items_counts = self.items_counts_aggr.withColumn(
-            self.rating_column, get_ucb("pos", "total")
-        )
+        items_counts = self.items_counts_aggr.withColumn(self.rating_column, get_ucb("pos", "total"))
 
         self.item_popularity = items_counts.drop("pos", "total")
 
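For reference, the quantity get_ucb solves for numerically in the hunk above is the standard KL-UCB bound: the largest success probability q whose Bernoulli KL divergence from the empirical positive rate p̂ = pos / total still fits the exploration budget (N is full_count, c is the exploration coefficient; notation is mine, matching the code's right_hand_side):

```latex
\mathrm{ucb} \;=\; \max\Bigl\{\, q \in [\hat p,\, 1) \;:\;
  \text{total}\cdot \mathrm{KL}(\hat p \,\|\, q) \;\le\; \ln N + c\,\ln\ln N \,\Bigr\},
\qquad
\mathrm{KL}(p \,\|\, q) \;=\; p\ln\frac{p}{q} + (1-p)\ln\frac{1-p}{1-q}
```

The proba == 0 and proba == 1 branches in get_ucb handle the degenerate limits where one of the two KL terms vanishes.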
replay/models/knn.py CHANGED
@@ -1,17 +1,17 @@
 from typing import Any, Dict, Optional
 
 from replay.data import Dataset
-from .base_neighbour_rec import NeighbourRec
-from .extensions.ann.index_builders.base_index_builder import IndexBuilder
 from replay.optimization.optuna_objective import ItemKNNObjective
 from replay.utils import PYSPARK_AVAILABLE, SparkDataFrame
 
+from .base_neighbour_rec import NeighbourRec
+from .extensions.ann.index_builders.base_index_builder import IndexBuilder
+
 if PYSPARK_AVAILABLE:
     from pyspark.sql import functions as sf
     from pyspark.sql.window import Window
 
 
-# pylint: disable=too-many-ancestors, too-many-instance-attributes
 class ItemKNN(NeighbourRec):
     """Item-based ItemKNN with modified cosine similarity measure."""
 
@@ -29,15 +29,15 @@ class ItemKNN(NeighbourRec):
     _search_space = {
         "num_neighbours": {"type": "int", "args": [1, 100]},
         "shrink": {"type": "int", "args": [0, 100]},
-        "weighting": {"type": "categorical", "args": [None, "tf_idf", "bm25"]}
+        "weighting": {"type": "categorical", "args": [None, "tf_idf", "bm25"]},
     }
 
-    def __init__(  # pylint: disable=too-many-arguments
+    def __init__(
         self,
         num_neighbours: int = 10,
         use_rating: bool = False,
         shrink: float = 0.0,
-        weighting: str = None,
+        weighting: Optional[str] = None,
         index_builder: Optional[IndexBuilder] = None,
     ):
         """
@@ -54,7 +54,8 @@ class ItemKNN(NeighbourRec):
 
         valid_weightings = self._search_space["weighting"]["args"]
         if weighting not in valid_weightings:
-            raise ValueError(f"weighting must be one of {valid_weightings}")
+            msg = f"weighting must be one of {valid_weightings}"
+            raise ValueError(msg)
         self.weighting = weighting
         if isinstance(index_builder, (IndexBuilder, type(None))):
             self.index_builder = index_builder
@@ -75,8 +76,7 @@ class ItemKNN(NeighbourRec):
     def _shrink(dot_products: SparkDataFrame, shrink: float) -> SparkDataFrame:
         return dot_products.withColumn(
             "similarity",
-            sf.col("dot_product")
-            / (sf.col("norm1") * sf.col("norm2") + shrink),
+            sf.col("dot_product") / (sf.col("norm1") * sf.col("norm2") + shrink),
         ).select("item_idx_one", "item_idx_two", "similarity")
 
     def _get_similarity(self, interactions: SparkDataFrame) -> SparkDataFrame:
@@ -116,25 +116,19 @@ class ItemKNN(NeighbourRec):
         :param interactions: SparkDataFrame with interactions, `[user_id, item_id, rating]`
         :return: interactions `[user_id, item_id, rating]`
         """
-        item_stats = interactions.groupBy(self.item_column).agg(
-            sf.count(self.query_column).alias("n_queries_per_item")
-        )
+        item_stats = interactions.groupBy(self.item_column).agg(sf.count(self.query_column).alias("n_queries_per_item"))
         avgdl = item_stats.select(sf.mean("n_queries_per_item")).take(1)[0][0]
         interactions = interactions.join(item_stats, how="inner", on=self.item_column)
 
-        interactions = (
-            interactions.withColumn(
-                self.rating_column,
-                sf.col(self.rating_column) * (self.bm25_k1 + 1) / (
-                    sf.col(self.rating_column) + self.bm25_k1 * (
-                        1 - self.bm25_b + self.bm25_b * (
-                            sf.col("n_queries_per_item") / avgdl
-                        )
-                    )
-                )
-            )
-            .drop("n_queries_per_item")
-        )
+        interactions = interactions.withColumn(
+            self.rating_column,
+            sf.col(self.rating_column)
+            * (self.bm25_k1 + 1)
+            / (
+                sf.col(self.rating_column)
+                + self.bm25_k1 * (1 - self.bm25_b + self.bm25_b * (sf.col("n_queries_per_item") / avgdl))
+            ),
+        ).drop("n_queries_per_item")
 
         return interactions
 
@@ -150,23 +144,15 @@ class ItemKNN(NeighbourRec):
         n_items = interactions.select(self.item_column).distinct().count()
 
         if self.weighting == "tf_idf":
-            idf = (
-                df.withColumn("idf", sf.log1p(sf.lit(n_items) / sf.col("DF")))
-                .drop("DF")
-            )
+            idf = df.withColumn("idf", sf.log1p(sf.lit(n_items) / sf.col("DF"))).drop("DF")
         elif self.weighting == "bm25":
-            idf = (
-                df.withColumn(
-                    "idf",
-                    sf.log1p(
-                        (sf.lit(n_items) - sf.col("DF") + 0.5)
-                        / (sf.col("DF") + 0.5)
-                    ),
-                )
-                .drop("DF")
-            )
+            idf = df.withColumn(
+                "idf",
+                sf.log1p((sf.lit(n_items) - sf.col("DF") + 0.5) / (sf.col("DF") + 0.5)),
+            ).drop("DF")
         else:
-            raise ValueError("weighting must be one of ['tf_idf', 'bm25']")
+            msg = "weighting must be one of ['tf_idf', 'bm25']"
+            raise ValueError(msg)
 
         return idf
 
@@ -180,12 +166,12 @@ class ItemKNN(NeighbourRec):
         if self.weighting:
            interactions = self._reweight_interactions(interactions)
 
-        left = interactions.withColumnRenamed(
-            self.item_column, "item_idx_one"
-        ).withColumnRenamed(self.rating_column, "rel_one")
-        right = interactions.withColumnRenamed(
-            self.item_column, "item_idx_two"
-        ).withColumnRenamed(self.rating_column, "rel_two")
+        left = interactions.withColumnRenamed(self.item_column, "item_idx_one").withColumnRenamed(
+            self.rating_column, "rel_one"
+        )
+        right = interactions.withColumnRenamed(self.item_column, "item_idx_two").withColumnRenamed(
+            self.rating_column, "rel_two"
+        )
 
         dot_products = (
             left.join(right, how="inner", on=self.query_column)
@@ -201,19 +187,11 @@ class ItemKNN(NeighbourRec):
             .agg(sf.sum(self.rating_column).alias("square_norm"))
             .select(sf.col(self.item_column), sf.sqrt("square_norm").alias("norm"))
         )
-        norm1 = item_norms.withColumnRenamed(
-            self.item_column, "item_id1"
-        ).withColumnRenamed("norm", "norm1")
-        norm2 = item_norms.withColumnRenamed(
-            self.item_column, "item_id2"
-        ).withColumnRenamed("norm", "norm2")
-
-        dot_products = dot_products.join(
-            norm1, how="inner", on=sf.col("item_id1") == sf.col("item_idx_one")
-        )
-        dot_products = dot_products.join(
-            norm2, how="inner", on=sf.col("item_id2") == sf.col("item_idx_two")
-        )
+        norm1 = item_norms.withColumnRenamed(self.item_column, "item_id1").withColumnRenamed("norm", "norm1")
+        norm2 = item_norms.withColumnRenamed(self.item_column, "item_id2").withColumnRenamed("norm", "norm2")
+
+        dot_products = dot_products.join(norm1, how="inner", on=sf.col("item_id1") == sf.col("item_idx_one"))
+        dot_products = dot_products.join(norm2, how="inner", on=sf.col("item_id2") == sf.col("item_idx_two"))
 
         return dot_products
 
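The _reweight_interactions hunks above reshape, but do not change, the BM25 term-frequency weighting applied per interaction (the idf part is handled separately in _get_idf). A tiny numeric sketch of that formula; the k1 and b values here are assumptions for illustration, not the model's configured attributes:

```python
import numpy as np

k1, b = 1.2, 0.75                     # assumed constants for illustration
rating = np.array([1.0, 3.0, 5.0])    # per-interaction "term frequency"
n_queries_per_item = np.array([10, 200, 50])
avgdl = n_queries_per_item.mean()

# Same expression as the new withColumn(...) body in the hunk above.
weighted = rating * (k1 + 1) / (rating + k1 * (1 - b + b * (n_queries_per_item / avgdl)))
print(weighted.round(3))
```

Items that appear in many more interactions than average (n_queries_per_item well above avgdl) have their ratings shrunk, which is the usual BM25 length normalization.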
replay/models/nn/optimizer_utils/__init__.py CHANGED
@@ -1,9 +1,4 @@
 from replay.utils import TORCH_AVAILABLE
 
 if TORCH_AVAILABLE:
-    from .optimizer_factory import (
-        FatLRSchedulerFactory,
-        FatOptimizerFactory,
-        LRSchedulerFactory,
-        OptimizerFactory
-    )
+    from .optimizer_factory import FatLRSchedulerFactory, FatOptimizerFactory, LRSchedulerFactory, OptimizerFactory
replay/models/nn/optimizer_utils/optimizer_factory.py CHANGED
@@ -4,7 +4,6 @@ from typing import Iterator, Tuple
 import torch
 
 
-# pylint: disable=too-few-public-methods
 class OptimizerFactory(abc.ABC):
     """
     Interface for optimizer factory
@@ -21,7 +20,6 @@ class OptimizerFactory(abc.ABC):
         """
 
 
-# pylint: disable=too-few-public-methods
 class LRSchedulerFactory(abc.ABC):
     """
     Interface for learning rate scheduler factory
@@ -38,13 +36,11 @@ class LRSchedulerFactory(abc.ABC):
         """
 
 
-# pylint: disable=too-few-public-methods
 class FatOptimizerFactory(OptimizerFactory):
     """
     Factory that creates optimizer depending on passed parameters
     """
 
-    # pylint: disable=too-many-arguments
     def __init__(
         self,
         optimizer: str = "adam",
@@ -74,10 +70,11 @@ class FatOptimizerFactory(OptimizerFactory):
             return torch.optim.SGD(
                 parameters, lr=self.learning_rate, weight_decay=self.weight_decay, momentum=self.sgd_momentum
             )
-        raise ValueError("Unexpected optimizer")
+
+        msg = "Unexpected optimizer"
+        raise ValueError(msg)
 
 
-# pylint: disable=too-few-public-methods
 class FatLRSchedulerFactory(LRSchedulerFactory):
     """
     Factory that creates learning rate schedule depending on passed parameters