replay-rec 0.20.0rc0__py3-none-any.whl → 0.20.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. replay/__init__.py +1 -1
  2. replay/data/nn/sequence_tokenizer.py +10 -3
  3. replay/data/nn/sequential_dataset.py +18 -14
  4. replay/data/nn/torch_sequential_dataset.py +12 -12
  5. replay/models/lin_ucb.py +55 -9
  6. replay/models/nn/sequential/bert4rec/dataset.py +3 -16
  7. replay/models/nn/sequential/postprocessors/postprocessors.py +1 -1
  8. replay/models/nn/sequential/sasrec/dataset.py +3 -16
  9. replay/utils/__init__.py +0 -1
  10. {replay_rec-0.20.0rc0.dist-info → replay_rec-0.20.1.dist-info}/METADATA +17 -11
  11. {replay_rec-0.20.0rc0.dist-info → replay_rec-0.20.1.dist-info}/RECORD +14 -70
  12. replay/experimental/__init__.py +0 -0
  13. replay/experimental/metrics/__init__.py +0 -62
  14. replay/experimental/metrics/base_metric.py +0 -603
  15. replay/experimental/metrics/coverage.py +0 -97
  16. replay/experimental/metrics/experiment.py +0 -175
  17. replay/experimental/metrics/hitrate.py +0 -26
  18. replay/experimental/metrics/map.py +0 -30
  19. replay/experimental/metrics/mrr.py +0 -18
  20. replay/experimental/metrics/ncis_precision.py +0 -31
  21. replay/experimental/metrics/ndcg.py +0 -49
  22. replay/experimental/metrics/precision.py +0 -22
  23. replay/experimental/metrics/recall.py +0 -25
  24. replay/experimental/metrics/rocauc.py +0 -49
  25. replay/experimental/metrics/surprisal.py +0 -90
  26. replay/experimental/metrics/unexpectedness.py +0 -76
  27. replay/experimental/models/__init__.py +0 -50
  28. replay/experimental/models/admm_slim.py +0 -257
  29. replay/experimental/models/base_neighbour_rec.py +0 -200
  30. replay/experimental/models/base_rec.py +0 -1386
  31. replay/experimental/models/base_torch_rec.py +0 -234
  32. replay/experimental/models/cql.py +0 -454
  33. replay/experimental/models/ddpg.py +0 -932
  34. replay/experimental/models/dt4rec/__init__.py +0 -0
  35. replay/experimental/models/dt4rec/dt4rec.py +0 -189
  36. replay/experimental/models/dt4rec/gpt1.py +0 -401
  37. replay/experimental/models/dt4rec/trainer.py +0 -127
  38. replay/experimental/models/dt4rec/utils.py +0 -264
  39. replay/experimental/models/extensions/spark_custom_models/__init__.py +0 -0
  40. replay/experimental/models/extensions/spark_custom_models/als_extension.py +0 -792
  41. replay/experimental/models/hierarchical_recommender.py +0 -331
  42. replay/experimental/models/implicit_wrap.py +0 -131
  43. replay/experimental/models/lightfm_wrap.py +0 -303
  44. replay/experimental/models/mult_vae.py +0 -332
  45. replay/experimental/models/neural_ts.py +0 -986
  46. replay/experimental/models/neuromf.py +0 -406
  47. replay/experimental/models/scala_als.py +0 -293
  48. replay/experimental/models/u_lin_ucb.py +0 -115
  49. replay/experimental/nn/data/__init__.py +0 -1
  50. replay/experimental/nn/data/schema_builder.py +0 -102
  51. replay/experimental/preprocessing/__init__.py +0 -3
  52. replay/experimental/preprocessing/data_preparator.py +0 -839
  53. replay/experimental/preprocessing/padder.py +0 -229
  54. replay/experimental/preprocessing/sequence_generator.py +0 -208
  55. replay/experimental/scenarios/__init__.py +0 -1
  56. replay/experimental/scenarios/obp_wrapper/__init__.py +0 -8
  57. replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +0 -74
  58. replay/experimental/scenarios/obp_wrapper/replay_offline.py +0 -261
  59. replay/experimental/scenarios/obp_wrapper/utils.py +0 -85
  60. replay/experimental/scenarios/two_stages/__init__.py +0 -0
  61. replay/experimental/scenarios/two_stages/reranker.py +0 -117
  62. replay/experimental/scenarios/two_stages/two_stages_scenario.py +0 -757
  63. replay/experimental/utils/__init__.py +0 -0
  64. replay/experimental/utils/logger.py +0 -24
  65. replay/experimental/utils/model_handler.py +0 -186
  66. replay/experimental/utils/session_handler.py +0 -44
  67. replay/utils/warnings.py +0 -26
  68. {replay_rec-0.20.0rc0.dist-info → replay_rec-0.20.1.dist-info}/WHEEL +0 -0
  69. {replay_rec-0.20.0rc0.dist-info → replay_rec-0.20.1.dist-info}/licenses/LICENSE +0 -0
  70. {replay_rec-0.20.0rc0.dist-info → replay_rec-0.20.1.dist-info}/licenses/NOTICE +0 -0
replay/experimental/utils/logger.py DELETED
@@ -1,24 +0,0 @@
- import logging
-
-
- def get_logger(
-     name,
-     level=logging.INFO,
-     format_str="%(asctime)s [%(pathname)s:%(lineno)s - %(levelname)s ] %(message)s",
-     date_format="%Y-%m-%d %H:%M:%S",
-     file=False,
- ):
-     """
-     Get python logger instance
-     """
-     logger = logging.getLogger(name)
-     logger.setLevel(level)
-
-     if not logger.hasHandlers():
-         handler = logging.StreamHandler() if not file else logging.FileHandler(name)
-         handler.setLevel(level)
-         formatter = logging.Formatter(fmt=format_str, datefmt=date_format)
-         handler.setFormatter(formatter)
-         logger.addHandler(handler)
-
-     return logger
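
For callers that relied on the removed `replay.experimental.utils.logger.get_logger` helper, the same behaviour is available from the standard library. The sketch below is illustrative only (the logger name and message are made up); it reproduces the removed helper's format string and handler setup:

```python
import logging

# Illustrative replacement for the deleted get_logger helper; the logger name is arbitrary.
logger = logging.getLogger("replay.example")
logger.setLevel(logging.INFO)
if not logger.hasHandlers():
    handler = logging.StreamHandler()  # the removed helper used FileHandler(name) when file=True
    handler.setLevel(logging.INFO)
    handler.setFormatter(
        logging.Formatter(
            fmt="%(asctime)s [%(pathname)s:%(lineno)s - %(levelname)s ] %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
    )
    logger.addHandler(handler)

logger.info("configured without replay.experimental.utils.logger")
```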
replay/experimental/utils/model_handler.py DELETED
@@ -1,186 +0,0 @@
- import json
- from inspect import getfullargspec
- from os.path import join
- from pathlib import Path
- from typing import Union
-
- from replay.experimental.models.base_rec import BaseRecommender
- from replay.experimental.preprocessing import Indexer
- from replay.utils import PYSPARK_AVAILABLE
- from replay.utils.session_handler import State
- from replay.utils.spark_utils import load_pickled_from_parquet, save_picklable_to_parquet
-
- if PYSPARK_AVAILABLE:
-     import pyspark.sql.types as st
-     from pyspark.ml.feature import IndexToString, StringIndexerModel
-     from pyspark.sql import SparkSession
-
-     from replay.utils.model_handler import get_fs
-
-     def get_list_of_paths(spark: SparkSession, dir_path: str):
-         """
-         Returns list of paths to files in the `dir_path`
-
-         :param spark: spark session
-         :param dir_path: path to dir in hdfs or local disk
-         :return: list of paths to files
-         """
-         fs = get_fs(spark)
-         statuses = fs.listStatus(spark._jvm.org.apache.hadoop.fs.Path(dir_path))
-         return [str(f.getPath()) for f in statuses]
-
-
- def save(model: BaseRecommender, path: Union[str, Path], overwrite: bool = False):
-     """
-     Save fitted model to disk as a folder
-
-     :param model: Trained recommender
-     :param path: destination where model files will be stored
-     :return:
-     """
-     if isinstance(path, Path):
-         path = str(path)
-
-     spark = State().session
-
-     fs = get_fs(spark)
-     if not overwrite:
-         is_exists = fs.exists(spark._jvm.org.apache.hadoop.fs.Path(path))
-         if is_exists:
-             msg = f"Path '{path}' already exists. Mode is 'overwrite = False'."
-             raise FileExistsError(msg)
-
-     fs.mkdirs(spark._jvm.org.apache.hadoop.fs.Path(path))
-     model._save_model(join(path, "model"))
-
-     init_args = model._init_args
-     init_args["_model_name"] = str(model)
-     sc = spark.sparkContext
-     df = spark.read.json(sc.parallelize([json.dumps(init_args)]))
-     df.coalesce(1).write.mode("overwrite").option("ignoreNullFields", "false").json(join(path, "init_args.json"))
-
-     dataframes = model._dataframes
-     df_path = join(path, "dataframes")
-     for name, df in dataframes.items():
-         if df is not None:
-             df.write.mode("overwrite").parquet(join(df_path, name))
-
-     if hasattr(model, "fit_users"):
-         model.fit_users.write.mode("overwrite").parquet(join(df_path, "fit_users"))
-     if hasattr(model, "fit_items"):
-         model.fit_items.write.mode("overwrite").parquet(join(df_path, "fit_items"))
-     if hasattr(model, "study"):
-         save_picklable_to_parquet(model.study, join(path, "study"))
-
-
- def load(path: str, model_type=None) -> BaseRecommender:
-     """
-     Load saved model from disk
-
-     :param path: path to model folder
-     :return: Restored trained model
-     """
-     spark = State().session
-     args = spark.read.json(join(path, "init_args.json")).first().asDict(recursive=True)
-     name = args["_model_name"]
-     del args["_model_name"]
-
-     model_class = model_type if model_type is not None else globals()[name]
-     if name == "CQL":
-         for a in args:
-             if isinstance(args[a], dict) and "type" in args[a] and args[a]["type"] == "none":
-                 args[a]["params"] = {}
-     init_args = getfullargspec(model_class.__init__).args
-     init_args.remove("self")
-     extra_args = set(args) - set(init_args)
-     if len(extra_args) > 0:
-         extra_args = {key: args[key] for key in args}
-         init_args = {key: args[key] for key in init_args}
-     else:
-         init_args = args
-         extra_args = {}
-
-     model = model_class(**init_args)
-     for arg in extra_args:
-         model.arg = extra_args[arg]
-
-     dataframes_paths = get_list_of_paths(spark, join(path, "dataframes"))
-     for dataframe_path in dataframes_paths:
-         df = spark.read.parquet(dataframe_path)
-         attr_name = dataframe_path.split("/")[-1]
-         setattr(model, attr_name, df)
-
-     model._load_model(join(path, "model"))
-     fs = get_fs(spark)
-     model.study = (
-         load_pickled_from_parquet(join(path, "study"))
-         if fs.exists(spark._jvm.org.apache.hadoop.fs.Path(join(path, "study")))
-         else None
-     )
-
-     return model
-
-
- def save_indexer(indexer: Indexer, path: Union[str, Path], overwrite: bool = False):
-     """
-     Save fitted indexer to disk as a folder
-
-     :param indexer: Trained indexer
-     :param path: destination where indexer files will be stored
-     """
-     if isinstance(path, Path):
-         path = str(path)
-
-     spark = State().session
-
-     if not overwrite:
-         fs = get_fs(spark)
-         is_exists = fs.exists(spark._jvm.org.apache.hadoop.fs.Path(path))
-         if is_exists:
-             msg = f"Path '{path}' already exists. Mode is 'overwrite = False'."
-             raise FileExistsError(msg)
-
-     init_args = indexer._init_args
-     init_args["user_type"] = str(indexer.user_type)
-     init_args["item_type"] = str(indexer.item_type)
-     sc = spark.sparkContext
-     df = spark.read.json(sc.parallelize([json.dumps(init_args)]))
-     df.coalesce(1).write.mode("overwrite").json(join(path, "init_args.json"))
-
-     indexer.user_indexer.write().overwrite().save(join(path, "user_indexer"))
-     indexer.item_indexer.write().overwrite().save(join(path, "item_indexer"))
-     indexer.inv_user_indexer.write().overwrite().save(join(path, "inv_user_indexer"))
-     indexer.inv_item_indexer.write().overwrite().save(join(path, "inv_item_indexer"))
-
-
- def load_indexer(path: str) -> Indexer:
-     """
-     Load saved indexer from disk
-
-     :param path: path to folder
-     :return: restored Indexer
-     """
-     spark = State().session
-     args = spark.read.json(join(path, "init_args.json")).first().asDict()
-
-     user_type = args["user_type"]
-     del args["user_type"]
-     item_type = args["item_type"]
-     del args["item_type"]
-
-     indexer = Indexer(**args)
-
-     if user_type.endswith("()"):
-         user_type = user_type[:-2]
-         item_type = item_type[:-2]
-     user_type = getattr(st, user_type)
-     item_type = getattr(st, item_type)
-     indexer.user_type = user_type()
-     indexer.item_type = item_type()
-
-     indexer.user_indexer = StringIndexerModel.load(join(path, "user_indexer"))
-     indexer.item_indexer = StringIndexerModel.load(join(path, "item_indexer"))
-     indexer.inv_user_indexer = IndexToString.load(join(path, "inv_user_indexer"))
-     indexer.inv_item_indexer = IndexToString.load(join(path, "inv_item_indexer"))
-
-     return indexer
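
The removed experimental `save`/`load` helpers persisted a model as a folder containing `model/`, `init_args.json` (with a `_model_name` key), and `dataframes/`. The sketch below only illustrates that on-disk layout with the standard library; it is not part of the replay 0.20.1 API, and the function, path, and argument names are made up:

```python
import json
import tempfile
from pathlib import Path


def sketch_save_layout(root: Path, init_args: dict, model_name: str) -> None:
    """Write the folder layout the deleted save() produced, without Spark or replay classes."""
    if root.exists():
        msg = f"Path '{root}' already exists. Mode is 'overwrite = False'."
        raise FileExistsError(msg)
    (root / "model").mkdir(parents=True)
    (root / "dataframes").mkdir()
    # Same "_model_name" convention the deleted code stored in init_args.json.
    payload = dict(init_args, _model_name=model_name)
    (root / "init_args.json").write_text(json.dumps(payload))


# Example: lay out a dummy "model" folder in a temporary directory.
sketch_save_layout(Path(tempfile.mkdtemp()) / "my_model", {"rank": 10}, "SomeRecommender")
```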
replay/experimental/utils/session_handler.py DELETED
@@ -1,44 +0,0 @@
- from typing import Optional
-
- import torch
-
- from replay.utils.session_handler import Borg, get_spark_session, logger_with_settings
- from replay.utils.types import PYSPARK_AVAILABLE, MissingImport
-
- if PYSPARK_AVAILABLE:
-     from pyspark.sql import SparkSession
- else:
-     SparkSession = MissingImport
-
-
- class State(Borg):
-     """
-     All modules look for Spark session via this class. You can put your own session here.
-
-     Other parameters are stored here too: ``default device`` for ``pytorch`` (CPU/CUDA)
-     """
-
-     def __init__(
-         self,
-         session: Optional[SparkSession] = None,
-         device: Optional[torch.device] = None,
-     ):
-         Borg.__init__(self)
-         if not hasattr(self, "logger_set"):
-             self.logger = logger_with_settings()
-             self.logger_set = True
-
-         if session is None:
-             if not hasattr(self, "session"):
-                 self.session = get_spark_session()
-         else:
-             self.session = session
-
-         if device is None:
-             if not hasattr(self, "device"):
-                 if torch.cuda.is_available():
-                     self.device = torch.device(f"cuda:{torch.cuda.current_device()}")
-                 else:
-                     self.device = torch.device("cpu")
-         else:
-             self.device = device
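
The removed experimental `State` class combined a shared Spark session with a default torch device choice (the current CUDA device when available, otherwise CPU). A minimal, torch-only sketch of that device-selection rule, assuming torch is installed:

```python
import torch

# Same device-selection rule as the deleted experimental State: prefer the current CUDA device.
if torch.cuda.is_available():
    device = torch.device(f"cuda:{torch.cuda.current_device()}")
else:
    device = torch.device("cpu")

print(f"default device: {device}")
```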
replay/utils/warnings.py DELETED
@@ -1,26 +0,0 @@
- import functools
- import warnings
- from collections.abc import Callable
- from typing import Any, Optional
-
-
- def deprecation_warning(message: Optional[str] = None) -> Callable[..., Any]:
-     """
-     Decorator that throws deprecation warnings.
-
-     :param message: message to deprecation warning without func name.
-     """
-     base_msg = "will be deprecated in future versions."
-
-     def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
-         @functools.wraps(func)
-         def wrapper(*args: Any, **kwargs: Any) -> Any:
-             msg = f"{func.__qualname__} {message if message else base_msg}"
-             warnings.simplefilter("always", DeprecationWarning)  # turn off filter
-             warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
-             warnings.simplefilter("default", DeprecationWarning)  # reset filter
-             return func(*args, **kwargs)
-
-         return wrapper
-
-     return decorator
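
The `replay.utils.warnings.deprecation_warning` decorator no longer ships with 0.20.1. If downstream code imported it, an equivalent stdlib-only sketch looks like this (the decorator and the decorated function below are illustrative, not part of replay, and the filter toggling of the original is omitted):

```python
import functools
import warnings
from typing import Any, Callable, Optional


def deprecation_warning(message: Optional[str] = None) -> Callable[..., Any]:
    """Stdlib-only stand-in for the deleted decorator; warns before each call of the wrapped function."""
    base_msg = "will be deprecated in future versions."

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            # DeprecationWarning is shown by default only in __main__; adjust warning filters if needed.
            warnings.warn(f"{func.__qualname__} {message or base_msg}", DeprecationWarning, stacklevel=2)
            return func(*args, **kwargs)

        return wrapper

    return decorator


@deprecation_warning("was removed from replay-rec; switch to a supported API.")
def legacy_call() -> str:
    return "still callable, but warns"


legacy_call()
```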