replay-rec 0.19.0-py3-none-any.whl → 0.20.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- replay/__init__.py +6 -2
- replay/data/dataset.py +9 -9
- replay/data/nn/__init__.py +6 -6
- replay/data/nn/sequence_tokenizer.py +44 -38
- replay/data/nn/sequential_dataset.py +13 -8
- replay/data/nn/torch_sequential_dataset.py +14 -13
- replay/data/nn/utils.py +1 -1
- replay/metrics/base_metric.py +1 -1
- replay/metrics/coverage.py +7 -11
- replay/metrics/experiment.py +3 -3
- replay/metrics/offline_metrics.py +2 -2
- replay/models/__init__.py +19 -0
- replay/models/association_rules.py +1 -4
- replay/models/base_neighbour_rec.py +6 -9
- replay/models/base_rec.py +44 -293
- replay/models/cat_pop_rec.py +2 -1
- replay/models/common.py +69 -0
- replay/models/extensions/ann/ann_mixin.py +30 -25
- replay/models/extensions/ann/index_builders/driver_hnswlib_index_builder.py +1 -1
- replay/models/extensions/ann/utils.py +4 -3
- replay/models/knn.py +18 -17
- replay/models/nn/sequential/bert4rec/dataset.py +1 -1
- replay/models/nn/sequential/callbacks/prediction_callbacks.py +2 -2
- replay/models/nn/sequential/compiled/__init__.py +10 -0
- replay/models/nn/sequential/compiled/base_compiled_model.py +3 -1
- replay/models/nn/sequential/compiled/bert4rec_compiled.py +11 -2
- replay/models/nn/sequential/compiled/sasrec_compiled.py +5 -1
- replay/models/nn/sequential/sasrec/dataset.py +1 -1
- replay/models/nn/sequential/sasrec/model.py +1 -1
- replay/models/optimization/__init__.py +14 -0
- replay/models/optimization/optuna_mixin.py +279 -0
- replay/{optimization → models/optimization}/optuna_objective.py +13 -15
- replay/models/slim.py +2 -4
- replay/models/word2vec.py +7 -12
- replay/preprocessing/discretizer.py +1 -2
- replay/preprocessing/history_based_fp.py +1 -1
- replay/preprocessing/label_encoder.py +1 -1
- replay/splitters/cold_user_random_splitter.py +13 -7
- replay/splitters/last_n_splitter.py +17 -10
- replay/utils/__init__.py +6 -2
- replay/utils/common.py +4 -2
- replay/utils/model_handler.py +11 -31
- replay/utils/session_handler.py +2 -2
- replay/utils/spark_utils.py +2 -2
- replay/utils/types.py +28 -18
- replay/utils/warnings.py +26 -0
- {replay_rec-0.19.0.dist-info → replay_rec-0.20.0.dist-info}/METADATA +56 -32
- {replay_rec-0.19.0.dist-info → replay_rec-0.20.0.dist-info}/RECORD +51 -47
- {replay_rec-0.19.0.dist-info → replay_rec-0.20.0.dist-info}/WHEEL +1 -1
- replay_rec-0.20.0.dist-info/licenses/NOTICE +41 -0
- replay/optimization/__init__.py +0 -5
- {replay_rec-0.19.0.dist-info → replay_rec-0.20.0.dist-info/licenses}/LICENSE +0 -0
@@ -21,12 +21,12 @@ from replay.utils import (
     PandasDataFrame,
     PolarsDataFrame,
     SparkDataFrame,
-    get_spark_session,
 )
 
 if PYSPARK_AVAILABLE:
     from pyspark.sql import Window, functions as sf  # noqa: I001
     from pyspark.sql.types import LongType, IntegerType, ArrayType
+    from replay.utils.session_handler import get_spark_session
 
 HandleUnknownStrategies = Literal["error", "use_default_value", "drop"]
 
replay/splitters/cold_user_random_splitter.py
CHANGED

@@ -38,12 +38,16 @@ class ColdUserRandomSplitter(Splitter):
         item_column: Optional[str] = "item_id",
     ):
         """
-        :param test_size:
-
-        :param
-
-        :param
-
+        :param test_size: The proportion of users to allocate to the test set.
+            Must be a float between 0.0 and 1.0.
+        :param drop_cold_items: Drop items from test DataFrame
+            which are not in train DataFrame, default: False.
+        :param seed: Seed for the random number generator to ensure
+            reproducibility of the split, default: None.
+        :param query_column: Name of query interaction column.
+            default: ``query_id``.
+        :param item_column: Name of item interaction column.
+            default: ``item_id``.
         """
         super().__init__(
             drop_cold_items=drop_cold_items,
@@ -81,7 +85,9 @@ class ColdUserRandomSplitter(Splitter):
             seed=self.seed,
         )
         interactions = interactions.join(
-            train_users.withColumn("is_test", sf.lit(False)),
+            train_users.withColumn("is_test", sf.lit(False)),
+            on=self.query_column,
+            how="left",
         ).na.fill({"is_test": True})
 
         train = interactions.filter(~sf.col("is_test")).drop("is_test")
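For context, a minimal usage sketch of the updated splitter. It is illustrative only: it assumes the standard `Splitter.split(interactions)` interface and the default column names documented in the new docstring, and the toy DataFrame is invented for the example.

```python
import pandas as pd

from replay.splitters import ColdUserRandomSplitter

# Toy interaction log using the default column names from the docstring above.
interactions = pd.DataFrame(
    {
        "query_id": [1, 1, 2, 2, 3, 3],
        "item_id": [10, 11, 10, 12, 11, 13],
        "timestamp": pd.date_range("2024-01-01", periods=6, freq="D"),
    }
)

# Roughly 25% of users (with all of their interactions) go to the test split.
splitter = ColdUserRandomSplitter(
    test_size=0.25,
    drop_cold_items=False,
    seed=42,
    query_column="query_id",
    item_column="item_id",
)
train, test = splitter.split(interactions)
```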
replay/splitters/last_n_splitter.py
CHANGED

@@ -4,7 +4,13 @@ import numpy as np
 import pandas as pd
 import polars as pl
 
-from replay.utils import
+from replay.utils import (
+    PYSPARK_AVAILABLE,
+    DataFrameLike,
+    PandasDataFrame,
+    PolarsDataFrame,
+    SparkDataFrame,
+)
 
 from .base_splitter import Splitter
 
@@ -118,14 +124,12 @@ class LastNSplitter(Splitter):
         session_id_processing_strategy: str = "test",
     ):
         """
-        :param N:
+        :param N: Number of last interactions or size of the time window in seconds
         :param divide_column: Name of column for dividing
             in dataframe, default: ``query_id``.
-        :param time_column_format: Format of
-
-            If
-            If time_column has already transformed into unix_timestamp type,
-            then you can omit this parameter.
+        :param time_column_format: Format of the timestamp column,
+            used for converting string dates to a numerical timestamp when strategy is 'timedelta'.
+            If the column is already a datetime object or a numerical timestamp, this parameter is ignored.
             default: ``yyyy-MM-dd HH:mm:ss``
         :param strategy: Defines the type of data splitting.
             Must be ``interactions`` or ``timedelta``.
@@ -223,7 +227,8 @@ class LastNSplitter(Splitter):
         time_column_type = dict(interactions.dtypes)[self.timestamp_column]
         if time_column_type == "date":
             interactions = interactions.withColumn(
-                self.timestamp_column,
+                self.timestamp_column,
+                sf.unix_timestamp(self.timestamp_column, self.time_column_format),
             )
 
         return interactions
@@ -260,7 +265,8 @@ class LastNSplitter(Splitter):
         self, interactions: SparkDataFrame, n: int
     ) -> Tuple[SparkDataFrame, SparkDataFrame]:
         interactions = interactions.withColumn(
-            "count",
+            "count",
+            sf.count(self.timestamp_column).over(Window.partitionBy(self.divide_column)),
         )
         # float(n) - because DataFrame.filter is changing order
         # of sorted DataFrame to descending
@@ -317,7 +323,8 @@ class LastNSplitter(Splitter):
         self, interactions: SparkDataFrame, timedelta: int
     ) -> Tuple[SparkDataFrame, SparkDataFrame]:
         inter_with_max_time = interactions.withColumn(
-            "max_timestamp",
+            "max_timestamp",
+            sf.max(self.timestamp_column).over(Window.partitionBy(self.divide_column)),
         )
         inter_with_diff = inter_with_max_time.withColumn(
             "diff_timestamp", sf.col("max_timestamp") - sf.col(self.timestamp_column)
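A small sketch of the `interactions` strategy described in the reworked docstring. It assumes the constructor accepts `N`, `divide_column`, and `strategy` as keyword arguments (per the docstring), the default `timestamp` column name, and the usual `split(interactions)` method; the sample data is invented.

```python
import pandas as pd

from replay.splitters import LastNSplitter

interactions = pd.DataFrame(
    {
        "query_id": [1, 1, 1, 2, 2, 2],
        "item_id": [10, 11, 12, 10, 13, 14],
        "timestamp": pd.to_datetime(
            ["2024-01-01", "2024-01-02", "2024-01-03"] * 2
        ),
    }
)

# Hold out the last 2 interactions of every query as the test split.
splitter = LastNSplitter(N=2, divide_column="query_id", strategy="interactions")
train, test = splitter.split(interactions)
```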
replay/utils/__init__.py
CHANGED

@@ -1,13 +1,17 @@
-from .session_handler import State, get_spark_session
 from .types import (
+    ANN_AVAILABLE,
     OPENVINO_AVAILABLE,
+    OPTUNA_AVAILABLE,
     PYSPARK_AVAILABLE,
     TORCH_AVAILABLE,
     DataFrameLike,
+    FeatureUnavailableError,
+    FeatureUnavailableWarning,
     IntOrList,
-
+    MissingImport,
     NumType,
     PandasDataFrame,
     PolarsDataFrame,
     SparkDataFrame,
 )
+from .warnings import deprecation_warning
replay/utils/common.py
CHANGED

@@ -126,6 +126,7 @@ def convert2pandas(
     """
     if isinstance(data, PandasDataFrame):
         return data
+
     if isinstance(data, PolarsDataFrame):
         return data.to_pandas()
     if isinstance(data, SparkDataFrame):
@@ -144,10 +145,11 @@ def convert2polars(
     :param allow_collect_to_master: If set to False (default) raises a warning
         about collecting parallelized data to the master node.
     """
-    if isinstance(data, PandasDataFrame):
-        return pl_from_pandas(data)
     if isinstance(data, PolarsDataFrame):
         return data
+
+    if isinstance(data, PandasDataFrame):
+        return pl_from_pandas(data)
     if isinstance(data, SparkDataFrame):
         return pl_from_pandas(spark_to_pandas(data, allow_collect_to_master, from_constructor=False))
 
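The two hunks above only reorder the `isinstance` dispatch and add blank lines, so behaviour is unchanged. A quick round-trip sketch, assuming both converters are importable from `replay.utils.common` as shown in the diff:

```python
import pandas as pd
import polars as pl

from replay.utils.common import convert2pandas, convert2polars

pdf = pd.DataFrame({"query_id": [1, 2], "item_id": [10, 20]})

# Each converter returns its input unchanged when it is already in the target
# format, otherwise it converts (pandas <-> polars here; Spark frames are
# collected to the driver first).
pldf = convert2polars(pdf)
assert isinstance(pldf, pl.DataFrame)

roundtrip = convert2pandas(pldf)
assert isinstance(roundtrip, pd.DataFrame)
```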
replay/utils/model_handler.py
CHANGED

@@ -1,16 +1,13 @@
-import functools
 import json
 import os
 import pickle
-import warnings
 from os.path import join
 from pathlib import Path
-from typing import
+from typing import Union
 
 from replay.data.dataset_utils import DatasetLabelEncoder
-from replay.models import *
 from replay.models.base_rec import BaseRecommender
-from replay.splitters import
+from replay.splitters import Splitter
 
 from .session_handler import State
 from .types import PYSPARK_AVAILABLE
@@ -43,7 +40,7 @@ if PYSPARK_AVAILABLE:
         return [str(f.getPath()) for f in statuses]
 
 
-def save(model: BaseRecommender, path: Union[str, Path], overwrite: bool = False):
+def save(model: "BaseRecommender", path: Union[str, Path], overwrite: bool = False):
     """
     Save fitted model to disk as a folder
 
@@ -86,19 +83,22 @@ def save(model: BaseRecommender, path: Union[str, Path], overwrite: bool = False):
         save_picklable_to_parquet(model.study, join(path, "study"))
 
 
-def load(path: str, model_type=None) -> BaseRecommender:
+def load(path: str, model_type=None) -> "BaseRecommender":
     """
     Load saved model from disk
 
     :param path: path to model folder
     :return: Restored trained model
     """
+    # FIXME: Surely there's a better way to handle this? Not having this method at all perhaps?
+    import replay.models as models
+
     spark = State().session
     args = spark.read.json(join(path, "init_args.json")).first().asDict(recursive=True)
     name = args["_model_name"]
     del args["_model_name"]
 
-    model_class = model_type if model_type is not None else
+    model_class = model_type if model_type is not None else getattr(models, name)
 
     model = model_class(**args)
 
@@ -175,31 +175,11 @@ def load_splitter(path: str) -> Splitter:
     :param path: path to folder
     :return: restored Splitter
     """
+    import replay.splitters as splitters
+
     spark = State().session
     args = spark.read.json(join(path, "init_args.json")).first().asDict()
     name = args["_splitter_name"]
     del args["_splitter_name"]
-    splitter =
+    splitter = getattr(splitters, name)
     return splitter(**args)
-
-
-def deprecation_warning(message: Optional[str] = None) -> Callable[..., Any]:
-    """
-    Decorator that throws deprecation warnings.
-
-    :param message: message to deprecation warning without func name.
-    """
-    base_msg = "will be deprecated in future versions."
-
-    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
-        @functools.wraps(func)
-        def wrapper(*args: Any, **kwargs: Any) -> Any:
-            msg = f"{func.__qualname__} {message if message else base_msg}"
-            warnings.simplefilter("always", DeprecationWarning)  # turn off filter
-            warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
-            warnings.simplefilter("default", DeprecationWarning)  # reset filter
-            return func(*args, **kwargs)
-
-        return wrapper
-
-    return decorator
replay/utils/session_handler.py
CHANGED

@@ -10,13 +10,13 @@ from typing import Any, Dict, Optional
 
 import psutil
 
-from .types import PYSPARK_AVAILABLE,
+from .types import PYSPARK_AVAILABLE, MissingImport
 
 if PYSPARK_AVAILABLE:
     from pyspark import __version__ as pyspark_version
     from pyspark.sql import SparkSession
 else:
-    SparkSession =
+    SparkSession = MissingImport
 
 
 def get_spark_session(
replay/utils/spark_utils.py
CHANGED

@@ -10,7 +10,7 @@ import pandas as pd
 from numpy.random import default_rng
 
 from .session_handler import State
-from .types import PYSPARK_AVAILABLE, DataFrameLike,
+from .types import PYSPARK_AVAILABLE, DataFrameLike, MissingImport, NumType, PolarsDataFrame, SparkDataFrame
 
 if PYSPARK_AVAILABLE:
     import pyspark.sql.types as st
@@ -24,7 +24,7 @@ if PYSPARK_AVAILABLE:
     from pyspark.sql.column import _to_java_column, _to_seq
     from pyspark.sql.types import DoubleType, IntegerType, StructField, StructType
 else:
-    Column =
+    Column = MissingImport
 
 
 class PolarsConvertToSparkWarning(Warning):
replay/utils/types.py
CHANGED

@@ -1,38 +1,48 @@
+from importlib.util import find_spec
 from typing import Iterable, Union
 
 from pandas import DataFrame as PandasDataFrame
 from polars import DataFrame as PolarsDataFrame
+from typing_extensions import TypeAlias
 
 
-class
+class MissingImport:
     """
     Replacement class with missing import
     """
 
 
-
-
+class FeatureUnavailableError(Exception):
+    """Exception class for failing a conditional import check."""
 
-    PYSPARK_AVAILABLE = True
-except ImportError:
-    PYSPARK_AVAILABLE = False
-    SparkDataFrame = MissingImportType
 
-
-
+class FeatureUnavailableWarning(Warning):
+    """Warning class for failing a conditional import check."""
 
-    TORCH_AVAILABLE = True
-except ImportError:
-    TORCH_AVAILABLE = False
 
-
-
-
+PYSPARK_AVAILABLE = find_spec("pyspark")
+if not PYSPARK_AVAILABLE:
+    SparkDataFrame: TypeAlias = MissingImport
+else:
+    from pyspark.sql import DataFrame
 
-
-
-
+    SparkDataFrame: TypeAlias = DataFrame
+
+
+TORCH_AVAILABLE = find_spec("torch") and find_spec("lightning")
 
 DataFrameLike = Union[PandasDataFrame, SparkDataFrame, PolarsDataFrame]
 IntOrList = Union[Iterable[int], int]
 NumType = Union[int, float]
+
+
+# Conditional import flags
+ANN_AVAILABLE = all(
+    [
+        find_spec("nmslib"),
+        find_spec("hnswlib"),
+        find_spec("pyarrow"),
+    ]
+)
+OPENVINO_AVAILABLE = TORCH_AVAILABLE and find_spec("onnx") and find_spec("openvino")
+OPTUNA_AVAILABLE = find_spec("optuna")
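The rewritten module exposes truthy/falsy availability flags (`find_spec` returns a `ModuleSpec` or `None`, not a `bool`) plus dedicated error and warning classes. A sketch of how downstream code might guard an optional dependency with them; the `require_optuna` helper is illustrative and not part of the package:

```python
from replay.utils import OPTUNA_AVAILABLE, FeatureUnavailableError


def require_optuna() -> None:
    """Fail fast when the optional Optuna dependency is missing."""
    # OPTUNA_AVAILABLE is truthy when importlib.util.find_spec("optuna") finds the package.
    if not OPTUNA_AVAILABLE:
        raise FeatureUnavailableError(
            "Hyperparameter search needs Optuna; install it with `pip install optuna`."
        )
```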
replay/utils/warnings.py
ADDED

@@ -0,0 +1,26 @@
+import functools
+import warnings
+from collections.abc import Callable
+from typing import Any, Optional
+
+
+def deprecation_warning(message: Optional[str] = None) -> Callable[..., Any]:
+    """
+    Decorator that throws deprecation warnings.
+
+    :param message: message to deprecation warning without func name.
+    """
+    base_msg = "will be deprecated in future versions."
+
+    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
+        @functools.wraps(func)
+        def wrapper(*args: Any, **kwargs: Any) -> Any:
+            msg = f"{func.__qualname__} {message if message else base_msg}"
+            warnings.simplefilter("always", DeprecationWarning)  # turn off filter
+            warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
+            warnings.simplefilter("default", DeprecationWarning)  # reset filter
+            return func(*args, **kwargs)
+
+        return wrapper
+
+    return decorator
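The decorator above is a verbatim move of the helper previously defined in `replay/utils/model_handler.py`. A short usage sketch; the decorated function is hypothetical and only shows how the warning message is composed:

```python
from replay.utils.warnings import deprecation_warning


@deprecation_warning("is deprecated, use `new_score` instead.")
def old_score(x: float) -> float:
    """Hypothetical function, present only to show the decorator in action."""
    return x * 2.0


# Emits: DeprecationWarning: old_score is deprecated, use `new_score` instead.
old_score(1.5)
```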
{replay_rec-0.19.0.dist-info → replay_rec-0.20.0.dist-info}/METADATA
CHANGED

@@ -1,45 +1,44 @@
-Metadata-Version: 2.
+Metadata-Version: 2.4
 Name: replay-rec
-Version: 0.
+Version: 0.20.0
 Summary: RecSys Library
-
-License:
+License-Expression: Apache-2.0
+License-File: LICENSE
+License-File: NOTICE
 Author: AI Lab
-Requires-Python: >=3.
+Requires-Python: >=3.9, <3.13
+Classifier: Operating System :: Unix
 Classifier: Development Status :: 4 - Beta
 Classifier: Environment :: Console
 Classifier: Intended Audience :: Developers
 Classifier: Intended Audience :: Science/Research
-Classifier: License :: OSI Approved :: Apache Software License
 Classifier: Natural Language :: English
-Classifier: Operating System :: Unix
-Classifier: Programming Language :: Python :: 3
-Classifier: Programming Language :: Python :: 3.9
-Classifier: Programming Language :: Python :: 3.10
-Classifier: Programming Language :: Python :: 3.11
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
-Provides-Extra: all
 Provides-Extra: spark
 Provides-Extra: torch
-Provides-Extra: torch-
-Requires-Dist:
-Requires-Dist:
-Requires-Dist: lightning
-Requires-Dist: numpy (>=1.20.0)
-Requires-Dist:
-Requires-Dist:
-Requires-Dist:
-Requires-Dist:
-Requires-Dist:
-Requires-Dist:
-Requires-Dist:
-Requires-Dist:
-Requires-Dist:
-Requires-Dist: pytorch-
-Requires-Dist: scikit-learn (>=1.
-Requires-Dist: scipy (>=1.
-Requires-Dist:
-Requires-Dist: torch (>=1.8
+Provides-Extra: torch-cpu
+Requires-Dist: lightning (<2.6.0) ; extra == "torch" or extra == "torch-cpu"
+Requires-Dist: lightning ; extra == "torch"
+Requires-Dist: lightning ; extra == "torch-cpu"
+Requires-Dist: numpy (>=1.20.0,<2)
+Requires-Dist: pandas (>=1.3.5,<2.4.0)
+Requires-Dist: polars (<2.0)
+Requires-Dist: psutil (<=7.0.0) ; extra == "spark"
+Requires-Dist: psutil ; extra == "spark"
+Requires-Dist: pyarrow (<22.0)
+Requires-Dist: pyspark (>=3.0,<3.5) ; extra == "spark"
+Requires-Dist: pyspark ; extra == "spark"
+Requires-Dist: pytorch-optimizer (>=3.8.0,<3.9.0) ; extra == "torch" or extra == "torch-cpu"
+Requires-Dist: pytorch-optimizer ; extra == "torch"
+Requires-Dist: pytorch-optimizer ; extra == "torch-cpu"
+Requires-Dist: scikit-learn (>=1.6.1,<1.7.0)
+Requires-Dist: scipy (>=1.13.1,<1.14)
+Requires-Dist: setuptools
+Requires-Dist: torch (>=1.8,<3.0.0) ; extra == "torch" or extra == "torch-cpu"
+Requires-Dist: torch ; extra == "torch"
+Requires-Dist: torch ; extra == "torch-cpu"
+Requires-Dist: tqdm (>=4.67,<5)
+Project-URL: Homepage, https://sb-ai-lab.github.io/RePlay/
 Project-URL: Repository, https://github.com/sb-ai-lab/RePlay
 Description-Content-Type: text/markdown
 
@@ -208,7 +207,6 @@ pip install replay-rec==XX.YY.ZZrc0
 In addition to the core package, several extras are also provided, including:
 - `[spark]`: Install PySpark functionality
 - `[torch]`: Install PyTorch and Lightning functionality
-- `[all]`: `[spark]` `[torch]`
 
 Example:
 ```bash
@@ -219,9 +217,35 @@ pip install replay-rec[spark]
 pip install replay-rec[spark]==XX.YY.ZZrc0
 ```
 
+Additionally, `replay-rec[torch]` may be installed with a CPU-only version of `torch` by providing its respective index URL during installation:
+```bash
+# Install package with the CPU version of torch
+pip install replay-rec[torch] --extra-index-url https://download.pytorch.org/whl/cpu
+```
+
+
 To build RePlay from sources please use the [instruction](CONTRIBUTING.md#installing-from-the-source).
 
 
+### Optional features
+RePlay includes a set of optional features which require users to install optional dependencies manually. These features include:
+
+1) Hyperparameter search via Optuna:
+```bash
+pip install optuna
+```
+
+2) Model compilation via OpenVINO:
+```bash
+pip install openvino onnx
+```
+
+3) Vector database and hierarchical search support:
+```bash
+pip install hnswlib fixed-install-nmslib
+```
+
+
 <a name="examples"></a>
 ## 📑 Resources
 