replay-rec 0.20.1rc0__py3-none-any.whl → 0.20.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- replay/__init__.py +1 -1
- {replay_rec-0.20.1rc0.dist-info → replay_rec-0.20.2.dist-info}/METADATA +18 -12
- {replay_rec-0.20.1rc0.dist-info → replay_rec-0.20.2.dist-info}/RECORD +6 -61
- replay/experimental/__init__.py +0 -0
- replay/experimental/metrics/__init__.py +0 -62
- replay/experimental/metrics/base_metric.py +0 -603
- replay/experimental/metrics/coverage.py +0 -97
- replay/experimental/metrics/experiment.py +0 -175
- replay/experimental/metrics/hitrate.py +0 -26
- replay/experimental/metrics/map.py +0 -30
- replay/experimental/metrics/mrr.py +0 -18
- replay/experimental/metrics/ncis_precision.py +0 -31
- replay/experimental/metrics/ndcg.py +0 -49
- replay/experimental/metrics/precision.py +0 -22
- replay/experimental/metrics/recall.py +0 -25
- replay/experimental/metrics/rocauc.py +0 -49
- replay/experimental/metrics/surprisal.py +0 -90
- replay/experimental/metrics/unexpectedness.py +0 -76
- replay/experimental/models/__init__.py +0 -50
- replay/experimental/models/admm_slim.py +0 -257
- replay/experimental/models/base_neighbour_rec.py +0 -200
- replay/experimental/models/base_rec.py +0 -1386
- replay/experimental/models/base_torch_rec.py +0 -234
- replay/experimental/models/cql.py +0 -454
- replay/experimental/models/ddpg.py +0 -932
- replay/experimental/models/dt4rec/__init__.py +0 -0
- replay/experimental/models/dt4rec/dt4rec.py +0 -189
- replay/experimental/models/dt4rec/gpt1.py +0 -401
- replay/experimental/models/dt4rec/trainer.py +0 -127
- replay/experimental/models/dt4rec/utils.py +0 -264
- replay/experimental/models/extensions/spark_custom_models/__init__.py +0 -0
- replay/experimental/models/extensions/spark_custom_models/als_extension.py +0 -792
- replay/experimental/models/hierarchical_recommender.py +0 -331
- replay/experimental/models/implicit_wrap.py +0 -131
- replay/experimental/models/lightfm_wrap.py +0 -303
- replay/experimental/models/mult_vae.py +0 -332
- replay/experimental/models/neural_ts.py +0 -986
- replay/experimental/models/neuromf.py +0 -406
- replay/experimental/models/scala_als.py +0 -293
- replay/experimental/models/u_lin_ucb.py +0 -115
- replay/experimental/nn/data/__init__.py +0 -1
- replay/experimental/nn/data/schema_builder.py +0 -102
- replay/experimental/preprocessing/__init__.py +0 -3
- replay/experimental/preprocessing/data_preparator.py +0 -839
- replay/experimental/preprocessing/padder.py +0 -229
- replay/experimental/preprocessing/sequence_generator.py +0 -208
- replay/experimental/scenarios/__init__.py +0 -1
- replay/experimental/scenarios/obp_wrapper/__init__.py +0 -8
- replay/experimental/scenarios/obp_wrapper/obp_optuna_objective.py +0 -74
- replay/experimental/scenarios/obp_wrapper/replay_offline.py +0 -261
- replay/experimental/scenarios/obp_wrapper/utils.py +0 -85
- replay/experimental/scenarios/two_stages/__init__.py +0 -0
- replay/experimental/scenarios/two_stages/reranker.py +0 -117
- replay/experimental/scenarios/two_stages/two_stages_scenario.py +0 -757
- replay/experimental/utils/__init__.py +0 -0
- replay/experimental/utils/logger.py +0 -24
- replay/experimental/utils/model_handler.py +0 -186
- replay/experimental/utils/session_handler.py +0 -44
- {replay_rec-0.20.1rc0.dist-info → replay_rec-0.20.2.dist-info}/WHEEL +0 -0
- {replay_rec-0.20.1rc0.dist-info → replay_rec-0.20.2.dist-info}/licenses/LICENSE +0 -0
- {replay_rec-0.20.1rc0.dist-info → replay_rec-0.20.2.dist-info}/licenses/NOTICE +0 -0
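The net effect of this release is that the entire replay.experimental subtree (metrics, models, preprocessing, scenarios, utils) is removed from the wheel in 0.20.2. As a rough illustration of the downstream impact, code that imported these modules under 0.20.1rc0 would need an optional-import guard; the sketch below is a hypothetical consumer-side pattern, not code from either wheel:

    # Hypothetical downstream guard: replay.experimental ships in the
    # 0.20.1rc0 wheel but not in 0.20.2, so treat it as optional.
    try:
        from replay.experimental.models.dt4rec.dt4rec import DT4Rec  # gone in 0.20.2
        HAS_EXPERIMENTAL = True
    except ImportError:
        DT4Rec = None
        HAS_EXPERIMENTAL = False

    if not HAS_EXPERIMENTAL:
        print("replay.experimental is unavailable; pin replay-rec<0.20.2 to keep it")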
replay/experimental/models/dt4rec/dt4rec.py
@@ -1,189 +0,0 @@
-from typing import Optional
-
-import pandas as pd
-from tqdm import tqdm
-
-from replay.experimental.models.base_rec import Recommender
-from replay.utils import PYSPARK_AVAILABLE, TORCH_AVAILABLE, SparkDataFrame
-
-if PYSPARK_AVAILABLE:
-    from replay.utils.spark_utils import convert2spark
-
-if TORCH_AVAILABLE:
-    import torch
-    from torch.utils.data.dataloader import DataLoader
-
-    from .gpt1 import GPT, GPTConfig
-    from .trainer import Trainer, TrainerConfig
-    from .utils import (
-        Collator,
-        StateActionReturnDataset,
-        ValidateDataset,
-        WarmUpScheduler,
-        create_dataset,
-        matrix2df,
-        set_seed,
-    )
-
-
-class DT4Rec(Recommender):
-    """
-    Decision Transformer for Recommendations
-
-    General Idea:
-    `Decision Transformer: Reinforcement Learning
-    via Sequence Modeling <https://arxiv.org/pdf/2106.01345.pdf>`_.
-
-    Ideas for improvements:
-    `User Retention-oriented Recommendation with Decision
-    Transformer <https://arxiv.org/pdf/2303.06347.pdf>`_.
-
-    Also, some sources are listed in their respective classes
-    """
-
-    optimizer = None
-    train_batch_size = 128
-    val_batch_size = 128
-    lr_scheduler = None
-
-    def __init__(
-        self,
-        item_num,
-        user_num,
-        seed=123,
-        trajectory_len=30,
-        epochs=1,
-        batch_size=64,
-        use_cuda=True,
-    ):
-        self.item_num = item_num
-        self.user_num = user_num
-        self.seed = seed
-        self.trajectory_len = trajectory_len
-        self.epochs = epochs
-        self.batch_size = batch_size
-        self.tconf: TrainerConfig = TrainerConfig(epochs=epochs)
-        self.mconf: GPTConfig = GPTConfig(
-            user_num=user_num,
-            item_num=item_num,
-            vocab_size=self.item_num + 1,
-            block_size=self.trajectory_len * 3,
-            max_timestep=self.item_num,
-        )
-        self.model: GPT
-        self.user_trajectory: list
-        self.trainer: Trainer
-        self.use_cuda = use_cuda
-        set_seed(self.seed)
-
-    def _init_args(self):
-        pass
-
-    def _update_mconf(self, **kwargs):
-        self.mconf.update(**kwargs)
-
-    def _update_tconf(self, **kwargs):
-        self.tconf.update(**kwargs)
-
-    def _make_prediction_dataloader(self, users, items, max_context_len=30):
-        val_dataset = ValidateDataset(
-            self.user_trajectory,
-            max_context_len=max_context_len - 1,
-            val_items=items,
-            val_users=users,
-        )
-
-        val_dataloader = DataLoader(
-            val_dataset,
-            pin_memory=True,
-            batch_size=self.val_batch_size,
-            collate_fn=Collator(self.item_num),
-        )
-
-        return val_dataloader
-
-    def train(
-        self,
-        log,
-        val_users=None,
-        val_items=None,
-        experiment=None,
-    ):
-        """
-        Run training loop
-        """
-        assert (val_users is None) == (val_items is None) == (experiment is None)
-        with_validate = experiment is not None
-        df = log.toPandas()[["user_idx", "item_idx", "relevance", "timestamp"]]
-        self.user_trajectory = create_dataset(df, user_num=self.user_num, item_pad=self.item_num)
-
-        train_dataset = StateActionReturnDataset(self.user_trajectory, self.trajectory_len)
-
-        train_dataloader = DataLoader(
-            train_dataset,
-            shuffle=True,
-            pin_memory=True,
-            batch_size=self.train_batch_size,
-            collate_fn=Collator(self.item_num),
-        )
-
-        if with_validate:
-            val_dataloader = self._make_prediction_dataloader(val_users, val_items, max_context_len=self.trajectory_len)
-        else:
-            val_dataloader = None
-
-        self.model = GPT(self.mconf)
-
-        optimizer = torch.optim.AdamW(
-            self.model.configure_optimizers(),
-            lr=3e-4,
-            betas=(0.9, 0.95),
-        )
-        lr_scheduler = WarmUpScheduler(optimizer, dim_embed=768, warmup_steps=4000)
-
-        self.tconf.update(optimizer=optimizer, lr_scheduler=lr_scheduler)
-        self.trainer = Trainer(
-            self.model,
-            train_dataloader,
-            self.tconf,
-            val_dataloader,
-            experiment,
-            self.use_cuda,
-        )
-        self.trainer.train()
-
-    def _fit(
-        self,
-        log: SparkDataFrame,
-        user_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
-        item_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
-    ) -> None:
-        self.train(log)
-
-    def _predict(
-        self,
-        log: SparkDataFrame,  # noqa: ARG002
-        k: int,  # noqa: ARG002
-        users: SparkDataFrame,
-        items: SparkDataFrame,
-        user_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
-        item_features: Optional[SparkDataFrame] = None,  # noqa: ARG002
-        filter_seen_items: bool = True,  # noqa: ARG002
-    ) -> SparkDataFrame:
-        items_consider_in_pred = items.toPandas()["item_idx"].values
-        users_consider_in_pred = users.toPandas()["user_idx"].values
-        ans = self._predict_helper(users_consider_in_pred, items_consider_in_pred)
-        return convert2spark(ans)
-
-    def _predict_helper(self, users, items, max_context_len=30):
-        predict_dataloader = self._make_prediction_dataloader(users, items, max_context_len)
-        self.model.eval()
-        ans_df = pd.DataFrame(columns=["user_idx", "item_idx", "relevance"])
-        with torch.no_grad():
-            for batch in tqdm(predict_dataloader):
-                states, actions, rtgs, timesteps, users = self.trainer._move_batch(batch)
-                logits = self.model(states, actions, rtgs, timesteps, users)
-                items_relevances = logits[:, -1, :][:, items]
-                ans_df = pd.concat([ans_df, matrix2df(items_relevances, users.squeeze(), items)])
-
-        return ans_df
replay/experimental/models/dt4rec/gpt1.py
@@ -1,401 +0,0 @@
-import logging
-import math
-
-from replay.utils import TORCH_AVAILABLE
-
-if TORCH_AVAILABLE:
-    import torch
-    from torch import nn
-    from torch.nn import functional as func
-
-logger = logging.getLogger(__name__)
-
-
-class GELU(nn.Module):
-    """
-    GELU callable class
-    """
-
-    def forward(self, x):
-        """
-        Apply GELU
-        """
-        return func.gelu(x)
-
-
-class GPTConfig:
-    """base GPT config, params common to all GPT versions"""
-
-    embd_pdrop = 0.1
-    resid_pdrop = 0.1
-    attn_pdrop = 0.1
-    n_layer = 6
-    n_head = 8
-    n_embd = 128
-    memory_size = 3
-
-    def __init__(self, vocab_size, block_size, **kwargs):
-        self.vocab_size = vocab_size
-        self.block_size = block_size
-        for k, v in kwargs.items():
-            setattr(self, k, v)
-
-    def update(self, **kwargs):
-        """
-        Arguments setter
-        """
-        for k, v in kwargs.items():
-            setattr(self, k, v)
-
-
-class GPT1Config(GPTConfig):
-    """GPT-1 like network roughly 125M params"""
-
-    n_layer = 12
-    n_head = 12
-    n_embd = 768
-
-
-class CausalSelfAttention(nn.Module):
-    """
-    A vanilla multi-head masked self-attention layer with a projection at the end.
-    It is possible to use torch.nn.MultiheadAttention here but I am including an
-    explicit implementation here to show that there is nothing too scary here.
-    """
-
-    def __init__(self, config):
-        super().__init__()
-        assert config.n_embd % config.n_head == 0
-        # key, query, value projections for all heads
-        self.key = nn.Linear(config.n_embd, config.n_embd)
-        self.query = nn.Linear(config.n_embd, config.n_embd)
-        self.value = nn.Linear(config.n_embd, config.n_embd)
-        # regularization
-        self.attn_drop = nn.Dropout(config.attn_pdrop)
-        self.resid_drop = nn.Dropout(config.resid_pdrop)
-        # output projection
-        self.proj = nn.Linear(config.n_embd, config.n_embd)
-        # causal mask to ensure that attention is only applied to the left in the input sequence
-        # self.register_buffer("mask", torch.tril(torch.ones(config.block_size, config.block_size))
-        # .view(1, 1, config.block_size, config.block_size))
-        self.register_buffer(
-            "mask",
-            torch.tril(torch.ones(config.block_size + 1, config.block_size + 1)).view(
-                1, 1, config.block_size + 1, config.block_size + 1
-            ),
-        )
-        self.n_head = config.n_head
-
-    def forward(self, x):
-        """
-        Apply attention
-        """
-        B, T, C = x.size()  # noqa: N806
-
-        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
-        k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
-        q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
-        v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
-
-        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
-        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
-        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
-        att = func.softmax(att, dim=-1)
-        att = self.attn_drop(att)
-        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
-        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
-
-        # output projection
-        y = self.resid_drop(self.proj(y))
-        return y
-
-
-class Block(nn.Module):
-    """an unassuming Transformer block"""
-
-    def __init__(self, config):
-        super().__init__()
-        self.ln1 = nn.LayerNorm(config.n_embd)
-        self.ln2 = nn.LayerNorm(config.n_embd)
-        self.attn = CausalSelfAttention(config)
-        self.mlp = nn.Sequential(
-            nn.Linear(config.n_embd, 4 * config.n_embd),
-            GELU(),
-            nn.Linear(4 * config.n_embd, config.n_embd),
-            nn.Dropout(config.resid_pdrop),
-        )
-
-    def forward(self, x):
-        """
-        :x: batch
-        """
-        x = x + self.attn(self.ln1(x))
-        x = x + self.mlp(self.ln2(x))
-        return x
-
-
-class StateReprModule(nn.Module):
-    """
-    Compute state for RL environment. Based on `DRR paper
-    <https://arxiv.org/pdf/1810.12027.pdf>`_
-
-    Computed state is a concatenation of user embedding,
-    weighted average pooling of `memory_size` latest relevant items
-    and their pairwise product.
-    """
-
-    def __init__(
-        self,
-        user_num,
-        item_num,
-        embedding_dim,
-        memory_size,
-    ):
-        super().__init__()
-        self.user_embeddings = nn.Embedding(user_num, embedding_dim)
-
-        self.item_embeddings = nn.Embedding(item_num + 1, embedding_dim, padding_idx=int(item_num))
-
-        self.drr_ave = torch.nn.Conv1d(in_channels=memory_size, out_channels=1, kernel_size=1)
-
-        self.linear = nn.Linear(3 * embedding_dim, embedding_dim)
-
-        self.initialize()
-
-    def initialize(self):
-        """weight init"""
-        nn.init.normal_(self.user_embeddings.weight, std=0.01)
-        self.item_embeddings.weight.data[-1].zero_()
-
-        nn.init.normal_(self.item_embeddings.weight, std=0.01)
-        nn.init.uniform_(self.drr_ave.weight)
-
-        self.drr_ave.bias.data.zero_()
-
-    def forward(self, user, memory):
-        """
-        :param user: user batch
-        :param memory: memory batch
-        :return: vector of dimension embedding_dim
-        """
-        user_embedding = self.user_embeddings(user.long()).squeeze(1)
-        item_embeddings = self.item_embeddings(memory.long())
-        drr_ave = self.drr_ave(item_embeddings).squeeze(1)
-        output = torch.cat((user_embedding, user_embedding * drr_ave, drr_ave), 1)
-        output = self.linear(output)
-
-        return output
-
-
-class GPT(nn.Module):
-    """the full GPT language model, with a context size of block_size"""
-
-    def __init__(self, config):
-        super().__init__()
-
-        self.config = config
-
-        self.user_num = config.user_num
-        self.memory_size = config.memory_size
-
-        self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size + 1, config.n_embd))
-        self.global_pos_emb = nn.Parameter(torch.zeros(1, config.max_timestep + 1, config.n_embd))
-        self.drop = nn.Dropout(config.embd_pdrop)
-
-        # transformer
-        self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
-        # decoder head
-        self.ln_f = nn.LayerNorm(config.n_embd)
-        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
-
-        self.block_size = config.block_size
-        self.apply(self._init_weights)
-
-        logger.info(
-            "number of parameters: %e",
-            sum(p.numel() for p in self.parameters()),
-        )
-
-        self.state_repr = StateReprModule(
-            user_num=config.user_num,
-            item_num=config.vocab_size,
-            embedding_dim=config.n_embd,
-            memory_size=config.memory_size,
-        )
-
-        self.ret_emb = nn.Sequential(nn.Linear(1, config.n_embd), nn.Tanh())
-        self.action_embeddings = nn.Sequential(self.state_repr.item_embeddings, nn.Tanh())
-        nn.init.normal_(self.action_embeddings[0].weight, mean=0.0, std=0.02)
-
-    def get_block_size(self):
-        """
-        Return block_size
-        """
-        return self.block_size
-
-    @staticmethod
-    def _init_weights(module):
-        if isinstance(module, (nn.Linear, nn.Embedding)):
-            module.weight.data.normal_(mean=0.0, std=0.02)
-            if isinstance(module, nn.Linear) and module.bias is not None:
-                module.bias.data.zero_()
-        elif isinstance(module, nn.LayerNorm):
-            module.bias.data.zero_()
-            module.weight.data.fill_(1.0)
-
-    def configure_optimizers(self):
-        """
-        This long function is unfortunately doing something very simple and is being very defensive:
-        We are separating out all parameters of the model into two buckets: those that will experience
-        weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
-        We are then returning the PyTorch optimizer object.
-        """
-
-        # separate out all parameters to those that will and won't experience regularizing weight decay
-        decay = set()
-        no_decay = set()
-        whitelist_weight_modules = (
-            torch.nn.Linear,
-            torch.nn.Conv2d,
-            torch.nn.Conv1d,
-        )
-        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
-        for mn, m in self.named_modules():
-            for pn, _ in m.named_parameters():
-                fpn = f"{mn}.{pn}" if mn else pn  # full param name
-
-                if pn.endswith("bias"):
-                    # all biases will not be decayed
-                    no_decay.add(fpn)
-                elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules):
-                    # weights of whitelist modules will be weight decayed
-                    decay.add(fpn)
-                elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules):
-                    # weights of blacklist modules will NOT be weight decayed
-                    no_decay.add(fpn)
-
-        # special case the position embedding parameter in the root GPT module as not decayed
-        no_decay.add("pos_emb")
-        no_decay.add("global_pos_emb")
-
-        # validate that we considered every parameter
-        param_dict = dict(self.named_parameters())
-        inter_params = decay & no_decay
-        union_params = decay | no_decay
-        assert len(inter_params) == 0, f"parameters {inter_params!s} made it into both decay/no_decay sets!"
-        assert (
-            len(param_dict.keys() - union_params) == 0
-        ), f"parameters {param_dict.keys() - union_params!s} were not separated into either decay/no_decay set!"
-
-        optim_groups = [
-            {
-                "params": [param_dict[pn] for pn in sorted(decay)],
-                "weight_decay": 0.1,
-            },
-            {
-                "params": [param_dict[pn] for pn in sorted(no_decay)],
-                "weight_decay": 0.0,
-            },
-        ]
-        return optim_groups
-
-    # state, action, and return
-    def forward(
-        self,
-        states,
-        actions,
-        rtgs,
-        timesteps,
-        users,
-    ):
-        """
-        :states: states batch, (batch, trajectory_len, 3)
-        :actions: actions batch, (batch, trajectory_len, 1)
-        :rtgs: rtgs batch, (batch, trajectory_len, 1)
-        :timesteps: timesteps batch, (batch, 1, 1)
-        :users: users batch, (batch, 1)
-        """
-        inference = not self.training
-        state_embeddings = self.state_repr(
-            users.repeat((1, self.block_size // 3)).reshape(-1, 1),
-            states.reshape(-1, 3),
-        )
-
-        state_embeddings = state_embeddings.reshape(
-            states.shape[0], states.shape[1], self.config.n_embd
-        )  # (batch, block_size, n_embd)
-
-        if actions is not None:
-            rtg_embeddings = self.ret_emb(rtgs.type(torch.float32))
-            action_embeddings = self.action_embeddings(
-                actions.type(torch.long).squeeze(-1)
-            )  # (batch, block_size, n_embd)
-
-            token_embeddings = torch.zeros(
-                (
-                    states.shape[0],
-                    states.shape[1] * 3 - int(inference),
-                    self.config.n_embd,
-                ),
-                dtype=torch.float32,
-                device=state_embeddings.device,
-            )
-            token_embeddings[:, ::3, :] = rtg_embeddings
-            token_embeddings[:, 1::3, :] = state_embeddings
-            token_embeddings[:, 2::3, :] = action_embeddings[:, -states.shape[1] + int(inference) :, :]
-        else:
-            # only happens at very first timestep of evaluation
-            rtg_embeddings = self.ret_emb(rtgs.type(torch.float32))
-
-            token_embeddings = torch.zeros(
-                (states.shape[0], states.shape[1] * 2, self.config.n_embd),
-                dtype=torch.float32,
-                device=state_embeddings.device,
-            )
-            token_embeddings[:, ::2, :] = rtg_embeddings  # really just [:,0,:]
-            token_embeddings[:, 1::2, :] = state_embeddings  # really just [:,1,:]
-
-        batch_size = states.shape[0]
-        all_global_pos_emb = torch.repeat_interleave(
-            self.global_pos_emb, batch_size, dim=0
-        )  # batch_size, traj_length, n_embd
-
-        position_embeddings = (
-            torch.gather(
-                all_global_pos_emb,
-                1,
-                torch.repeat_interleave(timesteps, self.config.n_embd, dim=-1),
-            )
-            + self.pos_emb[:, : token_embeddings.shape[1], :]
-        )
-        x = self.drop(token_embeddings + position_embeddings)
-        x = self.blocks(x)
-        x = self.ln_f(x)
-        logits = self.head(x)
-
-        if actions is not None:
-            logits = logits[:, 1::3, :]  # only keep predictions from state_embeddings
-        else:
-            logits = logits[:, 1:, :]
-
-        return logits
-
-    def predict(self, states, actions, rtgs, timesteps, users):
-        """
-        :states: states batch, (batch, block_size, 3)
-        :actions: actions batch, (batch, block_size, 1)
-        :rtgs: rtgs batch, (batch, block_size, 1)
-        :timesteps: timesteps batch, (batch, 1, 1)
-        :users: users batch, (batch, 1)
-        """
-        # forward() takes no targets argument and returns only logits
-        logits = self(
-            states=states.to(self.pos_emb.device),
-            actions=actions.to(self.pos_emb.device),
-            rtgs=rtgs.to(self.pos_emb.device),
-            timesteps=timesteps.to(self.pos_emb.device),
-            users=users.to(self.pos_emb.device),
-        )
-        logits = logits[:, -1, :]
-        actions = logits.argsort(dim=1, descending=True)
-        return actions