replay-rec 0.20.0__py3-none-any.whl → 0.20.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- replay/__init__.py +1 -1
- replay/data/dataset.py +10 -9
- replay/data/dataset_utils/dataset_label_encoder.py +5 -4
- replay/data/nn/schema.py +9 -18
- replay/data/nn/sequence_tokenizer.py +16 -15
- replay/data/nn/sequential_dataset.py +4 -4
- replay/data/nn/torch_sequential_dataset.py +5 -4
- replay/data/nn/utils.py +2 -1
- replay/data/schema.py +3 -12
- replay/experimental/__init__.py +0 -0
- replay/experimental/metrics/__init__.py +62 -0
- replay/experimental/metrics/base_metric.py +603 -0
- replay/experimental/metrics/coverage.py +97 -0
- replay/experimental/metrics/experiment.py +175 -0
- replay/experimental/metrics/hitrate.py +26 -0
- replay/experimental/metrics/map.py +30 -0
- replay/experimental/metrics/mrr.py +18 -0
- replay/experimental/metrics/ncis_precision.py +31 -0
- replay/experimental/metrics/ndcg.py +49 -0
- replay/experimental/metrics/precision.py +22 -0
- replay/experimental/metrics/recall.py +25 -0
- replay/experimental/metrics/rocauc.py +49 -0
- replay/experimental/metrics/surprisal.py +90 -0
- replay/experimental/metrics/unexpectedness.py +76 -0
- replay/experimental/models/__init__.py +50 -0
- replay/experimental/models/admm_slim.py +257 -0
- replay/experimental/models/base_neighbour_rec.py +200 -0
- replay/experimental/models/base_rec.py +1386 -0
- replay/experimental/models/base_torch_rec.py +234 -0
- replay/experimental/models/cql.py +454 -0
- replay/experimental/models/ddpg.py +932 -0
- replay/experimental/models/dt4rec/__init__.py +0 -0
- replay/experimental/models/dt4rec/dt4rec.py +189 -0
- replay/experimental/models/dt4rec/gpt1.py +401 -0
- replay/experimental/models/dt4rec/trainer.py +127 -0
- replay/experimental/models/dt4rec/utils.py +264 -0
- replay/experimental/models/extensions/spark_custom_models/__init__.py +0 -0
- replay/experimental/models/extensions/spark_custom_models/als_extension.py +792 -0
- replay/experimental/models/hierarchical_recommender.py +331 -0
- replay/experimental/models/implicit_wrap.py +131 -0
- replay/experimental/models/lightfm_wrap.py +303 -0
- replay/experimental/models/mult_vae.py +332 -0
- replay/experimental/models/neural_ts.py +986 -0
- replay/experimental/models/neuromf.py +406 -0
- replay/experimental/models/scala_als.py +293 -0
- replay/experimental/models/u_lin_ucb.py +115 -0
- replay/experimental/nn/data/__init__.py +1 -0
- replay/experimental/nn/data/schema_builder.py +102 -0
- replay/experimental/preprocessing/__init__.py +3 -0
- replay/experimental/preprocessing/data_preparator.py +839 -0
- replay/experimental/preprocessing/padder.py +229 -0
- replay/experimental/preprocessing/sequence_generator.py +208 -0
- replay/experimental/scenarios/__init__.py +1 -0
- replay/experimental/scenarios/obp_wrapper/__init__.py +8 -0
- replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +74 -0
- replay/experimental/scenarios/obp_wrapper/replay_offline.py +261 -0
- replay/experimental/scenarios/obp_wrapper/utils.py +85 -0
- replay/experimental/scenarios/two_stages/__init__.py +0 -0
- replay/experimental/scenarios/two_stages/reranker.py +117 -0
- replay/experimental/scenarios/two_stages/two_stages_scenario.py +757 -0
- replay/experimental/utils/__init__.py +0 -0
- replay/experimental/utils/logger.py +24 -0
- replay/experimental/utils/model_handler.py +186 -0
- replay/experimental/utils/session_handler.py +44 -0
- replay/metrics/base_metric.py +11 -10
- replay/metrics/categorical_diversity.py +8 -8
- replay/metrics/coverage.py +4 -4
- replay/metrics/experiment.py +3 -3
- replay/metrics/hitrate.py +1 -3
- replay/metrics/map.py +1 -3
- replay/metrics/mrr.py +1 -3
- replay/metrics/ndcg.py +1 -2
- replay/metrics/novelty.py +3 -3
- replay/metrics/offline_metrics.py +16 -16
- replay/metrics/precision.py +1 -3
- replay/metrics/recall.py +1 -3
- replay/metrics/rocauc.py +1 -3
- replay/metrics/surprisal.py +4 -4
- replay/metrics/torch_metrics_builder.py +13 -12
- replay/metrics/unexpectedness.py +2 -2
- replay/models/als.py +2 -2
- replay/models/association_rules.py +4 -3
- replay/models/base_neighbour_rec.py +3 -2
- replay/models/base_rec.py +11 -10
- replay/models/cat_pop_rec.py +2 -1
- replay/models/extensions/ann/ann_mixin.py +2 -1
- replay/models/extensions/ann/index_builders/executor_hnswlib_index_builder.py +2 -1
- replay/models/extensions/ann/index_builders/executor_nmslib_index_builder.py +2 -1
- replay/models/lin_ucb.py +3 -3
- replay/models/nn/optimizer_utils/optimizer_factory.py +2 -2
- replay/models/nn/sequential/bert4rec/dataset.py +2 -2
- replay/models/nn/sequential/bert4rec/lightning.py +3 -3
- replay/models/nn/sequential/bert4rec/model.py +2 -2
- replay/models/nn/sequential/callbacks/prediction_callbacks.py +12 -12
- replay/models/nn/sequential/callbacks/validation_callback.py +9 -9
- replay/models/nn/sequential/compiled/base_compiled_model.py +5 -5
- replay/models/nn/sequential/postprocessors/_base.py +2 -3
- replay/models/nn/sequential/postprocessors/postprocessors.py +10 -10
- replay/models/nn/sequential/sasrec/lightning.py +3 -3
- replay/models/nn/sequential/sasrec/model.py +8 -8
- replay/models/slim.py +2 -2
- replay/models/ucb.py +2 -2
- replay/models/word2vec.py +3 -3
- replay/preprocessing/discretizer.py +8 -7
- replay/preprocessing/filters.py +4 -4
- replay/preprocessing/history_based_fp.py +6 -6
- replay/preprocessing/label_encoder.py +8 -7
- replay/scenarios/fallback.py +4 -3
- replay/splitters/base_splitter.py +3 -3
- replay/splitters/cold_user_random_splitter.py +4 -4
- replay/splitters/k_folds.py +4 -4
- replay/splitters/last_n_splitter.py +10 -10
- replay/splitters/new_users_splitter.py +4 -4
- replay/splitters/random_splitter.py +4 -4
- replay/splitters/ratio_splitter.py +10 -10
- replay/splitters/time_splitter.py +6 -6
- replay/splitters/two_stage_splitter.py +4 -4
- replay/utils/__init__.py +1 -0
- replay/utils/common.py +1 -1
- replay/utils/session_handler.py +2 -2
- replay/utils/spark_utils.py +6 -5
- replay/utils/types.py +3 -1
- {replay_rec-0.20.0.dist-info → replay_rec-0.20.0rc0.dist-info}/METADATA +17 -17
- replay_rec-0.20.0rc0.dist-info/RECORD +194 -0
- replay_rec-0.20.0.dist-info/RECORD +0 -139
- {replay_rec-0.20.0.dist-info → replay_rec-0.20.0rc0.dist-info}/WHEEL +0 -0
- {replay_rec-0.20.0.dist-info → replay_rec-0.20.0rc0.dist-info}/licenses/LICENSE +0 -0
- {replay_rec-0.20.0.dist-info → replay_rec-0.20.0rc0.dist-info}/licenses/NOTICE +0 -0
replay/experimental/models/ddpg.py (added)

@@ -0,0 +1,932 @@
from pathlib import Path
from typing import Any, Optional

import numpy as np
import pandas as pd
import scipy.sparse as sp
import torch
import tqdm
from pytorch_optimizer import Ranger
from torch import nn
from torch.distributions.gamma import Gamma

from replay.data import get_schema
from replay.experimental.models.base_torch_rec import Recommender
from replay.utils import PYSPARK_AVAILABLE, PandasDataFrame, SparkDataFrame
from replay.utils.spark_utils import convert2spark

if PYSPARK_AVAILABLE:
    from pyspark.sql import functions as sf


def to_np(tensor: torch.Tensor) -> np.array:
    """Converts torch.Tensor to numpy."""
    return tensor.detach().cpu().numpy()

class ReplayBuffer:
    """
    Stores transitions for training RL model.

    Usually transition is (state, action, reward, next_state).
    In this implementation we compute state using embedding of user
    and embeddings of `memory_size` latest relevant items.
    Thereby in this ReplayBuffer we store (user, memory) instead of state.
    """

    def __init__(self, device, capacity, memory_size, embedding_dim):
        self.capacity = capacity

        self.buffer = {
            "user": torch.zeros((capacity,), device=device),
            "memory": torch.zeros((capacity, memory_size), device=device),
            "action": torch.zeros((capacity, embedding_dim), device=device),
            "reward": torch.zeros((capacity,), device=device),
            "next_user": torch.zeros((capacity,), device=device),
            "next_memory": torch.zeros((capacity, memory_size), device=device),
            "done": torch.zeros((capacity,), device=device),
            "sample_weight": torch.zeros((capacity,), device=device),
        }

        self.pos = 0
        self.is_filled = False

    def push(
        self,
        user,
        memory,
        action,
        reward,
        next_user,
        next_memory,
        done,
        sample_weight,
    ):
        """Add transition to buffer."""

        batch_size = user.shape[0]

        self.buffer["user"][self.pos : self.pos + batch_size] = user
        self.buffer["memory"][self.pos : self.pos + batch_size] = memory
        self.buffer["action"][self.pos : self.pos + batch_size] = action
        self.buffer["reward"][self.pos : self.pos + batch_size] = reward
        self.buffer["next_user"][self.pos : self.pos + batch_size] = next_user
        self.buffer["next_memory"][self.pos : self.pos + batch_size] = next_memory
        self.buffer["done"][self.pos : self.pos + batch_size] = done
        self.buffer["sample_weight"][self.pos : self.pos + batch_size] = sample_weight

        new_pos = self.pos + batch_size
        if new_pos >= self.capacity:
            self.is_filled = True
        self.pos = new_pos % self.capacity

    def sample(self, batch_size):
        """Sample transition from buffer."""
        current_buffer_len = len(self)

        indices = np.random.choice(current_buffer_len, batch_size)

        return {
            "user": self.buffer["user"][indices],
            "memory": self.buffer["memory"][indices],
            "action": self.buffer["action"][indices],
            "reward": self.buffer["reward"][indices],
            "next_user": self.buffer["next_user"][indices],
            "next_memory": self.buffer["next_memory"][indices],
            "done": self.buffer["done"][indices],
            "sample_weight": self.buffer["sample_weight"][indices],
        }

    def __len__(self):
        return self.capacity if self.is_filled else self.pos + 1

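The following snippet is not part of the diff; it is a minimal usage sketch of the ReplayBuffer above, assuming the module path replay.experimental.models.ddpg from the file list. Every transition field is preallocated as a fixed-capacity tensor, so push() writes a whole batch in place and sample() draws uniformly from the filled region.

import torch

from replay.experimental.models.ddpg import ReplayBuffer

device = torch.device("cpu")
buffer = ReplayBuffer(device, capacity=1000, memory_size=5, embedding_dim=8)

batch = 4
buffer.push(
    user=torch.arange(batch, dtype=torch.float32),   # user ids
    memory=torch.zeros((batch, 5)),                  # last 5 relevant items per user
    action=torch.randn(batch, 8),                    # actor output embeddings
    reward=torch.ones(batch),
    next_user=torch.arange(batch, dtype=torch.float32),
    next_memory=torch.zeros((batch, 5)),
    done=torch.zeros(batch),
    sample_weight=torch.ones(batch),
)

sample = buffer.sample(batch_size=2)   # dict with the same eight keys
print(sample["reward"].shape)          # torch.Size([2])
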
class OUNoise:
    """https://github.com/vitchyr/rlkit/blob/master/rlkit/exploration_strategies/ou_strategy.py"""

    def __init__(
        self,
        action_dim,
        device,
        theta=0.15,
        max_sigma=0.4,
        min_sigma=0.4,
        noise_type="gauss",
        decay_period=10,
    ):
        self.theta = theta
        self.sigma = max_sigma
        self.max_sigma = max_sigma
        self.min_sigma = min_sigma
        self.decay_period = decay_period
        self.action_dim = action_dim
        self.noise_type = noise_type
        self.device = device
        self.state = torch.zeros((1, action_dim), device=self.device)

    def reset(self, user_batch_size):
        """Fill state with zeros."""
        if self.state.shape[0] == user_batch_size:
            self.state.fill_(0)
        else:
            self.state = torch.zeros((user_batch_size, self.action_dim), device=self.device)

    def evolve_state(self):
        """Perform OU discrete approximation step"""
        x = self.state
        d_x = -self.theta * x + self.sigma * torch.randn(x.shape, device=self.device)
        self.state = x + d_x
        return self.state

    def get_action(self, action, step=0):
        """Get state after applying noise."""
        self.sigma = self.max_sigma - (self.max_sigma - self.min_sigma) * min(1.0, step / self.decay_period)
        if self.noise_type == "ou":
            ou_state = self.evolve_state()
            return action + ou_state
        elif self.noise_type == "gauss":
            return action + self.sigma * torch.randn(action.shape, device=self.device)
        else:
            msg = "noise_type must be one of ['ou', 'gauss']"
            raise ValueError(msg)

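Not part of the diff: a small sketch of the exploration noise above, assuming the same module path. With the default noise_type="gauss" the class simply adds sigma-scaled Gaussian noise to the actor output; "ou" would instead evolve an Ornstein-Uhlenbeck state between calls.

import torch

from replay.experimental.models.ddpg import OUNoise

noise = OUNoise(action_dim=8, device=torch.device("cpu"), noise_type="gauss")
noise.reset(user_batch_size=4)            # (re)allocate the internal state for 4 users

action = torch.zeros(4, 8)                # pretend actor output
noisy = noise.get_action(action, step=0)  # adds sigma * N(0, 1) noise elementwise
print(noisy.shape)                        # torch.Size([4, 8])
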
class ActorDRR(nn.Module):
    """
    DDPG Actor model (based on `DRR
    <https://arxiv.org/pdf/1802.05814.pdf>`).
    """

    def __init__(
        self,
        user_num,
        item_num,
        embedding_dim,
        hidden_dim,
        memory_size,
        env_gamma_alpha,
        device,
        min_trajectory_len,
    ):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(embedding_dim * 3, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embedding_dim),
        )

        self.state_repr = StateReprModule(user_num, item_num, embedding_dim, memory_size)

        self.initialize()

        self.environment = Env(
            item_num,
            user_num,
            memory_size,
            env_gamma_alpha,
            device,
            min_trajectory_len,
        )

    def initialize(self):
        """weight init"""
        for layer in self.layers:
            if isinstance(layer, nn.Linear):
                nn.init.kaiming_uniform_(layer.weight)

    def forward(self, user, memory):
        """
        :param user: user batch
        :param memory: memory batch
        :return: output, vector of the size `embedding_dim`
        """
        state = self.state_repr(user, memory)
        return self.layers(state)

    def get_action(self, action_emb, items, items_mask, return_scores=False):
        """
        :param action_emb: output of the .forward() (user_batch_size x emb_dim)
        :param items: items batch (user_batch_size x items_num)
        :param items_mask: mask of available items for recommendation (user_batch_size x items_num)
        :param return_scores: whether to return scores of items
        :return: output, prediction (and scores if return_scores)
        """

        assert items.shape == items_mask.shape

        items = self.state_repr.item_embeddings(items)  # B x i x emb_dim
        scores = torch.bmm(
            items,
            action_emb.unsqueeze(-1),  # B x emb_dim x 1
        ).squeeze(-1)

        assert scores.shape == items_mask.shape

        scores = scores * items_mask

        if return_scores:
            return scores, torch.argmax(scores, dim=1)
        else:
            return torch.argmax(scores, dim=1)

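Not part of the diff: an illustrative sketch of the actor above, assuming the same module path. forward() maps (user, memory) through StateReprModule to an action embedding, and get_action() scores candidate items by a dot product with that embedding.

import torch

from replay.experimental.models.ddpg import ActorDRR

actor = ActorDRR(
    user_num=100, item_num=50, embedding_dim=8, hidden_dim=16,
    memory_size=5, env_gamma_alpha=0.2, device=torch.device("cpu"),
    min_trajectory_len=10,
)

users = torch.tensor([0, 1])
memory = torch.full((2, 5), 50)        # item_num is the padding index -> zero item embeddings
action_emb = actor(users, memory)      # (2, 8) proposed "action" embedding

items = torch.arange(50).repeat(2, 1)  # candidate items per user
mask = torch.ones_like(items)
scores, best = actor.get_action(action_emb, items, mask, return_scores=True)
print(scores.shape, best.shape)        # torch.Size([2, 50]) torch.Size([2])
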
class CriticDRR(nn.Module):
    """
    DDPG Critic model (based on `DRR
    <https://arxiv.org/pdf/1802.05814.pdf>`
    and `Bayes-UCBDQN <https://arxiv.org/pdf/2205.07704.pdf>`).
    """

    def __init__(self, state_repr_dim, action_emb_dim, hidden_dim, heads_num, heads_q):
        """
        :param heads_num: number of heads (samples of Q function)
        :param heads_q: quantile of Q function distribution
        """
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(state_repr_dim + action_emb_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
        )

        self.heads = nn.ModuleList([nn.Linear(hidden_dim, 1) for _ in range(heads_num)])
        self.heads_q = heads_q

        self.initialize()

    def initialize(self):
        """weight init"""
        for layer in self.layers:
            if isinstance(layer, nn.Linear):
                nn.init.kaiming_uniform_(layer.weight)

        for head in self.heads:
            nn.init.kaiming_uniform_(head.weight)

    def forward(self, state, action):
        """
        :param state: state batch
        :param action: action batch
        :return: x, Q values for given states and actions
        """
        x = torch.cat([state, action], 1)
        out = self.layers(x)
        heads_out = torch.stack([head(out) for head in self.heads])
        out = torch.quantile(heads_out, self.heads_q, dim=0)

        return out

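Not part of the diff: a sketch of the quantile critic above, assuming the same module path. The critic keeps heads_num independent linear heads and returns the heads_q-quantile of their outputs across heads, which is a conservative Q-value estimate when heads_q is low (0.15 is the DDPG default below).

import torch

from replay.experimental.models.ddpg import CriticDRR

critic = CriticDRR(state_repr_dim=24, action_emb_dim=8, hidden_dim=16,
                   heads_num=10, heads_q=0.15)

state = torch.randn(4, 24)   # 3 * embedding_dim, as produced by StateReprModule
action = torch.randn(4, 8)
q = critic(state, action)    # 0.15-quantile over the 10 heads, shape (4, 1)
print(q.shape)
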
class Env:
    """
    RL environment for recommender systems.
    Simulates interacting with a batch of users.

    Keeps users' latest relevant items (memory).

    :param item_count: total number of items
    :param user_count: total number of users
    :param memory_size: maximum number of items in memory
    :param memory: torch.tensor with users' latest relevant items
    :param matrix: sparse matrix with user-item ratings
    :param user_ids: user ids from the batch
    :param related_items: relevant items for current users
    :param nonrelated_items: non-relevant items for current users
    :param max_num_rele: maximum number of related items by users in the batch
    :param available_items: items available for recommendation
    :param available_items_mask: mask of non-seen items
    :param gamma: param of Gamma distribution for sample weights
    """

    matrix: np.array
    related_items: torch.Tensor
    nonrelated_items: torch.Tensor
    available_items: torch.Tensor  # B x i
    available_items_mask: torch.Tensor  # B x i
    user_id: torch.Tensor  # batch of users B x i
    num_rele: int

    def __init__(
        self,
        item_count,
        user_count,
        memory_size,
        gamma_alpha,
        device,
        min_trajectory_len,
    ):
        """
        Initialize memory as ['item_num'] * 'memory_size' for each user.

        'item_num' is a padding index in StateReprModule.
        It will result in zero embeddings.
        """
        self.item_count = item_count
        self.user_count = user_count
        self.memory_size = memory_size
        self.device = device
        self.gamma = Gamma(
            torch.tensor([float(gamma_alpha)]),
            torch.tensor([1 / float(gamma_alpha)]),
        )
        self.memory = torch.full((user_count, memory_size), item_count, device=device)
        self.min_trajectory_len = min_trajectory_len
        self.max_num_rele = None
        self.user_batch_size = None
        self.user_ids = None

    def update_env(self, matrix=None, item_count=None):
        """Update some of Env attributes."""
        if item_count is not None:
            self.item_count = item_count
        if matrix is not None:
            self.matrix = matrix.copy()

    def reset(self, user_ids):
        """
        :param user_id: batch of user ids
        :return: user, memory
        """
        self.user_batch_size = len(user_ids)

        self.user_ids = torch.tensor(user_ids, dtype=torch.int64, device=self.device)

        self.max_num_rele = max((self.matrix[user_ids] > 0).sum(1).max(), self.min_trajectory_len)
        self.available_items = torch.zeros(
            (self.user_batch_size, 2 * self.max_num_rele),
            dtype=torch.int64,
            device=self.device,
        )
        self.available_items_mask = torch.ones_like(self.available_items, device=self.device)

        # padding with non-existent items
        self.related_items = torch.full(
            (self.user_batch_size, self.max_num_rele),
            -1,  # maybe define new constant
            device=self.device,
        )

        for idx, user_id in enumerate(user_ids):
            user_related_items = torch.tensor(np.argwhere(self.matrix[user_id] > 0)[:, 1], device=self.device)

            user_num_rele = len(user_related_items)

            self.related_items[idx, :user_num_rele] = user_related_items

            replace = bool(2 * self.max_num_rele > self.item_count)

            nonrelated_items = torch.tensor(
                np.random.choice(
                    list(set(range(self.item_count + 1)) - set(user_related_items.tolist())),
                    replace=replace,
                    size=2 * self.max_num_rele - user_num_rele,
                )
            ).to(self.device)

            self.available_items[idx, :user_num_rele] = user_related_items
            self.available_items[idx, user_num_rele:] = nonrelated_items
            self.available_items[self.available_items == -1] = self.item_count
            perm = torch.randperm(self.available_items.shape[1])
            self.available_items[idx] = self.available_items[idx, perm]

        return self.user_ids, self.memory[self.user_ids]

    def step(self, actions, actions_emb=None, buffer: ReplayBuffer = None):
        """Execute step and return (user, memory) for new state"""
        initial_users = self.user_ids
        initial_memory = self.memory[self.user_ids].clone()

        global_actions = self.available_items[torch.arange(self.available_items.shape[0]), actions]
        rewards = (global_actions.reshape(-1, 1) == self.related_items).sum(1)
        for idx, reward in enumerate(rewards):
            if reward:
                user_id = self.user_ids[idx]
                self.memory[user_id] = torch.tensor([*self.memory[user_id][1:], global_actions[idx]])

        self.available_items_mask[torch.arange(self.available_items_mask.shape[0]), actions] = 0

        if buffer is not None:
            sample_weight = self.gamma.sample((self.user_batch_size,)).squeeze().detach().to(self.device)
            buffer.push(
                initial_users.detach(),
                initial_memory.detach(),
                actions_emb.detach(),
                rewards.detach(),
                self.user_ids.detach(),
                self.memory[self.user_ids].detach(),
                rewards.detach(),
                sample_weight,
            )

        return self.user_ids, self.memory[self.user_ids], rewards, 0

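Not part of the diff: the reward rule inside Env.step, reproduced on toy tensors. An action earns reward 1 exactly when the chosen (global) item id appears among that user's related items; since related items are unique per user, the match count is 0 or 1, and the -1 padding never matches.

import torch

global_actions = torch.tensor([3, 7])                  # chosen item id per user
related_items = torch.tensor([[1, 3, -1], [2, 4, 5]])  # per-user relevant items, padded with -1
rewards = (global_actions.reshape(-1, 1) == related_items).sum(1)
print(rewards)                                         # tensor([1, 0])
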
class StateReprModule(nn.Module):
    """
    Compute state for RL environment. Based on `DRR paper
    <https://arxiv.org/pdf/1810.12027.pdf>`_

    The state is a concatenation of the user embedding,
    a weighted average pooling of the `memory_size` latest relevant items
    and their pairwise product.
    """

    def __init__(
        self,
        user_num,
        item_num,
        embedding_dim,
        memory_size,
    ):
        super().__init__()
        self.user_embeddings = nn.Embedding(user_num, embedding_dim)

        self.item_embeddings = nn.Embedding(item_num + 1, embedding_dim, padding_idx=int(item_num))

        self.drr_ave = torch.nn.Conv1d(in_channels=memory_size, out_channels=1, kernel_size=1)

        self.initialize()

    def initialize(self):
        """weight init"""
        nn.init.normal_(self.user_embeddings.weight, std=0.01)
        self.item_embeddings.weight.data[-1].zero_()

        nn.init.normal_(self.item_embeddings.weight, std=0.01)
        nn.init.uniform_(self.drr_ave.weight)

        self.drr_ave.bias.data.zero_()

    def forward(self, user, memory):
        """
        :param user: user batch
        :param memory: memory batch
        :return: vector of dimension 3 * embedding_dim
        """
        user_embedding = self.user_embeddings(user.long())

        item_embeddings = self.item_embeddings(memory.long())
        drr_ave = self.drr_ave(item_embeddings).squeeze(1)

        return torch.cat((user_embedding, user_embedding * drr_ave, drr_ave), 1)

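Not part of the diff: a sketch of the state representation, assuming the same module path. The module embeds the user and the memory of recent items, pools the item embeddings with a 1x1 convolution over the memory axis, and concatenates user embedding, element-wise product, and pooled item vector into a 3 * embedding_dim state.

import torch

from replay.experimental.models.ddpg import StateReprModule

state_repr = StateReprModule(user_num=100, item_num=50, embedding_dim=8, memory_size=5)

user = torch.tensor([0, 1])
memory = torch.full((2, 5), 50)   # item_num is the padding index -> zero item embeddings
state = state_repr(user, memory)
print(state.shape)                # torch.Size([2, 24]) == (batch, 3 * embedding_dim)
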
class DDPG(Recommender):
    """
    `Deep Deterministic Policy Gradient
    <https://arxiv.org/pdf/1810.12027.pdf>`_

    This implementation is enhanced by a more advanced noise strategy.
    """

    batch_size: int = 512
    embedding_dim: int = 8
    hidden_dim: int = 16
    value_lr: float = 1e-5
    value_decay: float = 1e-5
    policy_lr: float = 1e-5
    policy_decay: float = 1e-6
    gamma: float = 0.8
    memory_size: int = 5
    min_value: int = -10
    max_value: int = 10
    buffer_size: int = 1000000
    _search_space = {
        "noise_sigma": {"type": "uniform", "args": [0.1, 0.6]},
        "gamma": {"type": "uniform", "args": [0.7, 1.0]},
        "value_lr": {"type": "loguniform", "args": [1e-7, 1e-1]},
        "value_decay": {"type": "loguniform", "args": [1e-7, 1e-1]},
        "policy_lr": {"type": "loguniform", "args": [1e-7, 1e-1]},
        "policy_decay": {"type": "loguniform", "args": [1e-7, 1e-1]},
        "memory_size": {"type": "categorical", "args": [3, 5, 7, 9]},
        "noise_type": {"type": "categorical", "args": ["gauss", "ou"]},
    }
    checkpoint_step: int = 10000
    replay_buffer: ReplayBuffer
    ou_noise: OUNoise
    model: ActorDRR
    target_model: ActorDRR
    value_net: CriticDRR
    target_value_net: CriticDRR
    policy_optimizer: Ranger
    value_optimizer: Ranger

    def __init__(
        self,
        noise_sigma: float = 0.2,
        noise_theta: float = 0.05,
        noise_type: str = "gauss",
        seed: int = 9,
        user_num: int = 10,
        item_num: int = 10,
        log_dir: str = "logs/tmp",
        exact_embeddings_size=True,
        n_critics_head: int = 10,
        env_gamma_alpha: float = 0.2,
        critic_heads_q: float = 0.15,
        n_jobs=None,
        use_gpu=False,
        user_batch_size: int = 8,
        min_trajectory_len: int = 10,
    ):
        """
        :param noise_sigma: Ornstein-Uhlenbeck noise sigma value
        :param noise_theta: Ornstein-Uhlenbeck noise theta value
        :param noise_type: type of action noise, one of ["ou", "gauss"]
        :param seed: random seed
        :param user_num: number of users, specify when using ``exact_embeddings_size``
        :param item_num: number of items, specify when using ``exact_embeddings_size``
        :param log_dir: dir to save models
        :param exact_embeddings_size: flag whether to set user/item_num from training log
        """
        super().__init__()
        np.random.seed(seed)
        torch.manual_seed(seed)

        self.noise_theta = noise_theta
        self.noise_sigma = noise_sigma
        self.noise_type = noise_type
        self.seed = seed
        self.user_num = user_num
        self.item_num = item_num
        self.log_dir = Path(log_dir)
        self.exact_embeddings_size = exact_embeddings_size
        self.n_critics_head = n_critics_head
        self.env_gamma_alpha = env_gamma_alpha
        self.critic_heads_q = critic_heads_q
        self.user_batch_size = user_batch_size
        self.min_trajectory_len = min_trajectory_len

        self.memory = None
        self.fit_users = None
        self.fit_items = None

        if n_jobs is not None:
            torch.set_num_threads(n_jobs)

        if use_gpu:
            use_cuda = torch.cuda.is_available()
            if use_cuda:
                self.device = torch.device("cuda")
            else:
                self.device = torch.device("cpu")
        else:
            self.device = torch.device("cpu")

    @property
    def _init_args(self):
        return {
            "noise_sigma": self.noise_sigma,
            "noise_theta": self.noise_theta,
            "noise_type": self.noise_type,
            "seed": self.seed,
            "user_num": self.user_num,
            "item_num": self.item_num,
            "exact_embeddings_size": self.exact_embeddings_size,
        }

    @property
    def _dataframes(self):
        return {
            "memory": self.memory,
            "fit_users": self.fit_users,
            "fit_items": self.fit_items,
        }

    def _batch_pass(self, batch: dict) -> dict[str, Any]:
        user = batch["user"]
        memory = batch["memory"]
        action = batch["action"]
        reward = batch["reward"]
        next_user = batch["next_user"]
        next_memory = batch["next_memory"]
        done = batch["done"]
        sample_weight = batch["sample_weight"]
        state = self.model.state_repr(user, memory)

        with torch.no_grad():
            next_state = self.model.state_repr(next_user, next_memory)
            next_action = self.target_model(next_user, next_memory)
            target_value = self.target_value_net(next_state, next_action.detach())
            expected_value = reward + (1.0 - done) * self.gamma * target_value.squeeze(1)  # smth strange, check article
            expected_value = torch.clamp(expected_value, self.min_value, self.max_value)

        proto_action = self.model.layers(state)
        policy_loss = -self.value_net(state.detach(), proto_action).mean()

        value = self.value_net(state, action)
        value_loss = ((value - expected_value.detach()).pow(2) * sample_weight).squeeze(1).mean()
        return policy_loss, value_loss

    @staticmethod
    def _predict_pairs_inner(
        model,
        user_idx: int,
        items_np: np.ndarray,
    ) -> SparkDataFrame:
        with torch.no_grad():
            user_batch = torch.tensor([user_idx], dtype=torch.int64)
            memory = model.environment.memory[user_batch]
            action_emb = model(user_batch, memory)
            items = torch.tensor(items_np, dtype=torch.int64).unsqueeze(0)
            scores, _ = model.get_action(action_emb, items, torch.full_like(items, True), True)
            scores = scores.squeeze()
            return PandasDataFrame(
                {
                    "user_idx": scores.shape[0] * [user_idx],
                    "item_idx": items_np,
                    "relevance": scores,
                }
            )

    def _predict(
        self,
        log: SparkDataFrame,
        k: int,  # noqa: ARG002
        users: SparkDataFrame,
        items: SparkDataFrame,
        user_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
        item_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
        filter_seen_items: bool = True,  # noqa: ARG002
    ) -> SparkDataFrame:
        items_consider_in_pred = items.toPandas()["item_idx"].values
        model = self.model.cpu()

        def grouped_map(pandas_df: PandasDataFrame) -> PandasDataFrame:
            return DDPG._predict_pairs_inner(
                model=model,
                user_idx=pandas_df["user_idx"][0],
                items_np=items_consider_in_pred,
            )[["user_idx", "item_idx", "relevance"]]

        self.logger.debug("Predict started")
        rec_schema = get_schema(
            query_column="user_idx",
            item_column="item_idx",
            rating_column="relevance",
            has_timestamp=False,
        )
        recs = (
            users.join(log, how="left", on="user_idx")
            .select("user_idx", "item_idx")
            .groupby("user_idx")
            .applyInPandas(grouped_map, rec_schema)
        )
        return recs

    def _predict_pairs(
        self,
        pairs: SparkDataFrame,
        log: Optional[SparkDataFrame] = None,
        user_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
        item_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
    ) -> SparkDataFrame:
        model = self.model.cpu()

        def grouped_map(pandas_df: PandasDataFrame) -> PandasDataFrame:
            return DDPG._predict_pairs_inner(
                model=model,
                user_idx=pandas_df["user_idx"][0],
                items_np=np.array(pandas_df["item_idx_to_pred"][0]),
            )

        self.logger.debug("Calculate relevance for user-item pairs")

        rec_schema = get_schema(
            query_column="user_idx",
            item_column="item_idx",
            rating_column="relevance",
            has_timestamp=False,
        )
        recs = (
            pairs.groupBy("user_idx")
            .agg(sf.collect_list("item_idx").alias("item_idx_to_pred"))
            .join(log.select("user_idx").distinct(), on="user_idx", how="inner")
            .groupby("user_idx")
            .applyInPandas(grouped_map, rec_schema)
        )

        return recs

    @staticmethod
    def _preprocess_df(data):
        """
        :param data: pandas DataFrame
        """
        data = data[["user_idx", "item_idx", "relevance"]]
        users = data["user_idx"].values
        items = data["item_idx"].values
        scores = data["relevance"].values

        user_num = int(max(users)) + 1
        item_num = int(max(items)) + 1

        train_matrix = sp.dok_matrix((user_num, item_num), dtype=np.float32)
        for user, item, rel in zip(users, items, scores):
            train_matrix[user, item] = rel

        appropriate_users = data["user_idx"].unique()

        return train_matrix, user_num, item_num, appropriate_users

    @staticmethod
    def _preprocess_log(log):
        return DDPG._preprocess_df(log.toPandas())

    def _get_batch(self) -> dict:
        batch = self.replay_buffer.sample(self.batch_size)
        return batch

    def _run_train_step(self, batch: dict) -> None:
        policy_loss, value_loss = self._batch_pass(batch)
        total_loss = policy_loss + value_loss

        self.policy_optimizer.zero_grad()
        self.value_optimizer.zero_grad()
        total_loss.backward()
        self.policy_optimizer.step()
        self.value_optimizer.step()

        self._target_update(self.target_value_net, self.value_net)
        self._target_update(self.target_model, self.model)

    @staticmethod
    def _target_update(target_net, net, soft_tau=1e-3):
        for target_param, param in zip(target_net.parameters(), net.parameters()):
            target_param.data.copy_(target_param.data * (1.0 - soft_tau) + param.data * soft_tau)

    def _init_inner(self):
        self.replay_buffer = ReplayBuffer(
            self.device,
            self.buffer_size,
            memory_size=self.memory_size,
            embedding_dim=self.embedding_dim,
        )

        self.ou_noise = OUNoise(
            self.embedding_dim,
            device=self.device,
            theta=self.noise_theta,
            max_sigma=self.noise_sigma,
            min_sigma=self.noise_sigma,
            noise_type=self.noise_type,
        )

        self.model = ActorDRR(
            self.user_num,
            self.item_num,
            self.embedding_dim,
            self.hidden_dim,
            self.memory_size,
            env_gamma_alpha=self.env_gamma_alpha,
            device=self.device,
            min_trajectory_len=self.min_trajectory_len,
        ).to(self.device)

        self.target_model = ActorDRR(
            self.user_num,
            self.item_num,
            self.embedding_dim,
            self.hidden_dim,
            self.memory_size,
            env_gamma_alpha=self.env_gamma_alpha,
            device=self.device,
            min_trajectory_len=self.min_trajectory_len,
        ).to(self.device)

        self.value_net = CriticDRR(
            self.embedding_dim * 3,
            self.embedding_dim,
            self.hidden_dim,
            heads_num=self.n_critics_head,
            heads_q=self.critic_heads_q,
        ).to(self.device)

        self.target_value_net = CriticDRR(
            self.embedding_dim * 3,
            self.embedding_dim,
            self.hidden_dim,
            heads_num=self.n_critics_head,
            heads_q=self.critic_heads_q,
        ).to(self.device)

        self._target_update(self.target_value_net, self.value_net, soft_tau=1)
        self._target_update(self.target_model, self.model, soft_tau=1)
        self.policy_optimizer = Ranger(
            self.model.parameters(),
            lr=self.policy_lr,
            weight_decay=self.policy_decay,
        )
        self.value_optimizer = Ranger(
            self.value_net.parameters(),
            lr=self.value_lr,
            weight_decay=self.value_decay,
        )

    def _fit(
        self,
        log: SparkDataFrame,
        user_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
        item_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
    ) -> None:
        data = log.toPandas()
        self._fit_df(data)

    def _fit_df(self, data):
        train_matrix, user_num, item_num, users = self._preprocess_df(data)

        if self.exact_embeddings_size:
            self.user_num = user_num
            self.item_num = item_num
        self._init_inner()

        self.model.environment.update_env(matrix=train_matrix)
        users = np.random.permutation(users)

        self.logger.debug("Training DDPG")
        self.train(users)

    @staticmethod
    def users_loader(users, batch_size):
        """loader for users' batch"""
        pos = 0
        while pos != len(users):
            new_pos = min(pos + batch_size, len(users))
            yield users[pos:new_pos]
            pos = new_pos

    def train(self, users: np.array) -> None:
        """
        Run training loop

        :param users: array with users for training
        :return:
        """
        self.log_dir.mkdir(parents=True, exist_ok=True)
        step = 0
        users_loader = self.users_loader(users, self.user_batch_size)
        for user_ids in tqdm.auto.tqdm(list(users_loader)):
            user_ids, memory = self.model.environment.reset(user_ids)
            self.ou_noise.reset(user_ids.shape[0])
            for users_step in range(self.model.environment.max_num_rele):
                actions_emb = self.model(user_ids, memory)
                actions_emb = self.ou_noise.get_action(actions_emb, users_step)

                actions = self.model.get_action(
                    actions_emb,
                    self.model.environment.available_items,
                    self.model.environment.available_items_mask,
                )

                _, memory, _, _ = self.model.environment.step(actions, actions_emb, self.replay_buffer)

                if len(self.replay_buffer) > self.batch_size:
                    batch = self._get_batch()
                    self._run_train_step(batch)

                if step % self.checkpoint_step == 0 and step > 0:
                    self._save_model(self.log_dir / f"model_{step}.pt")
                step += 1

        self._save_model(self.log_dir / "model_final.pt")

    def _save_model(self, path: str) -> None:
        self.logger.debug(
            "-- Saving model to file (user_num=%d, item_num=%d)",
            self.user_num,
            self.item_num,
        )
        memory_df = pd.DataFrame(
            self.model.environment.memory.cpu(),
            columns=["item_n", "item_n-1", "item_n-2", "item_n-3", "item_n-4"],
        )
        memory_df.loc[:, "user_id_for_order"] = np.arange(self.user_num)
        self.memory = convert2spark(memory_df)

        torch.save(
            {
                "actor": self.model.state_dict(),
                "critic": self.value_net.state_dict(),
                "policy_optimizer": self.policy_optimizer.state_dict(),
                "value_optimizer": self.value_optimizer.state_dict(),
            },
            path,
        )

    def _load_model(self, path: str) -> None:
        self.logger.debug("-- Loading model from file")
        self._init_inner()

        checkpoint = torch.load(path)
        self.model.load_state_dict(checkpoint["actor"])
        self.value_net.load_state_dict(checkpoint["critic"])
        self.policy_optimizer.load_state_dict(checkpoint["policy_optimizer"])
        self.value_optimizer.load_state_dict(checkpoint["value_optimizer"])

        self._target_update(self.target_value_net, self.value_net, soft_tau=1)
        self._target_update(self.target_model, self.model, soft_tau=1)

        memory_df = self.memory.toPandas()
        memory_df = memory_df.sort_values(by="user_id_for_order").drop("user_id_for_order", axis=1)
        self.model.environment.memory = torch.tensor(memory_df.to_numpy()).to(self.device)