x-transformers 2.8.3__py3-none-any.whl → 2.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +14 -0
- {x_transformers-2.8.3.dist-info → x_transformers-2.9.0.dist-info}/METADATA +2 -1
- {x_transformers-2.8.3.dist-info → x_transformers-2.9.0.dist-info}/RECORD +5 -5
- {x_transformers-2.8.3.dist-info → x_transformers-2.9.0.dist-info}/WHEEL +0 -0
- {x_transformers-2.8.3.dist-info → x_transformers-2.9.0.dist-info}/licenses/LICENSE +0 -0
x_transformers/x_transformers.py
CHANGED
```diff
@@ -2493,6 +2493,17 @@ class AttentionLayers(Module):
         for attn_layer, attn_inter in zip(attn_layers, attn_intermeds):
             attn_layer.qk_clip_(attn_inter, tau = tau)
 
+    def muon_parameters(self):
+        params = []
+
+        for m in self.modules():
+            if not isinstance(m, (Attention, FeedForward)):
+                continue
+
+            params.extend(list(m.muon_parameters()))
+
+        return params
+
     def forward(
         self,
         x,
@@ -3230,6 +3241,9 @@ class TransformerWrapper(Module):
     ):
         self.attn_layers.attn_qk_clip_(intermediates, tau = tau)
 
+    def muon_parameters(self):
+        return self.attn_layers.muon_parameters()
+
     def forward(
         self,
         x,
```
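Both hunks add a `muon_parameters()` hook: `AttentionLayers` walks its submodules and collects whatever each `Attention` and `FeedForward` block reports (presumably their 2D weight matrices), and `TransformerWrapper` simply forwards to its `attn_layers`. Muon-style optimizers orthogonalize updates for weight matrices only, so embeddings, norms, and biases still need a conventional optimizer. A minimal sketch of how the hook could feed that split; the commented-out `Muon` class is a placeholder assumption, since x-transformers supplies the parameter partition, not the optimizer:

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

# parameters destined for Muon-style (orthogonalized) updates,
# as reported by the new hook
muon_params = set(model.muon_parameters())

# everything else (token embeddings, norms, biases, output head)
# stays on a conventional optimizer such as AdamW
other_params = [p for p in model.parameters() if p not in muon_params]

adamw = torch.optim.AdamW(other_params, lr = 3e-4)

# `Muon` below is a hypothetical stand-in for whichever Muon
# implementation you pair this with; it is not part of x-transformers
# muon = Muon(list(muon_params), lr = 2e-2)
```

Collecting the hook's output into a `set` makes the complement cheap to compute, and the membership test is safe because `torch.Tensor` hashes by object identity.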
{x_transformers-2.8.3.dist-info → x_transformers-2.9.0.dist-info}/METADATA
CHANGED
```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.8.3
+Version: 2.9.0
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -40,6 +40,7 @@ Requires-Dist: loguru
 Requires-Dist: packaging>=21.0
 Requires-Dist: torch>=2.0
 Provides-Extra: examples
+Requires-Dist: adam-atan2-pytorch>=0.2.2; extra == 'examples'
 Requires-Dist: lion-pytorch; extra == 'examples'
 Requires-Dist: tqdm; extra == 'examples'
 Provides-Extra: test
```
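Besides the version bump, the only metadata change is a new `examples` extra dependency on `adam-atan2-pytorch`, an Adam variant that replaces the epsilon-guarded division with an `atan2` update, removing `eps` as a hyperparameter. A hedged sketch of how an example script might use it, assuming the package exposes `AdamAtan2` as its optimizer class:

```python
import torch
from torch import nn

# assumption: adam_atan2_pytorch exports `AdamAtan2`
from adam_atan2_pytorch import AdamAtan2

net = nn.Linear(512, 512)

# intended as a drop-in replacement for Adam, minus the `eps`
# hyperparameter, which the atan2 formulation of the update removes
opt = AdamAtan2(net.parameters(), lr = 1e-4)

loss = net(torch.randn(2, 512)).sum()
loss.backward()
opt.step()
opt.zero_grad()
```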
{x_transformers-2.8.3.dist-info → x_transformers-2.9.0.dist-info}/RECORD
CHANGED
```diff
@@ -10,10 +10,10 @@ x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=hMQqNimGtchNIe13cR5LZule1V7I1qM5LmY8VQfVdnA,11698
 x_transformers/up_wrapper.py,sha256=YC2LN14_7Xx9Wtiek2rtEJ_qHqdfSmKlh3d7Cgxwd80,7073
-x_transformers/x_transformers.py,sha256=…
+x_transformers/x_transformers.py,sha256=xIaM5LGimNmDThr9WK0ZF5vfm3sSp0OMll0aRIqjymQ,125443
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
-x_transformers-2.8.3.dist-info/METADATA,sha256=…
-x_transformers-2.8.3.dist-info/WHEEL,sha256=…
-x_transformers-2.8.3.dist-info/licenses/LICENSE,sha256=…
-x_transformers-2.8.3.dist-info/RECORD,,
+x_transformers-2.9.0.dist-info/METADATA,sha256=tQ93i_TSfQIf9OViWNJj57ezY0Slred1NblIyf4UIs0,94986
+x_transformers-2.9.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.9.0.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.9.0.dist-info/RECORD,,
```
{x_transformers-2.8.3.dist-info → x_transformers-2.9.0.dist-info}/WHEEL
File without changes

{x_transformers-2.8.3.dist-info → x_transformers-2.9.0.dist-info}/licenses/LICENSE
File without changes