replay-rec 0.20.3__py3-none-any.whl → 0.20.3rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. replay/__init__.py +1 -1
  2. replay/experimental/__init__.py +0 -0
  3. replay/experimental/metrics/__init__.py +62 -0
  4. replay/experimental/metrics/base_metric.py +603 -0
  5. replay/experimental/metrics/coverage.py +97 -0
  6. replay/experimental/metrics/experiment.py +175 -0
  7. replay/experimental/metrics/hitrate.py +26 -0
  8. replay/experimental/metrics/map.py +30 -0
  9. replay/experimental/metrics/mrr.py +18 -0
  10. replay/experimental/metrics/ncis_precision.py +31 -0
  11. replay/experimental/metrics/ndcg.py +49 -0
  12. replay/experimental/metrics/precision.py +22 -0
  13. replay/experimental/metrics/recall.py +25 -0
  14. replay/experimental/metrics/rocauc.py +49 -0
  15. replay/experimental/metrics/surprisal.py +90 -0
  16. replay/experimental/metrics/unexpectedness.py +76 -0
  17. replay/experimental/models/__init__.py +50 -0
  18. replay/experimental/models/admm_slim.py +257 -0
  19. replay/experimental/models/base_neighbour_rec.py +200 -0
  20. replay/experimental/models/base_rec.py +1386 -0
  21. replay/experimental/models/base_torch_rec.py +234 -0
  22. replay/experimental/models/cql.py +454 -0
  23. replay/experimental/models/ddpg.py +932 -0
  24. replay/experimental/models/dt4rec/__init__.py +0 -0
  25. replay/experimental/models/dt4rec/dt4rec.py +189 -0
  26. replay/experimental/models/dt4rec/gpt1.py +401 -0
  27. replay/experimental/models/dt4rec/trainer.py +127 -0
  28. replay/experimental/models/dt4rec/utils.py +264 -0
  29. replay/experimental/models/extensions/spark_custom_models/__init__.py +0 -0
  30. replay/experimental/models/extensions/spark_custom_models/als_extension.py +792 -0
  31. replay/experimental/models/hierarchical_recommender.py +331 -0
  32. replay/experimental/models/implicit_wrap.py +131 -0
  33. replay/experimental/models/lightfm_wrap.py +303 -0
  34. replay/experimental/models/mult_vae.py +332 -0
  35. replay/experimental/models/neural_ts.py +986 -0
  36. replay/experimental/models/neuromf.py +406 -0
  37. replay/experimental/models/scala_als.py +293 -0
  38. replay/experimental/models/u_lin_ucb.py +115 -0
  39. replay/experimental/nn/data/__init__.py +1 -0
  40. replay/experimental/nn/data/schema_builder.py +102 -0
  41. replay/experimental/preprocessing/__init__.py +3 -0
  42. replay/experimental/preprocessing/data_preparator.py +839 -0
  43. replay/experimental/preprocessing/padder.py +229 -0
  44. replay/experimental/preprocessing/sequence_generator.py +208 -0
  45. replay/experimental/scenarios/__init__.py +1 -0
  46. replay/experimental/scenarios/obp_wrapper/__init__.py +8 -0
  47. replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +74 -0
  48. replay/experimental/scenarios/obp_wrapper/replay_offline.py +261 -0
  49. replay/experimental/scenarios/obp_wrapper/utils.py +85 -0
  50. replay/experimental/scenarios/two_stages/__init__.py +0 -0
  51. replay/experimental/scenarios/two_stages/reranker.py +117 -0
  52. replay/experimental/scenarios/two_stages/two_stages_scenario.py +757 -0
  53. replay/experimental/utils/__init__.py +0 -0
  54. replay/experimental/utils/logger.py +24 -0
  55. replay/experimental/utils/model_handler.py +186 -0
  56. replay/experimental/utils/session_handler.py +44 -0
  57. {replay_rec-0.20.3.dist-info → replay_rec-0.20.3rc0.dist-info}/METADATA +11 -17
  58. {replay_rec-0.20.3.dist-info → replay_rec-0.20.3rc0.dist-info}/RECORD +61 -6
  59. {replay_rec-0.20.3.dist-info → replay_rec-0.20.3rc0.dist-info}/WHEEL +0 -0
  60. {replay_rec-0.20.3.dist-info → replay_rec-0.20.3rc0.dist-info}/licenses/LICENSE +0 -0
  61. {replay_rec-0.20.3.dist-info → replay_rec-0.20.3rc0.dist-info}/licenses/NOTICE +0 -0
replay/experimental/models/base_torch_rec.py
@@ -0,0 +1,234 @@
+ from abc import abstractmethod
+ from typing import Any, Optional
+
+ import numpy as np
+ import torch
+ from torch import nn
+ from torch.optim.lr_scheduler import ReduceLROnPlateau
+ from torch.optim.optimizer import Optimizer
+ from torch.utils.data import DataLoader
+
+ from replay.data import get_schema
+ from replay.experimental.models.base_rec import Recommender
+ from replay.experimental.utils.session_handler import State
+ from replay.utils import PYSPARK_AVAILABLE, PandasDataFrame, SparkDataFrame
+
+ if PYSPARK_AVAILABLE:
+     from pyspark.sql import functions as sf
+
+
+ class TorchRecommender(Recommender):
+     """Base class for neural recommenders"""
+
+     model: Any
+     device: torch.device
+
+     def __init__(self):
+         self.logger.info("The model is neural network with non-distributed training")
+         self.checkpoint_path = State().session.conf.get("spark.local.dir")
+         self.device = State().device
+
+     def _run_train_step(self, batch, optimizer):
+         self.model.train()
+         optimizer.zero_grad()
+         model_result = self._batch_pass(batch, self.model)
+         loss = self._loss(**model_result)
+         loss.backward()
+         optimizer.step()
+         return loss.item()
+
+     def _run_validation(self, valid_data_loader: DataLoader, epoch: int) -> float:
+         self.model.eval()
+         valid_loss = 0
+         with torch.no_grad():
+             for batch in valid_data_loader:
+                 model_result = self._batch_pass(batch, self.model)
+                 valid_loss += self._loss(**model_result)
+             valid_loss /= len(valid_data_loader)
+         valid_debug_message = f"""Epoch[{epoch}] validation
+         average loss: {valid_loss:.5f}"""
+         self.logger.debug(valid_debug_message)
+         return valid_loss.item()
+
+     def train(
+         self,
+         train_data_loader: DataLoader,
+         valid_data_loader: DataLoader,
+         optimizer: Optimizer,
+         lr_scheduler: ReduceLROnPlateau,
+         epochs: int,
+         model_name: str,
+     ) -> None:
+         """
+         Run the training loop
+         :param train_data_loader: data loader for training
+         :param valid_data_loader: data loader for validation
+         :param optimizer: optimizer
+         :param lr_scheduler: scheduler used to decrease the learning rate
+         :param epochs: number of training epochs
+         :param model_name: model name for checkpoint saving
+         :return:
+         """
+         best_valid_loss = np.inf
+         for epoch in range(epochs):
+             for batch in train_data_loader:
+                 train_loss = self._run_train_step(batch, optimizer)
+
+             train_debug_message = f"""Epoch[{epoch}] current loss:
+             {train_loss:.5f}"""
+             self.logger.debug(train_debug_message)
+
+             valid_loss = self._run_validation(valid_data_loader, epoch)
+             lr_scheduler.step(valid_loss)
+
+             if valid_loss < best_valid_loss:
+                 best_checkpoint = "/".join(
+                     [
+                         self.checkpoint_path,
+                         f"best_{model_name}_{epoch+1}_loss={valid_loss}.pt",
+                     ]
+                 )
+                 self._save_model(best_checkpoint)
+                 best_valid_loss = valid_loss
+         self._load_model(best_checkpoint)
+
+     @abstractmethod
+     def _batch_pass(self, batch, model) -> dict[str, Any]:
+         """
+         Apply the model to a single batch.
+
+         :param batch: data batch
+         :param model: model object
+         :return: dictionary used to calculate loss.
+         """
+
+     @abstractmethod
+     def _loss(self, **kwargs) -> torch.Tensor:
+         """
+         Returns the loss value
+
+         :param **kwargs: dictionary used to calculate loss
+         :return: 1x1 tensor
+         """
+
+     def _predict(
+         self,
+         log: SparkDataFrame,
+         k: int,
+         users: SparkDataFrame,
+         items: SparkDataFrame,
+         user_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
+         item_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
+         filter_seen_items: bool = True,  # noqa: ARG002
+     ) -> SparkDataFrame:
+         items_consider_in_pred = items.toPandas()["item_idx"].values
+         items_count = self._item_dim
+         model = self.model.cpu()
+         agg_fn = self._predict_by_user
+
+         def grouped_map(pandas_df: PandasDataFrame) -> PandasDataFrame:
+             return agg_fn(pandas_df, model, items_consider_in_pred, k, items_count)[
+                 ["user_idx", "item_idx", "relevance"]
+             ]
+
+         self.logger.debug("Predict started")
+         # do not apply map on cold users for MultVAE predict
+         join_type = "inner" if str(self) == "MultVAE" else "left"
+         rec_schema = get_schema(
+             query_column="user_idx",
+             item_column="item_idx",
+             rating_column="relevance",
+             has_timestamp=False,
+         )
+         recs = (
+             users.join(log, how=join_type, on="user_idx")
+             .select("user_idx", "item_idx")
+             .groupby("user_idx")
+             .applyInPandas(grouped_map, rec_schema)
+         )
+         return recs
+
+     def _predict_pairs(
+         self,
+         pairs: SparkDataFrame,
+         log: Optional[SparkDataFrame] = None,
+         user_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
+         item_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
+     ) -> SparkDataFrame:
+         items_count = self._item_dim
+         model = self.model.cpu()
+         agg_fn = self._predict_by_user_pairs
+         users = pairs.select("user_idx").distinct()
+
+         def grouped_map(pandas_df: PandasDataFrame) -> PandasDataFrame:
+             return agg_fn(pandas_df, model, items_count)[["user_idx", "item_idx", "relevance"]]
+
+         self.logger.debug("Calculate relevance for user-item pairs")
+         user_history = (
+             users.join(log, how="inner", on="user_idx")
+             .groupBy("user_idx")
+             .agg(sf.collect_list("item_idx").alias("item_idx_history"))
+         )
+         user_pairs = pairs.groupBy("user_idx").agg(sf.collect_list("item_idx").alias("item_idx_to_pred"))
+         full_df = user_pairs.join(user_history, on="user_idx", how="inner")
+
+         rec_schema = get_schema(
+             query_column="user_idx",
+             item_column="item_idx",
+             rating_column="relevance",
+             has_timestamp=False,
+         )
+         recs = full_df.groupby("user_idx").applyInPandas(grouped_map, rec_schema)
+
+         return recs
+
+     @staticmethod
+     @abstractmethod
+     def _predict_by_user(
+         pandas_df: PandasDataFrame,
+         model: nn.Module,
+         items_np: np.ndarray,
+         k: int,
+         item_count: int,
+     ) -> PandasDataFrame:
+         """
+         Calculate predictions.
+
+         :param pandas_df: DataFrame with user-item interactions ``[user_idx, item_idx]``
+         :param model: trained model
+         :param items_np: items available for recommendations
+         :param k: length of recommendation list
+         :param item_count: total number of items
+         :return: DataFrame ``[user_idx, item_idx, relevance]``
+         """
+
+     @staticmethod
+     @abstractmethod
+     def _predict_by_user_pairs(
+         pandas_df: PandasDataFrame,
+         model: nn.Module,
+         item_count: int,
+     ) -> PandasDataFrame:
+         """
+         Get relevance for the provided pairs
+
+         :param pandas_df: DataFrame with rated items and items that need prediction
+             ``[user_idx, item_idx_history, item_idx_to_pred]``
+         :param model: trained model
+         :param item_count: total number of items
+         :return: DataFrame ``[user_idx, item_idx, relevance]``
+         """
+
+     def load_model(self, path: str) -> None:
+         """
+         Load the model from a file
+
+         :param path: path to the model
+         :return:
+         """
+         self.logger.debug("-- Loading model from file")
+         self.model.load_state_dict(torch.load(path))
+
+     def _save_model(self, path: str) -> None:
+         torch.save(self.model.state_dict(), path)
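
For orientation, the sketch below illustrates the contract a concrete subclass of TorchRecommender is expected to fulfil. The class name, the assumed (users, items, labels) batch layout, and the dot-product-style scoring are hypothetical; only the abstract interface (_batch_pass, _loss, _predict_by_user, _predict_by_user_pairs) comes from the file above.

# Hypothetical subclass sketch; DotProductRec is not part of the package.
import numpy as np
import pandas as pd
import torch
from torch import nn

from replay.experimental.models.base_torch_rec import TorchRecommender


class DotProductRec(TorchRecommender):
    """Toy example of the four abstract methods a subclass must provide."""

    def _batch_pass(self, batch, model):
        # assumed batch layout: (user id tensor, item id tensor, 0/1 label tensor)
        users, items, labels = batch
        scores = model(users, items)
        # keys of this dict become keyword arguments of _loss
        return {"scores": scores, "labels": labels.float()}

    def _loss(self, scores, labels):
        return nn.functional.binary_cross_entropy_with_logits(scores, labels)

    @staticmethod
    def _predict_by_user(pandas_df, model, items_np, k, item_count):
        # model is assumed to map (user ids, item ids) to a 1-D score tensor
        user_idx = int(pandas_df["user_idx"].iloc[0])
        with torch.no_grad():
            scores = model(
                torch.full((len(items_np),), user_idx, dtype=torch.long),
                torch.from_numpy(items_np).long(),
            ).numpy()
        top = np.argsort(-scores)[:k]
        return pd.DataFrame({"user_idx": user_idx, "item_idx": items_np[top], "relevance": scores[top]})

    @staticmethod
    def _predict_by_user_pairs(pandas_df, model, item_count):
        user_idx = int(pandas_df["user_idx"].iloc[0])
        items = np.array(pandas_df["item_idx_to_pred"].iloc[0])
        with torch.no_grad():
            scores = model(
                torch.full((len(items),), user_idx, dtype=torch.long),
                torch.from_numpy(items).long(),
            ).numpy()
        return pd.DataFrame({"user_idx": user_idx, "item_idx": items, "relevance": scores})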
replay/experimental/models/cql.py
@@ -0,0 +1,454 @@
+ """
+ Using the CQL implementation from the `d3rlpy` package.
+ """
+
+ import io
+ import logging
+ import tempfile
+ import timeit
+ from typing import Any, Optional, Union
+
+ import numpy as np
+ import torch
+ from d3rlpy.algos import (
+     CQL as CQL_d3rlpy,  # noqa: N811
+     CQLConfig,
+ )
+ from d3rlpy.base import LearnableConfigWithShape
+ from d3rlpy.constants import IMPL_NOT_INITIALIZED_ERROR
+ from d3rlpy.dataset import MDPDataset
+ from d3rlpy.models.encoders import DefaultEncoderFactory, EncoderFactory
+ from d3rlpy.models.q_functions import MeanQFunctionFactory, QFunctionFactory
+ from d3rlpy.optimizers import AdamFactory, OptimizerFactory
+ from d3rlpy.preprocessing import (
+     ActionScaler,
+     ObservationScaler,
+     RewardScaler,
+ )
+
+ from replay.data import get_schema
+ from replay.experimental.models.base_rec import Recommender
+ from replay.utils import PYSPARK_AVAILABLE, PandasDataFrame, SparkDataFrame
+ from replay.utils.spark_utils import assert_omp_single_thread
+
+ if PYSPARK_AVAILABLE:
+     from pyspark.sql import (
+         Window,
+         functions as sf,
+     )
+
+ timer = timeit.default_timer
+
+
+ class CQL(Recommender):
+     """Conservative Q-Learning algorithm.
+
+     CQL is a SAC-based data-driven deep reinforcement learning algorithm which
+     achieves state-of-the-art performance in offline RL problems.
+
+     CQL mitigates overestimation error by minimizing action-values under the
+     current policy and maximizing values under the data distribution to counter
+     underestimation.
+
+     .. math::
+         L(\\theta_i) = \\alpha\\, \\mathbb{E}_{s_t \\sim D}
+         \\left[\\log{\\sum_a \\exp{Q_{\\theta_i}(s_t, a)}} -
+         \\mathbb{E}_{a \\sim D} \\big[Q_{\\theta_i}(s_t, a)\\big] - \\tau\\right]
+         + L_\\mathrm{SAC}(\\theta_i)
+
+     where :math:`\\alpha` is an automatically adjustable value via Lagrangian
+     dual gradient descent and :math:`\\tau` is a threshold value.
+     If the action-value difference is smaller than :math:`\\tau`, the
+     :math:`\\alpha` will become smaller.
+     Otherwise, the :math:`\\alpha` will become larger to aggressively penalize
+     action-values.
+
+     In continuous control, :math:`\\log{\\sum_a \\exp{Q(s, a)}}` is computed as
+     follows.
+
+     .. math::
+         \\log{\\sum_a \\exp{Q(s, a)}} \\approx \\log{\\left(
+         \\frac{1}{2N} \\sum_{a_i \\sim \\text{Unif}(a)}^N
+         \\left[\\frac{\\exp{Q(s, a_i)}}{\\text{Unif}(a)}\\right]
+         + \\frac{1}{2N} \\sum_{a_i \\sim \\pi_\\phi(a|s)}^N
+         \\left[\\frac{\\exp{Q(s, a_i)}}{\\pi_\\phi(a_i|s)}\\right]\\right)}
+
+     where :math:`N` is the number of sampled actions.
+
+     This implementation is heavily based on the corresponding implementation
+     in the d3rlpy library (see https://github.com/takuseno/d3rlpy/blob/master/d3rlpy/algos/cql.py).
+
+     The rest of the optimization is exactly the same as :class:`d3rlpy.algos.SAC`.
+
+     References:
+         * `Kumar et al., Conservative Q-Learning for Offline Reinforcement
+           Learning. <https://arxiv.org/abs/2006.04779>`_
+
+     Args:
+         mdp_dataset_builder (MdpDatasetBuilder): the MDP dataset builder from users' log.
+         actor_learning_rate (float): learning rate for the policy function.
+         critic_learning_rate (float): learning rate for the Q functions.
+         temp_learning_rate (float): learning rate for the temperature parameter of SAC.
+         alpha_learning_rate (float): learning rate for :math:`\\alpha`.
+         actor_optim_factory (d3rlpy.models.optimizers.OptimizerFactory):
+             optimizer factory for the actor.
+             The available options are `[SGD, Adam or RMSprop]`.
+         critic_optim_factory (d3rlpy.models.optimizers.OptimizerFactory):
+             optimizer factory for the critic.
+             The available options are `[SGD, Adam or RMSprop]`.
+         temp_optim_factory (d3rlpy.models.optimizers.OptimizerFactory):
+             optimizer factory for the temperature.
+             The available options are `[SGD, Adam or RMSprop]`.
+         alpha_optim_factory (d3rlpy.models.optimizers.OptimizerFactory):
+             optimizer factory for :math:`\\alpha`.
+             The available options are `[SGD, Adam or RMSprop]`.
+         actor_encoder_factory (d3rlpy.models.encoders.EncoderFactory or str):
+             encoder factory for the actor.
+             The available options are `['pixel', 'dense', 'vector', 'default']`.
+             See d3rlpy.models.encoders.EncoderFactory for details.
+         critic_encoder_factory (d3rlpy.models.encoders.EncoderFactory or str):
+             encoder factory for the critic.
+             The available options are `['pixel', 'dense', 'vector', 'default']`.
+             See d3rlpy.models.encoders.EncoderFactory for details.
+         q_func_factory (d3rlpy.models.q_functions.QFunctionFactory or str):
+             Q function factory. The available options are `['mean', 'qr', 'iqn', 'fqf']`.
+             See d3rlpy.models.q_functions.QFunctionFactory for details.
+         batch_size (int): mini-batch size.
+         n_steps (int): number of training steps.
+         gamma (float): discount factor.
+         tau (float): target network synchronization coefficient.
+         n_critics (int): the number of Q functions for the ensemble.
+         initial_temperature (float): initial temperature value.
+         initial_alpha (float): initial :math:`\\alpha` value.
+         alpha_threshold (float): threshold value described as :math:`\\tau`.
+         conservative_weight (float): constant weight to scale the conservative loss.
+         n_action_samples (int): the number of sampled actions to compute
+             :math:`\\log{\\sum_a \\exp{Q(s, a)}}`.
+         soft_q_backup (bool): flag to use SAC-style backup.
+         use_gpu (Union[int, str, bool]): device option.
+             If the value is boolean and True, cuda:0 will be used.
+             If the value is integer, cuda:<device> will be used.
+             If the value is a string in torch device style, the specified device will be used.
+         observation_scaler (d3rlpy.preprocessing.Scaler or str): preprocessor.
+             The available options are `['pixel', 'min_max', 'standard']`.
+         action_scaler (d3rlpy.preprocessing.ActionScaler or str):
+             action preprocessor. The available options are `['min_max']`.
+         reward_scaler (d3rlpy.preprocessing.RewardScaler or str):
+             reward preprocessor. The available options are
+             `['clip', 'min_max', 'standard']`.
+         impl (d3rlpy.algos.torch.cql_impl.CQLImpl): algorithm implementation.
+     """
+
+     mdp_dataset_builder: "MdpDatasetBuilder"
+     model: CQL_d3rlpy
+
+     can_predict_cold_users = True
+
+     _observation_shape = (2,)
+     _action_size = 1
+
+     _search_space = {
+         "actor_learning_rate": {"type": "loguniform", "args": [1e-5, 1e-3]},
+         "critic_learning_rate": {"type": "loguniform", "args": [3e-5, 3e-4]},
+         "temp_learning_rate": {"type": "loguniform", "args": [1e-5, 1e-3]},
+         "alpha_learning_rate": {"type": "loguniform", "args": [1e-5, 1e-3]},
+         "gamma": {"type": "loguniform", "args": [0.9, 0.999]},
+         "n_critics": {"type": "int", "args": [2, 4]},
+     }
+
+     def __init__(
+         self,
+         mdp_dataset_builder: "MdpDatasetBuilder",
+         # CQL inner params
+         actor_learning_rate: float = 1e-4,
+         critic_learning_rate: float = 3e-4,
+         temp_learning_rate: float = 1e-4,
+         alpha_learning_rate: float = 1e-4,
+         actor_optim_factory: OptimizerFactory = AdamFactory(),
+         critic_optim_factory: OptimizerFactory = AdamFactory(),
+         temp_optim_factory: OptimizerFactory = AdamFactory(),
+         alpha_optim_factory: OptimizerFactory = AdamFactory(),
+         actor_encoder_factory: EncoderFactory = DefaultEncoderFactory(),
+         critic_encoder_factory: EncoderFactory = DefaultEncoderFactory(),
+         q_func_factory: QFunctionFactory = MeanQFunctionFactory(),
+         batch_size: int = 64,
+         n_steps: int = 1,
+         gamma: float = 0.99,
+         tau: float = 0.005,
+         n_critics: int = 2,
+         initial_temperature: float = 1.0,
+         initial_alpha: float = 1.0,
+         alpha_threshold: float = 10.0,
+         conservative_weight: float = 5.0,
+         n_action_samples: int = 10,
+         soft_q_backup: bool = False,
+         use_gpu: Union[int, str, bool] = False,
+         observation_scaler: Optional[ObservationScaler] = None,
+         action_scaler: Optional[ActionScaler] = None,
+         reward_scaler: Optional[RewardScaler] = None,
+         **params,
+     ):
+         super().__init__()
+         assert_omp_single_thread()
+
+         if isinstance(actor_optim_factory, dict):
+             local = {}
+             local["config"] = {}
+             local["config"]["params"] = dict(locals().items())
+             local["config"]["type"] = "cql"
+             local["observation_shape"] = self._observation_shape
+             local["action_size"] = self._action_size
+             deserialized_config = LearnableConfigWithShape.deserialize_from_dict(local)
+
+             self.logger.info("-- Deserializing CQL parameters")
+             actor_optim_factory = deserialized_config.config.actor_optim_factory
+             critic_optim_factory = deserialized_config.config.critic_optim_factory
+             temp_optim_factory = deserialized_config.config.temp_optim_factory
+             alpha_optim_factory = deserialized_config.config.alpha_optim_factory
+             actor_encoder_factory = deserialized_config.config.actor_encoder_factory
+             critic_encoder_factory = deserialized_config.config.critic_encoder_factory
+             q_func_factory = deserialized_config.config.q_func_factory
+             observation_scaler = deserialized_config.config.observation_scaler
+             action_scaler = deserialized_config.config.action_scaler
+             reward_scaler = deserialized_config.config.reward_scaler
+             # non-model params
+             mdp_dataset_builder = MdpDatasetBuilder(**mdp_dataset_builder)
+
+         self.mdp_dataset_builder = mdp_dataset_builder
+         self.n_steps = n_steps
+
+         self.model = CQLConfig(
+             actor_learning_rate=actor_learning_rate,
+             critic_learning_rate=critic_learning_rate,
+             temp_learning_rate=temp_learning_rate,
+             alpha_learning_rate=alpha_learning_rate,
+             actor_optim_factory=actor_optim_factory,
+             critic_optim_factory=critic_optim_factory,
+             temp_optim_factory=temp_optim_factory,
+             alpha_optim_factory=alpha_optim_factory,
+             actor_encoder_factory=actor_encoder_factory,
+             critic_encoder_factory=critic_encoder_factory,
+             q_func_factory=q_func_factory,
+             batch_size=batch_size,
+             gamma=gamma,
+             tau=tau,
+             n_critics=n_critics,
+             initial_temperature=initial_temperature,
+             initial_alpha=initial_alpha,
+             alpha_threshold=alpha_threshold,
+             conservative_weight=conservative_weight,
+             n_action_samples=n_action_samples,
+             soft_q_backup=soft_q_backup,
+             observation_scaler=observation_scaler,
+             action_scaler=action_scaler,
+             reward_scaler=reward_scaler,
+             **params,
+         ).create(device=use_gpu)
+
+         # explicitly create the model's algorithm implementation at init stage
+         # despite the lazy on-fit init convention in d3rlpy a) to avoid serialization
+         # complications and b) to make the model ready for prediction even before fitting
+         self.model.create_impl(observation_shape=self._observation_shape, action_size=self._action_size)
+
+     def _fit(
+         self,
+         log: SparkDataFrame,
+         user_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
+         item_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
+     ) -> None:
+         mdp_dataset: MDPDataset = self.mdp_dataset_builder.build(log)
+         self.model.fit(mdp_dataset, self.n_steps)
+
+     @staticmethod
+     def _predict_pairs_inner(
+         model: bytes,
+         user_idx: int,
+         items: np.ndarray,
+     ) -> PandasDataFrame:
+         user_item_pairs = PandasDataFrame({"user_idx": np.repeat(user_idx, len(items)), "item_idx": items})
+
+         # deserialize the model policy and predict item relevance for the user
+         policy = CQL._deserialize_policy(model)
+         items_batch = user_item_pairs.to_numpy()
+         user_item_pairs["relevance"] = CQL._predict_relevance_with_policy(policy, items_batch)
+
+         return user_item_pairs
+
+     def _predict(
+         self,
+         log: SparkDataFrame,  # noqa: ARG002
+         k: int,  # noqa: ARG002
+         users: SparkDataFrame,
+         items: SparkDataFrame,
+         user_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
+         item_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
+         filter_seen_items: bool = True,  # noqa: ARG002
+     ) -> SparkDataFrame:
+         available_items = items.toPandas()["item_idx"].values
+         policy_bytes = self._serialize_policy()
+
+         def grouped_map(log_slice: PandasDataFrame) -> PandasDataFrame:
+             return CQL._predict_pairs_inner(
+                 model=policy_bytes,
+                 user_idx=log_slice["user_idx"][0],
+                 items=available_items,
+             )[["user_idx", "item_idx", "relevance"]]
+
+         # predict relevance for all available items and return them as is;
+         # `filter_seen_items` and top `k` params are ignored
+         self.logger.debug("Predict started")
+         rec_schema = get_schema(
+             query_column="user_idx",
+             item_column="item_idx",
+             rating_column="relevance",
+             has_timestamp=False,
+         )
+         return users.groupby("user_idx").applyInPandas(grouped_map, rec_schema)
+
+     def _predict_pairs(
+         self,
+         pairs: SparkDataFrame,
+         log: Optional[SparkDataFrame] = None,
+         user_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
+         item_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
+     ) -> SparkDataFrame:
+         policy_bytes = self._serialize_policy()
+
+         def grouped_map(user_log: PandasDataFrame) -> PandasDataFrame:
+             return CQL._predict_pairs_inner(
+                 model=policy_bytes,
+                 user_idx=user_log["user_idx"][0],
+                 items=np.array(user_log["item_idx_to_pred"][0]),
+             )[["user_idx", "item_idx", "relevance"]]
+
+         self.logger.debug("Calculate relevance for user-item pairs")
+         rec_schema = get_schema(
+             query_column="user_idx",
+             item_column="item_idx",
+             rating_column="relevance",
+             has_timestamp=False,
+         )
+         return (
+             pairs.groupBy("user_idx")
+             .agg(sf.collect_list("item_idx").alias("item_idx_to_pred"))
+             .join(log.select("user_idx").distinct(), on="user_idx", how="inner")
+             .groupby("user_idx")
+             .applyInPandas(grouped_map, rec_schema)
+         )
+
+     @property
+     def _init_args(self) -> dict[str, Any]:
+         return {
+             # non-model hyperparams
+             "mdp_dataset_builder": self.mdp_dataset_builder.init_args(),
+             "n_steps": self.n_steps,
+             # model internal hyperparams
+             **self._get_model_hyperparams(),
+             "use_gpu": self.model._impl.device,
+         }
+
+     def _save_model(self, path: str) -> None:
+         self.logger.info("-- Saving model to %s", path)
+         self.model.save_model(path)
+
+     def _load_model(self, path: str) -> None:
+         self.logger.info("-- Loading model from %s", path)
+         self.model.load_model(path)
+
+     def _get_model_hyperparams(self) -> dict[str, Any]:
+         """Get the model hyperparams as a dictionary.
+         NB: the code is taken from the `d3rlpy.base.save_config(logger)` method, as
+         there is no method that returns these params without saving them.
+         """
+         assert self.model._impl is not None, IMPL_NOT_INITIALIZED_ERROR
+         config = LearnableConfigWithShape(
+             observation_shape=self.model.impl.observation_shape,
+             action_size=self.model.impl.action_size,
+             config=self.model.config,
+         )
+         config = config.serialize_to_dict()
+         config.update(config["config"]["params"])
+         for key_to_delete in ["observation_shape", "action_size", "config"]:
+             config.pop(key_to_delete)
+
+         return config
+
+     def _serialize_policy(self) -> bytes:
+         # store using a temporary file and immediately read back the serialized version
+         with tempfile.NamedTemporaryFile(suffix=".pt") as tmp:
+             # noinspection PyProtectedMember
+             self.model.save_policy(tmp.name)
+             with open(tmp.name, "rb") as policy_file:
+                 return policy_file.read()
+
+     @staticmethod
+     def _deserialize_policy(policy: bytes) -> torch.jit.ScriptModule:
+         with io.BytesIO(policy) as buffer:
+             return torch.jit.load(buffer, map_location=torch.device("cpu"))
+
+     @staticmethod
+     def _predict_relevance_with_policy(policy: torch.jit.ScriptModule, items: np.ndarray) -> np.ndarray:
+         items = torch.from_numpy(items).float().cpu()
+         with torch.no_grad():
+             return policy.forward(items).numpy()
+
+
+ class MdpDatasetBuilder:
+     r"""
+     Markov Decision Process dataset builder.
+     This class transforms a dataset of user logs, which is the natural form for recommender
+     systems, into a dataset of users' decision-making sessions, which is the natural form for RL methods.
+
+     Args:
+         top_k (int): the number of top user items to learn to predict.
+         action_randomization_scale (float): the scale of the Gaussian action randomization noise.
+     """
+
+     logger: logging.Logger
+     top_k: int
+     action_randomization_scale: float
+
+     def __init__(self, top_k: int, action_randomization_scale: float = 1e-3):
+         self.logger = logging.getLogger("replay")
+         self.top_k = top_k
+         # cannot set zero scale, as then d3rlpy will treat transitions as discrete
+         assert action_randomization_scale > 0
+         self.action_randomization_scale = action_randomization_scale
+
+     def build(self, log: SparkDataFrame) -> MDPDataset:
+         """Builds and returns an MDP dataset from the users' log."""
+
+         start_time = timer()
+         # reward the top-K watched items with 1, the others with 0
+         reward_condition = (
+             sf.row_number().over(Window.partitionBy("user_idx").orderBy([sf.desc("relevance"), sf.desc("timestamp")]))
+             <= self.top_k
+         )
+
+         # each user forms a separate episode (the latest item is marked as terminal)
+         terminal_condition = sf.row_number().over(Window.partitionBy("user_idx").orderBy(sf.desc("timestamp"))) == 1
+
+         user_logs = (
+             log.withColumn("reward", sf.when(reward_condition, sf.lit(1)).otherwise(sf.lit(0)))
+             .withColumn("terminal", sf.when(terminal_condition, sf.lit(1)).otherwise(sf.lit(0)))
+             .withColumn("action", sf.col("relevance").cast("float") + sf.randn() * self.action_randomization_scale)
+             .orderBy(["user_idx", "timestamp"], ascending=True)
+             .select(["user_idx", "item_idx", "action", "reward", "terminal"])
+             .toPandas()
+         )
+         train_dataset = MDPDataset(
+             observations=np.array(user_logs[["user_idx", "item_idx"]]),
+             actions=user_logs["action"].to_numpy()[:, None],
+             rewards=user_logs["reward"].to_numpy(),
+             terminals=user_logs["terminal"].to_numpy(),
+         )
+
+         prepare_time = timer() - start_time
+         self.logger.info("-- Building MDP dataset took %.2f seconds", prepare_time)
+         return train_dataset
+
+     def init_args(self):
+         return {
+             "top_k": self.top_k,
+             "action_randomization_scale": self.action_randomization_scale,
+         }
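
For context, a usage sketch of the CQL wrapper and MdpDatasetBuilder defined above. The toy interaction log, the chosen hyperparameters, and the public fit()/predict() calls (inherited from the experimental Recommender base class in base_rec.py, which is not shown in this diff) are assumptions for illustration only.

# Hypothetical usage sketch; not part of the package.
from replay.experimental.models.cql import CQL, MdpDatasetBuilder
from replay.experimental.utils.session_handler import State

spark = State().session  # reuse the Spark session managed by replay
# tiny log with the columns MdpDatasetBuilder.build() expects
log = spark.createDataFrame(
    [(0, 0, 1.0, 1), (0, 1, 3.0, 2), (1, 0, 2.0, 1), (1, 2, 4.0, 2)],
    schema="user_idx int, item_idx int, relevance double, timestamp long",
)

model = CQL(
    mdp_dataset_builder=MdpDatasetBuilder(top_k=1, action_randomization_scale=1e-3),
    n_steps=10,     # number of d3rlpy training steps; kept tiny for the toy log
    batch_size=4,
)

model.fit(log)  # builds the MDP dataset and trains the d3rlpy CQL model
recs = model.predict(log, k=2, users=log.select("user_idx").distinct())
recs.show()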