x-transformers 1.30.10__py3-none-any.whl → 1.30.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +30 -5
- {x_transformers-1.30.10.dist-info → x_transformers-1.30.12.dist-info}/METADATA +1 -1
- {x_transformers-1.30.10.dist-info → x_transformers-1.30.12.dist-info}/RECORD +6 -6
- {x_transformers-1.30.10.dist-info → x_transformers-1.30.12.dist-info}/LICENSE +0 -0
- {x_transformers-1.30.10.dist-info → x_transformers-1.30.12.dist-info}/WHEEL +0 -0
- {x_transformers-1.30.10.dist-info → x_transformers-1.30.12.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -80,6 +80,13 @@ class equals():
     def __call__(self, x, *args, **kwargs):
         return x == self.val
 
+class Identity(Module):
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+
+    def __call__(self, x, *args, **kwargs):
+        return x
+
 def Sequential(*modules):
     return nn.Sequential(*filter(exists, modules))
 
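Not part of the upstream diff: a minimal sketch of how the new Identity module behaves, assuming it is imported from x_transformers.x_transformers where the hunk above adds it. It accepts and ignores any constructor or call arguments and hands its input straight back, which is what lets it stand in for a norm constructor later in this release.

import torch
from x_transformers.x_transformers import Identity

ident = Identity(512, eps = 1e-5)      # extra args are accepted and ignored
x = torch.randn(2, 8, 512)
assert ident(x, mask = None) is x      # __call__ returns the input unchanged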
@@ -314,21 +321,29 @@ class CoPE(Module):
         heads,
         max_pos,
         soft_onehot = False,
-        talking_heads = False
+        talking_heads = False,
+        soft_onehot_temp = 5e-2
     ):
         super () . __init__ ()
         self.max_pos = max_pos
         self.pos_emb = nn.Parameter(torch.zeros(max_pos, dim))
 
-        self.
+        self.talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if talking_heads else None
         self.soft_onehot = soft_onehot
+        self.soft_onehot_temp = soft_onehot_temp
 
         if soft_onehot:
             self.register_buffer('positions', torch.arange(max_pos))
 
-    def forward(self, query, attn_logits
+    def forward(self, query, attn_logits):
 
-
+        if exists(self.talking_heads):
+            i, j = attn_logits.shape[-2:]
+            causal_mask = attn_logits.new_ones(i, j).triu_(j - i + 1).bool()
+
+            attn_logits = self.talking_heads(attn_logits)
+
+            attn_logits = attn_logits.masked_fill(causal_mask, -torch.finfo(attn_logits.dtype).max)
 
         # compute positions
 
@@ -341,7 +356,7 @@ class CoPE(Module):
 
         if self.soft_onehot:
             diff_pos = (pos[..., None] - self.positions).abs()
-            soft_onehot_pos = F.softmax(-diff_pos /
+            soft_onehot_pos = F.softmax(-diff_pos / self.soft_onehot_temp, dim = -1)
             cope_pos_emb = einsum('b h i j p, b h i p -> b h i j', soft_onehot_pos, logits_int)
         else:
             # interpolate from integer positions
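Not part of the upstream diff: a hedged sketch exercising the CoPE changes above (the causally masked talking_heads path, and soft_onehot_temp now fixed at construction rather than supplied per call). The dim argument name and the tensor shapes, query as (batch, heads, seq, dim) and attn_logits as (batch, heads, seq, seq), are assumptions about parts of the class that lie outside these hunks.

import torch
from x_transformers.x_transformers import CoPE

cope = CoPE(
    dim = 64,
    heads = 8,
    max_pos = 256,
    soft_onehot = True,
    soft_onehot_temp = 5e-2,    # new in 1.30.12: softmax temperature set at init
    talking_heads = True        # new in 1.30.12: causally masked mixing of attn_logits across heads
)

query = torch.randn(1, 8, 32, 64)          # (batch, heads, seq, dim) - assumed layout
attn_logits = torch.randn(1, 8, 32, 32)    # (batch, heads, seq, seq) - assumed layout

pos_bias = cope(query, attn_logits)        # contextual positional term for the attention logits
print(pos_bias.shape)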
@@ -1121,6 +1136,7 @@ class AttentionLayers(Module):
         use_scalenorm = False,
         use_rmsnorm = False,
         use_simple_rmsnorm = False,
+        no_pre_or_postnorm = False,
         alibi_pos_bias = False,
         alibi_num_heads = None,
         rel_pos_bias = False,
@@ -1350,6 +1366,15 @@ class AttentionLayers(Module):
             residual_fn = GRUGating if gate_residual else Residual
             residual = residual_fn(dim, scale_residual = scale_residual, scale_residual_constant = scale_residual_constant)
 
+            # for the peri-layernorm config from https://arxiv.org/abs/2405.16039
+            # must be paired with qk norm
+
+            layer_norm_fn = norm_fn
+            if no_pre_or_postnorm:
+                layer_norm_fn = Identity()
+
+            # all normalizations of the layer
+
             pre_branch_norm = norm_fn() if pre_norm else None
             post_branch_norm = norm_fn() if sandwich_norm else None
             post_main_norm = norm_fn() if not pre_norm else None
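Not part of the upstream diff: a hedged sketch of passing the new no_pre_or_postnorm flag through the usual Decoder kwargs, paired with qk norm as the in-code comment advises for the peri-layernorm setup of https://arxiv.org/abs/2405.16039. The attn_qk_norm kwarg is assumed to be the library's existing qk-norm switch; nothing here is prescribed by the release itself.

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        no_pre_or_postnorm = True,   # new flag in this release
        attn_qk_norm = True          # assumed pairing, per the "must be paired with qk norm" comment
    )
)

tokens = torch.randint(0, 256, (1, 128))
logits = model(tokens)               # (1, 128, 256)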
{x_transformers-1.30.10.dist-info → x_transformers-1.30.12.dist-info}/RECORD
CHANGED
@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3n
 x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=hp7upZuaffVtrhZdL1Ra5sqBHcMrUMUlNvtbK8ilBPI,69497
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.30.
-x_transformers-1.30.
-x_transformers-1.30.
-x_transformers-1.30.
-x_transformers-1.30.
+x_transformers-1.30.12.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.30.12.dist-info/METADATA,sha256=agGqwtqqfvOYid5YrtmwaW46jxFU5kQ-QvTke5EhuiE,662
+x_transformers-1.30.12.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+x_transformers-1.30.12.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.30.12.dist-info/RECORD,,
{x_transformers-1.30.10.dist-info → x_transformers-1.30.12.dist-info}/LICENSE
File without changes
{x_transformers-1.30.10.dist-info → x_transformers-1.30.12.dist-info}/WHEEL
File without changes
{x_transformers-1.30.10.dist-info → x_transformers-1.30.12.dist-info}/top_level.txt
File without changes