replay-rec 0.17.1rc0-py3-none-any.whl → 0.18.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80)
  1. replay/__init__.py +2 -1
  2. replay/data/dataset.py +3 -2
  3. replay/data/dataset_utils/dataset_label_encoder.py +1 -0
  4. replay/data/nn/schema.py +5 -5
  5. replay/metrics/__init__.py +1 -0
  6. replay/models/als.py +1 -1
  7. replay/models/base_rec.py +7 -7
  8. replay/models/extensions/ann/index_inferers/nmslib_filter_index_inferer.py +3 -3
  9. replay/models/extensions/ann/index_inferers/nmslib_index_inferer.py +3 -3
  10. replay/models/nn/sequential/bert4rec/model.py +5 -112
  11. replay/models/nn/sequential/sasrec/model.py +8 -5
  12. replay/optimization/optuna_objective.py +1 -0
  13. replay/preprocessing/converter.py +1 -1
  14. replay/preprocessing/filters.py +19 -18
  15. replay/preprocessing/history_based_fp.py +5 -5
  16. replay/preprocessing/label_encoder.py +1 -0
  17. replay/scenarios/__init__.py +1 -0
  18. replay/splitters/last_n_splitter.py +1 -1
  19. replay/splitters/time_splitter.py +1 -1
  20. replay/splitters/two_stage_splitter.py +8 -6
  21. replay/utils/distributions.py +1 -0
  22. replay/utils/session_handler.py +3 -3
  23. replay/utils/spark_utils.py +2 -2
  24. {replay_rec-0.17.1rc0.dist-info → replay_rec-0.18.0.dist-info}/METADATA +12 -18
  25. {replay_rec-0.17.1rc0.dist-info → replay_rec-0.18.0.dist-info}/RECORD +27 -80
  26. {replay_rec-0.17.1rc0.dist-info → replay_rec-0.18.0.dist-info}/WHEEL +1 -1
  27. replay/experimental/__init__.py +0 -0
  28. replay/experimental/metrics/__init__.py +0 -61
  29. replay/experimental/metrics/base_metric.py +0 -601
  30. replay/experimental/metrics/coverage.py +0 -97
  31. replay/experimental/metrics/experiment.py +0 -175
  32. replay/experimental/metrics/hitrate.py +0 -26
  33. replay/experimental/metrics/map.py +0 -30
  34. replay/experimental/metrics/mrr.py +0 -18
  35. replay/experimental/metrics/ncis_precision.py +0 -31
  36. replay/experimental/metrics/ndcg.py +0 -49
  37. replay/experimental/metrics/precision.py +0 -22
  38. replay/experimental/metrics/recall.py +0 -25
  39. replay/experimental/metrics/rocauc.py +0 -49
  40. replay/experimental/metrics/surprisal.py +0 -90
  41. replay/experimental/metrics/unexpectedness.py +0 -76
  42. replay/experimental/models/__init__.py +0 -10
  43. replay/experimental/models/admm_slim.py +0 -205
  44. replay/experimental/models/base_neighbour_rec.py +0 -204
  45. replay/experimental/models/base_rec.py +0 -1271
  46. replay/experimental/models/base_torch_rec.py +0 -234
  47. replay/experimental/models/cql.py +0 -452
  48. replay/experimental/models/ddpg.py +0 -921
  49. replay/experimental/models/dt4rec/__init__.py +0 -0
  50. replay/experimental/models/dt4rec/dt4rec.py +0 -189
  51. replay/experimental/models/dt4rec/gpt1.py +0 -401
  52. replay/experimental/models/dt4rec/trainer.py +0 -127
  53. replay/experimental/models/dt4rec/utils.py +0 -265
  54. replay/experimental/models/extensions/spark_custom_models/__init__.py +0 -0
  55. replay/experimental/models/extensions/spark_custom_models/als_extension.py +0 -792
  56. replay/experimental/models/implicit_wrap.py +0 -131
  57. replay/experimental/models/lightfm_wrap.py +0 -302
  58. replay/experimental/models/mult_vae.py +0 -331
  59. replay/experimental/models/neuromf.py +0 -405
  60. replay/experimental/models/scala_als.py +0 -296
  61. replay/experimental/nn/data/__init__.py +0 -1
  62. replay/experimental/nn/data/schema_builder.py +0 -55
  63. replay/experimental/preprocessing/__init__.py +0 -3
  64. replay/experimental/preprocessing/data_preparator.py +0 -838
  65. replay/experimental/preprocessing/padder.py +0 -229
  66. replay/experimental/preprocessing/sequence_generator.py +0 -208
  67. replay/experimental/scenarios/__init__.py +0 -1
  68. replay/experimental/scenarios/obp_wrapper/__init__.py +0 -8
  69. replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +0 -74
  70. replay/experimental/scenarios/obp_wrapper/replay_offline.py +0 -248
  71. replay/experimental/scenarios/obp_wrapper/utils.py +0 -87
  72. replay/experimental/scenarios/two_stages/__init__.py +0 -0
  73. replay/experimental/scenarios/two_stages/reranker.py +0 -117
  74. replay/experimental/scenarios/two_stages/two_stages_scenario.py +0 -757
  75. replay/experimental/utils/__init__.py +0 -0
  76. replay/experimental/utils/logger.py +0 -24
  77. replay/experimental/utils/model_handler.py +0 -181
  78. replay/experimental/utils/session_handler.py +0 -44
  79. replay_rec-0.17.1rc0.dist-info/NOTICE +0 -41
  80. {replay_rec-0.17.1rc0.dist-info → replay_rec-0.18.0.dist-info}/LICENSE +0 -0
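The file-level part of this comparison can be reproduced from the two published wheels themselves. The sketch below is one possible way to do that with the Python standard library and the public PyPI JSON API; the endpoint layout and response keys are assumptions based on that API's documented form, not something taken from this diff.

import io
import json
import urllib.request
import zipfile


def wheel_namelist(project: str, version: str) -> set:
    """Download the wheel published for a release and return the set of file paths inside it."""
    meta_url = f"https://pypi.org/pypi/{project}/{version}/json"
    with urllib.request.urlopen(meta_url) as response:
        release = json.load(response)
    # pick the wheel (bdist_wheel) artifact among the release files
    wheel = next(f for f in release["urls"] if f["packagetype"] == "bdist_wheel")
    with urllib.request.urlopen(wheel["url"]) as response:
        payload = response.read()
    with zipfile.ZipFile(io.BytesIO(payload)) as archive:
        return set(archive.namelist())


old_files = wheel_namelist("replay-rec", "0.17.1rc0")
new_files = wheel_namelist("replay-rec", "0.18.0")
print("removed:", sorted(old_files - new_files))
print("added:", sorted(new_files - old_files))

Comparing name sets only recovers added and removed files (for example, the replay/experimental modules listed above); files changed in place, such as replay/__init__.py, additionally require extracting and diffing their contents.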
replay/experimental/models/base_torch_rec.py
@@ -1,234 +0,0 @@
- from abc import abstractmethod
- from typing import Any, Dict, Optional
-
- import numpy as np
- import torch
- from torch import nn
- from torch.optim.lr_scheduler import ReduceLROnPlateau
- from torch.optim.optimizer import Optimizer
- from torch.utils.data import DataLoader
-
- from replay.data import get_schema
- from replay.experimental.models.base_rec import Recommender
- from replay.experimental.utils.session_handler import State
- from replay.utils import PYSPARK_AVAILABLE, PandasDataFrame, SparkDataFrame
-
- if PYSPARK_AVAILABLE:
-     from pyspark.sql import functions as sf
-
-
- class TorchRecommender(Recommender):
-     """Base class for neural recommenders"""
-
-     model: Any
-     device: torch.device
-
-     def __init__(self):
-         self.logger.info("The model is neural network with non-distributed training")
-         self.checkpoint_path = State().session.conf.get("spark.local.dir")
-         self.device = State().device
-
-     def _run_train_step(self, batch, optimizer):
-         self.model.train()
-         optimizer.zero_grad()
-         model_result = self._batch_pass(batch, self.model)
-         loss = self._loss(**model_result)
-         loss.backward()
-         optimizer.step()
-         return loss.item()
-
-     def _run_validation(self, valid_data_loader: DataLoader, epoch: int) -> float:
-         self.model.eval()
-         valid_loss = 0
-         with torch.no_grad():
-             for batch in valid_data_loader:
-                 model_result = self._batch_pass(batch, self.model)
-                 valid_loss += self._loss(**model_result)
-         valid_loss /= len(valid_data_loader)
-         valid_debug_message = f"""Epoch[{epoch}] validation
-             average loss: {valid_loss:.5f}"""
-         self.logger.debug(valid_debug_message)
-         return valid_loss.item()
-
-     def train(
-         self,
-         train_data_loader: DataLoader,
-         valid_data_loader: DataLoader,
-         optimizer: Optimizer,
-         lr_scheduler: ReduceLROnPlateau,
-         epochs: int,
-         model_name: str,
-     ) -> None:
-         """
-         Run training loop
-         :param train_data_loader: data loader for training
-         :param valid_data_loader: data loader for validation
-         :param optimizer: optimizer
-         :param lr_scheduler: scheduler used to decrease learning rate
-         :param lr_scheduler: scheduler used to decrease learning rate
-         :param epochs: num training epochs
-         :param model_name: model name for checkpoint saving
-         :return:
-         """
-         best_valid_loss = np.inf
-         for epoch in range(epochs):
-             for batch in train_data_loader:
-                 train_loss = self._run_train_step(batch, optimizer)
-
-                 train_debug_message = f"""Epoch[{epoch}] current loss:
-                     {train_loss:.5f}"""
-                 self.logger.debug(train_debug_message)
-
-             valid_loss = self._run_validation(valid_data_loader, epoch)
-             lr_scheduler.step(valid_loss)
-
-             if valid_loss < best_valid_loss:
-                 best_checkpoint = "/".join(
-                     [
-                         self.checkpoint_path,
-                         f"/best_{model_name}_{epoch+1}_loss={valid_loss}.pt",
-                     ]
-                 )
-                 self._save_model(best_checkpoint)
-                 best_valid_loss = valid_loss
-         self._load_model(best_checkpoint)
-
-     @abstractmethod
-     def _batch_pass(self, batch, model) -> Dict[str, Any]:
-         """
-         Apply model to a single batch.
-
-         :param batch: data batch
-         :param model: model object
-         :return: dictionary used to calculate loss.
-         """
-
-     @abstractmethod
-     def _loss(self, **kwargs) -> torch.Tensor:
-         """
-         Returns loss value
-
-         :param **kwargs: dictionary used to calculate loss
-         :return: 1x1 tensor
-         """
-
-     def _predict(
-         self,
-         log: SparkDataFrame,
-         k: int,
-         users: SparkDataFrame,
-         items: SparkDataFrame,
-         user_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
-         item_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
-         filter_seen_items: bool = True,  # noqa: ARG002
-     ) -> SparkDataFrame:
-         items_consider_in_pred = items.toPandas()["item_idx"].values
-         items_count = self._item_dim
-         model = self.model.cpu()
-         agg_fn = self._predict_by_user
-
-         def grouped_map(pandas_df: PandasDataFrame) -> PandasDataFrame:
-             return agg_fn(pandas_df, model, items_consider_in_pred, k, items_count)[
-                 ["user_idx", "item_idx", "relevance"]
-             ]
-
-         self.logger.debug("Predict started")
-         # do not apply map on cold users for MultVAE predict
-         join_type = "inner" if str(self) == "MultVAE" else "left"
-         rec_schema = get_schema(
-             query_column="user_idx",
-             item_column="item_idx",
-             rating_column="relevance",
-             has_timestamp=False,
-         )
-         recs = (
-             users.join(log, how=join_type, on="user_idx")
-             .select("user_idx", "item_idx")
-             .groupby("user_idx")
-             .applyInPandas(grouped_map, rec_schema)
-         )
-         return recs
-
-     def _predict_pairs(
-         self,
-         pairs: SparkDataFrame,
-         log: Optional[SparkDataFrame] = None,
-         user_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
-         item_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
-     ) -> SparkDataFrame:
-         items_count = self._item_dim
-         model = self.model.cpu()
-         agg_fn = self._predict_by_user_pairs
-         users = pairs.select("user_idx").distinct()
-
-         def grouped_map(pandas_df: PandasDataFrame) -> PandasDataFrame:
-             return agg_fn(pandas_df, model, items_count)[["user_idx", "item_idx", "relevance"]]
-
-         self.logger.debug("Calculate relevance for user-item pairs")
-         user_history = (
-             users.join(log, how="inner", on="user_idx")
-             .groupBy("user_idx")
-             .agg(sf.collect_list("item_idx").alias("item_idx_history"))
-         )
-         user_pairs = pairs.groupBy("user_idx").agg(sf.collect_list("item_idx").alias("item_idx_to_pred"))
-         full_df = user_pairs.join(user_history, on="user_idx", how="inner")
-
-         rec_schema = get_schema(
-             query_column="user_idx",
-             item_column="item_idx",
-             rating_column="relevance",
-             has_timestamp=False,
-         )
-         recs = full_df.groupby("user_idx").applyInPandas(grouped_map, rec_schema)
-
-         return recs
-
-     @staticmethod
-     @abstractmethod
-     def _predict_by_user(
-         pandas_df: PandasDataFrame,
-         model: nn.Module,
-         items_np: np.ndarray,
-         k: int,
-         item_count: int,
-     ) -> PandasDataFrame:
-         """
-         Calculate predictions.
-
-         :param pandas_df: DataFrame with user-item interactions ``[user_idx, item_idx]``
-         :param model: trained model
-         :param items_np: items available for recommendations
-         :param k: length of recommendation list
-         :param item_count: total number of items
-         :return: DataFrame ``[user_idx , item_idx , relevance]``
-         """
-
-     @staticmethod
-     @abstractmethod
-     def _predict_by_user_pairs(
-         pandas_df: PandasDataFrame,
-         model: nn.Module,
-         item_count: int,
-     ) -> PandasDataFrame:
-         """
-         Get relevance for provided pairs
-
-         :param pandas_df: DataFrame with rated items and items that need prediction
-             ``[user_idx, item_idx_history, item_idx_to_pred]``
-         :param model: trained model
-         :param item_count: total number of items
-         :return: DataFrame ``[user_idx , item_idx , relevance]``
-         """
-
-     def load_model(self, path: str) -> None:
-         """
-         Load model from file
-
-         :param path: path to model
-         :return:
-         """
-         self.logger.debug("-- Loading model from file")
-         self.model.load_state_dict(torch.load(path))
-
-     def _save_model(self, path: str) -> None:
-         torch.save(self.model.state_dict(), path)
replay/experimental/models/cql.py
@@ -1,452 +0,0 @@
- """
- Using CQL implementation from `d3rlpy` package.
- """
- import io
- import logging
- import tempfile
- import timeit
- from typing import Any, Dict, Optional, Union
-
- import numpy as np
- import torch
- from d3rlpy.algos import (
-     CQL as CQL_d3rlpy,  # noqa: N811
-     CQLConfig,
- )
- from d3rlpy.base import LearnableConfigWithShape
- from d3rlpy.constants import IMPL_NOT_INITIALIZED_ERROR
- from d3rlpy.dataset import MDPDataset
- from d3rlpy.models.encoders import DefaultEncoderFactory, EncoderFactory
- from d3rlpy.models.optimizers import AdamFactory, OptimizerFactory
- from d3rlpy.models.q_functions import MeanQFunctionFactory, QFunctionFactory
- from d3rlpy.preprocessing import (
-     ActionScaler,
-     ObservationScaler,
-     RewardScaler,
- )
-
- from replay.data import get_schema
- from replay.experimental.models.base_rec import Recommender
- from replay.utils import PYSPARK_AVAILABLE, PandasDataFrame, SparkDataFrame
- from replay.utils.spark_utils import assert_omp_single_thread
-
- if PYSPARK_AVAILABLE:
-     from pyspark.sql import (
-         Window,
-         functions as sf,
-     )
-
- timer = timeit.default_timer
-
-
- class CQL(Recommender):
-     """Conservative Q-Learning algorithm.
-
-     CQL is a SAC-based data-driven deep reinforcement learning algorithm, which
-     achieves state-of-the-art performance in offline RL problems.
-
-     CQL mitigates overestimation error by minimizing action-values under the
-     current policy and maximizing values under data distribution for
-     underestimation issue.
-
-     .. math::
-         L(\\theta_i) = \\alpha\\, \\mathbb{E}_{s_t \\sim D}
-         \\left[\\log{\\sum_a \\exp{Q_{\\theta_i}(s_t, a)}} -
-         \\mathbb{E}_{a \\sim D} \\big[Q_{\\theta_i}(s_t, a)\\big] - \\tau\\right]
-         + L_\\mathrm{SAC}(\\theta_i)
-
-     where :math:`\alpha` is an automatically adjustable value via Lagrangian
-     dual gradient descent and :math:`\tau` is a threshold value.
-     If the action-value difference is smaller than :math:`\tau`, the
-     :math:`\alpha` will become smaller.
-     Otherwise, the :math:`\alpha` will become larger to aggressively penalize
-     action-values.
-
-     In continuous control, :math:`\\log{\\sum_a \\exp{Q(s, a)}}` is computed as
-     follows.
-
-     .. math::
-         \\log{\\sum_a \\exp{Q(s, a)}} \\approx \\log{\\left(
-         \\frac{1}{2N} \\sum_{a_i \\sim \\text{Unif}(a)}^N
-         \\left[\\frac{\\exp{Q(s, a_i)}}{\\text{Unif}(a)}\\right]
-         + \\frac{1}{2N} \\sum_{a_i \\sim \\pi_\\phi(a|s)}^N
-         \\left[\\frac{\\exp{Q(s, a_i)}}{\\pi_\\phi(a_i|s)}\\right]\\right)}
-
-     where :math:`N` is the number of sampled actions.
-
-     An implementation of this algorithm is heavily based on the corresponding implementation
-     in the d3rlpy library (see https://github.com/takuseno/d3rlpy/blob/master/d3rlpy/algos/cql.py)
-
-     The rest of optimization is exactly same as :class:`d3rlpy.algos.SAC`.
-
-     References:
-         * `Kumar et al., Conservative Q-Learning for Offline Reinforcement
-           Learning. <https://arxiv.org/abs/2006.04779>`_
-
-     Args:
-         mdp_dataset_builder (MdpDatasetBuilder): the MDP dataset builder from users' log.
-         actor_learning_rate (float): learning rate for policy function.
-         critic_learning_rate (float): learning rate for Q functions.
-         temp_learning_rate (float): learning rate for temperature parameter of SAC.
-         alpha_learning_rate (float): learning rate for :math:`\alpha`.
-         actor_optim_factory (d3rlpy.models.optimizers.OptimizerFactory):
-             optimizer factory for the actor.
-             The available options are `[SGD, Adam or RMSprop]`.
-         critic_optim_factory (d3rlpy.models.optimizers.OptimizerFactory):
-             optimizer factory for the critic.
-             The available options are `[SGD, Adam or RMSprop]`.
-         temp_optim_factory (d3rlpy.models.optimizers.OptimizerFactory):
-             optimizer factory for the temperature.
-             The available options are `[SGD, Adam or RMSprop]`.
-         alpha_optim_factory (d3rlpy.models.optimizers.OptimizerFactory):
-             optimizer factory for :math:`\alpha`.
-             The available options are `[SGD, Adam or RMSprop]`.
-         actor_encoder_factory (d3rlpy.models.encoders.EncoderFactory or str):
-             encoder factory for the actor.
-             The available options are `['pixel', 'dense', 'vector', 'default']`.
-             See d3rlpy.models.encoders.EncoderFactory for details.
-         critic_encoder_factory (d3rlpy.models.encoders.EncoderFactory or str):
-             encoder factory for the critic.
-             The available options are `['pixel', 'dense', 'vector', 'default']`.
-             See d3rlpy.models.encoders.EncoderFactory for details.
-         q_func_factory (d3rlpy.models.q_functions.QFunctionFactory or str):
-             Q function factory. The available options are `['mean', 'qr', 'iqn', 'fqf']`.
-             See d3rlpy.models.q_functions.QFunctionFactory for details.
-         batch_size (int): mini-batch size.
-         n_steps (int): Number of training steps.
-         gamma (float): discount factor.
-         tau (float): target network synchronization coefficient.
-         n_critics (int): the number of Q functions for ensemble.
-         initial_temperature (float): initial temperature value.
-         initial_alpha (float): initial :math:`\alpha` value.
-         alpha_threshold (float): threshold value described as :math:`\tau`.
-         conservative_weight (float): constant weight to scale conservative loss.
-         n_action_samples (int): the number of sampled actions to compute
-             :math:`\\log{\\sum_a \\exp{Q(s, a)}}`.
-         soft_q_backup (bool): flag to use SAC-style backup.
-         use_gpu (Union[int, str, bool]): device option.
-             If the value is boolean and True, cuda:0 will be used.
-             If the value is integer, cuda:<device> will be used.
-             If the value is string in torch device style, the specified device will be used.
-         observation_scaler (d3rlpy.preprocessing.Scaler or str): preprocessor.
-             The available options are `['pixel', 'min_max', 'standard']`.
-         action_scaler (d3rlpy.preprocessing.ActionScaler or str):
-             action preprocessor. The available options are `['min_max']`.
-         reward_scaler (d3rlpy.preprocessing.RewardScaler or str):
-             reward preprocessor. The available options are
-             `['clip', 'min_max', 'standard']`.
-         impl (d3rlpy.algos.torch.cql_impl.CQLImpl): algorithm implementation.
-     """
-
-     mdp_dataset_builder: "MdpDatasetBuilder"
-     model: CQL_d3rlpy
-
-     can_predict_cold_users = True
-
-     _observation_shape = (2,)
-     _action_size = 1
-
-     _search_space = {
-         "actor_learning_rate": {"type": "loguniform", "args": [1e-5, 1e-3]},
-         "critic_learning_rate": {"type": "loguniform", "args": [3e-5, 3e-4]},
-         "temp_learning_rate": {"type": "loguniform", "args": [1e-5, 1e-3]},
-         "alpha_learning_rate": {"type": "loguniform", "args": [1e-5, 1e-3]},
-         "gamma": {"type": "loguniform", "args": [0.9, 0.999]},
-         "n_critics": {"type": "int", "args": [2, 4]},
-     }
-
-     def __init__(
-         self,
-         mdp_dataset_builder: "MdpDatasetBuilder",
-         # CQL inner params
-         actor_learning_rate: float = 1e-4,
-         critic_learning_rate: float = 3e-4,
-         temp_learning_rate: float = 1e-4,
-         alpha_learning_rate: float = 1e-4,
-         actor_optim_factory: OptimizerFactory = AdamFactory(),
-         critic_optim_factory: OptimizerFactory = AdamFactory(),
-         temp_optim_factory: OptimizerFactory = AdamFactory(),
-         alpha_optim_factory: OptimizerFactory = AdamFactory(),
-         actor_encoder_factory: EncoderFactory = DefaultEncoderFactory(),
-         critic_encoder_factory: EncoderFactory = DefaultEncoderFactory(),
-         q_func_factory: QFunctionFactory = MeanQFunctionFactory(),
-         batch_size: int = 64,
-         n_steps: int = 1,
-         gamma: float = 0.99,
-         tau: float = 0.005,
-         n_critics: int = 2,
-         initial_temperature: float = 1.0,
-         initial_alpha: float = 1.0,
-         alpha_threshold: float = 10.0,
-         conservative_weight: float = 5.0,
-         n_action_samples: int = 10,
-         soft_q_backup: bool = False,
-         use_gpu: Union[int, str, bool] = False,
-         observation_scaler: ObservationScaler = None,
-         action_scaler: ActionScaler = None,
-         reward_scaler: RewardScaler = None,
-         **params
-     ):
-         super().__init__()
-         assert_omp_single_thread()
-
-         if isinstance(actor_optim_factory, dict):
-             local = {}
-             local["config"] = {}
-             local["config"]["params"] = dict(locals().items())
-             local["config"]["type"] = "cql"
-             local["observation_shape"] = self._observation_shape
-             local["action_size"] = self._action_size
-             deserialized_config = LearnableConfigWithShape.deserialize_from_dict(local)
-
-             self.logger.info("-- Desiarializing CQL parameters")
-             actor_optim_factory = deserialized_config.config.actor_optim_factory
-             critic_optim_factory = deserialized_config.config.critic_optim_factory
-             temp_optim_factory = deserialized_config.config.temp_optim_factory
-             alpha_optim_factory = deserialized_config.config.alpha_optim_factory
-             actor_encoder_factory = deserialized_config.config.actor_encoder_factory
-             critic_encoder_factory = deserialized_config.config.critic_encoder_factory
-             q_func_factory = deserialized_config.config.q_func_factory
-             observation_scaler = deserialized_config.config.observation_scaler
-             action_scaler = deserialized_config.config.action_scaler
-             reward_scaler = deserialized_config.config.reward_scaler
-             # non-model params
-             mdp_dataset_builder = MdpDatasetBuilder(**mdp_dataset_builder)
-
-         self.mdp_dataset_builder = mdp_dataset_builder
-         self.n_steps = n_steps
-
-         self.model = CQLConfig(
-             actor_learning_rate=actor_learning_rate,
-             critic_learning_rate=critic_learning_rate,
-             temp_learning_rate=temp_learning_rate,
-             alpha_learning_rate=alpha_learning_rate,
-             actor_optim_factory=actor_optim_factory,
-             critic_optim_factory=critic_optim_factory,
-             temp_optim_factory=temp_optim_factory,
-             alpha_optim_factory=alpha_optim_factory,
-             actor_encoder_factory=actor_encoder_factory,
-             critic_encoder_factory=critic_encoder_factory,
-             q_func_factory=q_func_factory,
-             batch_size=batch_size,
-             gamma=gamma,
-             tau=tau,
-             n_critics=n_critics,
-             initial_temperature=initial_temperature,
-             initial_alpha=initial_alpha,
-             alpha_threshold=alpha_threshold,
-             conservative_weight=conservative_weight,
-             n_action_samples=n_action_samples,
-             soft_q_backup=soft_q_backup,
-             observation_scaler=observation_scaler,
-             action_scaler=action_scaler,
-             reward_scaler=reward_scaler,
-             **params
-         ).create(device=use_gpu)
-
-         # explicitly create the model's algorithm implementation at init stage
-         # despite the lazy on-fit init convention in d3rlpy a) to avoid serialization
-         # complications and b) to make model ready for prediction even before fitting
-         self.model.create_impl(observation_shape=self._observation_shape, action_size=self._action_size)
-
-     def _fit(
-         self,
-         log: SparkDataFrame,
-         user_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
-         item_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
-     ) -> None:
-         mdp_dataset: MDPDataset = self.mdp_dataset_builder.build(log)
-         self.model.fit(mdp_dataset, self.n_steps)
-
-     @staticmethod
-     def _predict_pairs_inner(
-         model: bytes,
-         user_idx: int,
-         items: np.ndarray,
-     ) -> PandasDataFrame:
-         user_item_pairs = PandasDataFrame({"user_idx": np.repeat(user_idx, len(items)), "item_idx": items})
-
-         # deserialize model policy and predict items relevance for the user
-         policy = CQL._deserialize_policy(model)
-         items_batch = user_item_pairs.to_numpy()
-         user_item_pairs["relevance"] = CQL._predict_relevance_with_policy(policy, items_batch)
-
-         return user_item_pairs
-
-     def _predict(
-         self,
-         log: SparkDataFrame,  # noqa: ARG002
-         k: int,  # noqa: ARG002
-         users: SparkDataFrame,
-         items: SparkDataFrame,
-         user_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
-         item_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
-         filter_seen_items: bool = True,  # noqa: ARG002
-     ) -> SparkDataFrame:
-         available_items = items.toPandas()["item_idx"].values
-         policy_bytes = self._serialize_policy()
-
-         def grouped_map(log_slice: PandasDataFrame) -> PandasDataFrame:
-             return CQL._predict_pairs_inner(
-                 model=policy_bytes,
-                 user_idx=log_slice["user_idx"][0],
-                 items=available_items,
-             )[["user_idx", "item_idx", "relevance"]]
-
-         # predict relevance for all available items and return them as is;
-         # `filter_seen_items` and top `k` params are ignored
-         self.logger.debug("Predict started")
-         rec_schema = get_schema(
-             query_column="user_idx",
-             item_column="item_idx",
-             rating_column="relevance",
-             has_timestamp=False,
-         )
-         return users.groupby("user_idx").applyInPandas(grouped_map, rec_schema)
-
-     def _predict_pairs(
-         self,
-         pairs: SparkDataFrame,
-         log: Optional[SparkDataFrame] = None,
-         user_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
-         item_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
-     ) -> SparkDataFrame:
-         policy_bytes = self._serialize_policy()
-
-         def grouped_map(user_log: PandasDataFrame) -> PandasDataFrame:
-             return CQL._predict_pairs_inner(
-                 model=policy_bytes,
-                 user_idx=user_log["user_idx"][0],
-                 items=np.array(user_log["item_idx_to_pred"][0]),
-             )[["user_idx", "item_idx", "relevance"]]
-
-         self.logger.debug("Calculate relevance for user-item pairs")
-         rec_schema = get_schema(
-             query_column="user_idx",
-             item_column="item_idx",
-             rating_column="relevance",
-             has_timestamp=False,
-         )
-         return (
-             pairs.groupBy("user_idx")
-             .agg(sf.collect_list("item_idx").alias("item_idx_to_pred"))
-             .join(log.select("user_idx").distinct(), on="user_idx", how="inner")
-             .groupby("user_idx")
-             .applyInPandas(grouped_map, rec_schema)
-         )
-
-     @property
-     def _init_args(self) -> Dict[str, Any]:
-         return {
-             # non-model hyperparams
-             "mdp_dataset_builder": self.mdp_dataset_builder.init_args(),
-             "n_steps": self.n_steps,
-             # model internal hyperparams
-             **self._get_model_hyperparams(),
-             "use_gpu": self.model._impl.device,
-         }
-
-     def _save_model(self, path: str) -> None:
-         self.logger.info("-- Saving model to %s", path)
-         self.model.save_model(path)
-
-     def _load_model(self, path: str) -> None:
-         self.logger.info("-- Loading model from %s", path)
-         self.model.load_model(path)
-
-     def _get_model_hyperparams(self) -> Dict[str, Any]:
-         """Get model hyperparams as dictionary.
-         NB: The code is taken from a `d3rlpy.base.save_config(logger)` method as
-         there's no method to just return such params without saving them.
-         """
-         assert self.model._impl is not None, IMPL_NOT_INITIALIZED_ERROR
-         config = LearnableConfigWithShape(
-             observation_shape=self.model.impl.observation_shape,
-             action_size=self.model.impl.action_size,
-             config=self.model.config,
-         )
-         config = config.serialize_to_dict()
-         config.update(config["config"]["params"])
-         for key_to_delete in ["observation_shape", "action_size", "config"]:
-             config.pop(key_to_delete)
-
-         return config
-
-     def _serialize_policy(self) -> bytes:
-         # store using temporary file and immediately read serialized version
-         with tempfile.NamedTemporaryFile(suffix=".pt") as tmp:
-             # noinspection PyProtectedMember
-             self.model.save_policy(tmp.name)
-             with open(tmp.name, "rb") as policy_file:
-                 return policy_file.read()
-
-     @staticmethod
-     def _deserialize_policy(policy: bytes) -> torch.jit.ScriptModule:
-         with io.BytesIO(policy) as buffer:
-             return torch.jit.load(buffer, map_location=torch.device("cpu"))
-
-     @staticmethod
-     def _predict_relevance_with_policy(policy: torch.jit.ScriptModule, items: np.ndarray) -> np.ndarray:
-         items = torch.from_numpy(items).float().cpu()
-         with torch.no_grad():
-             return policy.forward(items).numpy()
-
-
- class MdpDatasetBuilder:
-     r"""
-     Markov Decision Process Dataset builder.
-     This class transforms datasets with user logs, which is natural for recommender systems,
-     to datasets consisting of users' decision-making session logs, which is natural for RL methods.
-
-     Args:
-         top_k (int): the number of top user items to learn predicting.
-         action_randomization_scale (float): the scale of action randomization gaussian noise.
-     """
-     logger: logging.Logger
-     top_k: int
-     action_randomization_scale: float
-
-     def __init__(self, top_k: int, action_randomization_scale: float = 1e-3):
-         self.logger = logging.getLogger("replay")
-         self.top_k = top_k
-         # cannot set zero scale as then d3rlpy will treat transitions as discrete
-         assert action_randomization_scale > 0
-         self.action_randomization_scale = action_randomization_scale
-
-     def build(self, log: SparkDataFrame) -> MDPDataset:
-         """Builds and returns MDP dataset from users' log."""
-
-         start_time = timer()
-         # reward top-K watched movies with 1, the others - with 0
-         reward_condition = (
-             sf.row_number().over(Window.partitionBy("user_idx").orderBy([sf.desc("relevance"), sf.desc("timestamp")]))
-             <= self.top_k
-         )
-
-         # every user has his own episode (the latest item is defined as terminal)
-         terminal_condition = sf.row_number().over(Window.partitionBy("user_idx").orderBy(sf.desc("timestamp"))) == 1
-
-         user_logs = (
-             log.withColumn("reward", sf.when(reward_condition, sf.lit(1)).otherwise(sf.lit(0)))
-             .withColumn("terminal", sf.when(terminal_condition, sf.lit(1)).otherwise(sf.lit(0)))
-             .withColumn("action", sf.col("relevance").cast("float") + sf.randn() * self.action_randomization_scale)
-             .orderBy(["user_idx", "timestamp"], ascending=True)
-             .select(["user_idx", "item_idx", "action", "reward", "terminal"])
-             .toPandas()
-         )
-         train_dataset = MDPDataset(
-             observations=np.array(user_logs[["user_idx", "item_idx"]]),
-             actions=user_logs["action"].to_numpy()[:, None],
-             rewards=user_logs["reward"].to_numpy(),
-             terminals=user_logs["terminal"].to_numpy(),
-         )
-
-         prepare_time = timer() - start_time
-         self.logger.info("-- Building MDP dataset took %.2f seconds", prepare_time)
-         return train_dataset
-
-     def init_args(self):
-         return {
-             "top_k": self.top_k,
-             "action_randomization_scale": self.action_randomization_scale,
-         }