replay-rec 0.18.0__py3-none-any.whl → 0.18.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. replay/__init__.py +1 -1
  2. replay/experimental/__init__.py +0 -0
  3. replay/experimental/metrics/__init__.py +62 -0
  4. replay/experimental/metrics/base_metric.py +602 -0
  5. replay/experimental/metrics/coverage.py +97 -0
  6. replay/experimental/metrics/experiment.py +175 -0
  7. replay/experimental/metrics/hitrate.py +26 -0
  8. replay/experimental/metrics/map.py +30 -0
  9. replay/experimental/metrics/mrr.py +18 -0
  10. replay/experimental/metrics/ncis_precision.py +31 -0
  11. replay/experimental/metrics/ndcg.py +49 -0
  12. replay/experimental/metrics/precision.py +22 -0
  13. replay/experimental/metrics/recall.py +25 -0
  14. replay/experimental/metrics/rocauc.py +49 -0
  15. replay/experimental/metrics/surprisal.py +90 -0
  16. replay/experimental/metrics/unexpectedness.py +76 -0
  17. replay/experimental/models/__init__.py +10 -0
  18. replay/experimental/models/admm_slim.py +205 -0
  19. replay/experimental/models/base_neighbour_rec.py +204 -0
  20. replay/experimental/models/base_rec.py +1271 -0
  21. replay/experimental/models/base_torch_rec.py +234 -0
  22. replay/experimental/models/cql.py +454 -0
  23. replay/experimental/models/ddpg.py +923 -0
  24. replay/experimental/models/dt4rec/__init__.py +0 -0
  25. replay/experimental/models/dt4rec/dt4rec.py +189 -0
  26. replay/experimental/models/dt4rec/gpt1.py +401 -0
  27. replay/experimental/models/dt4rec/trainer.py +127 -0
  28. replay/experimental/models/dt4rec/utils.py +265 -0
  29. replay/experimental/models/extensions/spark_custom_models/__init__.py +0 -0
  30. replay/experimental/models/extensions/spark_custom_models/als_extension.py +792 -0
  31. replay/experimental/models/implicit_wrap.py +131 -0
  32. replay/experimental/models/lightfm_wrap.py +302 -0
  33. replay/experimental/models/mult_vae.py +332 -0
  34. replay/experimental/models/neuromf.py +406 -0
  35. replay/experimental/models/scala_als.py +296 -0
  36. replay/experimental/nn/data/__init__.py +1 -0
  37. replay/experimental/nn/data/schema_builder.py +55 -0
  38. replay/experimental/preprocessing/__init__.py +3 -0
  39. replay/experimental/preprocessing/data_preparator.py +839 -0
  40. replay/experimental/preprocessing/padder.py +229 -0
  41. replay/experimental/preprocessing/sequence_generator.py +208 -0
  42. replay/experimental/scenarios/__init__.py +1 -0
  43. replay/experimental/scenarios/obp_wrapper/__init__.py +8 -0
  44. replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +74 -0
  45. replay/experimental/scenarios/obp_wrapper/replay_offline.py +248 -0
  46. replay/experimental/scenarios/obp_wrapper/utils.py +87 -0
  47. replay/experimental/scenarios/two_stages/__init__.py +0 -0
  48. replay/experimental/scenarios/two_stages/reranker.py +117 -0
  49. replay/experimental/scenarios/two_stages/two_stages_scenario.py +757 -0
  50. replay/experimental/utils/__init__.py +0 -0
  51. replay/experimental/utils/logger.py +24 -0
  52. replay/experimental/utils/model_handler.py +186 -0
  53. replay/experimental/utils/session_handler.py +44 -0
  54. {replay_rec-0.18.0.dist-info → replay_rec-0.18.0rc0.dist-info}/METADATA +11 -3
  55. replay_rec-0.18.0rc0.dist-info/NOTICE +41 -0
  56. {replay_rec-0.18.0.dist-info → replay_rec-0.18.0rc0.dist-info}/RECORD +58 -5
  57. {replay_rec-0.18.0.dist-info → replay_rec-0.18.0rc0.dist-info}/WHEEL +1 -1
  58. {replay_rec-0.18.0.dist-info → replay_rec-0.18.0rc0.dist-info}/LICENSE +0 -0
File without changes
@@ -0,0 +1,189 @@
1
+ from typing import List, Optional
2
+
3
+ import pandas as pd
4
+ from tqdm import tqdm
5
+
6
+ from replay.experimental.models.base_rec import Recommender
7
+ from replay.utils import PYSPARK_AVAILABLE, TORCH_AVAILABLE, SparkDataFrame
8
+
9
+ if PYSPARK_AVAILABLE:
10
+ from replay.utils.spark_utils import convert2spark
11
+
12
+ if TORCH_AVAILABLE:
13
+ import torch
14
+ from torch.utils.data.dataloader import DataLoader
15
+
16
+ from .gpt1 import GPT, GPTConfig
17
+ from .trainer import Trainer, TrainerConfig
18
+ from .utils import (
19
+ Collator,
20
+ StateActionReturnDataset,
21
+ ValidateDataset,
22
+ WarmUpScheduler,
23
+ create_dataset,
24
+ matrix2df,
25
+ set_seed,
26
+ )
27
+
28
+
29
class DT4Rec(Recommender):
    """
    Decision Transformer for Recommendations

    General Idea:
    `Decision Transformer: Reinforcement Learning
    via Sequence Modeling <https://arxiv.org/pdf/2106.01345.pdf>`_.

    Ideas for improvements:
    `User Retention-oriented Recommendation with Decision
    Transformer <https://arxiv.org/pdf/2303.06347.pdf>`_.

    Also, some sources are listed in their respective classes
    """

    # class-level defaults; the real optimizer/scheduler are created in ``train``
    optimizer = None
    train_batch_size = 128
    val_batch_size = 128
    lr_scheduler = None

    def __init__(
        self,
        item_num,
        user_num,
        seed=123,
        trajectory_len=30,
        epochs=1,
        batch_size=64,
        use_cuda=True,
    ):
        """
        :param item_num: number of items in the catalogue
        :param user_num: number of users
        :param seed: random seed passed to ``set_seed``
        :param trajectory_len: length of one user trajectory fed to the transformer
        :param epochs: number of training epochs
        :param batch_size: stored for interface compatibility (loaders use
            ``train_batch_size``/``val_batch_size``)
        :param use_cuda: whether the trainer should use GPU
        """
        # NOTE(review): ``super().__init__()`` is not called — confirm that
        # ``Recommender`` requires no initialization.
        self.item_num = item_num
        self.user_num = user_num
        self.seed = seed
        self.trajectory_len = trajectory_len
        self.epochs = epochs
        self.batch_size = batch_size
        self.tconf: TrainerConfig = TrainerConfig(epochs=epochs)
        self.mconf: GPTConfig = GPTConfig(
            user_num=user_num,
            item_num=item_num,
            vocab_size=self.item_num + 1,  # +1 for the padding token
            block_size=self.trajectory_len * 3,  # (rtg, state, action) per step
            max_timestep=self.item_num,
        )
        self.model: GPT
        self.user_trajectory: List
        self.trainer: Trainer
        self.use_cuda = use_cuda
        set_seed(self.seed)

    def _init_args(self):
        pass

    def _update_mconf(self, **kwargs):
        """Update the model (GPT) configuration in place."""
        self.mconf.update(**kwargs)

    def _update_tconf(self, **kwargs):
        """Update the trainer configuration in place."""
        self.tconf.update(**kwargs)

    def _make_prediction_dataloader(self, users, items, max_context_len=30):
        """Build a validation/prediction ``DataLoader`` over the given users and items."""
        val_dataset = ValidateDataset(
            self.user_trajectory,
            max_context_len=max_context_len - 1,
            val_items=items,
            val_users=users,
        )

        val_dataloader = DataLoader(
            val_dataset,
            pin_memory=True,
            batch_size=self.val_batch_size,
            collate_fn=Collator(self.item_num),
        )

        return val_dataloader

    def train(
        self,
        log,
        val_users=None,
        val_items=None,
        experiment=None,
    ):
        """
        Run training loop

        :param log: Spark DataFrame with ``user_idx``, ``item_idx``,
            ``relevance`` and ``timestamp`` columns
        :param val_users: users to validate on (all-or-none with the other val args)
        :param val_items: items to validate on
        :param experiment: experiment tracker; validation runs only when given
        """
        # validation arguments must be provided together or not at all
        assert (val_users is None) == (val_items is None) == (experiment is None)
        with_validate = experiment is not None
        df = log.toPandas()[["user_idx", "item_idx", "relevance", "timestamp"]]
        self.user_trajectory = create_dataset(df, user_num=self.user_num, item_pad=self.item_num)

        train_dataset = StateActionReturnDataset(self.user_trajectory, self.trajectory_len)

        train_dataloader = DataLoader(
            train_dataset,
            shuffle=True,
            pin_memory=True,
            batch_size=self.train_batch_size,
            collate_fn=Collator(self.item_num),
        )

        if with_validate:
            val_dataloader = self._make_prediction_dataloader(val_users, val_items, max_context_len=self.trajectory_len)
        else:
            val_dataloader = None

        self.model = GPT(self.mconf)

        optimizer = torch.optim.AdamW(
            self.model.configure_optimizers(),
            lr=3e-4,
            betas=(0.9, 0.95),
        )
        lr_scheduler = WarmUpScheduler(optimizer, dim_embed=768, warmup_steps=4000)

        self.tconf.update(optimizer=optimizer, lr_scheduler=lr_scheduler)
        self.trainer = Trainer(
            self.model,
            train_dataloader,
            self.tconf,
            val_dataloader,
            experiment,
            self.use_cuda,
        )
        self.trainer.train()

    def _fit(
        self,
        log: SparkDataFrame,
        user_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
        item_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
    ) -> None:
        self.train(log)

    def _predict(
        self,
        log: SparkDataFrame,  # noqa: ARG002
        k: int,  # noqa: ARG002
        users: SparkDataFrame,
        items: SparkDataFrame,
        user_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
        item_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
        filter_seen_items: bool = True,  # noqa: ARG002
    ) -> SparkDataFrame:
        items_consider_in_pred = items.toPandas()["item_idx"].values
        users_consider_in_pred = users.toPandas()["user_idx"].values
        ans = self._predict_helper(users_consider_in_pred, items_consider_in_pred)
        return convert2spark(ans)

    def _predict_helper(self, users, items, max_context_len=30):
        """
        Score ``items`` for ``users`` batch by batch.

        :return: pandas DataFrame with ``user_idx``, ``item_idx``, ``relevance``
        """
        predict_dataloader = self._make_prediction_dataloader(users, items, max_context_len)
        self.model.eval()
        # Accumulate per-batch frames and concatenate once:
        # ``DataFrame.append`` was removed in pandas 2.0 and was O(n^2) anyway.
        result_chunks = []
        with torch.no_grad():
            for batch in tqdm(predict_dataloader):
                # use a distinct name so the ``users`` argument is not clobbered
                states, actions, rtgs, timesteps, batch_users = self.trainer._move_batch(batch)
                logits = self.model(states, actions, rtgs, timesteps, batch_users)
                items_relevances = logits[:, -1, :][:, items]
                result_chunks.append(matrix2df(items_relevances, batch_users.squeeze(), items))

        if not result_chunks:
            return pd.DataFrame(columns=["user_idx", "item_idx", "relevance"])
        return pd.concat(result_chunks, ignore_index=True)
@@ -0,0 +1,401 @@
1
+ import logging
2
+ import math
3
+
4
+ from replay.utils import TORCH_AVAILABLE
5
+
6
+ if TORCH_AVAILABLE:
7
+ import torch
8
+ from torch import nn
9
+ from torch.nn import functional as func
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
class GELU(nn.Module):
    """``nn.Module`` wrapper around the functional GELU activation."""

    def forward(self, x):
        """Return ``gelu(x)`` applied element-wise to the input tensor."""
        return func.gelu(x)
24
+
25
+
26
class GPTConfig:
    """base GPT config, params common to all GPT versions"""

    # dropout probabilities
    embd_pdrop = 0.1
    resid_pdrop = 0.1
    attn_pdrop = 0.1
    # architecture defaults (overridable via kwargs)
    n_layer = 6
    n_head = 8
    n_embd = 128
    memory_size = 3

    def __init__(self, vocab_size, block_size, **kwargs):
        """
        :param vocab_size: size of the token vocabulary
        :param block_size: maximum sequence length
        :param kwargs: any extra attributes to set on the config
        """
        self.vocab_size = vocab_size
        self.block_size = block_size
        self.update(**kwargs)

    def update(self, **kwargs):
        """
        Arguments setter
        """
        for name, value in kwargs.items():
            setattr(self, name, value)
49
+
50
+
51
class GPT1Config(GPTConfig):
    """GPT-1 sized configuration (roughly 125M parameters)."""

    n_embd = 768
    n_head = 12
    n_layer = 12
57
+
58
+
59
class CausalSelfAttention(nn.Module):
    """
    A vanilla multi-head masked self-attention layer with a projection at the end.
    It is possible to use torch.nn.MultiheadAttention here but I am including an
    explicit implementation here to show that there is nothing too scary here.
    """

    def __init__(self, config):
        """
        :param config: object exposing ``n_embd``, ``n_head``, ``attn_pdrop``,
            ``resid_pdrop`` and ``block_size`` (e.g. ``GPTConfig``)
        """
        super().__init__()
        # embedding width must split evenly across attention heads
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads
        self.key = nn.Linear(config.n_embd, config.n_embd)
        self.query = nn.Linear(config.n_embd, config.n_embd)
        self.value = nn.Linear(config.n_embd, config.n_embd)
        # regularization
        self.attn_drop = nn.Dropout(config.attn_pdrop)
        self.resid_drop = nn.Dropout(config.resid_pdrop)
        # output projection
        self.proj = nn.Linear(config.n_embd, config.n_embd)
        # causal mask to ensure that attention is only applied to the left in the input sequence
        # self.register_buffer("mask", torch.tril(torch.ones(config.block_size, config.block_size))
        # .view(1, 1, config.block_size, config.block_size))
        # NOTE(review): the mask covers block_size + 1 positions — presumably to
        # allow one extra token at inference time; confirm against callers.
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(config.block_size + 1, config.block_size + 1)).view(
                1, 1, config.block_size + 1, config.block_size + 1
            ),
        )
        self.n_head = config.n_head

    def forward(self, x):
        """
        Apply attention

        :param x: input batch of shape (B, T, C) = (batch, sequence, embedding)
        :return: tensor of the same shape (B, T, C)
        """
        B, T, C = x.size()  # noqa: N806

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        # zero entries of the lower-triangular mask (future positions) are sent to -inf
        # so that, after softmax, a token attends only to itself and the past
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        att = func.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side

        # output projection
        y = self.resid_drop(self.proj(y))
        return y
111
+
112
+
113
class Block(nn.Module):
    """A single pre-norm Transformer block: masked self-attention followed by an MLP."""

    def __init__(self, config):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.ln2 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        hidden_dim = 4 * config.n_embd  # standard 4x expansion inside the MLP
        self.mlp = nn.Sequential(
            nn.Linear(config.n_embd, hidden_dim),
            GELU(),
            nn.Linear(hidden_dim, config.n_embd),
            nn.Dropout(config.resid_pdrop),
        )

    def forward(self, x):
        """
        Apply both sub-layers, each wrapped in a residual connection.

        :param x: batch
        """
        x = x + self.attn(self.ln1(x))
        return x + self.mlp(self.ln2(x))
135
+
136
+
137
class StateReprModule(nn.Module):
    """
    Compute state for RL environment. Based on `DRR paper
    <https://arxiv.org/pdf/1810.12027.pdf>`_

    Computes State is a concatenation of user embedding,
    weighted average pooling of `memory_size` latest relevant items
    and their pairwise product.
    """

    def __init__(
        self,
        user_num,
        item_num,
        embedding_dim,
        memory_size,
    ):
        """
        :param user_num: number of users (rows of the user embedding table)
        :param item_num: number of items; one extra row is reserved for padding
        :param embedding_dim: dimensionality of the user/item embeddings
        :param memory_size: how many latest items are pooled into the state
        """
        super().__init__()
        self.user_embeddings = nn.Embedding(user_num, embedding_dim)

        # last row (index item_num) is the padding item
        self.item_embeddings = nn.Embedding(item_num + 1, embedding_dim, padding_idx=int(item_num))

        # learnable weighted-average pooling across the memory dimension
        self.drr_ave = torch.nn.Conv1d(in_channels=memory_size, out_channels=1, kernel_size=1)

        self.linear = nn.Linear(3 * embedding_dim, embedding_dim)

        self.initialize()

    def initialize(self):
        """weight init"""
        nn.init.normal_(self.user_embeddings.weight, std=0.01)
        nn.init.normal_(self.item_embeddings.weight, std=0.01)
        # Zero the padding embedding AFTER the normal init. The original code
        # zeroed it first, so ``normal_`` immediately overwrote the zeros and
        # the padding item ended up with a random, non-zero embedding.
        self.item_embeddings.weight.data[-1].zero_()

        nn.init.uniform_(self.drr_ave.weight)
        self.drr_ave.bias.data.zero_()

    def forward(self, user, memory):
        """
        :param user: user batch, shape (batch, 1)
        :param memory: memory batch, shape (batch, memory_size)
        :return: vector of dimension embedding_dim
        """
        user_embedding = self.user_embeddings(user.long()).squeeze(1)
        item_embeddings = self.item_embeddings(memory.long())
        drr_ave = self.drr_ave(item_embeddings).squeeze(1)
        output = torch.cat((user_embedding, user_embedding * drr_ave, drr_ave), 1)
        output = self.linear(output)

        return output
188
+
189
+
190
class GPT(nn.Module):
    """the full GPT language model, with a context size of block_size"""

    def __init__(self, config):
        """
        :param config: ``GPTConfig``-like object; must additionally expose
            ``user_num``, ``memory_size`` and ``max_timestep``
        """
        super().__init__()

        self.config = config

        self.user_num = config.user_num
        self.memory_size = config.memory_size

        # learned positional embeddings: local (within the context window)
        # and global (per absolute timestep)
        self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size + 1, config.n_embd))
        self.global_pos_emb = nn.Parameter(torch.zeros(1, config.max_timestep + 1, config.n_embd))
        self.drop = nn.Dropout(config.embd_pdrop)

        # transformer
        self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
        # decoder head
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.block_size = config.block_size
        self.apply(self._init_weights)

        logger.info(
            "number of parameters: %e",
            sum(p.numel() for p in self.parameters()),
        )

        self.state_repr = StateReprModule(
            user_num=config.user_num,
            item_num=config.vocab_size,
            embedding_dim=config.n_embd,
            memory_size=config.memory_size,
        )

        self.ret_emb = nn.Sequential(nn.Linear(1, config.n_embd), nn.Tanh())
        # action embeddings share the item embedding table of the state encoder
        self.action_embeddings = nn.Sequential(self.state_repr.item_embeddings, nn.Tanh())
        nn.init.normal_(self.action_embeddings[0].weight, mean=0.0, std=0.02)
        # keep the padding row at zero: the normal_ re-init above would
        # otherwise give the padding item a random embedding
        self.action_embeddings[0].weight.data[-1].zero_()

    def get_block_size(self):
        """
        Return block_size
        """
        return self.block_size

    @staticmethod
    def _init_weights(module):
        """Standard GPT init: N(0, 0.02) weights, zero biases, unit LayerNorm."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def configure_optimizers(self):
        """
        This long function is unfortunately doing something very simple and is being very defensive:
        We are separating out all parameters of the model into two buckets: those that will experience
        weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
        We are then returning the PyTorch optimizer object.
        """

        # separate out all parameters to those that will and won't experience regularizing weight decay
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (
            torch.nn.Linear,
            torch.nn.Conv2d,
            torch.nn.Conv1d,
        )
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.named_modules():
            for pn, _ in m.named_parameters():
                fpn = f"{mn}.{pn}" if mn else pn  # full param name

                if pn.endswith("bias"):
                    # all biases will not be decayed
                    no_decay.add(fpn)
                elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules):
                    # weights of whitelist modules will be weight decayed
                    decay.add(fpn)
                elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules):
                    # weights of blacklist modules will NOT be weight decayed
                    no_decay.add(fpn)

        # special case the position embedding parameter in the root GPT module as not decayed
        no_decay.add("pos_emb")
        no_decay.add("global_pos_emb")

        # validate that we considered every parameter
        param_dict = dict(self.named_parameters())
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, f"parameters {inter_params!s} made it into both decay/no_decay sets!"
        assert (
            len(param_dict.keys() - union_params) == 0
        ), f"parameters {param_dict.keys() - union_params!s} were not separated into either decay/no_decay set!"

        optim_groups = [
            {
                "params": [param_dict[pn] for pn in sorted(decay)],
                "weight_decay": 0.1,
            },
            {
                "params": [param_dict[pn] for pn in sorted(no_decay)],
                "weight_decay": 0.0,
            },
        ]
        return optim_groups

    # state, action, and return
    def forward(
        self,
        states,
        actions,
        rtgs,
        timesteps,
        users,
    ):
        """
        :states: states batch, (batch, trajectory_len, 3)
        :actions: actions batch, (batch, trajectory_len, 1)
        :rtgs: rtgs batch, (batch, trajectory_len, 1)
        :timesteps: timesteps batch, (batch, 1, 1)
        :users: users batch, (batch, 1)
        :return: logits over the vocabulary, one row per state position
        """
        # at inference time the last action of the trajectory is not yet known
        inference = not self.training
        state_embeddings = self.state_repr(
            users.repeat((1, self.block_size // 3)).reshape(-1, 1),
            states.reshape(-1, 3),
        )

        state_embeddings = state_embeddings.reshape(
            states.shape[0], states.shape[1], self.config.n_embd
        )  # (batch, block_size, n_embd)

        if actions is not None:
            rtg_embeddings = self.ret_emb(rtgs.type(torch.float32))
            action_embeddings = self.action_embeddings(
                actions.type(torch.long).squeeze(-1)
            )  # (batch, block_size, n_embd)

            # interleave tokens as (rtg, state, action, rtg, state, action, ...);
            # at inference the trailing action slot is dropped
            token_embeddings = torch.zeros(
                (
                    states.shape[0],
                    states.shape[1] * 3 - int(inference),
                    self.config.n_embd,
                ),
                dtype=torch.float32,
                device=state_embeddings.device,
            )
            token_embeddings[:, ::3, :] = rtg_embeddings
            token_embeddings[:, 1::3, :] = state_embeddings
            token_embeddings[:, 2::3, :] = action_embeddings[:, -states.shape[1] + int(inference) :, :]
        else:
            # only happens at very first timestep of evaluation
            rtg_embeddings = self.ret_emb(rtgs.type(torch.float32))

            token_embeddings = torch.zeros(
                (states.shape[0], states.shape[1] * 2, self.config.n_embd),
                dtype=torch.float32,
                device=state_embeddings.device,
            )
            token_embeddings[:, ::2, :] = rtg_embeddings  # really just [:,0,:]
            token_embeddings[:, 1::2, :] = state_embeddings  # really just [:,1,:]

        batch_size = states.shape[0]
        all_global_pos_emb = torch.repeat_interleave(
            self.global_pos_emb, batch_size, dim=0
        )  # batch_size, traj_length, n_embd

        # gather the global positional embedding of each trajectory's timestep
        # and add the local positional embedding of each token slot
        position_embeddings = (
            torch.gather(
                all_global_pos_emb,
                1,
                torch.repeat_interleave(timesteps, self.config.n_embd, dim=-1),
            )
            + self.pos_emb[:, : token_embeddings.shape[1], :]
        )
        x = self.drop(token_embeddings + position_embeddings)
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.head(x)

        if actions is not None:
            logits = logits[:, 1::3, :]  # only keep predictions from state_embeddings
        else:
            logits = logits[:, 1:, :]

        return logits

    def predict(self, states, actions, rtgs, timesteps, users):
        """
        :states: states batch, (batch, block_size, 3)
        :actions: actions batch, (batch, block_size, 1)
        :rtgs: rtgs batch, (batch, block_size, 1)
        :timesteps: timesteps batch, (batch, 1, 1)
        :users: users batch, (batch, 1)
        :return: items of the last position sorted by descending score
        """
        device = self.pos_emb.device
        # Bug fix: the original call passed a ``targets=None`` keyword that
        # ``forward`` does not accept (TypeError) and tuple-unpacked
        # ``logits, _`` from forward's single tensor return value.
        logits = self(
            states=states.to(device),
            actions=actions.to(device),
            rtgs=rtgs.to(device),
            timesteps=timesteps.to(device),
            users=users.to(device),
        )
        logits = logits[:, -1, :]
        actions = logits.argsort(dim=1, descending=True)
        return actions