x-transformers 1.30.1__tar.gz → 1.30.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {x_transformers-1.30.1/x_transformers.egg-info → x_transformers-1.30.2}/PKG-INFO +1 -1
- {x_transformers-1.30.1 → x_transformers-1.30.2}/README.md +1 -1
- {x_transformers-1.30.1 → x_transformers-1.30.2}/setup.py +1 -1
- {x_transformers-1.30.1 → x_transformers-1.30.2}/x_transformers/x_transformers.py +5 -2
- {x_transformers-1.30.1 → x_transformers-1.30.2/x_transformers.egg-info}/PKG-INFO +1 -1
- {x_transformers-1.30.1 → x_transformers-1.30.2}/LICENSE +0 -0
- {x_transformers-1.30.1 → x_transformers-1.30.2}/setup.cfg +0 -0
- {x_transformers-1.30.1 → x_transformers-1.30.2}/x_transformers/__init__.py +0 -0
- {x_transformers-1.30.1 → x_transformers-1.30.2}/x_transformers/attend.py +0 -0
- {x_transformers-1.30.1 → x_transformers-1.30.2}/x_transformers/autoregressive_wrapper.py +0 -0
- {x_transformers-1.30.1 → x_transformers-1.30.2}/x_transformers/continuous.py +0 -0
- {x_transformers-1.30.1 → x_transformers-1.30.2}/x_transformers/dpo.py +0 -0
- {x_transformers-1.30.1 → x_transformers-1.30.2}/x_transformers/nonautoregressive_wrapper.py +0 -0
- {x_transformers-1.30.1 → x_transformers-1.30.2}/x_transformers/xl_autoregressive_wrapper.py +0 -0
- {x_transformers-1.30.1 → x_transformers-1.30.2}/x_transformers/xval.py +0 -0
- {x_transformers-1.30.1 → x_transformers-1.30.2}/x_transformers.egg-info/SOURCES.txt +0 -0
- {x_transformers-1.30.1 → x_transformers-1.30.2}/x_transformers.egg-info/dependency_links.txt +0 -0
- {x_transformers-1.30.1 → x_transformers-1.30.2}/x_transformers.egg-info/requires.txt +0 -0
- {x_transformers-1.30.1 → x_transformers-1.30.2}/x_transformers.egg-info/top_level.txt +0 -0
{x_transformers-1.30.1 → x_transformers-1.30.2}/README.md

@@ -693,7 +693,7 @@ model = TransformerWrapper(
 )
 ```
 
-If you wish to do something more sophisticated, say 3 layers, with each layer recurrent 4 times before onto the next, that is possible as well. Be aware the `layers_execute_order` is 0-indexed
+If you wish to do something more sophisticated, say 3 layers, with each layer recurrent 4 times before onto the next (similar to <a href="https://arxiv.org/abs/2405.15071">this paper</a>), that is possible as well. Be aware the `layers_execute_order` is 0-indexed
 
 ```python
 import torch
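To make the recurrence pattern described in that README sentence concrete, here is a minimal, self-contained sketch of a 0-indexed execution order over 3 weight-tied blocks, each run 4 times before moving on to the next. This is an illustration of the idea only, not the x-transformers internals; the block type, dimensions, and variable names are made up for the example.

```python
import torch
from torch import nn

# three hypothetical blocks; an execution order of (0,0,0,0, 1,1,1,1, 2,2,2,2)
# runs each block 4 times before moving on to the next (indices are 0-indexed)
blocks = nn.ModuleList([nn.Linear(512, 512) for _ in range(3)])
layers_execute_order = (0,) * 4 + (1,) * 4 + (2,) * 4

x = torch.randn(2, 1024, 512)
for i in layers_execute_order:
    x = torch.relu(blocks[i](x))

print(x.shape)  # torch.Size([2, 1024, 512])
```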
{x_transformers-1.30.1 → x_transformers-1.30.2}/x_transformers/x_transformers.py

@@ -468,7 +468,8 @@ def rotate_half(x):
 
 @autocast(enabled = False)
 def apply_rotary_pos_emb(t, freqs, scale = 1):
-    rot_dim, seq_len = freqs.shape[-1], t.shape[-2]
+    rot_dim, seq_len, orig_dtype = freqs.shape[-1], t.shape[-2], t.dtype
+
     freqs = freqs[-seq_len:, :]
     scale = scale[-seq_len:, :] if isinstance(scale, torch.Tensor) else scale
 
@@ -478,7 +479,9 @@ def apply_rotary_pos_emb(t, freqs, scale = 1):
     # partial rotary embeddings, Wang et al. GPT-J
     t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
     t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
-    return torch.cat((t, t_unrotated), dim = -1)
+    out = torch.cat((t, t_unrotated), dim = -1)
+
+    return out.type(orig_dtype)
 
 # norms
 
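Taken together, the two hunks above make `apply_rotary_pos_emb` remember the incoming dtype and cast its result back to it. Since the function runs with `@autocast(enabled = False)`, half-precision queries/keys are promoted to float32 by the float32 `freqs` math, and the new `return out.type(orig_dtype)` restores the caller's dtype. Below is a self-contained sketch of the patched logic for illustration: the decorator and the broadcasting branch outside the hunk context are omitted, `rotate_half` is the standard half-rotation, and the demo tensor shapes are made-up assumptions.

```python
import torch

def rotate_half(x):
    # rotate the two halves of the last dimension: (x1, x2) -> (-x2, x1)
    x1, x2 = x.chunk(2, dim = -1)
    return torch.cat((-x2, x1), dim = -1)

def apply_rotary_pos_emb(t, freqs, scale = 1):
    # sketch of the 1.30.2 version from this diff
    rot_dim, seq_len, orig_dtype = freqs.shape[-1], t.shape[-2], t.dtype

    freqs = freqs[-seq_len:, :]
    scale = scale[-seq_len:, :] if isinstance(scale, torch.Tensor) else scale

    # partial rotary embeddings, Wang et al. GPT-J
    t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
    t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
    out = torch.cat((t, t_unrotated), dim = -1)

    return out.type(orig_dtype)

# half-precision input with float32 frequencies: the rotation math promotes to
# float32, and the final cast returns the result in the input's original dtype
q = torch.randn(2, 8, 16, 64, dtype = torch.float16)
freqs = torch.randn(16, 32)
assert apply_rotary_pos_emb(q, freqs).dtype == torch.float16
```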