replay-rec 0.19.0__py3-none-any.whl → 0.20.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. replay/__init__.py +6 -2
  2. replay/data/dataset.py +9 -9
  3. replay/data/nn/__init__.py +6 -6
  4. replay/data/nn/sequence_tokenizer.py +44 -38
  5. replay/data/nn/sequential_dataset.py +13 -8
  6. replay/data/nn/torch_sequential_dataset.py +14 -13
  7. replay/data/nn/utils.py +1 -1
  8. replay/metrics/base_metric.py +1 -1
  9. replay/metrics/coverage.py +7 -11
  10. replay/metrics/experiment.py +3 -3
  11. replay/metrics/offline_metrics.py +2 -2
  12. replay/models/__init__.py +19 -0
  13. replay/models/association_rules.py +1 -4
  14. replay/models/base_neighbour_rec.py +6 -9
  15. replay/models/base_rec.py +44 -293
  16. replay/models/cat_pop_rec.py +2 -1
  17. replay/models/common.py +69 -0
  18. replay/models/extensions/ann/ann_mixin.py +30 -25
  19. replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +1 -1
  20. replay/models/extensions/ann/utils.py +4 -3
  21. replay/models/knn.py +18 -17
  22. replay/models/nn/sequential/bert4rec/dataset.py +1 -1
  23. replay/models/nn/sequential/callbacks/prediction_callbacks.py +2 -2
  24. replay/models/nn/sequential/compiled/__init__.py +10 -0
  25. replay/models/nn/sequential/compiled/base_compiled_model.py +3 -1
  26. replay/models/nn/sequential/compiled/bert4rec_compiled.py +11 -2
  27. replay/models/nn/sequential/compiled/sasrec_compiled.py +5 -1
  28. replay/models/nn/sequential/sasrec/dataset.py +1 -1
  29. replay/models/nn/sequential/sasrec/model.py +1 -1
  30. replay/models/optimization/__init__.py +14 -0
  31. replay/models/optimization/optuna_mixin.py +279 -0
  32. replay/{optimization → models/optimization}/optuna_objective.py +13 -15
  33. replay/models/slim.py +2 -4
  34. replay/models/word2vec.py +7 -12
  35. replay/preprocessing/discretizer.py +1 -2
  36. replay/preprocessing/history_based_fp.py +1 -1
  37. replay/preprocessing/label_encoder.py +1 -1
  38. replay/splitters/cold_user_random_splitter.py +13 -7
  39. replay/splitters/last_n_splitter.py +17 -10
  40. replay/utils/__init__.py +6 -2
  41. replay/utils/common.py +4 -2
  42. replay/utils/model_handler.py +11 -31
  43. replay/utils/session_handler.py +2 -2
  44. replay/utils/spark_utils.py +2 -2
  45. replay/utils/types.py +28 -18
  46. replay/utils/warnings.py +26 -0
  47. {replay_rec-0.19.0.dist-info → replay_rec-0.20.0.dist-info}/METADATA +56 -32
  48. {replay_rec-0.19.0.dist-info → replay_rec-0.20.0.dist-info}/RECORD +51 -47
  49. {replay_rec-0.19.0.dist-info → replay_rec-0.20.0.dist-info}/WHEEL +1 -1
  50. replay_rec-0.20.0.dist-info/licenses/NOTICE +41 -0
  51. replay/optimization/__init__.py +0 -5
  52. {replay_rec-0.19.0.dist-info → replay_rec-0.20.0.dist-info/licenses}/LICENSE +0 -0
@@ -21,12 +21,12 @@ from replay.utils import (
     PandasDataFrame,
     PolarsDataFrame,
     SparkDataFrame,
-    get_spark_session,
 )
 
 if PYSPARK_AVAILABLE:
     from pyspark.sql import Window, functions as sf  # noqa: I001
     from pyspark.sql.types import LongType, IntegerType, ArrayType
+    from replay.utils.session_handler import get_spark_session
 
 HandleUnknownStrategies = Literal["error", "use_default_value", "drop"]
 
@@ -38,12 +38,16 @@ class ColdUserRandomSplitter(Splitter):
         item_column: Optional[str] = "item_id",
     ):
         """
-        :param test_size: fraction of users to be in test
-        :param drop_cold_items: flag to drop cold items from test
-        :param drop_cold_users: flag to drop cold users from test
-        :param seed: random seed
-        :param query_column: query id column name
-        :param item_column: item id column name
+        :param test_size: The proportion of users to allocate to the test set.
+            Must be a float between 0.0 and 1.0.
+        :param drop_cold_items: Drop items from test DataFrame
+            which are not in train DataFrame, default: False.
+        :param seed: Seed for the random number generator to ensure
+            reproducibility of the split, default: None.
+        :param query_column: Name of query interaction column.
+            default: ``query_id``.
+        :param item_column: Name of item interaction column.
+            default: ``item_id``.
         """
         super().__init__(
             drop_cold_items=drop_cold_items,
@@ -81,7 +85,9 @@ class ColdUserRandomSplitter(Splitter):
             seed=self.seed,
         )
         interactions = interactions.join(
-            train_users.withColumn("is_test", sf.lit(False)), on=self.query_column, how="left"
+            train_users.withColumn("is_test", sf.lit(False)),
+            on=self.query_column,
+            how="left",
         ).na.fill({"is_test": True})
 
         train = interactions.filter(~sf.col("is_test")).drop("is_test")
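For orientation, a minimal usage sketch of the splitter whose docstring is updated above; the pandas input and the `split()` call are assumptions based on the documented parameters, not part of this diff:

```python
import pandas as pd

from replay.splitters import ColdUserRandomSplitter

# Toy interaction log using the default column names from the docstring above.
interactions = pd.DataFrame(
    {
        "query_id": [1, 1, 2, 2, 3, 3],
        "item_id": [10, 11, 10, 12, 11, 13],
    }
)

splitter = ColdUserRandomSplitter(
    test_size=0.3,  # proportion of users moved entirely to the test set
    seed=42,        # fixed seed for a reproducible split
)
train, test = splitter.split(interactions)  # assumed Splitter.split() interface
```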
@@ -4,7 +4,13 @@ import numpy as np
 import pandas as pd
 import polars as pl
 
-from replay.utils import PYSPARK_AVAILABLE, DataFrameLike, PandasDataFrame, PolarsDataFrame, SparkDataFrame
+from replay.utils import (
+    PYSPARK_AVAILABLE,
+    DataFrameLike,
+    PandasDataFrame,
+    PolarsDataFrame,
+    SparkDataFrame,
+)
 
 from .base_splitter import Splitter
 
@@ -118,14 +124,12 @@ class LastNSplitter(Splitter):
         session_id_processing_strategy: str = "test",
     ):
         """
-        :param N: Array of interactions/timedelta to split.
+        :param N: Number of last interactions or size of the time window in seconds
         :param divide_column: Name of column for dividing
             in dataframe, default: ``query_id``.
-        :param time_column_format: Format of time_column,
-            needs for convert time_column into unix_timestamp type.
-            If strategy is set to 'interactions', then you can omit this parameter.
-            If time_column has already transformed into unix_timestamp type,
-            then you can omit this parameter.
+        :param time_column_format: Format of the timestamp column,
+            used for converting string dates to a numerical timestamp when strategy is 'timedelta'.
+            If the column is already a datetime object or a numerical timestamp, this parameter is ignored.
             default: ``yyyy-MM-dd HH:mm:ss``
         :param strategy: Defines the type of data splitting.
             Must be ``interactions`` or ``timedelta``.
@@ -223,7 +227,8 @@ class LastNSplitter(Splitter):
         time_column_type = dict(interactions.dtypes)[self.timestamp_column]
         if time_column_type == "date":
             interactions = interactions.withColumn(
-                self.timestamp_column, sf.unix_timestamp(self.timestamp_column, self.time_column_format)
+                self.timestamp_column,
+                sf.unix_timestamp(self.timestamp_column, self.time_column_format),
             )
 
         return interactions
@@ -260,7 +265,8 @@ class LastNSplitter(Splitter):
         self, interactions: SparkDataFrame, n: int
     ) -> Tuple[SparkDataFrame, SparkDataFrame]:
         interactions = interactions.withColumn(
-            "count", sf.count(self.timestamp_column).over(Window.partitionBy(self.divide_column))
+            "count",
+            sf.count(self.timestamp_column).over(Window.partitionBy(self.divide_column)),
        )
         # float(n) - because DataFrame.filter is changing order
         # of sorted DataFrame to descending
@@ -317,7 +323,8 @@ class LastNSplitter(Splitter):
         self, interactions: SparkDataFrame, timedelta: int
     ) -> Tuple[SparkDataFrame, SparkDataFrame]:
         inter_with_max_time = interactions.withColumn(
-            "max_timestamp", sf.max(self.timestamp_column).over(Window.partitionBy(self.divide_column))
+            "max_timestamp",
+            sf.max(self.timestamp_column).over(Window.partitionBy(self.divide_column)),
         )
         inter_with_diff = inter_with_max_time.withColumn(
             "diff_timestamp", sf.col("max_timestamp") - sf.col(self.timestamp_column)
replay/utils/__init__.py CHANGED
@@ -1,13 +1,17 @@
-from .session_handler import State, get_spark_session
 from .types import (
+    ANN_AVAILABLE,
     OPENVINO_AVAILABLE,
+    OPTUNA_AVAILABLE,
     PYSPARK_AVAILABLE,
     TORCH_AVAILABLE,
     DataFrameLike,
+    FeatureUnavailableError,
+    FeatureUnavailableWarning,
     IntOrList,
-    MissingImportType,
+    MissingImport,
     NumType,
     PandasDataFrame,
     PolarsDataFrame,
     SparkDataFrame,
 )
+from .warnings import deprecation_warning
replay/utils/common.py CHANGED
@@ -126,6 +126,7 @@ def convert2pandas(
     """
     if isinstance(data, PandasDataFrame):
         return data
+
     if isinstance(data, PolarsDataFrame):
         return data.to_pandas()
     if isinstance(data, SparkDataFrame):
@@ -144,10 +145,11 @@ def convert2polars(
     :param allow_collect_to_master: If set to False (default) raises a warning
         about collecting parallelized data to the master node.
     """
-    if isinstance(data, PandasDataFrame):
-        return pl_from_pandas(data)
     if isinstance(data, PolarsDataFrame):
         return data
+
+    if isinstance(data, PandasDataFrame):
+        return pl_from_pandas(data)
     if isinstance(data, SparkDataFrame):
         return pl_from_pandas(spark_to_pandas(data, allow_collect_to_master, from_constructor=False))
 
@@ -1,16 +1,13 @@
-import functools
 import json
 import os
 import pickle
-import warnings
 from os.path import join
 from pathlib import Path
-from typing import Any, Callable, Optional, Union
+from typing import Union
 
 from replay.data.dataset_utils import DatasetLabelEncoder
-from replay.models import *
 from replay.models.base_rec import BaseRecommender
-from replay.splitters import *
+from replay.splitters import Splitter
 
 from .session_handler import State
 from .types import PYSPARK_AVAILABLE
@@ -43,7 +40,7 @@ if PYSPARK_AVAILABLE:
         return [str(f.getPath()) for f in statuses]
 
 
-def save(model: BaseRecommender, path: Union[str, Path], overwrite: bool = False):
+def save(model: "BaseRecommender", path: Union[str, Path], overwrite: bool = False):
     """
     Save fitted model to disk as a folder
 
@@ -86,19 +83,22 @@ def save(model: BaseRecommender, path: Union[str, Path], overwrite: bool = False
     save_picklable_to_parquet(model.study, join(path, "study"))
 
 
-def load(path: str, model_type=None) -> BaseRecommender:
+def load(path: str, model_type=None) -> "BaseRecommender":
     """
     Load saved model from disk
 
     :param path: path to model folder
     :return: Restored trained model
     """
+    # FIXME: Surely there's a better way to handle this? Not having this method at all perhaps?
+    import replay.models as models
+
     spark = State().session
     args = spark.read.json(join(path, "init_args.json")).first().asDict(recursive=True)
     name = args["_model_name"]
     del args["_model_name"]
 
-    model_class = model_type if model_type is not None else globals()[name]
+    model_class = model_type if model_type is not None else getattr(models, name)
 
     model = model_class(**args)
 
@@ -175,31 +175,11 @@ def load_splitter(path: str) -> Splitter:
     :param path: path to folder
     :return: restored Splitter
     """
+    import replay.splitters as splitters
+
     spark = State().session
     args = spark.read.json(join(path, "init_args.json")).first().asDict()
     name = args["_splitter_name"]
     del args["_splitter_name"]
-    splitter = globals()[name]
+    splitter = getattr(splitters, name)
     return splitter(**args)
-
-
-def deprecation_warning(message: Optional[str] = None) -> Callable[..., Any]:
-    """
-    Decorator that throws deprecation warnings.
-
-    :param message: message to deprecation warning without func name.
-    """
-    base_msg = "will be deprecated in future versions."
-
-    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
-        @functools.wraps(func)
-        def wrapper(*args: Any, **kwargs: Any) -> Any:
-            msg = f"{func.__qualname__} {message if message else base_msg}"
-            warnings.simplefilter("always", DeprecationWarning)  # turn off filter
-            warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
-            warnings.simplefilter("default", DeprecationWarning)  # reset filter
-            return func(*args, **kwargs)
-
-        return wrapper
-
-    return decorator
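For context, a sketch of the `save`/`load` round trip affected by this change; `model` stands for an already fitted recommender and an active Spark session is required, so this is illustrative rather than directly runnable:

```python
from replay.utils.model_handler import save, load

# `model` is assumed to be a fitted BaseRecommender instance.
save(model, "/tmp/replay_model", overwrite=True)

# The class is now resolved by name via getattr(replay.models, name)
# instead of globals(), so wildcard imports are no longer needed here.
restored = load("/tmp/replay_model")
```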
@@ -10,13 +10,13 @@ from typing import Any, Dict, Optional
 
 import psutil
 
-from .types import PYSPARK_AVAILABLE, MissingImportType
+from .types import PYSPARK_AVAILABLE, MissingImport
 
 if PYSPARK_AVAILABLE:
     from pyspark import __version__ as pyspark_version
     from pyspark.sql import SparkSession
 else:
-    SparkSession = MissingImportType
+    SparkSession = MissingImport
 
 
 def get_spark_session(
@@ -10,7 +10,7 @@ import pandas as pd
 from numpy.random import default_rng
 
 from .session_handler import State
-from .types import PYSPARK_AVAILABLE, DataFrameLike, MissingImportType, NumType, PolarsDataFrame, SparkDataFrame
+from .types import PYSPARK_AVAILABLE, DataFrameLike, MissingImport, NumType, PolarsDataFrame, SparkDataFrame
 
 if PYSPARK_AVAILABLE:
     import pyspark.sql.types as st
@@ -24,7 +24,7 @@ if PYSPARK_AVAILABLE:
     from pyspark.sql.column import _to_java_column, _to_seq
     from pyspark.sql.types import DoubleType, IntegerType, StructField, StructType
 else:
-    Column = MissingImportType
+    Column = MissingImport
 
 
 class PolarsConvertToSparkWarning(Warning):
replay/utils/types.py CHANGED
@@ -1,38 +1,48 @@
+from importlib.util import find_spec
 from typing import Iterable, Union
 
 from pandas import DataFrame as PandasDataFrame
 from polars import DataFrame as PolarsDataFrame
+from typing_extensions import TypeAlias
 
 
-class MissingImportType:
+class MissingImport:
     """
     Replacement class with missing import
     """
 
 
-try:
-    from pyspark.sql import DataFrame as SparkDataFrame
+class FeatureUnavailableError(Exception):
+    """Exception class for failing a conditional import check."""
 
-    PYSPARK_AVAILABLE = True
-except ImportError:
-    PYSPARK_AVAILABLE = False
-    SparkDataFrame = MissingImportType
 
-try:
-    import torch  # noqa: F401
+class FeatureUnavailableWarning(Warning):
+    """Warning class for failing a conditional import check."""
 
-    TORCH_AVAILABLE = True
-except ImportError:
-    TORCH_AVAILABLE = False
 
-try:
-    import onnx  # noqa: F401
-    import openvino  # noqa: F401
+PYSPARK_AVAILABLE = find_spec("pyspark")
+if not PYSPARK_AVAILABLE:
+    SparkDataFrame: TypeAlias = MissingImport
+else:
+    from pyspark.sql import DataFrame
 
-    OPENVINO_AVAILABLE = TORCH_AVAILABLE
-except ImportError:
-    OPENVINO_AVAILABLE = False
+    SparkDataFrame: TypeAlias = DataFrame
+
+
+TORCH_AVAILABLE = find_spec("torch") and find_spec("lightning")
 
 DataFrameLike = Union[PandasDataFrame, SparkDataFrame, PolarsDataFrame]
 IntOrList = Union[Iterable[int], int]
 NumType = Union[int, float]
+
+
+# Conditional import flags
+ANN_AVAILABLE = all(
+    [
+        find_spec("nmslib"),
+        find_spec("hnswlib"),
+        find_spec("pyarrow"),
+    ]
+)
+OPENVINO_AVAILABLE = TORCH_AVAILABLE and find_spec("onnx") and find_spec("openvino")
+OPTUNA_AVAILABLE = find_spec("optuna")
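For context, a sketch (not from the package) of how downstream code might combine the new availability flags with `FeatureUnavailableError`; the helper below is hypothetical:

```python
from replay.utils import OPTUNA_AVAILABLE, FeatureUnavailableError


def create_optuna_study():
    # Hypothetical helper for illustration: gate the optional dependency behind
    # the flag computed with importlib.util.find_spec above.
    if not OPTUNA_AVAILABLE:
        raise FeatureUnavailableError(
            "Hyperparameter search requires optuna; install it with `pip install optuna`."
        )
    import optuna  # safe to import: find_spec("optuna") located the package

    return optuna.create_study(direction="maximize")
```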
@@ -0,0 +1,26 @@
+import functools
+import warnings
+from collections.abc import Callable
+from typing import Any, Optional
+
+
+def deprecation_warning(message: Optional[str] = None) -> Callable[..., Any]:
+    """
+    Decorator that throws deprecation warnings.
+
+    :param message: message to deprecation warning without func name.
+    """
+    base_msg = "will be deprecated in future versions."
+
+    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
+        @functools.wraps(func)
+        def wrapper(*args: Any, **kwargs: Any) -> Any:
+            msg = f"{func.__qualname__} {message if message else base_msg}"
+            warnings.simplefilter("always", DeprecationWarning)  # turn off filter
+            warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
+            warnings.simplefilter("default", DeprecationWarning)  # reset filter
+            return func(*args, **kwargs)
+
+        return wrapper
+
+    return decorator
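A brief usage illustration of the relocated decorator; `old_helper` and the message text are made up for the example:

```python
from replay.utils import deprecation_warning


@deprecation_warning("is deprecated, use `new_helper` instead.")
def old_helper(x: int) -> int:
    return x * 2


old_helper(3)  # warns: "old_helper is deprecated, use `new_helper` instead."
```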
@@ -1,45 +1,44 @@
-Metadata-Version: 2.1
+Metadata-Version: 2.4
 Name: replay-rec
-Version: 0.19.0
+Version: 0.20.0
 Summary: RecSys Library
-Home-page: https://sb-ai-lab.github.io/RePlay/
-License: Apache-2.0
+License-Expression: Apache-2.0
+License-File: LICENSE
+License-File: NOTICE
 Author: AI Lab
-Requires-Python: >=3.8.1,<3.12
+Requires-Python: >=3.9, <3.13
+Classifier: Operating System :: Unix
 Classifier: Development Status :: 4 - Beta
 Classifier: Environment :: Console
 Classifier: Intended Audience :: Developers
 Classifier: Intended Audience :: Science/Research
-Classifier: License :: OSI Approved :: Apache Software License
 Classifier: Natural Language :: English
-Classifier: Operating System :: Unix
-Classifier: Programming Language :: Python :: 3
-Classifier: Programming Language :: Python :: 3.9
-Classifier: Programming Language :: Python :: 3.10
-Classifier: Programming Language :: Python :: 3.11
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
-Provides-Extra: all
 Provides-Extra: spark
 Provides-Extra: torch
-Provides-Extra: torch-openvino
-Requires-Dist: fixed-install-nmslib (==2.1.2)
-Requires-Dist: hnswlib (>=0.7.0,<0.8.0)
-Requires-Dist: lightning (>=2.0.2,<=2.4.0) ; extra == "torch" or extra == "torch-openvino" or extra == "all"
-Requires-Dist: numpy (>=1.20.0)
-Requires-Dist: onnx (>=1.16.2,<1.17.0) ; extra == "torch-openvino" or extra == "all"
-Requires-Dist: openvino (>=2024.3.0,<2024.4.0) ; extra == "torch-openvino" or extra == "all"
-Requires-Dist: optuna (>=3.2.0,<3.3.0)
-Requires-Dist: pandas (>=1.3.5,<=2.2.2)
-Requires-Dist: polars (>=1.0.0,<1.1.0)
-Requires-Dist: psutil (>=6.0.0,<6.1.0)
-Requires-Dist: pyarrow (>=12.0.1)
-Requires-Dist: pyspark (>=3.0,<3.6) ; (python_full_version >= "3.8.1" and python_version < "3.11") and (extra == "spark" or extra == "all")
-Requires-Dist: pyspark (>=3.4,<3.6) ; (python_version >= "3.11" and python_version < "3.12") and (extra == "spark" or extra == "all")
-Requires-Dist: pytorch-ranger (>=0.1.1,<0.2.0) ; extra == "torch" or extra == "torch-openvino" or extra == "all"
-Requires-Dist: scikit-learn (>=1.0.2,<2.0.0)
-Requires-Dist: scipy (>=1.8.1,<2.0.0)
-Requires-Dist: torch (>=1.8,<3.0.0) ; (python_version >= "3.9") and (extra == "torch" or extra == "torch-openvino" or extra == "all")
-Requires-Dist: torch (>=1.8,<=2.4.1) ; (python_version >= "3.8" and python_version < "3.9") and (extra == "torch" or extra == "torch-openvino" or extra == "all")
+Provides-Extra: torch-cpu
+Requires-Dist: lightning (<2.6.0) ; extra == "torch" or extra == "torch-cpu"
+Requires-Dist: lightning ; extra == "torch"
+Requires-Dist: lightning ; extra == "torch-cpu"
+Requires-Dist: numpy (>=1.20.0,<2)
+Requires-Dist: pandas (>=1.3.5,<2.4.0)
+Requires-Dist: polars (<2.0)
+Requires-Dist: psutil (<=7.0.0) ; extra == "spark"
+Requires-Dist: psutil ; extra == "spark"
+Requires-Dist: pyarrow (<22.0)
+Requires-Dist: pyspark (>=3.0,<3.5) ; extra == "spark"
+Requires-Dist: pyspark ; extra == "spark"
+Requires-Dist: pytorch-optimizer (>=3.8.0,<3.9.0) ; extra == "torch" or extra == "torch-cpu"
+Requires-Dist: pytorch-optimizer ; extra == "torch"
+Requires-Dist: pytorch-optimizer ; extra == "torch-cpu"
+Requires-Dist: scikit-learn (>=1.6.1,<1.7.0)
+Requires-Dist: scipy (>=1.13.1,<1.14)
+Requires-Dist: setuptools
+Requires-Dist: torch (>=1.8,<3.0.0) ; extra == "torch" or extra == "torch-cpu"
+Requires-Dist: torch ; extra == "torch"
+Requires-Dist: torch ; extra == "torch-cpu"
+Requires-Dist: tqdm (>=4.67,<5)
+Project-URL: Homepage, https://sb-ai-lab.github.io/RePlay/
 Project-URL: Repository, https://github.com/sb-ai-lab/RePlay
 Description-Content-Type: text/markdown
 
@@ -208,7 +207,6 @@ pip install replay-rec==XX.YY.ZZrc0
 In addition to the core package, several extras are also provided, including:
 - `[spark]`: Install PySpark functionality
 - `[torch]`: Install PyTorch and Lightning functionality
-- `[all]`: `[spark]` `[torch]`
 
 Example:
 ```bash
@@ -219,9 +217,35 @@ pip install replay-rec[spark]
 pip install replay-rec[spark]==XX.YY.ZZrc0
 ```
 
+Additionally, `replay-rec[torch]` may be installed with CPU-only version of `torch` by providing its respective index URL during installation:
+```bash
+# Install package with the CPU version of torch
+pip install replay-rec[torch] --extra-index-url https://download.pytorch.org/whl/cpu
+```
+
+
 To build RePlay from sources please use the [instruction](CONTRIBUTING.md#installing-from-the-source).
 
 
+### Optional features
+RePlay includes a set of optional features which require users to install optional dependencies manually. These features include:
+
+1) Hyperpearameter search via Optuna:
+```bash
+pip install optuna
+```
+
+2) Model compilation via OpenVINO:
+```bash
+pip install openvino onnx
+```
+
+3) Vector database and hierarchical search support:
+```bash
+pip install hnswlib fixed-install-nmslib
+```
+
+
 <a name="examples"></a>
 ## 📑 Resources