egogym 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. baselines/pi_policy.py +110 -0
  2. baselines/rum/__init__.py +1 -0
  3. baselines/rum/loss_fns/__init__.py +37 -0
  4. baselines/rum/loss_fns/abstract_loss_fn.py +13 -0
  5. baselines/rum/loss_fns/diffusion_policy_loss_fn.py +114 -0
  6. baselines/rum/loss_fns/rvq_loss_fn.py +104 -0
  7. baselines/rum/loss_fns/vqbet_loss_fn.py +202 -0
  8. baselines/rum/models/__init__.py +1 -0
  9. baselines/rum/models/bet/__init__.py +3 -0
  10. baselines/rum/models/bet/bet.py +347 -0
  11. baselines/rum/models/bet/gpt.py +277 -0
  12. baselines/rum/models/bet/tokenized_bet.py +454 -0
  13. baselines/rum/models/bet/utils.py +124 -0
  14. baselines/rum/models/bet/vqbet.py +410 -0
  15. baselines/rum/models/bet/vqvae/__init__.py +3 -0
  16. baselines/rum/models/bet/vqvae/residual_vq.py +346 -0
  17. baselines/rum/models/bet/vqvae/vector_quantize_pytorch.py +1194 -0
  18. baselines/rum/models/bet/vqvae/vqvae.py +313 -0
  19. baselines/rum/models/bet/vqvae/vqvae_utils.py +30 -0
  20. baselines/rum/models/custom.py +33 -0
  21. baselines/rum/models/encoders/__init__.py +0 -0
  22. baselines/rum/models/encoders/abstract_base_encoder.py +70 -0
  23. baselines/rum/models/encoders/identity.py +45 -0
  24. baselines/rum/models/encoders/timm_encoders.py +82 -0
  25. baselines/rum/models/policies/diffusion_policy.py +881 -0
  26. baselines/rum/models/policies/open_loop.py +122 -0
  27. baselines/rum/models/policies/simple_open_loop.py +108 -0
  28. baselines/rum/molmo/server.py +144 -0
  29. baselines/rum/policy.py +293 -0
  30. baselines/rum/utils/__init__.py +212 -0
  31. baselines/rum/utils/action_transforms.py +22 -0
  32. baselines/rum/utils/decord_transforms.py +135 -0
  33. baselines/rum/utils/rpc.py +249 -0
  34. baselines/rum/utils/schedulers.py +71 -0
  35. baselines/rum/utils/trajectory_vis.py +128 -0
  36. baselines/rum/utils/zmq_utils.py +281 -0
  37. baselines/rum_policy.py +108 -0
  38. egogym/__init__.py +8 -0
  39. egogym/assets/constants.py +1804 -0
  40. egogym/components/__init__.py +1 -0
  41. egogym/components/object.py +94 -0
  42. egogym/egogym.py +106 -0
  43. egogym/embodiments/__init__.py +10 -0
  44. egogym/embodiments/arms/__init__.py +4 -0
  45. egogym/embodiments/arms/arm.py +65 -0
  46. egogym/embodiments/arms/droid.py +49 -0
  47. egogym/embodiments/grippers/__init__.py +4 -0
  48. egogym/embodiments/grippers/floating_gripper.py +58 -0
  49. egogym/embodiments/grippers/rum.py +6 -0
  50. egogym/embodiments/robot.py +95 -0
  51. egogym/evaluate.py +216 -0
  52. egogym/managers/__init__.py +2 -0
  53. egogym/managers/objects_managers.py +30 -0
  54. egogym/managers/textures_manager.py +21 -0
  55. egogym/misc/molmo_client.py +49 -0
  56. egogym/misc/molmo_server.py +197 -0
  57. egogym/policies/__init__.py +1 -0
  58. egogym/policies/base_policy.py +13 -0
  59. egogym/scripts/analayze.py +834 -0
  60. egogym/scripts/plot.py +87 -0
  61. egogym/scripts/plot_correlation.py +392 -0
  62. egogym/scripts/plot_correlation_hardcoded.py +338 -0
  63. egogym/scripts/plot_failure.py +248 -0
  64. egogym/scripts/plot_failure_hardcoded.py +195 -0
  65. egogym/scripts/plot_failure_vlm.py +257 -0
  66. egogym/scripts/plot_failure_vlm_hardcoded.py +177 -0
  67. egogym/scripts/plot_line.py +303 -0
  68. egogym/scripts/plot_line_hardcoded.py +285 -0
  69. egogym/scripts/plot_pi0_bars.py +169 -0
  70. egogym/tasks/close.py +84 -0
  71. egogym/tasks/open.py +85 -0
  72. egogym/tasks/pick.py +121 -0
  73. egogym/utils.py +969 -0
  74. egogym/wrappers/__init__.py +20 -0
  75. egogym/wrappers/episode_monitor.py +282 -0
  76. egogym/wrappers/unprivileged_chatgpt.py +163 -0
  77. egogym/wrappers/unprivileged_gemini.py +157 -0
  78. egogym/wrappers/unprivileged_molmo.py +88 -0
  79. egogym/wrappers/unprivileged_moondream.py +121 -0
  80. egogym-0.1.0.dist-info/METADATA +52 -0
  81. egogym-0.1.0.dist-info/RECORD +83 -0
  82. egogym-0.1.0.dist-info/WHEEL +5 -0
  83. egogym-0.1.0.dist-info/top_level.txt +2 -0
baselines/rum/models/bet/bet.py
@@ -0,0 +1,347 @@
+import logging
+from enum import Enum
+from typing import Dict, Optional, Tuple
+
+import accelerate
+import einops
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import tqdm
+
+from baselines.rum.models.bet.gpt import GPT
+from baselines.rum.models.bet.utils import MLP
+
+
+GENERATOR_SEED_FIXED = 123456789
+
+
+class KMeansDiscretizer:
+    """
+    Simplified and modified version of KMeans algorithm from sklearn.
+    We initialize this with a fixed seed to ensure that on each GPU we come up with the same
+    clusters.
+    """
+
+    def __init__(
+        self,
+        num_bins: int = 100,
+        kmeans_iters: int = 50,
+    ):
+        super().__init__()
+        self.n_bins = num_bins
+        self.kmeans_iters = kmeans_iters
+
+    def fit(self, input_actions: torch.Tensor) -> None:
+        self.bin_centers = KMeansDiscretizer._kmeans(
+            input_actions, ncluster=self.n_bins, niter=self.kmeans_iters
+        )
+
+    @classmethod
+    def _kmeans(cls, x: torch.Tensor, ncluster: int = 512, niter: int = 50):
+        """
+        Simple k-means clustering algorithm adapted from Karpathy's minGPT library
+        https://github.com/karpathy/minGPT/blob/master/play_image.ipynb
+        """
+        N, D = x.size()
+        generator = torch.Generator()
+        generator.manual_seed(GENERATOR_SEED_FIXED)
+
+        c = x[
+            torch.randperm(N, generator=generator)[:ncluster]
+        ]  # init clusters at random, with a fixed seed
+
+        pbar = tqdm.trange(niter)
+        pbar.set_description("K-means clustering")
+        for i in pbar:
+            # assign all pixels to the closest codebook element
+            a = ((x[:, None, :] - c[None, :, :]) ** 2).sum(-1).argmin(1)
+            # move each codebook element to be the mean of the pixels assigned to it
+            c = torch.stack([x[a == k].mean(0) for k in range(ncluster)])
+            # re-assign any poorly positioned codebook elements
+            nanix = torch.any(torch.isnan(c), dim=1)
+            ndead = nanix.sum().item()
+            if ndead:
+                tqdm.tqdm.write(
+                    "done step %d/%d, re-initialized %d dead clusters"
+                    % (i + 1, niter, ndead)
+                )
+            c[nanix] = x[
+                torch.randperm(N, generator=generator)[:ndead]
+            ]  # re-init dead clusters
+        return c
+
+
+class BehaviorTransformer(nn.Module):
+    GOAL_SPEC = Enum("GOAL_SPEC", "concat stack unconditional")
+
+    def __init__(
+        self,
+        obs_dim: int,
+        act_dim: int,
+        goal_dim: int,
+        gpt_model: GPT,
+        num_extra_predicted_actions: Optional[int] = None,
+        trainable_obs_padding: bool = True,
+        n_clusters: int = 32,
+        kmeans_fit_steps: int = 500,
+        kmeans_iters: int = 50,
+        offset_loss_multiplier: float = 1.0e3,
+        offset_distance_metric: str = "L2",
+        gamma: float = 2.0,
+        **kwargs,
+    ):
+        super().__init__()
+        self._obs_dim = obs_dim
+        self._act_dim = act_dim
+        self._goal_dim = goal_dim
+        self._num_extra_predicted_actions = num_extra_predicted_actions
+        # Gradient-free, all zeros if we don't want to train this.
+        self._obs_padding = nn.Parameter(
+            trainable_obs_padding * torch.randn(obs_dim),
+            requires_grad=trainable_obs_padding,
+        )
+
+        if goal_dim <= 0:
+            self._cbet_method = self.GOAL_SPEC.unconditional
+        elif obs_dim == goal_dim:
+            self._cbet_method = self.GOAL_SPEC.concat
+        else:
+            self._cbet_method = self.GOAL_SPEC.stack
+
+        self._gpt_model = gpt_model
+        # For now, we assume the number of clusters is given.
+        assert n_clusters > 0
+        self._K = n_clusters
+        self._kmeans_fit_steps = kmeans_fit_steps
+        self._clustering_algo = KMeansDiscretizer(
+            num_bins=n_clusters, kmeans_iters=kmeans_iters
+        )
+        self._current_steps = 0
+        self._map_to_cbet_preds = MLP(
+            in_channels=gpt_model.config.output_dim,
+            hidden_channels=[(act_dim + 1) * n_clusters],
+        )
+        self._collected_actions = []
+        self._have_fit_kmeans = False
+        self._offset_loss_multiplier = offset_loss_multiplier
+        # Placeholder for the cluster centers.
+        generator = torch.Generator()
+        generator.manual_seed(GENERATOR_SEED_FIXED)
+        self.register_buffer(
+            "_cluster_centers",
+            torch.randn(
+                (n_clusters, act_dim), generator=generator, dtype=torch.float32
+            ),
+        )
+        self._criterion = FocalLoss(gamma=gamma, reduction="none")
+        self._offset_criterion = (
+            nn.MSELoss(reduction="none")
+            if offset_distance_metric == "L2"
+            else nn.L1Loss(reduction="none")
+        )
+        self._accelerator = accelerate.Accelerator()
+
+    def _load_from_state_dict(self, *args, **kwargs):
+        # Don't fit kmeans if we are loading from a state dict.
+        self._current_steps = self._kmeans_fit_steps
+        self._have_fit_kmeans = True
+        return super()._load_from_state_dict(*args, **kwargs)
+
+    def forward(
+        self,
+        obs_seq: torch.Tensor,
+        goal_seq: Optional[torch.Tensor],
+        action_seq: Optional[torch.Tensor],
+        padding_seq: Optional[torch.Tensor],
+        predict_with_offset: bool = True,
+    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+        if self._current_steps == 0:
+            self._cluster_centers = self._cluster_centers.to(obs_seq.device)
+        if self._current_steps < self._kmeans_fit_steps and action_seq is not None:
+            self._current_steps += 1
+            self._fit_kmeans(obs_seq, goal_seq, action_seq, padding_seq)
+        return self._predict(
+            obs_seq,
+            goal_seq,
+            action_seq,
+            padding_seq,
+            predict_with_offset=predict_with_offset,
+        )
+
+    def _fit_kmeans(
+        self,
+        obs_seq: torch.Tensor,
+        goal_seq: Optional[torch.Tensor],
+        action_seq: Optional[torch.Tensor],
+        padding_seq: Optional[torch.Tensor],
+    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+        assert self._current_steps <= self._kmeans_fit_steps
+        if self._current_steps == 1:
+            self._cluster_centers = self._cluster_centers.to(action_seq.device)
+
+        all_action_seq = self._accelerator.gather(action_seq)
+        all_padding_seq = self._accelerator.gather(padding_seq)
+        self._collected_actions.append(
+            all_action_seq[torch.logical_not(all_padding_seq)]
+        )
+        if self._current_steps == self._kmeans_fit_steps:
+            logging.info("Fitting KMeans")
+            self._clustering_algo.fit(
+                torch.cat(self._collected_actions, dim=0).view(-1, self._act_dim)
+            )
+            self._have_fit_kmeans = True
+            self._cluster_centers = self._clustering_algo.bin_centers.float().to(
+                action_seq.device
+            )
+
+    def _predict(
+        self,
+        obs_seq: torch.Tensor,
+        goal_seq: Optional[torch.Tensor],
+        action_seq: Optional[torch.Tensor],
+        is_padded_action_seq: Optional[torch.Tensor],
+        predict_with_offset: bool = True,
+    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Dict[str, float]]:
+        batch_size, obs_T, _ = obs_seq.shape
+        _, action_T, _ = (
+            action_seq.shape if action_seq is not None else (None, None, None)
+        )
+        # Take the one that is not None.
+        actions_to_predict = action_T or obs_T
+        if self._num_extra_predicted_actions:
+            actions_to_predict += self._num_extra_predicted_actions
+        # Now, figure out if we should pad the obs seq.
+        if obs_T < actions_to_predict:
+            # We need to pad the obs seq.
+            pad_size = actions_to_predict - obs_T
+            padded_obs_seq = torch.cat(
+                [
+                    obs_seq,
+                    einops.repeat(
+                        self._obs_padding, "D -> N T D", N=batch_size, T=pad_size
+                    ),
+                ],
+                dim=1,
+            )
+        else:
+            padded_obs_seq = obs_seq
+        # Assume dimensions are N T D for N sequences of T timesteps with dimension D.
+        if self._cbet_method == self.GOAL_SPEC.unconditional:
+            gpt_input = padded_obs_seq
+        elif self._cbet_method == self.GOAL_SPEC.concat:
+            gpt_input = torch.cat([goal_seq, padded_obs_seq], dim=1)
+        elif self._cbet_method == self.GOAL_SPEC.stack:
+            gpt_input = torch.cat([goal_seq, padded_obs_seq], dim=-1)
+        else:
+            raise NotImplementedError
+
+        gpt_output = self._gpt_model(gpt_input)
+        if self._cbet_method == self.GOAL_SPEC.concat:
+            # Chop off the goal encodings.
+            gpt_output = gpt_output[:, goal_seq.size(1) :, :]
+        cbet_preds = self._map_to_cbet_preds(gpt_output)
+        cbet_logits, cbet_offsets = torch.split(
+            cbet_preds, [self._K, self._K * self._act_dim], dim=-1
+        )
+        cbet_offsets = einops.rearrange(cbet_offsets, "N T (K A) -> N T K A", K=self._K)
+
+        cbet_probs = torch.softmax(cbet_logits, dim=-1)
+        N, T, choices = cbet_probs.shape
+        # Sample from the multinomial distribution, one per row.
+        sampled_centers = einops.rearrange(
+            torch.multinomial(cbet_probs.view(-1, choices), num_samples=1),
+            "(N T) 1 -> N T 1",
+            N=N,
+        )
+        flattened_cbet_offsets = einops.rearrange(cbet_offsets, "N T K A -> (N T) K A")
+        sampled_offsets = flattened_cbet_offsets[
+            torch.arange(flattened_cbet_offsets.shape[0]), sampled_centers.flatten()
+        ].view(N, T, self._act_dim)
+        centers = self._cluster_centers[sampled_centers.flatten()].view(
+            N, T, self._act_dim
+        )
+        a_hat = centers
+        if predict_with_offset:
+            a_hat += sampled_offsets
+        if action_seq is None:
+            return a_hat, None, {}
+        # We are in training, so figure out the loss for the actions.
+        # First, we need to find the closest cluster center for each action.
+        action_bins = self._find_closest_cluster(action_seq)
+        true_offsets = action_seq - self._cluster_centers[action_bins]
+        predicted_offsets = flattened_cbet_offsets[
+            torch.arange(flattened_cbet_offsets.shape[0]), action_bins.flatten()
+        ].view(N, T, self._act_dim)
+        # Now we can compute the loss.
+        offset_loss = self._offset_criterion(predicted_offsets, true_offsets)
+        cbet_loss = self._criterion(
+            einops.rearrange(cbet_logits, "N T D -> (N T) D"),
+            einops.rearrange(action_bins, "N T -> (N T)"),
+        )
+        # Now, use the padding mask to mask out the loss.
+        if is_padded_action_seq is not None:
+            cbet_loss *= ~is_padded_action_seq.view(-1)
+            offset_loss *= ~is_padded_action_seq.unsqueeze(-1)
+        cbet_loss, offset_loss = cbet_loss.mean(), offset_loss.mean()
+        loss = cbet_loss + self._offset_loss_multiplier * offset_loss
+        action_mse = F.mse_loss(a_hat, action_seq, reduction="none")
+        action_l1 = F.l1_loss(a_hat, action_seq, reduction="none")
+        norm = torch.norm(action_seq, p=2, dim=-1, keepdim=True) + 1e-9
+        normalized_mse = (action_mse / norm).mean()
+        if self._current_steps < self._kmeans_fit_steps:
+            loss = loss.detach() + (loss * 0.0)
+
+        loss_dict = {
+            "classification_loss": cbet_loss.detach().cpu().item(),
+            "offset_loss": offset_loss.detach().cpu().item(),
+            "loss": loss.detach().cpu().item(),
+            "L2_loss": action_mse.mean().detach().cpu().item(),
+            "L2_loss_normalized": normalized_mse.mean().detach().cpu().item(),
+            "L1_loss": action_l1.mean().detach().cpu().item(),
+        }
+        return a_hat, loss, loss_dict
+
+    def _find_closest_cluster(self, action_seq: torch.Tensor) -> torch.Tensor:
+        N, T, _ = action_seq.shape
+        flattened_actions = einops.rearrange(action_seq, "N T A -> (N T) A")
+        cluster_center_distance = torch.sum(
+            (flattened_actions[:, None, :] - self._cluster_centers[None, :, :]) ** 2,
+            dim=2,
+        )  # (N T) K A -> (N T) K
+        closest_cluster_center = torch.argmin(cluster_center_distance, dim=1)  # (N T)
+        discretized_action = einops.rearrange(
+            closest_cluster_center, "(N T) -> N T", N=N, T=T
+        )
+        return discretized_action
+
+    def configure_optimizers(self, weight_decay, learning_rate, betas):
+        optimizer = self._gpt_model.configure_optimizers(
+            weight_decay=weight_decay,
+            learning_rate=learning_rate,
+            betas=betas,
+        )
+        optimizer.add_param_group({"params": self._map_to_cbet_preds.parameters()})
+        return optimizer
+
+
+class FocalLoss(nn.Module):
+    def __init__(self, gamma: float = 0, reduction: str = "mean"):
+        super(FocalLoss, self).__init__()
+        self.gamma = gamma
+        if reduction not in ("mean", "sum", "none"):
+            raise NotImplementedError
+        self.reduction = reduction
+
+    def forward(self, input, target):
+        logpt = F.log_softmax(input, dim=-1)
+        logpt = logpt.gather(1, target.view(-1, 1)).view(-1)
+        pt = logpt.exp()
+
+        loss = -1 * (1 - pt) ** self.gamma * logpt
+        if self.reduction == "mean":
+            return loss.mean()
+        elif self.reduction == "sum":
+            return loss.sum()
+        else:
+            return loss
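
The hunk above adds baselines/rum/models/bet/bet.py (the only file in the list with +347 lines). For orientation only, and not as code shipped in the wheel, a minimal driver of the BehaviorTransformer it defines could look roughly like the sketch below. All sizes and hyperparameters are placeholders, and it assumes the MLP helper in baselines/rum/models/bet/utils.py accepts the in_channels/hidden_channels keywords used above and that accelerate can initialize in the current environment.

import torch

from baselines.rum.models.bet.bet import BehaviorTransformer
from baselines.rum.models.bet.gpt import GPT, GPTConfig

# Placeholder dimensions, chosen only for illustration.
obs_dim, act_dim, goal_dim = 512, 7, 0  # goal_dim <= 0 selects GOAL_SPEC.unconditional

gpt = GPT(GPTConfig(block_size=32, input_dim=obs_dim, output_dim=256, n_layer=4, n_head=4, n_embd=256))
bet = BehaviorTransformer(
    obs_dim=obs_dim,
    act_dim=act_dim,
    goal_dim=goal_dim,
    gpt_model=gpt,
    n_clusters=8,
    kmeans_fit_steps=1,  # fit k-means after the first training batch, purely for this example
)

obs = torch.randn(2, 10, obs_dim)               # (N, T, obs_dim)
actions = torch.randn(2, 10, act_dim)           # (N, T, act_dim)
padding = torch.zeros(2, 10, dtype=torch.bool)  # True marks padded timesteps

# Training-style call: returns sampled actions, the focal + offset loss, and a metrics dict.
a_hat, loss, loss_dict = bet(obs, None, actions, padding)

# Inference-style call: with action_seq=None only sampled (bin center + offset) actions come back.
a_hat, _, _ = bet(obs, None, None, None)

The class follows the Behavior Transformer recipe: actions are discretized into k-means bins, the bin is classified with a focal loss, and a continuous offset from the chosen bin center is regressed with an L2 or L1 loss.
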
baselines/rum/models/bet/gpt.py
@@ -0,0 +1,277 @@
+"""
+An adaptation of Andrej Karpathy's nanoGPT implementation in PyTorch.
+Original source: https://github.com/karpathy/nanoGPT
+
+Original License:
+MIT License
+
+Copyright (c) 2022 Andrej Karpathy
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+Original comments:
+Full definition of a GPT Language Model, all of it in this single file.
+References:
+1) the official GPT-2 TensorFlow implementation released by OpenAI:
+https://github.com/openai/gpt-2/blob/master/src/model.py
+2) huggingface/transformers PyTorch implementation:
+https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
+"""
+
+import math
+from dataclasses import dataclass
+
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+
+
+# @torch.jit.script  # good to enable when not using torch.compile, disable when using (our default)
+def new_gelu(x):
+    """
+    Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT).
+    Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415
+    """
+    return (
+        0.5
+        * x
+        * (
+            1.0
+            + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))
+        )
+    )
+
+
+class CausalSelfAttention(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        assert config.n_embd % config.n_head == 0
+        # key, query, value projections for all heads, but in a batch
+        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
+        # output projection
+        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
+        # regularization
+        self.attn_dropout = nn.Dropout(config.dropout)
+        self.resid_dropout = nn.Dropout(config.dropout)
+        # causal mask to ensure that attention is only applied to the left in the input sequence
+        self.register_buffer(
+            "bias",
+            torch.tril(torch.ones(config.block_size, config.block_size)).view(
+                1, 1, config.block_size, config.block_size
+            ),
+        )
+        self.n_head = config.n_head
+        self.n_embd = config.n_embd
+
+    def forward(self, x):
+        # batch size, sequence length, embedding dimensionality (n_embd)
+        B, T, C = x.size()
+
+        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
+        k = k.view(B, T, self.n_head, C // self.n_head).transpose(
+            1, 2
+        )  # (B, nh, T, hs)
+        q = q.view(B, T, self.n_head, C // self.n_head).transpose(
+            1, 2
+        )  # (B, nh, T, hs)
+        v = v.view(B, T, self.n_head, C // self.n_head).transpose(
+            1, 2
+        )  # (B, nh, T, hs)
+
+        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
+        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
+        att = F.softmax(att, dim=-1)
+        att = self.attn_dropout(att)
+        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
+        y = (
+            y.transpose(1, 2).contiguous().view(B, T, C)
+        )  # re-assemble all head outputs side by side
+
+        # output projection
+        y = self.resid_dropout(self.c_proj(y))
+        return y
+
+
+class MLP(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
+        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
+        self.dropout = nn.Dropout(config.dropout)
+
+    def forward(self, x):
+        x = self.c_fc(x)
+        x = new_gelu(x)
+        x = self.c_proj(x)
+        x = self.dropout(x)
+        return x
+
+
+class Block(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.ln_1 = nn.LayerNorm(config.n_embd)
+        self.attn = CausalSelfAttention(config)
+        self.ln_2 = nn.LayerNorm(config.n_embd)
+        self.mlp = MLP(config)
+
+    def forward(self, x):
+        x = x + self.attn(self.ln_1(x))
+        x = x + self.mlp(self.ln_2(x))
+        return x
+
+
+@dataclass
+class GPTConfig:
+    block_size: int = 1024
+    input_dim: int = 256
+    output_dim: int = 256
+    n_layer: int = 12
+    n_head: int = 12
+    n_embd: int = 768
+    dropout: float = 0.1
+
+
+class GPT(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        assert config.input_dim is not None
+        assert config.output_dim is not None
+        assert config.block_size is not None
+        self.config = config
+
+        self.transformer = nn.ModuleDict(
+            dict(
+                wte=nn.Linear(config.input_dim, config.n_embd),
+                wpe=nn.Embedding(config.block_size, config.n_embd),
+                drop=nn.Dropout(config.dropout),
+                h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
+                ln_f=nn.LayerNorm(config.n_embd),
+            )
+        )
+        self.lm_head = nn.Linear(config.n_embd, config.output_dim, bias=False)
+        # init all weights, and apply a special scaled init to the residual projections, per GPT-2 paper
+        self.apply(self._init_weights)
+        for pn, p in self.named_parameters():
+            if pn.endswith("c_proj.weight"):
+                torch.nn.init.normal_(
+                    p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer)
+                )
+
+        # report number of parameters
+        n_params = sum(p.numel() for p in self.parameters())
+
+    def forward(self, input, targets=None):
+        device = input.device
+        b, t, d = input.size()
+        assert (
+            t <= self.config.block_size
+        ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
+        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(
+            0
+        )  # shape (1, t)
+
+        # forward the GPT model itself
+        tok_emb = self.transformer.wte(
+            input
+        )  # token embeddings of shape (b, t, n_embd)
+        pos_emb = self.transformer.wpe(
+            pos
+        )  # position embeddings of shape (1, t, n_embd)
+        x = self.transformer.drop(tok_emb + pos_emb)
+        for block in self.transformer.h:
+            x = block(x)
+        x = self.transformer.ln_f(x)
+        logits = self.lm_head(x)
+        return logits
+
+    def _init_weights(self, module):
+        if isinstance(module, nn.Linear):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+            if module.bias is not None:
+                torch.nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.Embedding):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+        elif isinstance(module, nn.LayerNorm):
+            torch.nn.init.zeros_(module.bias)
+            torch.nn.init.ones_(module.weight)
+
+    def crop_block_size(self, block_size):
+        assert block_size <= self.config.block_size
+        self.config.block_size = block_size
+        self.transformer.wpe.weight = nn.Parameter(
+            self.transformer.wpe.weight[:block_size]
+        )
+        for block in self.transformer.h:
+            block.attn.bias = block.attn.bias[:, :, :block_size, :block_size]
+
+    def configure_optimizers(self, weight_decay, learning_rate, betas):
+        """
+        This long function is unfortunately doing something very simple and is being very defensive:
+        We are separating out all parameters of the model into two buckets: those that will experience
+        weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
+        We are then returning the PyTorch optimizer object.
+        """
+
+        # separate out all parameters to those that will and won't experience regularizing weight decay
+        decay = set()
+        no_decay = set()
+        whitelist_weight_modules = (torch.nn.Linear,)
+        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
+        for mn, m in self.named_modules():
+            for pn, p in m.named_parameters():
+                fpn = "%s.%s" % (mn, pn) if mn else pn  # full param name
+                if pn.endswith("bias"):
+                    # all biases will not be decayed
+                    no_decay.add(fpn)
+                elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules):
+                    # weights of whitelist modules will be weight decayed
+                    decay.add(fpn)
+                elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules):
+                    # weights of blacklist modules will NOT be weight decayed
+                    no_decay.add(fpn)
+
+        # validate that we considered every parameter
+        param_dict = {pn: p for pn, p in self.named_parameters()}
+        inter_params = decay & no_decay
+        union_params = decay | no_decay
+        assert len(inter_params) == 0, (
+            "parameters %s made it into both decay/no_decay sets!"
+            % (str(inter_params),)
+        )
+        assert len(param_dict.keys() - union_params) == 0, (
+            "parameters %s were not separated into either decay/no_decay set!"
+            % (str(param_dict.keys() - union_params),)
+        )
+
+        # create the pytorch optimizer object
+        optim_groups = [
+            {
+                "params": [param_dict[pn] for pn in sorted(list(decay))],
+                "weight_decay": weight_decay,
+            },
+            {
+                "params": [param_dict[pn] for pn in sorted(list(no_decay))],
+                "weight_decay": 0.0,
+            },
+        ]
+        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
+        return optimizer
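
This second hunk is baselines/rum/models/bet/gpt.py (+277 lines in the file list), a nanoGPT trunk adapted for continuous features: wte is an nn.Linear over input_dim rather than a token-embedding table, and lm_head emits output_dim values per timestep rather than vocabulary logits, so forward returns a continuous vector per timestep. The sketch below is illustrative only; the sizes are made up for the example and are not defaults used anywhere in the package.

import torch

from baselines.rum.models.bet.gpt import GPT, GPTConfig

config = GPTConfig(block_size=16, input_dim=10, output_dim=4, n_layer=2, n_head=2, n_embd=64)
model = GPT(config)

x = torch.randn(3, 16, 10)  # (batch, timesteps, input_dim): continuous features, not token ids
out = model(x)              # shape (3, 16, 4): one output_dim vector per timestep

# AdamW with weight decay restricted to Linear weights, as configure_optimizers above arranges.
optimizer = model.configure_optimizers(weight_decay=0.1, learning_rate=3e-4, betas=(0.9, 0.95))
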