egogym 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- baselines/pi_policy.py +110 -0
- baselines/rum/__init__.py +1 -0
- baselines/rum/loss_fns/__init__.py +37 -0
- baselines/rum/loss_fns/abstract_loss_fn.py +13 -0
- baselines/rum/loss_fns/diffusion_policy_loss_fn.py +114 -0
- baselines/rum/loss_fns/rvq_loss_fn.py +104 -0
- baselines/rum/loss_fns/vqbet_loss_fn.py +202 -0
- baselines/rum/models/__init__.py +1 -0
- baselines/rum/models/bet/__init__.py +3 -0
- baselines/rum/models/bet/bet.py +347 -0
- baselines/rum/models/bet/gpt.py +277 -0
- baselines/rum/models/bet/tokenized_bet.py +454 -0
- baselines/rum/models/bet/utils.py +124 -0
- baselines/rum/models/bet/vqbet.py +410 -0
- baselines/rum/models/bet/vqvae/__init__.py +3 -0
- baselines/rum/models/bet/vqvae/residual_vq.py +346 -0
- baselines/rum/models/bet/vqvae/vector_quantize_pytorch.py +1194 -0
- baselines/rum/models/bet/vqvae/vqvae.py +313 -0
- baselines/rum/models/bet/vqvae/vqvae_utils.py +30 -0
- baselines/rum/models/custom.py +33 -0
- baselines/rum/models/encoders/__init__.py +0 -0
- baselines/rum/models/encoders/abstract_base_encoder.py +70 -0
- baselines/rum/models/encoders/identity.py +45 -0
- baselines/rum/models/encoders/timm_encoders.py +82 -0
- baselines/rum/models/policies/diffusion_policy.py +881 -0
- baselines/rum/models/policies/open_loop.py +122 -0
- baselines/rum/models/policies/simple_open_loop.py +108 -0
- baselines/rum/molmo/server.py +144 -0
- baselines/rum/policy.py +293 -0
- baselines/rum/utils/__init__.py +212 -0
- baselines/rum/utils/action_transforms.py +22 -0
- baselines/rum/utils/decord_transforms.py +135 -0
- baselines/rum/utils/rpc.py +249 -0
- baselines/rum/utils/schedulers.py +71 -0
- baselines/rum/utils/trajectory_vis.py +128 -0
- baselines/rum/utils/zmq_utils.py +281 -0
- baselines/rum_policy.py +108 -0
- egogym/__init__.py +8 -0
- egogym/assets/constants.py +1804 -0
- egogym/components/__init__.py +1 -0
- egogym/components/object.py +94 -0
- egogym/egogym.py +106 -0
- egogym/embodiments/__init__.py +10 -0
- egogym/embodiments/arms/__init__.py +4 -0
- egogym/embodiments/arms/arm.py +65 -0
- egogym/embodiments/arms/droid.py +49 -0
- egogym/embodiments/grippers/__init__.py +4 -0
- egogym/embodiments/grippers/floating_gripper.py +58 -0
- egogym/embodiments/grippers/rum.py +6 -0
- egogym/embodiments/robot.py +95 -0
- egogym/evaluate.py +216 -0
- egogym/managers/__init__.py +2 -0
- egogym/managers/objects_managers.py +30 -0
- egogym/managers/textures_manager.py +21 -0
- egogym/misc/molmo_client.py +49 -0
- egogym/misc/molmo_server.py +197 -0
- egogym/policies/__init__.py +1 -0
- egogym/policies/base_policy.py +13 -0
- egogym/scripts/analayze.py +834 -0
- egogym/scripts/plot.py +87 -0
- egogym/scripts/plot_correlation.py +392 -0
- egogym/scripts/plot_correlation_hardcoded.py +338 -0
- egogym/scripts/plot_failure.py +248 -0
- egogym/scripts/plot_failure_hardcoded.py +195 -0
- egogym/scripts/plot_failure_vlm.py +257 -0
- egogym/scripts/plot_failure_vlm_hardcoded.py +177 -0
- egogym/scripts/plot_line.py +303 -0
- egogym/scripts/plot_line_hardcoded.py +285 -0
- egogym/scripts/plot_pi0_bars.py +169 -0
- egogym/tasks/close.py +84 -0
- egogym/tasks/open.py +85 -0
- egogym/tasks/pick.py +121 -0
- egogym/utils.py +969 -0
- egogym/wrappers/__init__.py +20 -0
- egogym/wrappers/episode_monitor.py +282 -0
- egogym/wrappers/unprivileged_chatgpt.py +163 -0
- egogym/wrappers/unprivileged_gemini.py +157 -0
- egogym/wrappers/unprivileged_molmo.py +88 -0
- egogym/wrappers/unprivileged_moondream.py +121 -0
- egogym-0.1.0.dist-info/METADATA +52 -0
- egogym-0.1.0.dist-info/RECORD +83 -0
- egogym-0.1.0.dist-info/WHEEL +5 -0
- egogym-0.1.0.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,347 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from enum import Enum
|
|
3
|
+
from typing import Dict, Optional, Tuple
|
|
4
|
+
|
|
5
|
+
import accelerate
|
|
6
|
+
import einops
|
|
7
|
+
import torch
|
|
8
|
+
import torch.nn as nn
|
|
9
|
+
import torch.nn.functional as F
|
|
10
|
+
import tqdm
|
|
11
|
+
|
|
12
|
+
from baselines.rum.models.bet.gpt import GPT
|
|
13
|
+
from baselines.rum.models.bet.utils import MLP
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
# Fixed RNG seed used both for the k-means initialisation and for the placeholder
# cluster centers, so that every process (e.g. each GPU) derives the same values.
GENERATOR_SEED_FIXED: int = 123_456_789
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class KMeansDiscretizer:
    """
    Simplified and modified version of the KMeans algorithm from sklearn.

    Centroids are initialised with a fixed-seed generator so that every process
    (e.g. each GPU in a distributed run) fitting on the same data arrives at the
    same clusters.

    Args:
        num_bins: Number of clusters (bins) to fit.
        kmeans_iters: Number of assignment/update iterations to run.
    """

    def __init__(
        self,
        num_bins: int = 100,
        kmeans_iters: int = 50,
    ):
        super().__init__()
        self.n_bins = num_bins
        self.kmeans_iters = kmeans_iters

    def fit(self, input_actions: torch.Tensor) -> None:
        """Fit clusters on ``input_actions`` of shape (N, D); store them in ``bin_centers``."""
        self.bin_centers = KMeansDiscretizer._kmeans(
            input_actions, ncluster=self.n_bins, niter=self.kmeans_iters
        )

    @staticmethod
    def _sample_rows(
        x: torch.Tensor, count: int, generator: torch.Generator
    ) -> torch.Tensor:
        """Return ``count`` randomly chosen rows of ``x`` using ``generator``.

        Rows are drawn without replacement (``randperm``) whenever ``x`` has at
        least ``count`` rows. Only when it has fewer rows than requested do we
        sample with replacement, because a permutation cannot supply enough rows
        and the caller needs exactly ``count`` of them.
        """
        num_rows = x.size(0)
        if num_rows >= count:
            return x[torch.randperm(num_rows, generator=generator)[:count]]
        return x[torch.randint(num_rows, (count,), generator=generator)]

    @classmethod
    def _kmeans(cls, x: torch.Tensor, ncluster: int = 512, niter: int = 50):
        """
        Simple k-means clustering algorithm adapted from Karpathy's minGPT library
        https://github.com/karpathy/minGPT/blob/master/play_image.ipynb

        Args:
            x: Data of shape (N, D).
            ncluster: Number of centroids to fit.
            niter: Number of assignment/update iterations.

        Returns:
            The centroids; shape (ncluster, D) once at least one iteration ran.

        Raises:
            ValueError: if ``x`` has no rows.
        """
        N, D = x.size()
        if N == 0:
            raise ValueError("K-means clustering needs at least one data point.")
        generator = torch.Generator()
        generator.manual_seed(GENERATOR_SEED_FIXED)

        # Init clusters at random, with a fixed seed. If N < ncluster this gives
        # only N centroids; the missing ones show up as empty (NaN) clusters in
        # the first iteration and are re-initialised below.
        c = x[torch.randperm(N, generator=generator)[:ncluster]]

        pbar = tqdm.trange(niter)
        pbar.set_description("K-means clustering")
        for i in pbar:
            # Assign every point to its closest centroid.
            a = ((x[:, None, :] - c[None, :, :]) ** 2).sum(-1).argmin(1)
            # Move each centroid to the mean of its points; empty clusters give NaN.
            c = torch.stack([x[a == k].mean(0) for k in range(ncluster)])
            # Re-assign any poorly positioned (empty) centroids.
            nanix = torch.any(torch.isnan(c), dim=1)
            ndead = nanix.sum().item()
            if ndead:
                tqdm.tqdm.write(
                    "done step %d/%d, re-initialized %d dead clusters"
                    % (i + 1, niter, ndead)
                )
                # Exactly ndead rows are required here, even if ndead > N.
                c[nanix] = cls._sample_rows(x, ndead, generator)
        return c
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class BehaviorTransformer(nn.Module):
    """
    Behavior Transformer (BeT) action head on top of a GPT trunk.

    At every timestep, an MLP maps the GPT output to ``K`` cluster logits plus
    ``K`` per-cluster offsets of size ``act_dim``. A predicted action is a
    sampled k-means cluster center, optionally plus the offset predicted for
    that cluster.

    Cluster centers are fit online. For the first ``kmeans_fit_steps`` calls of
    ``forward`` that receive ``action_seq``, the non-padded actions (gathered
    across processes with ``accelerate``) are collected. On the last of those
    calls, k-means is run on them. Before that step, the returned loss is
    detached, so it yields zero gradients.

    Goal conditioning (``GOAL_SPEC``):
        * ``unconditional``: ``goal_dim <= 0``; goals are ignored.
        * ``concat``: ``goal_dim == obs_dim``; goal tokens are prepended along
          the time axis, and their outputs are dropped after the GPT.
        * ``stack``: otherwise; goals are concatenated to the observations along
          the feature axis.
    """

    GOAL_SPEC = Enum("GOAL_SPEC", "concat stack unconditional")

    def __init__(
        self,
        obs_dim: int,
        act_dim: int,
        goal_dim: int,
        gpt_model: GPT,
        num_extra_predicted_actions: Optional[int] = None,
        trainable_obs_padding: bool = True,
        n_clusters: int = 32,
        kmeans_fit_steps: int = 500,
        kmeans_iters: int = 50,
        offset_loss_multiplier: float = 1.0e3,
        offset_distance_metric: str = "L2",
        gamma: float = 2.0,
        **kwargs,
    ):
        super().__init__()
        self._obs_dim = obs_dim
        self._act_dim = act_dim
        self._goal_dim = goal_dim
        self._num_extra_predicted_actions = num_extra_predicted_actions
        # Padding token appended to obs sequences shorter than the number of
        # actions to predict. Gradient-free and all zeros if we don't train it.
        self._obs_padding = nn.Parameter(
            trainable_obs_padding * torch.randn(obs_dim),
            requires_grad=trainable_obs_padding,
        )

        if goal_dim <= 0:
            self._cbet_method = self.GOAL_SPEC.unconditional
        elif obs_dim == goal_dim:
            self._cbet_method = self.GOAL_SPEC.concat
        else:
            self._cbet_method = self.GOAL_SPEC.stack

        self._gpt_model = gpt_model
        # For now, we assume the number of clusters is given.
        assert n_clusters > 0
        self._K = n_clusters
        self._kmeans_fit_steps = kmeans_fit_steps
        self._clustering_algo = KMeansDiscretizer(
            num_bins=n_clusters, kmeans_iters=kmeans_iters
        )
        self._current_steps = 0
        # Head output per timestep: K logits followed by K * act_dim offsets.
        self._map_to_cbet_preds = MLP(
            in_channels=gpt_model.config.output_dim,
            hidden_channels=[(act_dim + 1) * n_clusters],
        )
        self._collected_actions = []
        self._have_fit_kmeans = False
        self._offset_loss_multiplier = offset_loss_multiplier
        # Placeholder for the cluster centers (fixed seed, identical on all processes).
        generator = torch.Generator()
        generator.manual_seed(GENERATOR_SEED_FIXED)
        self.register_buffer(
            "_cluster_centers",
            torch.randn(
                (n_clusters, act_dim), generator=generator, dtype=torch.float32
            ),
        )
        self._criterion = FocalLoss(gamma=gamma, reduction="none")
        self._offset_criterion = (
            nn.MSELoss(reduction="none")
            if offset_distance_metric == "L2"
            else nn.L1Loss(reduction="none")
        )
        self._accelerator = accelerate.Accelerator()

    def _load_from_state_dict(self, *args, **kwargs):
        # Don't fit kmeans if we are loading from a state dict: the loaded
        # `_cluster_centers` buffer is treated as already fitted.
        self._current_steps = self._kmeans_fit_steps
        self._have_fit_kmeans = True
        return super()._load_from_state_dict(*args, **kwargs)

    def forward(
        self,
        obs_seq: torch.Tensor,
        goal_seq: Optional[torch.Tensor],
        action_seq: Optional[torch.Tensor],
        padding_seq: Optional[torch.Tensor],
        predict_with_offset: bool = True,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Dict[str, float]]:
        """
        Predict actions and, when ``action_seq`` is given, compute the training loss.

        Args:
            obs_seq: Observations, shape (N, T, obs_dim).
            goal_seq: Goal features, or None for unconditional models.
            action_seq: Target actions (N, T_a, act_dim) during training, else None.
            padding_seq: Mask that is True where ``action_seq`` is padding
                (presumably a bool tensor), or None when nothing is padded.
            predict_with_offset: Add the predicted per-cluster offset to the
                sampled cluster center.

        Returns:
            ``(a_hat, loss, loss_dict)``. ``loss`` is None and ``loss_dict`` is
            empty when ``action_seq`` is None.
        """
        if self._current_steps == 0:
            self._cluster_centers = self._cluster_centers.to(obs_seq.device)
        if self._current_steps < self._kmeans_fit_steps and action_seq is not None:
            self._current_steps += 1
            self._fit_kmeans(obs_seq, goal_seq, action_seq, padding_seq)
        return self._predict(
            obs_seq,
            goal_seq,
            action_seq,
            padding_seq,
            predict_with_offset=predict_with_offset,
        )

    def _fit_kmeans(
        self,
        obs_seq: torch.Tensor,
        goal_seq: Optional[torch.Tensor],
        action_seq: Optional[torch.Tensor],
        padding_seq: Optional[torch.Tensor],
    ) -> None:
        """Collect this step's non-padded actions; fit k-means on the final collection step."""
        assert self._current_steps <= self._kmeans_fit_steps
        if self._current_steps == 1:
            self._cluster_centers = self._cluster_centers.to(action_seq.device)

        all_action_seq = self._accelerator.gather(action_seq)
        if padding_seq is None:
            # No padding mask: every action is real data. (Gathering / negating
            # None would raise, although `_predict` accepts a None mask.)
            valid_actions = all_action_seq.reshape(-1, self._act_dim)
        else:
            all_padding_seq = self._accelerator.gather(padding_seq)
            valid_actions = all_action_seq[torch.logical_not(all_padding_seq)]
        self._collected_actions.append(valid_actions)
        if self._current_steps == self._kmeans_fit_steps:
            logging.info("Fitting KMeans")
            self._clustering_algo.fit(
                torch.cat(self._collected_actions, dim=0).view(-1, self._act_dim)
            )
            self._have_fit_kmeans = True
            self._cluster_centers = self._clustering_algo.bin_centers.float().to(
                action_seq.device
            )

    def _predict(
        self,
        obs_seq: torch.Tensor,
        goal_seq: Optional[torch.Tensor],
        action_seq: Optional[torch.Tensor],
        is_padded_action_seq: Optional[torch.Tensor],
        predict_with_offset: bool = True,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Dict[str, float]]:
        """Sample actions from the head and, if targets are given, compute losses/metrics."""
        batch_size, obs_T, _ = obs_seq.shape
        _, action_T, _ = (
            action_seq.shape if action_seq is not None else (None, None, None)
        )
        # Take the one that is not None.
        actions_to_predict = action_T or obs_T
        if self._num_extra_predicted_actions:
            actions_to_predict += self._num_extra_predicted_actions
        # Now, figure out if we should pad the obs seq.
        if obs_T < actions_to_predict:
            # We need to pad the obs seq with the (possibly learned) padding token.
            pad_size = actions_to_predict - obs_T
            padded_obs_seq = torch.cat(
                [
                    obs_seq,
                    einops.repeat(
                        self._obs_padding, "D -> N T D", N=batch_size, T=pad_size
                    ),
                ],
                dim=1,
            )
        else:
            padded_obs_seq = obs_seq
        # Assume dimensions are N T D for N sequences of T timesteps with dimension D.
        if self._cbet_method == self.GOAL_SPEC.unconditional:
            gpt_input = padded_obs_seq
        elif self._cbet_method == self.GOAL_SPEC.concat:
            gpt_input = torch.cat([goal_seq, padded_obs_seq], dim=1)
        elif self._cbet_method == self.GOAL_SPEC.stack:
            gpt_input = torch.cat([goal_seq, padded_obs_seq], dim=-1)
        else:
            raise NotImplementedError

        gpt_output = self._gpt_model(gpt_input)
        if self._cbet_method == self.GOAL_SPEC.concat:
            # Chop off the goal encodings.
            gpt_output = gpt_output[:, goal_seq.size(1) :, :]
        cbet_preds = self._map_to_cbet_preds(gpt_output)
        cbet_logits, cbet_offsets = torch.split(
            cbet_preds, [self._K, self._K * self._act_dim], dim=-1
        )
        cbet_offsets = einops.rearrange(cbet_offsets, "N T (K A) -> N T K A", K=self._K)

        cbet_probs = torch.softmax(cbet_logits, dim=-1)
        N, T, choices = cbet_probs.shape
        # Sample from the multinomial distribution, one per row.
        sampled_centers = einops.rearrange(
            torch.multinomial(cbet_probs.view(-1, choices), num_samples=1),
            "(N T) 1 -> N T 1",
            N=N,
        )
        flattened_cbet_offsets = einops.rearrange(cbet_offsets, "N T K A -> (N T) K A")
        # Row indices are created on the same device as the tensor they index.
        row_index = torch.arange(
            flattened_cbet_offsets.shape[0], device=flattened_cbet_offsets.device
        )
        sampled_offsets = flattened_cbet_offsets[
            row_index, sampled_centers.flatten()
        ].view(N, T, self._act_dim)
        centers = self._cluster_centers[sampled_centers.flatten()].view(
            N, T, self._act_dim
        )
        a_hat = centers
        if predict_with_offset:
            a_hat += sampled_offsets
        if action_seq is None:
            return a_hat, None, {}
        # We are in training, so figure out the loss for the actions.
        # First, we need to find the closest cluster center for each action.
        action_bins = self._find_closest_cluster(action_seq)
        true_offsets = action_seq - self._cluster_centers[action_bins]
        predicted_offsets = flattened_cbet_offsets[
            row_index, action_bins.flatten()
        ].view(N, T, self._act_dim)
        # Now we can compute the loss.
        offset_loss = self._offset_criterion(predicted_offsets, true_offsets)
        cbet_loss = self._criterion(
            einops.rearrange(cbet_logits, "N T D -> (N T) D"),
            einops.rearrange(action_bins, "N T -> (N T)"),
        )
        # Now, use the padding mask to zero out the loss on padded positions.
        # NOTE(review): the means below still divide by the padded count too.
        if is_padded_action_seq is not None:
            cbet_loss *= ~is_padded_action_seq.view(-1)
            offset_loss *= ~is_padded_action_seq.unsqueeze(-1)
        cbet_loss, offset_loss = cbet_loss.mean(), offset_loss.mean()
        loss = cbet_loss + self._offset_loss_multiplier * offset_loss
        action_mse = F.mse_loss(a_hat, action_seq, reduction="none")
        action_l1 = F.l1_loss(a_hat, action_seq, reduction="none")
        norm = torch.norm(action_seq, p=2, dim=-1, keepdim=True) + 1e-9
        normalized_mse = (action_mse / norm).mean()
        if self._current_steps < self._kmeans_fit_steps:
            # Clusters are not fit yet: keep the graph but contribute zero gradient.
            loss = loss.detach() + (loss * 0.0)

        loss_dict = {
            "classification_loss": cbet_loss.detach().cpu().item(),
            "offset_loss": offset_loss.detach().cpu().item(),
            "loss": loss.detach().cpu().item(),
            "L2_loss": action_mse.mean().detach().cpu().item(),
            "L2_loss_normalized": normalized_mse.mean().detach().cpu().item(),
            "L1_loss": action_l1.mean().detach().cpu().item(),
        }
        return a_hat, loss, loss_dict

    def _find_closest_cluster(self, action_seq: torch.Tensor) -> torch.Tensor:
        """Return, for each action in (N, T, A), the index of its nearest cluster center (N, T)."""
        N, T, _ = action_seq.shape
        flattened_actions = einops.rearrange(action_seq, "N T A -> (N T) A")
        cluster_center_distance = torch.sum(
            (flattened_actions[:, None, :] - self._cluster_centers[None, :, :]) ** 2,
            dim=2,
        )  # (N T) K A -> (N T) K
        closest_cluster_center = torch.argmin(cluster_center_distance, dim=1)  # (N T)
        discretized_action = einops.rearrange(
            closest_cluster_center, "(N T) -> N T", N=N, T=T
        )
        return discretized_action

    def configure_optimizers(self, weight_decay, learning_rate, betas):
        """
        Build the optimizer from the GPT's optimizer plus this head's parameters.

        The GPT supplies its own decay/no-decay groups. The prediction MLP is
        added as a group using the optimizer defaults. When the obs padding token
        is trainable, it gets its own group without weight decay. Without that
        group, the padding token would never be updated.
        """
        optimizer = self._gpt_model.configure_optimizers(
            weight_decay=weight_decay,
            learning_rate=learning_rate,
            betas=betas,
        )
        optimizer.add_param_group({"params": self._map_to_cbet_preds.parameters()})
        if self._obs_padding.requires_grad:
            optimizer.add_param_group(
                {"params": [self._obs_padding], "weight_decay": 0.0}
            )
        return optimizer
|
|
326
|
+
|
|
327
|
+
|
|
328
|
+
class FocalLoss(nn.Module):
    """Focal loss on top of a softmax classifier.

    The per-sample loss is ``-(1 - p_t) ** gamma * log(p_t)``, where ``p_t`` is
    the predicted probability of the target class. With ``gamma == 0`` this is
    plain cross-entropy.

    Args:
        gamma: Focusing exponent; larger values down-weight well-classified samples.
        reduction: ``"mean"``, ``"sum"`` or ``"none"``.

    Raises:
        NotImplementedError: if ``reduction`` is not one of the supported values.
    """

    def __init__(self, gamma: float = 0, reduction: str = "mean"):
        super().__init__()
        self.gamma = gamma
        if reduction in ("mean", "sum", "none"):
            self.reduction = reduction
        else:
            raise NotImplementedError

    def forward(self, input, target):
        """Compute the focal loss.

        Args:
            input: Logits of shape (B, C).
            target: Integer class indices, B elements.
        """
        log_probs = F.log_softmax(input, dim=-1)
        target_log_prob = log_probs.gather(1, target.view(-1, 1)).view(-1)
        target_prob = target_log_prob.exp()

        modulating_factor = (1 - target_prob) ** self.gamma
        per_sample = -1 * modulating_factor * target_log_prob

        if self.reduction == "mean":
            return per_sample.mean()
        if self.reduction == "sum":
            return per_sample.sum()
        return per_sample
|
|
@@ -0,0 +1,277 @@
|
|
|
1
|
+
"""
|
|
2
|
+
An adaptation of Andrej Karpathy's nanoGPT implementation in PyTorch.
|
|
3
|
+
Original source: https://github.com/karpathy/nanoGPT
|
|
4
|
+
|
|
5
|
+
Original License:
|
|
6
|
+
MIT License
|
|
7
|
+
|
|
8
|
+
Copyright (c) 2022 Andrej Karpathy
|
|
9
|
+
|
|
10
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
11
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
12
|
+
in the Software without restriction, including without limitation the rights
|
|
13
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
14
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
15
|
+
furnished to do so, subject to the following conditions:
|
|
16
|
+
|
|
17
|
+
The above copyright notice and this permission notice shall be included in all
|
|
18
|
+
copies or substantial portions of the Software.
|
|
19
|
+
|
|
20
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
21
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
22
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
23
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
24
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
25
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
26
|
+
SOFTWARE.
|
|
27
|
+
|
|
28
|
+
Original comments:
|
|
29
|
+
Full definition of a GPT Language Model, all of it in this single file.
|
|
30
|
+
References:
|
|
31
|
+
1) the official GPT-2 TensorFlow implementation released by OpenAI:
|
|
32
|
+
https://github.com/openai/gpt-2/blob/master/src/model.py
|
|
33
|
+
2) huggingface/transformers PyTorch implementation:
|
|
34
|
+
https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
|
|
35
|
+
"""
|
|
36
|
+
|
|
37
|
+
import math
|
|
38
|
+
from dataclasses import dataclass
|
|
39
|
+
|
|
40
|
+
import torch
|
|
41
|
+
import torch.nn as nn
|
|
42
|
+
from torch.nn import functional as F
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
# @torch.jit.script # good to enable when not using torch.compile, disable when using (our default)
|
|
46
|
+
def new_gelu(x):
    """
    GELU activation (tanh approximation), as in Google's BERT repo and OpenAI GPT.

    Computes ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``.
    Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415
    """
    cubic_term = 0.044715 * torch.pow(x, 3.0)
    tanh_arg = math.sqrt(2.0 / math.pi) * (x + cubic_term)
    return 0.5 * x * (1.0 + torch.tanh(tanh_arg))
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class CausalSelfAttention(nn.Module):
    """
    Multi-head masked self-attention: position i attends only to positions <= i.

    ``config`` must provide ``n_embd``, ``n_head`` (which must divide
    ``n_embd``), ``dropout``, and ``block_size`` (the longest sequence the
    causal mask covers).
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        width = config.n_embd
        # Fused key/query/value projection for all heads.
        self.c_attn = nn.Linear(width, 3 * width)
        # Output projection.
        self.c_proj = nn.Linear(width, width)
        # Regularization.
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        # Lower-triangular causal mask, kept as the persistent buffer "bias"
        # with shape (1, 1, block_size, block_size).
        size = config.block_size
        causal_mask = torch.tril(torch.ones(size, size)).view(1, 1, size, size)
        self.register_buffer("bias", causal_mask)
        self.n_head = config.n_head
        self.n_embd = width

    def forward(self, x):
        """Apply causal attention to ``x`` of shape (B, T, n_embd); returns the same shape."""
        batch, seq_len, channels = x.size()
        head_dim = channels // self.n_head

        def split_heads(t):
            # (B, T, C) -> (B, nh, T, hs)
            return t.view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)

        query, key, value = self.c_attn(x).split(self.n_embd, dim=2)
        key = split_heads(key)
        query = split_heads(query)
        value = split_heads(value)

        # (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T), scaled by 1/sqrt(hs).
        scores = (query @ key.transpose(-2, -1)) * (1.0 / math.sqrt(key.size(-1)))
        visible = self.bias[:, :, :seq_len, :seq_len]
        scores = scores.masked_fill(visible == 0, float("-inf"))
        weights = self.attn_dropout(F.softmax(scores, dim=-1))

        # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs), then merge heads back.
        merged = (weights @ value).transpose(1, 2).contiguous().view(
            batch, seq_len, channels
        )
        return self.resid_dropout(self.c_proj(merged))
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
class MLP(nn.Module):
    """Transformer feed-forward sub-layer: n_embd -> 4*n_embd -> GELU -> n_embd, then dropout."""

    def __init__(self, config):
        super().__init__()
        hidden = 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, hidden)
        self.c_proj = nn.Linear(hidden, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Apply the feed-forward network to ``x`` (last dim n_embd)."""
        return self.dropout(self.c_proj(new_gelu(self.c_fc(x))))
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
class Block(nn.Module):
    """Pre-LayerNorm transformer block: residual causal attention, then a residual MLP."""

    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x):
        """Apply attention and MLP sub-layers, each added back onto its input."""
        after_attention = x + self.attn(self.ln_1(x))
        return after_attention + self.mlp(self.ln_2(after_attention))
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
@dataclass
class GPTConfig:
    """Hyper-parameters for :class:`GPT`.

    Tokens are continuous vectors: each input token has ``input_dim`` features
    (linearly embedded), and each output position is projected to
    ``output_dim`` features.
    """

    # Longest sequence supported (size of the positional-embedding table and causal mask).
    block_size: int = 1024
    # Feature size of each input token.
    input_dim: int = 256
    # Feature size of each output vector.
    output_dim: int = 256
    # Number of transformer blocks.
    n_layer: int = 12
    # Attention heads per block; must divide n_embd.
    n_head: int = 12
    # Width of the residual stream.
    n_embd: int = 768
    # Dropout probability used throughout the model.
    dropout: float = 0.1
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
class GPT(nn.Module):
    """
    GPT trunk over continuous token vectors.

    The model works as follows:

    * Input tokens of size ``config.input_dim`` are linearly embedded.
    * Learned positional embeddings are added.
    * The result passes through ``config.n_layer`` causal transformer blocks
      and a final LayerNorm.
    * It is then projected linearly (without bias) to ``config.output_dim``.
    """

    def __init__(self, config):
        super().__init__()
        assert config.input_dim is not None
        assert config.output_dim is not None
        assert config.block_size is not None
        self.config = config

        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Linear(config.input_dim, config.n_embd),
                wpe=nn.Embedding(config.block_size, config.n_embd),
                drop=nn.Dropout(config.dropout),
                h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
                ln_f=nn.LayerNorm(config.n_embd),
            )
        )
        self.lm_head = nn.Linear(config.n_embd, config.output_dim, bias=False)
        # Init all weights, and apply a special scaled init to the residual
        # projections, per the GPT-2 paper.
        self.apply(self._init_weights)
        for pn, p in self.named_parameters():
            if pn.endswith("c_proj.weight"):
                torch.nn.init.normal_(
                    p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer)
                )

    def get_num_params(self) -> int:
        """Return the total number of parameters (trainable or frozen) in the model."""
        return sum(p.numel() for p in self.parameters())

    def forward(self, input, targets=None):
        """
        Run the model on a batch of token sequences.

        Args:
            input: Tensor of shape (b, t, input_dim), with ``t <= config.block_size``.
            targets: Unused; accepted for interface compatibility.

        Returns:
            Tensor of shape (b, t, output_dim).
        """
        device = input.device
        _, t, _ = input.size()
        assert (
            t <= self.config.block_size
        ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(
            0
        )  # shape (1, t)

        # forward the GPT model itself
        tok_emb = self.transformer.wte(input)  # token embeddings of shape (b, t, n_embd)
        pos_emb = self.transformer.wpe(pos)  # position embeddings of shape (1, t, n_embd)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)
        return logits

    def _init_weights(self, module):
        """GPT-2 style init: N(0, 0.02) for Linear/Embedding weights, zero biases, unit LayerNorm."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)

    def crop_block_size(self, block_size):
        """Shrink the supported context length to ``block_size`` (must not exceed the current one)."""
        assert block_size <= self.config.block_size
        self.config.block_size = block_size
        self.transformer.wpe.weight = nn.Parameter(
            self.transformer.wpe.weight[:block_size]
        )
        for block in self.transformer.h:
            block.attn.bias = block.attn.bias[:, :, :block_size, :block_size]

    def configure_optimizers(self, weight_decay, learning_rate, betas):
        """
        Build an AdamW optimizer with two parameter groups.

        Linear-layer weights use ``weight_decay``. Biases and LayerNorm/Embedding
        weights are not decayed. Every parameter must land in exactly one group.
        """
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear,)
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.named_modules():
            # Only the module's own parameters: each one is visited exactly once,
            # at the module that owns it (and whose type decides its group).
            for pn, p in m.named_parameters(recurse=False):
                fpn = "%s.%s" % (mn, pn) if mn else pn  # full param name
                if pn.endswith("bias"):
                    # all biases will not be decayed
                    no_decay.add(fpn)
                elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules):
                    # weights of whitelist modules will be weight decayed
                    decay.add(fpn)
                elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules):
                    # weights of blacklist modules will NOT be weight decayed
                    no_decay.add(fpn)

        # validate that we considered every parameter
        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, (
            "parameters %s made it into both decay/no_decay sets!"
            % (str(inter_params),)
        )
        assert len(param_dict.keys() - union_params) == 0, (
            "parameters %s were not separated into either decay/no_decay set!"
            % (str(param_dict.keys() - union_params),)
        )

        # create the pytorch optimizer object
        optim_groups = [
            {
                "params": [param_dict[pn] for pn in sorted(decay)],
                "weight_decay": weight_decay,
            },
            {
                "params": [param_dict[pn] for pn in sorted(no_decay)],
                "weight_decay": 0.0,
            },
        ]
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
        return optimizer
|