replay-rec 0.20.2__py3-none-any.whl → 0.20.3rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. replay/__init__.py +1 -1
  2. replay/data/nn/sequential_dataset.py +8 -2
  3. replay/experimental/__init__.py +0 -0
  4. replay/experimental/metrics/__init__.py +62 -0
  5. replay/experimental/metrics/base_metric.py +603 -0
  6. replay/experimental/metrics/coverage.py +97 -0
  7. replay/experimental/metrics/experiment.py +175 -0
  8. replay/experimental/metrics/hitrate.py +26 -0
  9. replay/experimental/metrics/map.py +30 -0
  10. replay/experimental/metrics/mrr.py +18 -0
  11. replay/experimental/metrics/ncis_precision.py +31 -0
  12. replay/experimental/metrics/ndcg.py +49 -0
  13. replay/experimental/metrics/precision.py +22 -0
  14. replay/experimental/metrics/recall.py +25 -0
  15. replay/experimental/metrics/rocauc.py +49 -0
  16. replay/experimental/metrics/surprisal.py +90 -0
  17. replay/experimental/metrics/unexpectedness.py +76 -0
  18. replay/experimental/models/__init__.py +50 -0
  19. replay/experimental/models/admm_slim.py +257 -0
  20. replay/experimental/models/base_neighbour_rec.py +200 -0
  21. replay/experimental/models/base_rec.py +1386 -0
  22. replay/experimental/models/base_torch_rec.py +234 -0
  23. replay/experimental/models/cql.py +454 -0
  24. replay/experimental/models/ddpg.py +932 -0
  25. replay/experimental/models/dt4rec/__init__.py +0 -0
  26. replay/experimental/models/dt4rec/dt4rec.py +189 -0
  27. replay/experimental/models/dt4rec/gpt1.py +401 -0
  28. replay/experimental/models/dt4rec/trainer.py +127 -0
  29. replay/experimental/models/dt4rec/utils.py +264 -0
  30. replay/experimental/models/extensions/spark_custom_models/__init__.py +0 -0
  31. replay/experimental/models/extensions/spark_custom_models/als_extension.py +792 -0
  32. replay/experimental/models/hierarchical_recommender.py +331 -0
  33. replay/experimental/models/implicit_wrap.py +131 -0
  34. replay/experimental/models/lightfm_wrap.py +303 -0
  35. replay/experimental/models/mult_vae.py +332 -0
  36. replay/experimental/models/neural_ts.py +986 -0
  37. replay/experimental/models/neuromf.py +406 -0
  38. replay/experimental/models/scala_als.py +293 -0
  39. replay/experimental/models/u_lin_ucb.py +115 -0
  40. replay/experimental/nn/data/__init__.py +1 -0
  41. replay/experimental/nn/data/schema_builder.py +102 -0
  42. replay/experimental/preprocessing/__init__.py +3 -0
  43. replay/experimental/preprocessing/data_preparator.py +839 -0
  44. replay/experimental/preprocessing/padder.py +229 -0
  45. replay/experimental/preprocessing/sequence_generator.py +208 -0
  46. replay/experimental/scenarios/__init__.py +1 -0
  47. replay/experimental/scenarios/obp_wrapper/__init__.py +8 -0
  48. replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +74 -0
  49. replay/experimental/scenarios/obp_wrapper/replay_offline.py +261 -0
  50. replay/experimental/scenarios/obp_wrapper/utils.py +85 -0
  51. replay/experimental/scenarios/two_stages/__init__.py +0 -0
  52. replay/experimental/scenarios/two_stages/reranker.py +117 -0
  53. replay/experimental/scenarios/two_stages/two_stages_scenario.py +757 -0
  54. replay/experimental/utils/__init__.py +0 -0
  55. replay/experimental/utils/logger.py +24 -0
  56. replay/experimental/utils/model_handler.py +186 -0
  57. replay/experimental/utils/session_handler.py +44 -0
  58. {replay_rec-0.20.2.dist-info → replay_rec-0.20.3rc0.dist-info}/METADATA +11 -17
  59. {replay_rec-0.20.2.dist-info → replay_rec-0.20.3rc0.dist-info}/RECORD +62 -7
  60. {replay_rec-0.20.2.dist-info → replay_rec-0.20.3rc0.dist-info}/WHEEL +0 -0
  61. {replay_rec-0.20.2.dist-info → replay_rec-0.20.3rc0.dist-info}/licenses/LICENSE +0 -0
  62. {replay_rec-0.20.2.dist-info → replay_rec-0.20.3rc0.dist-info}/licenses/NOTICE +0 -0
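The additions introduce a replay.experimental namespace (metrics, models, preprocessing, scenarios, utils). As a rough orientation only, the sketch below assumes the new __init__.py files re-export their public classes and that class names follow the file names; nothing in this diff confirms that, and only ddpg.py is shown in full further down.

    # Hypothetical import sketch for the new experimental namespace.
    # Module paths come from the file list above; the class names NDCG and
    # HitRate are assumed from ndcg.py / hitrate.py and earlier RePlay releases.
    from replay.experimental.metrics import HitRate, NDCG
    from replay.experimental.models import DDPG

    metrics = {NDCG(): [5, 10], HitRate(): 10}  # metric -> cutoff(s), for an Experiment-style run
    model = DDPG(user_num=100, item_num=1000)   # constructor defined in ddpg.py below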
replay/experimental/models/ddpg.py
@@ -0,0 +1,932 @@
1
+ from pathlib import Path
2
+ from typing import Any, Optional
3
+
4
+ import numpy as np
5
+ import pandas as pd
6
+ import scipy.sparse as sp
7
+ import torch
8
+ import tqdm
9
+ from pytorch_optimizer import Ranger
10
+ from torch import nn
11
+ from torch.distributions.gamma import Gamma
12
+
13
+ from replay.data import get_schema
14
+ from replay.experimental.models.base_torch_rec import Recommender
15
+ from replay.utils import PYSPARK_AVAILABLE, PandasDataFrame, SparkDataFrame
16
+ from replay.utils.spark_utils import convert2spark
17
+
18
+ if PYSPARK_AVAILABLE:
19
+ from pyspark.sql import functions as sf
20
+
21
+
22
+ def to_np(tensor: torch.Tensor) -> np.ndarray:
23
+ """Converts torch.Tensor to numpy."""
24
+ return tensor.detach().cpu().numpy()
25
+
26
+
27
+ class ReplayBuffer:
28
+ """
29
+ Stores transitions for training the RL model.
30
+
31
+ Usually a transition is (state, action, reward, next_state).
32
+ In this implementation the state is computed from the user embedding
33
+ and the embeddings of the `memory_size` latest relevant items,
34
+ so this ReplayBuffer stores (user, memory) instead of the state.
35
+ """
36
+
37
+ def __init__(self, device, capacity, memory_size, embedding_dim):
38
+ self.capacity = capacity
39
+
40
+ self.buffer = {
41
+ "user": torch.zeros((capacity,), device=device),
42
+ "memory": torch.zeros((capacity, memory_size), device=device),
43
+ "action": torch.zeros((capacity, embedding_dim), device=device),
44
+ "reward": torch.zeros((capacity,), device=device),
45
+ "next_user": torch.zeros((capacity,), device=device),
46
+ "next_memory": torch.zeros((capacity, memory_size), device=device),
47
+ "done": torch.zeros((capacity,), device=device),
48
+ "sample_weight": torch.zeros((capacity,), device=device),
49
+ }
50
+
51
+ self.pos = 0
52
+ self.is_filled = False
53
+
54
+ def push(
55
+ self,
56
+ user,
57
+ memory,
58
+ action,
59
+ reward,
60
+ next_user,
61
+ next_memory,
62
+ done,
63
+ sample_weight,
64
+ ):
65
+ """Add transition to buffer."""
66
+
67
+ batch_size = user.shape[0]
68
+
69
+ self.buffer["user"][self.pos : self.pos + batch_size] = user
70
+ self.buffer["memory"][self.pos : self.pos + batch_size] = memory
71
+ self.buffer["action"][self.pos : self.pos + batch_size] = action
72
+ self.buffer["reward"][self.pos : self.pos + batch_size] = reward
73
+ self.buffer["next_user"][self.pos : self.pos + batch_size] = next_user
74
+ self.buffer["next_memory"][self.pos : self.pos + batch_size] = next_memory
75
+ self.buffer["done"][self.pos : self.pos + batch_size] = done
76
+ self.buffer["sample_weight"][self.pos : self.pos + batch_size] = sample_weight
77
+
78
+ new_pos = self.pos + batch_size
79
+ if new_pos >= self.capacity:
80
+ self.is_filled = True
81
+ self.pos = new_pos % self.capacity
82
+
83
+ def sample(self, batch_size):
84
+ """Sample transition from buffer."""
85
+ current_buffer_len = len(self)
86
+
87
+ indices = np.random.choice(current_buffer_len, batch_size)
88
+
89
+ return {
90
+ "user": self.buffer["user"][indices],
91
+ "memory": self.buffer["memory"][indices],
92
+ "action": self.buffer["action"][indices],
93
+ "reward": self.buffer["reward"][indices],
94
+ "next_user": self.buffer["next_user"][indices],
95
+ "next_memory": self.buffer["next_memory"][indices],
96
+ "done": self.buffer["done"][indices],
97
+ "sample_weight": self.buffer["sample_weight"][indices],
98
+ }
99
+
100
+ def __len__(self):
101
+ return self.capacity if self.is_filled else self.pos + 1
102
+
103
+
104
+ class OUNoise:
105
+ """https://github.com/vitchyr/rlkit/blob/master/rlkit/exploration_strategies/ou_strategy.py"""
106
+
107
+ def __init__(
108
+ self,
109
+ action_dim,
110
+ device,
111
+ theta=0.15,
112
+ max_sigma=0.4,
113
+ min_sigma=0.4,
114
+ noise_type="gauss",
115
+ decay_period=10,
116
+ ):
117
+ self.theta = theta
118
+ self.sigma = max_sigma
119
+ self.max_sigma = max_sigma
120
+ self.min_sigma = min_sigma
121
+ self.decay_period = decay_period
122
+ self.action_dim = action_dim
123
+ self.noise_type = noise_type
124
+ self.device = device
125
+ self.state = torch.zeros((1, action_dim), device=self.device)
126
+
127
+ def reset(self, user_batch_size):
128
+ """Fill state with zeros."""
129
+ if self.state.shape[0] == user_batch_size:
130
+ self.state.fill_(0)
131
+ else:
132
+ self.state = torch.zeros((user_batch_size, self.action_dim), device=self.device)
133
+
134
+ def evolve_state(self):
135
+ """Perform OU discrete approximation step"""
136
+ x = self.state
137
+ d_x = -self.theta * x + self.sigma * torch.randn(x.shape, device=self.device)
138
+ self.state = x + d_x
139
+ return self.state
140
+
141
+ def get_action(self, action, step=0):
142
+ """Get state after applying noise."""
143
+ self.sigma = self.max_sigma - (self.max_sigma - self.min_sigma) * min(1.0, step / self.decay_period)
144
+ if self.noise_type == "ou":
145
+ ou_state = self.evolve_state()
146
+ return action + ou_state
147
+ elif self.noise_type == "gauss":
148
+ return action + self.sigma * torch.randn(action.shape, device=self.device)
149
+ else:
150
+ msg = "noise_type must be one of ['ou', 'gauss']"
151
+ raise ValueError(msg)
152
+
153
+
154
+ class ActorDRR(nn.Module):
155
+ """
156
+ DDPG Actor model (based on `DRR
157
+ <https://arxiv.org/pdf/1810.12027.pdf>`_).
158
+ """
159
+
160
+ def __init__(
161
+ self,
162
+ user_num,
163
+ item_num,
164
+ embedding_dim,
165
+ hidden_dim,
166
+ memory_size,
167
+ env_gamma_alpha,
168
+ device,
169
+ min_trajectory_len,
170
+ ):
171
+ super().__init__()
172
+ self.layers = nn.Sequential(
173
+ nn.Linear(embedding_dim * 3, hidden_dim),
174
+ nn.LayerNorm(hidden_dim),
175
+ nn.ReLU(),
176
+ nn.Linear(hidden_dim, embedding_dim),
177
+ )
178
+
179
+ self.state_repr = StateReprModule(user_num, item_num, embedding_dim, memory_size)
180
+
181
+ self.initialize()
182
+
183
+ self.environment = Env(
184
+ item_num,
185
+ user_num,
186
+ memory_size,
187
+ env_gamma_alpha,
188
+ device,
189
+ min_trajectory_len,
190
+ )
191
+
192
+ def initialize(self):
193
+ """weight init"""
194
+ for layer in self.layers:
195
+ if isinstance(layer, nn.Linear):
196
+ nn.init.kaiming_uniform_(layer.weight)
197
+
198
+ def forward(self, user, memory):
199
+ """
200
+ :param user: user batch
201
+ :param memory: memory batch
202
+ :return: output, vector of the size `embedding_dim`
203
+ """
204
+ state = self.state_repr(user, memory)
205
+ return self.layers(state)
206
+
207
+ def get_action(self, action_emb, items, items_mask, return_scores=False):
208
+ """
209
+ :param action_emb: output of the .forward() (user_batch_size x emb_dim)
210
+ :param items: items batch (user_batch_size x items_num)
211
+ :param items_mask: mask of available items for recommendation (user_batch_size x items_num)
212
+ :param return_scores: whether to return scores of items
213
+ :return: output, prediction (and scores if return_scores)
214
+ """
215
+
216
+ assert items.shape == items_mask.shape
217
+
218
+ items = self.state_repr.item_embeddings(items) # B x i x emb_dim
219
+ scores = torch.bmm(
220
+ items,
221
+ action_emb.unsqueeze(-1), # B x emb_dim x 1
222
+ ).squeeze(-1)
223
+
224
+ assert scores.shape == items_mask.shape
225
+
226
+ scores = scores * items_mask
227
+
228
+ if return_scores:
229
+ return scores, torch.argmax(scores, dim=1)
230
+ else:
231
+ return torch.argmax(scores, dim=1)
232
+
233
+
234
+ class CriticDRR(nn.Module):
235
+ """
236
+ DDPG Critic model (based on `DRR
237
+ <https://arxiv.org/pdf/1810.12027.pdf>`
238
+ and `Bayes-UCBDQN <https://arxiv.org/pdf/2205.07704.pdf>`).
239
+ """
240
+
241
+ def __init__(self, state_repr_dim, action_emb_dim, hidden_dim, heads_num, heads_q):
242
+ """
243
+ :param heads_num: number of heads (samples of the Q function)
244
+ :param heads_q: quantile of Q function distribution
245
+ """
246
+ super().__init__()
247
+ self.layers = nn.Sequential(
248
+ nn.Linear(state_repr_dim + action_emb_dim, hidden_dim),
249
+ nn.LayerNorm(hidden_dim),
250
+ nn.ReLU(),
251
+ )
252
+
253
+ self.heads = nn.ModuleList([nn.Linear(hidden_dim, 1) for _ in range(heads_num)])
254
+ self.heads_q = heads_q
255
+
256
+ self.initialize()
257
+
258
+ def initialize(self):
259
+ """weight init"""
260
+ for layer in self.layers:
261
+ if isinstance(layer, nn.Linear):
262
+ nn.init.kaiming_uniform_(layer.weight)
263
+
264
+ for head in self.heads:
265
+ nn.init.kaiming_uniform_(head.weight)
266
+
267
+ def forward(self, state, action):
268
+ """
269
+ :param state: state batch
270
+ :param action: action batch
271
+ :return: x, Q values for given states and actions
272
+ """
273
+ x = torch.cat([state, action], 1)
274
+ out = self.layers(x)
275
+ heads_out = torch.stack([head(out) for head in self.heads])
276
+ out = torch.quantile(heads_out, self.heads_q, dim=0)
277
+
278
+ return out
279
+
280
+
281
+ class Env:
282
+ """
283
+ RL environment for recommender systems.
284
+ Simulates interaction with a batch of users.
285
+
286
+ Keeps users' latest relevant items (memory).
287
+
288
+ :param item_count: total number of items
289
+ :param user_count: total number of users
290
+ :param memory_size: maximum number of items in memory
291
+ :param memory: torch.tensor with users' latest relevant items
292
+ :param matrix: sparse matrix with users-item ratings
293
+ :param user_ids: user ids from the batch
294
+ :param related_items: relevant items for current users
295
+ :param nonrelated_items: non-relevant items for current users
296
+ :param max_num_rele: maximum number of relevant items over users in the batch
297
+ :param available_items: items available for recommendation
298
+ :param available_items_mask: mask of non-seen items
299
+ :param gamma: parameter of the Gamma distribution for sample weights
300
+ """
301
+
302
+ matrix: np.array
303
+ related_items: torch.Tensor
304
+ nonrelated_items: torch.Tensor
305
+ available_items: torch.Tensor # B x i
306
+ available_items_mask: torch.Tensor # B x i
307
+ user_id: torch.Tensor # batch of users B x i
308
+ num_rele: int
309
+
310
+ def __init__(
311
+ self,
312
+ item_count,
313
+ user_count,
314
+ memory_size,
315
+ gamma_alpha,
316
+ device,
317
+ min_trajectory_len,
318
+ ):
319
+ """
320
+ Initialize memory as ['item_num'] * 'memory_size' for each user.
321
+
322
+ 'item_num' is a padding index in StateReprModule.
323
+ It will result in zero embeddings.
324
+ """
325
+ self.item_count = item_count
326
+ self.user_count = user_count
327
+ self.memory_size = memory_size
328
+ self.device = device
329
+ self.gamma = Gamma(
330
+ torch.tensor([float(gamma_alpha)]),
331
+ torch.tensor([1 / float(gamma_alpha)]),
332
+ )
333
+ self.memory = torch.full((user_count, memory_size), item_count, device=device)
334
+ self.min_trajectory_len = min_trajectory_len
335
+ self.max_num_rele = None
336
+ self.user_batch_size = None
337
+ self.user_ids = None
338
+
339
+ def update_env(self, matrix=None, item_count=None):
340
+ """Update some of Env attributes."""
341
+ if item_count is not None:
342
+ self.item_count = item_count
343
+ if matrix is not None:
344
+ self.matrix = matrix.copy()
345
+
346
+ def reset(self, user_ids):
347
+ """
348
+ :param user_ids: batch of user ids
349
+ :return: user, memory
350
+ """
351
+ self.user_batch_size = len(user_ids)
352
+
353
+ self.user_ids = torch.tensor(user_ids, dtype=torch.int64, device=self.device)
354
+
355
+ self.max_num_rele = max((self.matrix[user_ids] > 0).sum(1).max(), self.min_trajectory_len)
356
+ self.available_items = torch.zeros(
357
+ (self.user_batch_size, 2 * self.max_num_rele),
358
+ dtype=torch.int64,
359
+ device=self.device,
360
+ )
361
+ self.available_items_mask = torch.ones_like(self.available_items, device=self.device)
362
+
363
+ # padding with non-existent items
364
+ self.related_items = torch.full(
365
+ (self.user_batch_size, self.max_num_rele),
366
+ -1, # maybe define new constant
367
+ device=self.device,
368
+ )
369
+
370
+ for idx, user_id in enumerate(user_ids):
371
+ user_related_items = torch.tensor(np.argwhere(self.matrix[user_id] > 0)[:, 1], device=self.device)
372
+
373
+ user_num_rele = len(user_related_items)
374
+
375
+ self.related_items[idx, :user_num_rele] = user_related_items
376
+
377
+ replace = bool(2 * self.max_num_rele > self.item_count)
378
+
379
+ nonrelated_items = torch.tensor(
380
+ np.random.choice(
381
+ list(set(range(self.item_count + 1)) - set(user_related_items.tolist())),
382
+ replace=replace,
383
+ size=2 * self.max_num_rele - user_num_rele,
384
+ )
385
+ ).to(self.device)
386
+
387
+ self.available_items[idx, :user_num_rele] = user_related_items
388
+ self.available_items[idx, user_num_rele:] = nonrelated_items
389
+ self.available_items[self.available_items == -1] = self.item_count
390
+ perm = torch.randperm(self.available_items.shape[1])
391
+ self.available_items[idx] = self.available_items[idx, perm]
392
+
393
+ return self.user_ids, self.memory[self.user_ids]
394
+
395
+ def step(self, actions, actions_emb=None, buffer: ReplayBuffer = None):
396
+ """Execute step and return (user, memory) for new state"""
397
+ initial_users = self.user_ids
398
+ initial_memory = self.memory[self.user_ids].clone()
399
+
400
+ global_actions = self.available_items[torch.arange(self.available_items.shape[0]), actions]
401
+ rewards = (global_actions.reshape(-1, 1) == self.related_items).sum(1)
402
+ for idx, reward in enumerate(rewards):
403
+ if reward:
404
+ user_id = self.user_ids[idx]
405
+ self.memory[user_id] = torch.tensor([*self.memory[user_id][1:], global_actions[idx]])
406
+
407
+ self.available_items_mask[torch.arange(self.available_items_mask.shape[0]), actions] = 0
408
+
409
+ if buffer is not None:
410
+ sample_weight = self.gamma.sample((self.user_batch_size,)).squeeze().detach().to(self.device)
411
+ buffer.push(
412
+ initial_users.detach(),
413
+ initial_memory.detach(),
414
+ actions_emb.detach(),
415
+ rewards.detach(),
416
+ self.user_ids.detach(),
417
+ self.memory[self.user_ids].detach(),
418
+ rewards.detach(),
419
+ sample_weight,
420
+ )
421
+
422
+ return self.user_ids, self.memory[self.user_ids], rewards, 0
423
+
424
+
425
+ class StateReprModule(nn.Module):
426
+ """
427
+ Compute state for RL environment. Based on `DRR paper
428
+ <https://arxiv.org/pdf/1810.12027.pdf>`_
429
+
430
+ The state is a concatenation of the user embedding,
430
+ a weighted average pooling of the `memory_size` latest relevant items,
431
+ and their pairwise product.
433
+ """
434
+
435
+ def __init__(
436
+ self,
437
+ user_num,
438
+ item_num,
439
+ embedding_dim,
440
+ memory_size,
441
+ ):
442
+ super().__init__()
443
+ self.user_embeddings = nn.Embedding(user_num, embedding_dim)
444
+
445
+ self.item_embeddings = nn.Embedding(item_num + 1, embedding_dim, padding_idx=int(item_num))
446
+
447
+ self.drr_ave = torch.nn.Conv1d(in_channels=memory_size, out_channels=1, kernel_size=1)
448
+
449
+ self.initialize()
450
+
451
+ def initialize(self):
452
+ """weight init"""
453
+ nn.init.normal_(self.user_embeddings.weight, std=0.01)
454
+ self.item_embeddings.weight.data[-1].zero_()
455
+
456
+ nn.init.normal_(self.item_embeddings.weight, std=0.01)
457
+ nn.init.uniform_(self.drr_ave.weight)
458
+
459
+ self.drr_ave.bias.data.zero_()
460
+
461
+ def forward(self, user, memory):
462
+ """
463
+ :param user: user batch
464
+ :param memory: memory batch
465
+ :return: vector of dimension 3 * embedding_dim
466
+ """
467
+ user_embedding = self.user_embeddings(user.long())
468
+
469
+ item_embeddings = self.item_embeddings(memory.long())
470
+ drr_ave = self.drr_ave(item_embeddings).squeeze(1)
471
+
472
+ return torch.cat((user_embedding, user_embedding * drr_ave, drr_ave), 1)
473
+
474
+
475
+ class DDPG(Recommender):
476
+ """
477
+ `Deep Deterministic Policy Gradient
478
+ <https://arxiv.org/pdf/1810.12027.pdf>`_
479
+
480
+ This implementation is enhanced by a more advanced noise strategy.
481
+ """
482
+
483
+ batch_size: int = 512
484
+ embedding_dim: int = 8
485
+ hidden_dim: int = 16
486
+ value_lr: float = 1e-5
487
+ value_decay: float = 1e-5
488
+ policy_lr: float = 1e-5
489
+ policy_decay: float = 1e-6
490
+ gamma: float = 0.8
491
+ memory_size: int = 5
492
+ min_value: int = -10
493
+ max_value: int = 10
494
+ buffer_size: int = 1000000
495
+ _search_space = {
496
+ "noise_sigma": {"type": "uniform", "args": [0.1, 0.6]},
497
+ "gamma": {"type": "uniform", "args": [0.7, 1.0]},
498
+ "value_lr": {"type": "loguniform", "args": [1e-7, 1e-1]},
499
+ "value_decay": {"type": "loguniform", "args": [1e-7, 1e-1]},
500
+ "policy_lr": {"type": "loguniform", "args": [1e-7, 1e-1]},
501
+ "policy_decay": {"type": "loguniform", "args": [1e-7, 1e-1]},
502
+ "memory_size": {"type": "categorical", "args": [3, 5, 7, 9]},
503
+ "noise_type": {"type": "categorical", "args": ["gauss", "ou"]},
504
+ }
505
+ checkpoint_step: int = 10000
506
+ replay_buffer: ReplayBuffer
507
+ ou_noise: OUNoise
508
+ model: ActorDRR
509
+ target_model: ActorDRR
510
+ value_net: CriticDRR
511
+ target_value_net: CriticDRR
512
+ policy_optimizer: Ranger
513
+ value_optimizer: Ranger
514
+
515
+ def __init__(
516
+ self,
517
+ noise_sigma: float = 0.2,
518
+ noise_theta: float = 0.05,
519
+ noise_type: str = "gauss",
520
+ seed: int = 9,
521
+ user_num: int = 10,
522
+ item_num: int = 10,
523
+ log_dir: str = "logs/tmp",
524
+ exact_embeddings_size=True,
525
+ n_critics_head: int = 10,
526
+ env_gamma_alpha: float = 0.2,
527
+ critic_heads_q: float = 0.15,
528
+ n_jobs=None,
529
+ use_gpu=False,
530
+ user_batch_size: int = 8,
531
+ min_trajectory_len: int = 10,
532
+ ):
533
+ """
534
+ :param noise_sigma: Ornstein-Uhlenbeck noise sigma value
535
+ :param noise_theta: Ornstein-Uhlenbeck noise theta value
536
+ :param noise_type: type of action noise, one of ["ou", "gauss"]
537
+ :param seed: random seed
538
+ :param user_num: number of users, specify when using ``exact_embeddings_size``
539
+ :param item_num: number of items, specify when using ``exact_embeddings_size``
540
+ :param log_dir: dir to save models
541
+ :param exact_embeddings_size: flag whether to set user_num/item_num from the training log
542
+ """
543
+ super().__init__()
544
+ np.random.seed(seed)
545
+ torch.manual_seed(seed)
546
+
547
+ self.noise_theta = noise_theta
548
+ self.noise_sigma = noise_sigma
549
+ self.noise_type = noise_type
550
+ self.seed = seed
551
+ self.user_num = user_num
552
+ self.item_num = item_num
553
+ self.log_dir = Path(log_dir)
554
+ self.exact_embeddings_size = exact_embeddings_size
555
+ self.n_critics_head = n_critics_head
556
+ self.env_gamma_alpha = env_gamma_alpha
557
+ self.critic_heads_q = critic_heads_q
558
+ self.user_batch_size = user_batch_size
559
+ self.min_trajectory_len = min_trajectory_len
560
+
561
+ self.memory = None
562
+ self.fit_users = None
563
+ self.fit_items = None
564
+
565
+ if n_jobs is not None:
566
+ torch.set_num_threads(n_jobs)
567
+
568
+ if use_gpu:
569
+ use_cuda = torch.cuda.is_available()
570
+ if use_cuda:
571
+ self.device = torch.device("cuda")
572
+ else:
573
+ self.device = torch.device("cpu")
574
+ else:
575
+ self.device = torch.device("cpu")
576
+
577
+ @property
578
+ def _init_args(self):
579
+ return {
580
+ "noise_sigma": self.noise_sigma,
581
+ "noise_theta": self.noise_theta,
582
+ "noise_type": self.noise_type,
583
+ "seed": self.seed,
584
+ "user_num": self.user_num,
585
+ "item_num": self.item_num,
586
+ "exact_embeddings_size": self.exact_embeddings_size,
587
+ }
588
+
589
+ @property
590
+ def _dataframes(self):
591
+ return {
592
+ "memory": self.memory,
593
+ "fit_users": self.fit_users,
594
+ "fit_items": self.fit_items,
595
+ }
596
+
597
+ def _batch_pass(self, batch: dict) -> tuple[torch.Tensor, torch.Tensor]:
598
+ user = batch["user"]
599
+ memory = batch["memory"]
600
+ action = batch["action"]
601
+ reward = batch["reward"]
602
+ next_user = batch["next_user"]
603
+ next_memory = batch["next_memory"]
604
+ done = batch["done"]
605
+ sample_weight = batch["sample_weight"]
606
+ state = self.model.state_repr(user, memory)
607
+
608
+ with torch.no_grad():
609
+ next_state = self.model.state_repr(next_user, next_memory)
610
+ next_action = self.target_model(next_user, next_memory)
611
+ target_value = self.target_value_net(next_state, next_action.detach())
612
+ expected_value = reward + (1.0 - done) * self.gamma * target_value.squeeze(1) # smth strange, check article
613
+ expected_value = torch.clamp(expected_value, self.min_value, self.max_value)
614
+
615
+ proto_action = self.model.layers(state)
616
+ policy_loss = -self.value_net(state.detach(), proto_action).mean()
617
+
618
+ value = self.value_net(state, action)
619
+ value_loss = ((value - expected_value.detach()).pow(2) * sample_weight).squeeze(1).mean()
620
+ return policy_loss, value_loss
621
+
622
+ @staticmethod
623
+ def _predict_pairs_inner(
624
+ model,
625
+ user_idx: int,
626
+ items_np: np.ndarray,
627
+ ) -> SparkDataFrame:
628
+ with torch.no_grad():
629
+ user_batch = torch.tensor([user_idx], dtype=torch.int64)
630
+ memory = model.environment.memory[user_batch]
631
+ action_emb = model(user_batch, memory)
632
+ items = torch.tensor(items_np, dtype=torch.int64).unsqueeze(0)
633
+ scores, _ = model.get_action(action_emb, items, torch.full_like(items, True), True)
634
+ scores = scores.squeeze()
635
+ return PandasDataFrame(
636
+ {
637
+ "user_idx": scores.shape[0] * [user_idx],
638
+ "item_idx": items_np,
639
+ "relevance": scores,
640
+ }
641
+ )
642
+
643
+ def _predict(
644
+ self,
645
+ log: SparkDataFrame,
646
+ k: int, # noqa: ARG002
647
+ users: SparkDataFrame,
648
+ items: SparkDataFrame,
649
+ user_features: Optional[SparkDataFrame] = None, # noqa: ARG002
650
+ item_features: Optional[SparkDataFrame] = None, # noqa: ARG002
651
+ filter_seen_items: bool = True, # noqa: ARG002
652
+ ) -> SparkDataFrame:
653
+ items_consider_in_pred = items.toPandas()["item_idx"].values
654
+ model = self.model.cpu()
655
+
656
+ def grouped_map(pandas_df: PandasDataFrame) -> PandasDataFrame:
657
+ return DDPG._predict_pairs_inner(
658
+ model=model,
659
+ user_idx=pandas_df["user_idx"][0],
660
+ items_np=items_consider_in_pred,
661
+ )[["user_idx", "item_idx", "relevance"]]
662
+
663
+ self.logger.debug("Predict started")
664
+ rec_schema = get_schema(
665
+ query_column="user_idx",
666
+ item_column="item_idx",
667
+ rating_column="relevance",
668
+ has_timestamp=False,
669
+ )
670
+ recs = (
671
+ users.join(log, how="left", on="user_idx")
672
+ .select("user_idx", "item_idx")
673
+ .groupby("user_idx")
674
+ .applyInPandas(grouped_map, rec_schema)
675
+ )
676
+ return recs
677
+
678
+ def _predict_pairs(
679
+ self,
680
+ pairs: SparkDataFrame,
681
+ log: Optional[SparkDataFrame] = None,
682
+ user_features: Optional[SparkDataFrame] = None, # noqa: ARG002
683
+ item_features: Optional[SparkDataFrame] = None, # noqa: ARG002
684
+ ) -> SparkDataFrame:
685
+ model = self.model.cpu()
686
+
687
+ def grouped_map(pandas_df: PandasDataFrame) -> PandasDataFrame:
688
+ return DDPG._predict_pairs_inner(
689
+ model=model,
690
+ user_idx=pandas_df["user_idx"][0],
691
+ items_np=np.array(pandas_df["item_idx_to_pred"][0]),
692
+ )
693
+
694
+ self.logger.debug("Calculate relevance for user-item pairs")
695
+
696
+ rec_schema = get_schema(
697
+ query_column="user_idx",
698
+ item_column="item_idx",
699
+ rating_column="relevance",
700
+ has_timestamp=False,
701
+ )
702
+ recs = (
703
+ pairs.groupBy("user_idx")
704
+ .agg(sf.collect_list("item_idx").alias("item_idx_to_pred"))
705
+ .join(log.select("user_idx").distinct(), on="user_idx", how="inner")
706
+ .groupby("user_idx")
707
+ .applyInPandas(grouped_map, rec_schema)
708
+ )
709
+
710
+ return recs
711
+
712
+ @staticmethod
713
+ def _preprocess_df(data):
714
+ """
715
+ :param data: pandas DataFrame
716
+ """
717
+ data = data[["user_idx", "item_idx", "relevance"]]
718
+ users = data["user_idx"].values
719
+ items = data["item_idx"].values
720
+ scores = data["relevance"].values
721
+
722
+ user_num = int(max(users)) + 1
723
+ item_num = int(max(items)) + 1
724
+
725
+ train_matrix = sp.dok_matrix((user_num, item_num), dtype=np.float32)
726
+ for user, item, rel in zip(users, items, scores):
727
+ train_matrix[user, item] = rel
728
+
729
+ appropriate_users = data["user_idx"].unique()
730
+
731
+ return train_matrix, user_num, item_num, appropriate_users
732
+
733
+ @staticmethod
734
+ def _preprocess_log(log):
735
+ return DDPG._preprocess_df(log.toPandas())
736
+
737
+ def _get_batch(self) -> dict:
738
+ batch = self.replay_buffer.sample(self.batch_size)
739
+ return batch
740
+
741
+ def _run_train_step(self, batch: dict) -> None:
742
+ policy_loss, value_loss = self._batch_pass(batch)
743
+ total_loss = policy_loss + value_loss
744
+
745
+ self.policy_optimizer.zero_grad()
746
+ self.value_optimizer.zero_grad()
747
+ total_loss.backward()
748
+ self.policy_optimizer.step()
749
+ self.value_optimizer.step()
750
+
751
+ self._target_update(self.target_value_net, self.value_net)
752
+ self._target_update(self.target_model, self.model)
753
+
754
+ @staticmethod
755
+ def _target_update(target_net, net, soft_tau=1e-3):
756
+ for target_param, param in zip(target_net.parameters(), net.parameters()):
757
+ target_param.data.copy_(target_param.data * (1.0 - soft_tau) + param.data * soft_tau)
758
+
759
+ def _init_inner(self):
760
+ self.replay_buffer = ReplayBuffer(
761
+ self.device,
762
+ self.buffer_size,
763
+ memory_size=self.memory_size,
764
+ embedding_dim=self.embedding_dim,
765
+ )
766
+
767
+ self.ou_noise = OUNoise(
768
+ self.embedding_dim,
769
+ device=self.device,
770
+ theta=self.noise_theta,
771
+ max_sigma=self.noise_sigma,
772
+ min_sigma=self.noise_sigma,
773
+ noise_type=self.noise_type,
774
+ )
775
+
776
+ self.model = ActorDRR(
777
+ self.user_num,
778
+ self.item_num,
779
+ self.embedding_dim,
780
+ self.hidden_dim,
781
+ self.memory_size,
782
+ env_gamma_alpha=self.env_gamma_alpha,
783
+ device=self.device,
784
+ min_trajectory_len=self.min_trajectory_len,
785
+ ).to(self.device)
786
+
787
+ self.target_model = ActorDRR(
788
+ self.user_num,
789
+ self.item_num,
790
+ self.embedding_dim,
791
+ self.hidden_dim,
792
+ self.memory_size,
793
+ env_gamma_alpha=self.env_gamma_alpha,
794
+ device=self.device,
795
+ min_trajectory_len=self.min_trajectory_len,
796
+ ).to(self.device)
797
+
798
+ self.value_net = CriticDRR(
799
+ self.embedding_dim * 3,
800
+ self.embedding_dim,
801
+ self.hidden_dim,
802
+ heads_num=self.n_critics_head,
803
+ heads_q=self.critic_heads_q,
804
+ ).to(self.device)
805
+
806
+ self.target_value_net = CriticDRR(
807
+ self.embedding_dim * 3,
808
+ self.embedding_dim,
809
+ self.hidden_dim,
810
+ heads_num=self.n_critics_head,
811
+ heads_q=self.critic_heads_q,
812
+ ).to(self.device)
813
+
814
+ self._target_update(self.target_value_net, self.value_net, soft_tau=1)
815
+ self._target_update(self.target_model, self.model, soft_tau=1)
816
+ self.policy_optimizer = Ranger(
817
+ self.model.parameters(),
818
+ lr=self.policy_lr,
819
+ weight_decay=self.policy_decay,
820
+ )
821
+ self.value_optimizer = Ranger(
822
+ self.value_net.parameters(),
823
+ lr=self.value_lr,
824
+ weight_decay=self.value_decay,
825
+ )
826
+
827
+ def _fit(
828
+ self,
829
+ log: SparkDataFrame,
830
+ user_features: Optional[SparkDataFrame] = None, # noqa: ARG002
831
+ item_features: Optional[SparkDataFrame] = None, # noqa: ARG002
832
+ ) -> None:
833
+ data = log.toPandas()
834
+ self._fit_df(data)
835
+
836
+ def _fit_df(self, data):
837
+ train_matrix, user_num, item_num, users = self._preprocess_df(data)
838
+
839
+ if self.exact_embeddings_size:
840
+ self.user_num = user_num
841
+ self.item_num = item_num
842
+ self._init_inner()
843
+
844
+ self.model.environment.update_env(matrix=train_matrix)
845
+ users = np.random.permutation(users)
846
+
847
+ self.logger.debug("Training DDPG")
848
+ self.train(users)
849
+
850
+ @staticmethod
851
+ def users_loader(users, batch_size):
852
+ """loader for users' batch"""
853
+ pos = 0
854
+ while pos != len(users):
855
+ new_pos = min(pos + batch_size, len(users))
856
+ yield users[pos:new_pos]
857
+ pos = new_pos
858
+
859
+ def train(self, users: np.array) -> None:
860
+ """
861
+ Run training loop
862
+
863
+ :param users: array with users for training
864
+ :return:
865
+ """
866
+ self.log_dir.mkdir(parents=True, exist_ok=True)
867
+ step = 0
868
+ users_loader = self.users_loader(users, self.user_batch_size)
869
+ for user_ids in tqdm.auto.tqdm(list(users_loader)):
870
+ user_ids, memory = self.model.environment.reset(user_ids)
871
+ self.ou_noise.reset(user_ids.shape[0])
872
+ for users_step in range(self.model.environment.max_num_rele):
873
+ actions_emb = self.model(user_ids, memory)
874
+ actions_emb = self.ou_noise.get_action(actions_emb, users_step)
875
+
876
+ actions = self.model.get_action(
877
+ actions_emb,
878
+ self.model.environment.available_items,
879
+ self.model.environment.available_items_mask,
880
+ )
881
+
882
+ _, memory, _, _ = self.model.environment.step(actions, actions_emb, self.replay_buffer)
883
+
884
+ if len(self.replay_buffer) > self.batch_size:
885
+ batch = self._get_batch()
886
+ self._run_train_step(batch)
887
+
888
+ if step % self.checkpoint_step == 0 and step > 0:
889
+ self._save_model(self.log_dir / f"model_{step}.pt")
890
+ step += 1
891
+
892
+ self._save_model(self.log_dir / "model_final.pt")
893
+
894
+ def _save_model(self, path: str) -> None:
895
+ self.logger.debug(
896
+ "-- Saving model to file (user_num=%d, item_num=%d)",
897
+ self.user_num,
898
+ self.item_num,
899
+ )
900
+ memory_df = pd.DataFrame(
901
+ self.model.environment.memory.cpu(),
902
+ columns=["item_n", "item_n-1", "item_n-2", "item_n-3", "item_n-4"],
903
+ )
904
+ memory_df.loc[:, "user_id_for_order"] = np.arange(self.user_num)
905
+ self.memory = convert2spark(memory_df)
906
+
907
+ torch.save(
908
+ {
909
+ "actor": self.model.state_dict(),
910
+ "critic": self.value_net.state_dict(),
911
+ "policy_optimizer": self.policy_optimizer.state_dict(),
912
+ "value_optimizer": self.value_optimizer.state_dict(),
913
+ },
914
+ path,
915
+ )
916
+
917
+ def _load_model(self, path: str) -> None:
918
+ self.logger.debug("-- Loading model from file")
919
+ self._init_inner()
920
+
921
+ checkpoint = torch.load(path)
922
+ self.model.load_state_dict(checkpoint["actor"])
923
+ self.value_net.load_state_dict(checkpoint["critic"])
924
+ self.policy_optimizer.load_state_dict(checkpoint["policy_optimizer"])
925
+ self.value_optimizer.load_state_dict(checkpoint["value_optimizer"])
926
+
927
+ self._target_update(self.target_value_net, self.value_net, soft_tau=1)
928
+ self._target_update(self.target_model, self.model, soft_tau=1)
929
+
930
+ memory_df = self.memory.toPandas()
931
+ memory_df = memory_df.sort_values(by="user_id_for_order").drop("user_id_for_order", axis=1)
932
+ self.model.environment.memory = torch.tensor(memory_df.to_numpy()).to(self.device)
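For orientation, here is a minimal, hypothetical training sketch for the DDPG model above. It only uses names visible in this hunk: the constructor arguments, the pandas entry point _fit_df (which expects user_idx, item_idx and relevance columns), and the _predict_pairs_inner scoring helper. The toy log and the call pattern are assumptions, not facts from this diff, and _save_model still goes through convert2spark, so PySpark is required even on this pandas path.

    # Hypothetical usage sketch for replay.experimental.models.ddpg.DDPG.
    # Method and column names are taken from the diff above; the toy data
    # and the use of the underscored internals are assumptions.
    import numpy as np
    import pandas as pd

    from replay.experimental.models.ddpg import DDPG

    # Tiny interaction log in the format _preprocess_df expects.
    log = pd.DataFrame(
        {
            "user_idx": [0, 0, 1, 1, 2, 2],
            "item_idx": [0, 1, 1, 2, 0, 2],
            "relevance": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
        }
    )

    model = DDPG(
        user_num=3,                # overwritten from the log when exact_embeddings_size=True
        item_num=3,
        user_batch_size=2,
        min_trajectory_len=2,
        log_dir="logs/ddpg_demo",  # hypothetical path
    )

    # Pandas training path from the diff: builds the sparse interaction matrix,
    # initialises the actor/critic networks and runs the RL loop over user batches.
    model._fit_df(log)

    # Score a few items for one user via the internal helper; the public
    # fit()/predict() wrappers live in the Recommender base class (not shown here).
    scores = DDPG._predict_pairs_inner(model.model, user_idx=0, items_np=np.array([0, 1, 2]))
    print(scores)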