replay-rec: replay_rec-0.20.1rc0-py3-none-any.whl → replay_rec-0.20.3-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- replay/__init__.py +1 -1
- replay/data/nn/sequential_dataset.py +8 -2
- {replay_rec-0.20.1rc0.dist-info → replay_rec-0.20.3.dist-info}/METADATA +18 -12
- {replay_rec-0.20.1rc0.dist-info → replay_rec-0.20.3.dist-info}/RECORD +7 -62
- replay/experimental/__init__.py +0 -0
- replay/experimental/metrics/__init__.py +0 -62
- replay/experimental/metrics/base_metric.py +0 -603
- replay/experimental/metrics/coverage.py +0 -97
- replay/experimental/metrics/experiment.py +0 -175
- replay/experimental/metrics/hitrate.py +0 -26
- replay/experimental/metrics/map.py +0 -30
- replay/experimental/metrics/mrr.py +0 -18
- replay/experimental/metrics/ncis_precision.py +0 -31
- replay/experimental/metrics/ndcg.py +0 -49
- replay/experimental/metrics/precision.py +0 -22
- replay/experimental/metrics/recall.py +0 -25
- replay/experimental/metrics/rocauc.py +0 -49
- replay/experimental/metrics/surprisal.py +0 -90
- replay/experimental/metrics/unexpectedness.py +0 -76
- replay/experimental/models/__init__.py +0 -50
- replay/experimental/models/admm_slim.py +0 -257
- replay/experimental/models/base_neighbour_rec.py +0 -200
- replay/experimental/models/base_rec.py +0 -1386
- replay/experimental/models/base_torch_rec.py +0 -234
- replay/experimental/models/cql.py +0 -454
- replay/experimental/models/ddpg.py +0 -932
- replay/experimental/models/dt4rec/__init__.py +0 -0
- replay/experimental/models/dt4rec/dt4rec.py +0 -189
- replay/experimental/models/dt4rec/gpt1.py +0 -401
- replay/experimental/models/dt4rec/trainer.py +0 -127
- replay/experimental/models/dt4rec/utils.py +0 -264
- replay/experimental/models/extensions/spark_custom_models/__init__.py +0 -0
- replay/experimental/models/extensions/spark_custom_models/als_extension.py +0 -792
- replay/experimental/models/hierarchical_recommender.py +0 -331
- replay/experimental/models/implicit_wrap.py +0 -131
- replay/experimental/models/lightfm_wrap.py +0 -303
- replay/experimental/models/mult_vae.py +0 -332
- replay/experimental/models/neural_ts.py +0 -986
- replay/experimental/models/neuromf.py +0 -406
- replay/experimental/models/scala_als.py +0 -293
- replay/experimental/models/u_lin_ucb.py +0 -115
- replay/experimental/nn/data/__init__.py +0 -1
- replay/experimental/nn/data/schema_builder.py +0 -102
- replay/experimental/preprocessing/__init__.py +0 -3
- replay/experimental/preprocessing/data_preparator.py +0 -839
- replay/experimental/preprocessing/padder.py +0 -229
- replay/experimental/preprocessing/sequence_generator.py +0 -208
- replay/experimental/scenarios/__init__.py +0 -1
- replay/experimental/scenarios/obp_wrapper/__init__.py +0 -8
- replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +0 -74
- replay/experimental/scenarios/obp_wrapper/replay_offline.py +0 -261
- replay/experimental/scenarios/obp_wrapper/utils.py +0 -85
- replay/experimental/scenarios/two_stages/__init__.py +0 -0
- replay/experimental/scenarios/two_stages/reranker.py +0 -117
- replay/experimental/scenarios/two_stages/two_stages_scenario.py +0 -757
- replay/experimental/utils/__init__.py +0 -0
- replay/experimental/utils/logger.py +0 -24
- replay/experimental/utils/model_handler.py +0 -186
- replay/experimental/utils/session_handler.py +0 -44
- {replay_rec-0.20.1rc0.dist-info → replay_rec-0.20.3.dist-info}/WHEEL +0 -0
- {replay_rec-0.20.1rc0.dist-info → replay_rec-0.20.3.dist-info}/licenses/LICENSE +0 -0
- {replay_rec-0.20.1rc0.dist-info → replay_rec-0.20.3.dist-info}/licenses/NOTICE +0 -0
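The bulk of this diff is the removal of the entire replay.experimental tree (metrics, models, preprocessing, scenarios, utils) from the 0.20.3 wheel; apart from the dist-info metadata, only `replay/__init__.py` and `replay/data/nn/sequential_dataset.py` otherwise change. Downstream code importing the experimental modules will therefore break on upgrade. A minimal defensive-import sketch (the fallback pattern is hypothetical, not part of the package):

    # Hypothetical upgrade guard: the replay.experimental tree ships in
    # replay-rec 0.20.1rc0 but is absent from the 0.20.3 wheel (see RECORD above),
    # so this import raises ImportError after upgrading.
    try:
        from replay.experimental.models.cql import CQL
    except ImportError:
        CQL = None  # removed upstream; pin an earlier replay-rec if the model is still needed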
replay/experimental/models/base_torch_rec.py
@@ -1,234 +0,0 @@
-from abc import abstractmethod
-from typing import Any, Optional
-
-import numpy as np
-import torch
-from torch import nn
-from torch.optim.lr_scheduler import ReduceLROnPlateau
-from torch.optim.optimizer import Optimizer
-from torch.utils.data import DataLoader
-
-from replay.data import get_schema
-from replay.experimental.models.base_rec import Recommender
-from replay.experimental.utils.session_handler import State
-from replay.utils import PYSPARK_AVAILABLE, PandasDataFrame, SparkDataFrame
-
-if PYSPARK_AVAILABLE:
-    from pyspark.sql import functions as sf
-
-
-class TorchRecommender(Recommender):
-    """Base class for neural recommenders"""
-
-    model: Any
-    device: torch.device
-
-    def __init__(self):
-        self.logger.info("The model is neural network with non-distributed training")
-        self.checkpoint_path = State().session.conf.get("spark.local.dir")
-        self.device = State().device
-
-    def _run_train_step(self, batch, optimizer):
-        self.model.train()
-        optimizer.zero_grad()
-        model_result = self._batch_pass(batch, self.model)
-        loss = self._loss(**model_result)
-        loss.backward()
-        optimizer.step()
-        return loss.item()
-
-    def _run_validation(self, valid_data_loader: DataLoader, epoch: int) -> float:
-        self.model.eval()
-        valid_loss = 0
-        with torch.no_grad():
-            for batch in valid_data_loader:
-                model_result = self._batch_pass(batch, self.model)
-                valid_loss += self._loss(**model_result)
-        valid_loss /= len(valid_data_loader)
-        valid_debug_message = f"""Epoch[{epoch}] validation
-        average loss: {valid_loss:.5f}"""
-        self.logger.debug(valid_debug_message)
-        return valid_loss.item()
-
-    def train(
-        self,
-        train_data_loader: DataLoader,
-        valid_data_loader: DataLoader,
-        optimizer: Optimizer,
-        lr_scheduler: ReduceLROnPlateau,
-        epochs: int,
-        model_name: str,
-    ) -> None:
-        """
-        Run training loop
-        :param train_data_loader: data loader for training
-        :param valid_data_loader: data loader for validation
-        :param optimizer: optimizer
-        :param lr_scheduler: scheduler used to decrease learning rate
-        :param lr_scheduler: scheduler used to decrease learning rate
-        :param epochs: num training epochs
-        :param model_name: model name for checkpoint saving
-        :return:
-        """
-        best_valid_loss = np.inf
-        for epoch in range(epochs):
-            for batch in train_data_loader:
-                train_loss = self._run_train_step(batch, optimizer)
-
-            train_debug_message = f"""Epoch[{epoch}] current loss:
-            {train_loss:.5f}"""
-            self.logger.debug(train_debug_message)
-
-            valid_loss = self._run_validation(valid_data_loader, epoch)
-            lr_scheduler.step(valid_loss)
-
-            if valid_loss < best_valid_loss:
-                best_checkpoint = "/".join(
-                    [
-                        self.checkpoint_path,
-                        f"/best_{model_name}_{epoch+1}_loss={valid_loss}.pt",
-                    ]
-                )
-                self._save_model(best_checkpoint)
-                best_valid_loss = valid_loss
-        self._load_model(best_checkpoint)
-
-    @abstractmethod
-    def _batch_pass(self, batch, model) -> dict[str, Any]:
-        """
-        Apply model to a single batch.
-
-        :param batch: data batch
-        :param model: model object
-        :return: dictionary used to calculate loss.
-        """
-
-    @abstractmethod
-    def _loss(self, **kwargs) -> torch.Tensor:
-        """
-        Returns loss value
-
-        :param **kwargs: dictionary used to calculate loss
-        :return: 1x1 tensor
-        """
-
-    def _predict(
-        self,
-        log: SparkDataFrame,
-        k: int,
-        users: SparkDataFrame,
-        items: SparkDataFrame,
-        user_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
-        item_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
-        filter_seen_items: bool = True,  # noqa: ARG002
-    ) -> SparkDataFrame:
-        items_consider_in_pred = items.toPandas()["item_idx"].values
-        items_count = self._item_dim
-        model = self.model.cpu()
-        agg_fn = self._predict_by_user
-
-        def grouped_map(pandas_df: PandasDataFrame) -> PandasDataFrame:
-            return agg_fn(pandas_df, model, items_consider_in_pred, k, items_count)[
-                ["user_idx", "item_idx", "relevance"]
-            ]
-
-        self.logger.debug("Predict started")
-        # do not apply map on cold users for MultVAE predict
-        join_type = "inner" if str(self) == "MultVAE" else "left"
-        rec_schema = get_schema(
-            query_column="user_idx",
-            item_column="item_idx",
-            rating_column="relevance",
-            has_timestamp=False,
-        )
-        recs = (
-            users.join(log, how=join_type, on="user_idx")
-            .select("user_idx", "item_idx")
-            .groupby("user_idx")
-            .applyInPandas(grouped_map, rec_schema)
-        )
-        return recs
-
-    def _predict_pairs(
-        self,
-        pairs: SparkDataFrame,
-        log: Optional[SparkDataFrame] = None,
-        user_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
-        item_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
-    ) -> SparkDataFrame:
-        items_count = self._item_dim
-        model = self.model.cpu()
-        agg_fn = self._predict_by_user_pairs
-        users = pairs.select("user_idx").distinct()
-
-        def grouped_map(pandas_df: PandasDataFrame) -> PandasDataFrame:
-            return agg_fn(pandas_df, model, items_count)[["user_idx", "item_idx", "relevance"]]
-
-        self.logger.debug("Calculate relevance for user-item pairs")
-        user_history = (
-            users.join(log, how="inner", on="user_idx")
-            .groupBy("user_idx")
-            .agg(sf.collect_list("item_idx").alias("item_idx_history"))
-        )
-        user_pairs = pairs.groupBy("user_idx").agg(sf.collect_list("item_idx").alias("item_idx_to_pred"))
-        full_df = user_pairs.join(user_history, on="user_idx", how="inner")
-
-        rec_schema = get_schema(
-            query_column="user_idx",
-            item_column="item_idx",
-            rating_column="relevance",
-            has_timestamp=False,
-        )
-        recs = full_df.groupby("user_idx").applyInPandas(grouped_map, rec_schema)
-
-        return recs
-
-    @staticmethod
-    @abstractmethod
-    def _predict_by_user(
-        pandas_df: PandasDataFrame,
-        model: nn.Module,
-        items_np: np.ndarray,
-        k: int,
-        item_count: int,
-    ) -> PandasDataFrame:
-        """
-        Calculate predictions.
-
-        :param pandas_df: DataFrame with user-item interactions ``[user_idx, item_idx]``
-        :param model: trained model
-        :param items_np: items available for recommendations
-        :param k: length of recommendation list
-        :param item_count: total number of items
-        :return: DataFrame ``[user_idx , item_idx , relevance]``
-        """
-
-    @staticmethod
-    @abstractmethod
-    def _predict_by_user_pairs(
-        pandas_df: PandasDataFrame,
-        model: nn.Module,
-        item_count: int,
-    ) -> PandasDataFrame:
-        """
-        Get relevance for provided pairs
-
-        :param pandas_df: DataFrame with rated items and items that need prediction
-        ``[user_idx, item_idx_history, item_idx_to_pred]``
-        :param model: trained model
-        :param item_count: total number of items
-        :return: DataFrame ``[user_idx , item_idx , relevance]``
-        """
-
-    def load_model(self, path: str) -> None:
-        """
-        Load model from file
-
-        :param path: path to model
-        :return:
-        """
-        self.logger.debug("-- Loading model from file")
-        self.model.load_state_dict(torch.load(path))
-
-    def _save_model(self, path: str) -> None:
-        torch.save(self.model.state_dict(), path)
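The deleted TorchRecommender drove training through `_run_train_step` and `_run_validation` and left four hooks abstract: `_batch_pass`, `_loss`, `_predict_by_user`, and `_predict_by_user_pairs`. A sketch of how a subclass filled them in (illustrative only, runnable against 0.20.1rc0 where the base class still exists; the model signature and batch layout here are invented):

    import pandas as pd
    import torch

    from replay.experimental.models.base_torch_rec import TorchRecommender  # 0.20.1rc0 only


    class ToyTorchRec(TorchRecommender):
        def _batch_pass(self, batch, model):
            # assumes a DataLoader yielding (users, items, labels) tensors
            users, items, labels = batch
            return {"scores": model(users, items), "labels": labels}

        def _loss(self, scores, labels):
            return torch.nn.functional.binary_cross_entropy_with_logits(scores, labels)

        @staticmethod
        def _predict_by_user(pandas_df, model, items_np, k, item_count):
            user = pandas_df["user_idx"].iloc[0]
            with torch.no_grad():  # score every candidate item for this user
                scores = model(torch.full((len(items_np),), int(user)), torch.from_numpy(items_np))
            top = scores.topk(min(k, len(items_np)))
            return pd.DataFrame(
                {"user_idx": user, "item_idx": items_np[top.indices.numpy()], "relevance": top.values.numpy()}
            )

`_predict_by_user_pairs` would follow the same pattern over the `item_idx_to_pred` lists, and a real subclass (the removed MultVAE, referenced in `_predict` above, was one) would also build `self.model` in `__init__`.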
replay/experimental/models/cql.py
@@ -1,454 +0,0 @@
-"""
-Using CQL implementation from `d3rlpy` package.
-"""
-
-import io
-import logging
-import tempfile
-import timeit
-from typing import Any, Optional, Union
-
-import numpy as np
-import torch
-from d3rlpy.algos import (
-    CQL as CQL_d3rlpy,  # noqa: N811
-    CQLConfig,
-)
-from d3rlpy.base import LearnableConfigWithShape
-from d3rlpy.constants import IMPL_NOT_INITIALIZED_ERROR
-from d3rlpy.dataset import MDPDataset
-from d3rlpy.models.encoders import DefaultEncoderFactory, EncoderFactory
-from d3rlpy.models.q_functions import MeanQFunctionFactory, QFunctionFactory
-from d3rlpy.optimizers import AdamFactory, OptimizerFactory
-from d3rlpy.preprocessing import (
-    ActionScaler,
-    ObservationScaler,
-    RewardScaler,
-)
-
-from replay.data import get_schema
-from replay.experimental.models.base_rec import Recommender
-from replay.utils import PYSPARK_AVAILABLE, PandasDataFrame, SparkDataFrame
-from replay.utils.spark_utils import assert_omp_single_thread
-
-if PYSPARK_AVAILABLE:
-    from pyspark.sql import (
-        Window,
-        functions as sf,
-    )
-
-timer = timeit.default_timer
-
-
-class CQL(Recommender):
-    """Conservative Q-Learning algorithm.
-
-    CQL is a SAC-based data-driven deep reinforcement learning algorithm, which
-    achieves state-of-the-art performance in offline RL problems.
-
-    CQL mitigates overestimation error by minimizing action-values under the
-    current policy and maximizing values under data distribution for
-    underestimation issue.
-
-    .. math::
-        L(\\theta_i) = \\alpha\\, \\mathbb{E}_{s_t \\sim D}
-        \\left[\\log{\\sum_a \\exp{Q_{\\theta_i}(s_t, a)}} -
-        \\mathbb{E}_{a \\sim D} \\big[Q_{\\theta_i}(s_t, a)\\big] - \\tau\\right]
-        + L_\\mathrm{SAC}(\\theta_i)
-
-    where :math:`\alpha` is an automatically adjustable value via Lagrangian
-    dual gradient descent and :math:`\tau` is a threshold value.
-    If the action-value difference is smaller than :math:`\tau`, the
-    :math:`\alpha` will become smaller.
-    Otherwise, the :math:`\alpha` will become larger to aggressively penalize
-    action-values.
-
-    In continuous control, :math:`\\log{\\sum_a \\exp{Q(s, a)}}` is computed as
-    follows.
-
-    .. math::
-        \\log{\\sum_a \\exp{Q(s, a)}} \\approx \\log{\\left(
-        \\frac{1}{2N} \\sum_{a_i \\sim \\text{Unif}(a)}^N
-        \\left[\\frac{\\exp{Q(s, a_i)}}{\\text{Unif}(a)}\\right]
-        + \\frac{1}{2N} \\sum_{a_i \\sim \\pi_\\phi(a|s)}^N
-        \\left[\\frac{\\exp{Q(s, a_i)}}{\\pi_\\phi(a_i|s)}\\right]\\right)}
-
-    where :math:`N` is the number of sampled actions.
-
-    An implementation of this algorithm is heavily based on the corresponding implementation
-    in the d3rlpy library (see https://github.com/takuseno/d3rlpy/blob/master/d3rlpy/algos/cql.py)
-
-    The rest of optimization is exactly same as :class:`d3rlpy.algos.SAC`.
-
-    References:
-        * `Kumar et al., Conservative Q-Learning for Offline Reinforcement
-          Learning. <https://arxiv.org/abs/2006.04779>`_
-
-    Args:
-        mdp_dataset_builder (MdpDatasetBuilder): the MDP dataset builder from users' log.
-        actor_learning_rate (float): learning rate for policy function.
-        critic_learning_rate (float): learning rate for Q functions.
-        temp_learning_rate (float): learning rate for temperature parameter of SAC.
-        alpha_learning_rate (float): learning rate for :math:`\alpha`.
-        actor_optim_factory (d3rlpy.models.optimizers.OptimizerFactory):
-            optimizer factory for the actor.
-            The available options are `[SGD, Adam or RMSprop]`.
-        critic_optim_factory (d3rlpy.models.optimizers.OptimizerFactory):
-            optimizer factory for the critic.
-            The available options are `[SGD, Adam or RMSprop]`.
-        temp_optim_factory (d3rlpy.models.optimizers.OptimizerFactory):
-            optimizer factory for the temperature.
-            The available options are `[SGD, Adam or RMSprop]`.
-        alpha_optim_factory (d3rlpy.models.optimizers.OptimizerFactory):
-            optimizer factory for :math:`\alpha`.
-            The available options are `[SGD, Adam or RMSprop]`.
-        actor_encoder_factory (d3rlpy.models.encoders.EncoderFactory or str):
-            encoder factory for the actor.
-            The available options are `['pixel', 'dense', 'vector', 'default']`.
-            See d3rlpy.models.encoders.EncoderFactory for details.
-        critic_encoder_factory (d3rlpy.models.encoders.EncoderFactory or str):
-            encoder factory for the critic.
-            The available options are `['pixel', 'dense', 'vector', 'default']`.
-            See d3rlpy.models.encoders.EncoderFactory for details.
-        q_func_factory (d3rlpy.models.q_functions.QFunctionFactory or str):
-            Q function factory. The available options are `['mean', 'qr', 'iqn', 'fqf']`.
-            See d3rlpy.models.q_functions.QFunctionFactory for details.
-        batch_size (int): mini-batch size.
-        n_steps (int): Number of training steps.
-        gamma (float): discount factor.
-        tau (float): target network synchronization coefficient.
-        n_critics (int): the number of Q functions for ensemble.
-        initial_temperature (float): initial temperature value.
-        initial_alpha (float): initial :math:`\alpha` value.
-        alpha_threshold (float): threshold value described as :math:`\tau`.
-        conservative_weight (float): constant weight to scale conservative loss.
-        n_action_samples (int): the number of sampled actions to compute
-            :math:`\\log{\\sum_a \\exp{Q(s, a)}}`.
-        soft_q_backup (bool): flag to use SAC-style backup.
-        use_gpu (Union[int, str, bool]): device option.
-            If the value is boolean and True, cuda:0 will be used.
-            If the value is integer, cuda:<device> will be used.
-            If the value is string in torch device style, the specified device will be used.
-        observation_scaler (d3rlpy.preprocessing.Scaler or str): preprocessor.
-            The available options are `['pixel', 'min_max', 'standard']`.
-        action_scaler (d3rlpy.preprocessing.ActionScaler or str):
-            action preprocessor. The available options are `['min_max']`.
-        reward_scaler (d3rlpy.preprocessing.RewardScaler or str):
-            reward preprocessor. The available options are
-            `['clip', 'min_max', 'standard']`.
-        impl (d3rlpy.algos.torch.cql_impl.CQLImpl): algorithm implementation.
-    """
-
-    mdp_dataset_builder: "MdpDatasetBuilder"
-    model: CQL_d3rlpy
-
-    can_predict_cold_users = True
-
-    _observation_shape = (2,)
-    _action_size = 1
-
-    _search_space = {
-        "actor_learning_rate": {"type": "loguniform", "args": [1e-5, 1e-3]},
-        "critic_learning_rate": {"type": "loguniform", "args": [3e-5, 3e-4]},
-        "temp_learning_rate": {"type": "loguniform", "args": [1e-5, 1e-3]},
-        "alpha_learning_rate": {"type": "loguniform", "args": [1e-5, 1e-3]},
-        "gamma": {"type": "loguniform", "args": [0.9, 0.999]},
-        "n_critics": {"type": "int", "args": [2, 4]},
-    }
-
-    def __init__(
-        self,
-        mdp_dataset_builder: "MdpDatasetBuilder",
-        # CQL inner params
-        actor_learning_rate: float = 1e-4,
-        critic_learning_rate: float = 3e-4,
-        temp_learning_rate: float = 1e-4,
-        alpha_learning_rate: float = 1e-4,
-        actor_optim_factory: OptimizerFactory = AdamFactory(),
-        critic_optim_factory: OptimizerFactory = AdamFactory(),
-        temp_optim_factory: OptimizerFactory = AdamFactory(),
-        alpha_optim_factory: OptimizerFactory = AdamFactory(),
-        actor_encoder_factory: EncoderFactory = DefaultEncoderFactory(),
-        critic_encoder_factory: EncoderFactory = DefaultEncoderFactory(),
-        q_func_factory: QFunctionFactory = MeanQFunctionFactory(),
-        batch_size: int = 64,
-        n_steps: int = 1,
-        gamma: float = 0.99,
-        tau: float = 0.005,
-        n_critics: int = 2,
-        initial_temperature: float = 1.0,
-        initial_alpha: float = 1.0,
-        alpha_threshold: float = 10.0,
-        conservative_weight: float = 5.0,
-        n_action_samples: int = 10,
-        soft_q_backup: bool = False,
-        use_gpu: Union[int, str, bool] = False,
-        observation_scaler: ObservationScaler = None,
-        action_scaler: ActionScaler = None,
-        reward_scaler: RewardScaler = None,
-        **params,
-    ):
-        super().__init__()
-        assert_omp_single_thread()
-
-        if isinstance(actor_optim_factory, dict):
-            local = {}
-            local["config"] = {}
-            local["config"]["params"] = dict(locals().items())
-            local["config"]["type"] = "cql"
-            local["observation_shape"] = self._observation_shape
-            local["action_size"] = self._action_size
-            deserialized_config = LearnableConfigWithShape.deserialize_from_dict(local)
-
-            self.logger.info("-- Desiarializing CQL parameters")
-            actor_optim_factory = deserialized_config.config.actor_optim_factory
-            critic_optim_factory = deserialized_config.config.critic_optim_factory
-            temp_optim_factory = deserialized_config.config.temp_optim_factory
-            alpha_optim_factory = deserialized_config.config.alpha_optim_factory
-            actor_encoder_factory = deserialized_config.config.actor_encoder_factory
-            critic_encoder_factory = deserialized_config.config.critic_encoder_factory
-            q_func_factory = deserialized_config.config.q_func_factory
-            observation_scaler = deserialized_config.config.observation_scaler
-            action_scaler = deserialized_config.config.action_scaler
-            reward_scaler = deserialized_config.config.reward_scaler
-            # non-model params
-            mdp_dataset_builder = MdpDatasetBuilder(**mdp_dataset_builder)
-
-        self.mdp_dataset_builder = mdp_dataset_builder
-        self.n_steps = n_steps
-
-        self.model = CQLConfig(
-            actor_learning_rate=actor_learning_rate,
-            critic_learning_rate=critic_learning_rate,
-            temp_learning_rate=temp_learning_rate,
-            alpha_learning_rate=alpha_learning_rate,
-            actor_optim_factory=actor_optim_factory,
-            critic_optim_factory=critic_optim_factory,
-            temp_optim_factory=temp_optim_factory,
-            alpha_optim_factory=alpha_optim_factory,
-            actor_encoder_factory=actor_encoder_factory,
-            critic_encoder_factory=critic_encoder_factory,
-            q_func_factory=q_func_factory,
-            batch_size=batch_size,
-            gamma=gamma,
-            tau=tau,
-            n_critics=n_critics,
-            initial_temperature=initial_temperature,
-            initial_alpha=initial_alpha,
-            alpha_threshold=alpha_threshold,
-            conservative_weight=conservative_weight,
-            n_action_samples=n_action_samples,
-            soft_q_backup=soft_q_backup,
-            observation_scaler=observation_scaler,
-            action_scaler=action_scaler,
-            reward_scaler=reward_scaler,
-            **params,
-        ).create(device=use_gpu)
-
-        # explicitly create the model's algorithm implementation at init stage
-        # despite the lazy on-fit init convention in d3rlpy a) to avoid serialization
-        # complications and b) to make model ready for prediction even before fitting
-        self.model.create_impl(observation_shape=self._observation_shape, action_size=self._action_size)
-
-    def _fit(
-        self,
-        log: SparkDataFrame,
-        user_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
-        item_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
-    ) -> None:
-        mdp_dataset: MDPDataset = self.mdp_dataset_builder.build(log)
-        self.model.fit(mdp_dataset, self.n_steps)
-
-    @staticmethod
-    def _predict_pairs_inner(
-        model: bytes,
-        user_idx: int,
-        items: np.ndarray,
-    ) -> PandasDataFrame:
-        user_item_pairs = PandasDataFrame({"user_idx": np.repeat(user_idx, len(items)), "item_idx": items})
-
-        # deserialize model policy and predict items relevance for the user
-        policy = CQL._deserialize_policy(model)
-        items_batch = user_item_pairs.to_numpy()
-        user_item_pairs["relevance"] = CQL._predict_relevance_with_policy(policy, items_batch)
-
-        return user_item_pairs
-
-    def _predict(
-        self,
-        log: SparkDataFrame,  # noqa: ARG002
-        k: int,  # noqa: ARG002
-        users: SparkDataFrame,
-        items: SparkDataFrame,
-        user_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
-        item_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
-        filter_seen_items: bool = True,  # noqa: ARG002
-    ) -> SparkDataFrame:
-        available_items = items.toPandas()["item_idx"].values
-        policy_bytes = self._serialize_policy()
-
-        def grouped_map(log_slice: PandasDataFrame) -> PandasDataFrame:
-            return CQL._predict_pairs_inner(
-                model=policy_bytes,
-                user_idx=log_slice["user_idx"][0],
-                items=available_items,
-            )[["user_idx", "item_idx", "relevance"]]
-
-        # predict relevance for all available items and return them as is;
-        # `filter_seen_items` and top `k` params are ignored
-        self.logger.debug("Predict started")
-        rec_schema = get_schema(
-            query_column="user_idx",
-            item_column="item_idx",
-            rating_column="relevance",
-            has_timestamp=False,
-        )
-        return users.groupby("user_idx").applyInPandas(grouped_map, rec_schema)
-
-    def _predict_pairs(
-        self,
-        pairs: SparkDataFrame,
-        log: Optional[SparkDataFrame] = None,
-        user_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
-        item_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
-    ) -> SparkDataFrame:
-        policy_bytes = self._serialize_policy()
-
-        def grouped_map(user_log: PandasDataFrame) -> PandasDataFrame:
-            return CQL._predict_pairs_inner(
-                model=policy_bytes,
-                user_idx=user_log["user_idx"][0],
-                items=np.array(user_log["item_idx_to_pred"][0]),
-            )[["user_idx", "item_idx", "relevance"]]
-
-        self.logger.debug("Calculate relevance for user-item pairs")
-        rec_schema = get_schema(
-            query_column="user_idx",
-            item_column="item_idx",
-            rating_column="relevance",
-            has_timestamp=False,
-        )
-        return (
-            pairs.groupBy("user_idx")
-            .agg(sf.collect_list("item_idx").alias("item_idx_to_pred"))
-            .join(log.select("user_idx").distinct(), on="user_idx", how="inner")
-            .groupby("user_idx")
-            .applyInPandas(grouped_map, rec_schema)
-        )
-
-    @property
-    def _init_args(self) -> dict[str, Any]:
-        return {
-            # non-model hyperparams
-            "mdp_dataset_builder": self.mdp_dataset_builder.init_args(),
-            "n_steps": self.n_steps,
-            # model internal hyperparams
-            **self._get_model_hyperparams(),
-            "use_gpu": self.model._impl.device,
-        }
-
-    def _save_model(self, path: str) -> None:
-        self.logger.info("-- Saving model to %s", path)
-        self.model.save_model(path)
-
-    def _load_model(self, path: str) -> None:
-        self.logger.info("-- Loading model from %s", path)
-        self.model.load_model(path)
-
-    def _get_model_hyperparams(self) -> dict[str, Any]:
-        """Get model hyperparams as dictionary.
-        NB: The code is taken from a `d3rlpy.base.save_config(logger)` method as
-        there's no method to just return such params without saving them.
-        """
-        assert self.model._impl is not None, IMPL_NOT_INITIALIZED_ERROR
-        config = LearnableConfigWithShape(
-            observation_shape=self.model.impl.observation_shape,
-            action_size=self.model.impl.action_size,
-            config=self.model.config,
-        )
-        config = config.serialize_to_dict()
-        config.update(config["config"]["params"])
-        for key_to_delete in ["observation_shape", "action_size", "config"]:
-            config.pop(key_to_delete)
-
-        return config
-
-    def _serialize_policy(self) -> bytes:
-        # store using temporary file and immediately read serialized version
-        with tempfile.NamedTemporaryFile(suffix=".pt") as tmp:
-            # noinspection PyProtectedMember
-            self.model.save_policy(tmp.name)
-            with open(tmp.name, "rb") as policy_file:
-                return policy_file.read()
-
-    @staticmethod
-    def _deserialize_policy(policy: bytes) -> torch.jit.ScriptModule:
-        with io.BytesIO(policy) as buffer:
-            return torch.jit.load(buffer, map_location=torch.device("cpu"))
-
-    @staticmethod
-    def _predict_relevance_with_policy(policy: torch.jit.ScriptModule, items: np.ndarray) -> np.ndarray:
-        items = torch.from_numpy(items).float().cpu()
-        with torch.no_grad():
-            return policy.forward(items).numpy()
-
-
-class MdpDatasetBuilder:
-    r"""
-    Markov Decision Process Dataset builder.
-    This class transforms datasets with user logs, which is natural for recommender systems,
-    to datasets consisting of users' decision-making session logs, which is natural for RL methods.
-
-    Args:
-        top_k (int): the number of top user items to learn predicting.
-        action_randomization_scale (float): the scale of action randomization gaussian noise.
-    """
-
-    logger: logging.Logger
-    top_k: int
-    action_randomization_scale: float
-
-    def __init__(self, top_k: int, action_randomization_scale: float = 1e-3):
-        self.logger = logging.getLogger("replay")
-        self.top_k = top_k
-        # cannot set zero scale as then d3rlpy will treat transitions as discrete
-        assert action_randomization_scale > 0
-        self.action_randomization_scale = action_randomization_scale
-
-    def build(self, log: SparkDataFrame) -> MDPDataset:
-        """Builds and returns MDP dataset from users' log."""
-
-        start_time = timer()
-        # reward top-K watched movies with 1, the others - with 0
-        reward_condition = (
-            sf.row_number().over(Window.partitionBy("user_idx").orderBy([sf.desc("relevance"), sf.desc("timestamp")]))
-            <= self.top_k
-        )
-
-        # every user has his own episode (the latest item is defined as terminal)
-        terminal_condition = sf.row_number().over(Window.partitionBy("user_idx").orderBy(sf.desc("timestamp"))) == 1
-
-        user_logs = (
-            log.withColumn("reward", sf.when(reward_condition, sf.lit(1)).otherwise(sf.lit(0)))
-            .withColumn("terminal", sf.when(terminal_condition, sf.lit(1)).otherwise(sf.lit(0)))
-            .withColumn("action", sf.col("relevance").cast("float") + sf.randn() * self.action_randomization_scale)
-            .orderBy(["user_idx", "timestamp"], ascending=True)
-            .select(["user_idx", "item_idx", "action", "reward", "terminal"])
-            .toPandas()
-        )
-        train_dataset = MDPDataset(
-            observations=np.array(user_logs[["user_idx", "item_idx"]]),
-            actions=user_logs["action"].to_numpy()[:, None],
-            rewards=user_logs["reward"].to_numpy(),
-            terminals=user_logs["terminal"].to_numpy(),
-        )
-
-        prepare_time = timer() - start_time
-        self.logger.info("-- Building MDP dataset took %.2f seconds", prepare_time)
-        return train_dataset
-
-    def init_args(self):
-        return {
-            "top_k": self.top_k,
-            "action_randomization_scale": self.action_randomization_scale,
-        }
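For reference, `MdpDatasetBuilder.build` above turns a flat interaction log into per-user episodes: the `top_k` most relevant items get reward 1, the chronologically last interaction terminates the episode, and the action is the relevance score plus small Gaussian noise. A toy pandas equivalent of that labeling, on invented data:

    import numpy as np
    import pandas as pd

    # Toy single-user log (invented data) labeled the way MdpDatasetBuilder.build
    # does it with Spark window functions above.
    top_k, noise_scale = 1, 1e-3
    log = pd.DataFrame(
        {
            "user_idx": [0, 0, 0],
            "item_idx": [10, 11, 12],
            "relevance": [3.0, 5.0, 1.0],
            "timestamp": [1, 2, 3],
        }
    )
    # rank items per user by (relevance desc, timestamp desc); the top_k get reward 1
    rank = log.sort_values(["relevance", "timestamp"], ascending=False).groupby("user_idx").cumcount() + 1
    log["reward"] = (rank <= top_k).astype(int)
    # the latest interaction of each user closes the episode
    log["terminal"] = (log["timestamp"] == log.groupby("user_idx")["timestamp"].transform("max")).astype(int)
    # noisy relevance keeps the action space continuous for d3rlpy
    log["action"] = log["relevance"].astype("float32") + np.random.randn(len(log)) * noise_scale

The resulting columns map directly onto the `observations` (`user_idx`/`item_idx` pairs), `actions`, `rewards`, and `terminals` arrays passed to `MDPDataset` in the deleted code.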