x-transformers 2.0.4__py3-none-any.whl → 2.0.5__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- x_transformers/x_transformers.py +4 -5
- {x_transformers-2.0.4.dist-info → x_transformers-2.0.5.dist-info}/METADATA +1 -1
- {x_transformers-2.0.4.dist-info → x_transformers-2.0.5.dist-info}/RECORD +5 -5
- {x_transformers-2.0.4.dist-info → x_transformers-2.0.5.dist-info}/WHEEL +0 -0
- {x_transformers-2.0.4.dist-info → x_transformers-2.0.5.dist-info}/licenses/LICENSE +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -449,17 +449,16 @@ class DynamicPositionBias(Module):
|
|
449
449
|
return next(self.parameters()).device
|
450
450
|
|
451
451
|
def forward(self, i, j):
|
452
|
-
assert i == j
|
453
452
|
n, device = j, self.device
|
454
453
|
|
455
454
|
# get the (n x n) matrix of distances
|
456
|
-
seq_arange = arange(
|
457
|
-
context_arange = arange(
|
455
|
+
seq_arange = arange(j - i, j, device = device)
|
456
|
+
context_arange = arange(j, device = device)
|
458
457
|
indices = einx.subtract('i, j -> i j', seq_arange, context_arange)
|
459
|
-
indices += (
|
458
|
+
indices += (j - 1)
|
460
459
|
|
461
460
|
# input to continuous positions MLP
|
462
|
-
pos = arange(-
|
461
|
+
pos = arange(-j + 1, j, device = device).float()
|
463
462
|
pos = rearrange(pos, '... -> ... 1')
|
464
463
|
|
465
464
|
if self.log_distance:
|
@@ -6,10 +6,10 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
|
|
6
6
|
x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
|
7
7
|
x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
|
8
8
|
x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
|
9
|
-
x_transformers/x_transformers.py,sha256=
|
9
|
+
x_transformers/x_transformers.py,sha256=yijnlpQnhC0lK5qYzSxII7IkVf7ILhsTyntw_S5MvRU,107670
|
10
10
|
x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
|
11
11
|
x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
|
12
|
-
x_transformers-2.0.
|
13
|
-
x_transformers-2.0.
|
14
|
-
x_transformers-2.0.
|
15
|
-
x_transformers-2.0.
|
12
|
+
x_transformers-2.0.5.dist-info/METADATA,sha256=9U0kHbTwa2sv4z-pCqlkcm998SDcELMonQ9JoHaYgR4,86938
|
13
|
+
x_transformers-2.0.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
14
|
+
x_transformers-2.0.5.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
|
15
|
+
x_transformers-2.0.5.dist-info/RECORD,,
|
File without changes
|
File without changes
|