x-transformers 1.30.10__py3-none-any.whl → 1.30.12__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
@@ -80,6 +80,13 @@ class equals():
     def __call__(self, x, *args, **kwargs):
         return x == self.val
 
+class Identity(Module):
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+
+    def __call__(self, x, *args, **kwargs):
+        return x
+
 def Sequential(*modules):
     return nn.Sequential(*filter(exists, modules))
 
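The new Identity module is what the peri-layernorm option added further down in this release plugs in when the per-layer norms are disabled. A minimal sketch of its behavior, assuming torch is installed and that Identity is importable from x_transformers.x_transformers where the hunk above defines it: it ignores all constructor and call arguments and returns its input unchanged, so it can stand in for any norm constructor.

import torch
from x_transformers.x_transformers import Identity

norm = Identity(dim = 512)               # constructor arguments are ignored
x = torch.randn(2, 16, 512)

assert norm(x) is x                      # input is passed through untouched
assert Identity()(x, scale = 2.0) is x   # extra call arguments are ignored too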
@@ -314,21 +321,29 @@ class CoPE(Module):
         heads,
         max_pos,
         soft_onehot = False,
-        talking_heads = False
+        talking_heads = False,
+        soft_onehot_temp = 5e-2
     ):
         super().__init__()
         self.max_pos = max_pos
         self.pos_emb = nn.Parameter(torch.zeros(max_pos, dim))
 
-        self.maybe_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if talking_heads else nn.Identity()
+        self.talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if talking_heads else None
         self.soft_onehot = soft_onehot
+        self.soft_onehot_temp = soft_onehot_temp
 
         if soft_onehot:
             self.register_buffer('positions', torch.arange(max_pos))
 
-    def forward(self, query, attn_logits, temp = 5e-2):
+    def forward(self, query, attn_logits):
 
-        attn_logits = self.maybe_talking_heads(attn_logits)
+        if exists(self.talking_heads):
+            i, j = attn_logits.shape[-2:]
+            causal_mask = attn_logits.new_ones(i, j).triu_(j - i + 1).bool()
+
+            attn_logits = self.talking_heads(attn_logits)
+
+            attn_logits = attn_logits.masked_fill(causal_mask, -torch.finfo(attn_logits.dtype).max)
 
         # compute positions
 
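A small worked check of the mask construction in the hunk above, showing what torch.ones(i, j).triu_(j - i + 1).bool() marks as masked when the logits are not square (the i, j names follow the diff; the concrete sizes here are only illustrative):

import torch

i, j = 3, 5
causal_mask = torch.ones(i, j).triu_(j - i + 1).bool()
print(causal_mask)
# tensor([[False, False, False,  True,  True],
#         [False, False, False, False,  True],
#         [False, False, False, False, False]])

For square logits (i == j) this reduces to triu_(1), the usual strictly-upper-triangular causal mask, and the talking-heads convolution output at those positions is overwritten with -torch.finfo(dtype).max, a large negative value.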
@@ -341,7 +356,7 @@ class CoPE(Module):
 
         if self.soft_onehot:
             diff_pos = (pos[..., None] - self.positions).abs()
-            soft_onehot_pos = F.softmax(-diff_pos / temp, dim = -1)
+            soft_onehot_pos = F.softmax(-diff_pos / self.soft_onehot_temp, dim = -1)
             cope_pos_emb = einsum('b h i j p, b h i p -> b h i j', soft_onehot_pos, logits_int)
         else:
             # interpolate from integer positions
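The softmax temperature for the soft one-hot positions now lives on the module as soft_onehot_temp instead of being a forward() keyword. A hedged construction sketch, assuming CoPE's full signature begins with dim (only heads onward is visible in the hunks above) and that calling the module returns the contextual position bias:

import torch
from x_transformers.x_transformers import CoPE

cope = CoPE(
    dim = 64,                   # per-head dimension (assumed leading argument)
    heads = 8,
    max_pos = 256,
    soft_onehot = True,
    soft_onehot_temp = 5e-2     # new in 1.30.12; was forward(..., temp = 5e-2)
)

q = torch.randn(1, 8, 128, 64)             # (batch, heads, seq, dim_head)
attn_logits = torch.randn(1, 8, 128, 128)  # (batch, heads, i, j)
bias = cope(q, attn_logits)                # forward no longer takes a temp argument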
@@ -1121,6 +1136,7 @@ class AttentionLayers(Module):
         use_scalenorm = False,
         use_rmsnorm = False,
         use_simple_rmsnorm = False,
+        no_pre_or_postnorm = False,
         alibi_pos_bias = False,
         alibi_num_heads = None,
         rel_pos_bias = False,
@@ -1350,6 +1366,15 @@ class AttentionLayers(Module):
             residual_fn = GRUGating if gate_residual else Residual
             residual = residual_fn(dim, scale_residual = scale_residual, scale_residual_constant = scale_residual_constant)
 
+            # for the peri-layernorm config from https://arxiv.org/abs/2405.16039
+            # must be paired with qk norm
+
+            layer_norm_fn = norm_fn
+            if no_pre_or_postnorm:
+                layer_norm_fn = Identity()
+
+            # all normalizations of the layer
+
             pre_branch_norm = norm_fn() if pre_norm else None
             post_branch_norm = norm_fn() if sandwich_norm else None
             post_main_norm = norm_fn() if not pre_norm else None
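The new no_pre_or_postnorm flag threads through the AttentionLayers keyword arguments, and the in-code comment says it targets the peri-layernorm configuration and should be paired with qk norm. A hedged usage sketch with the standard TransformerWrapper/Decoder entry points (the surrounding hyperparameters are arbitrary):

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        no_pre_or_postnorm = True,   # new in 1.30.12
        attn_qk_norm = True          # qk norm, as the in-code comment advises
    )
)

logits = model(torch.randint(0, 256, (1, 1024)))   # (1, 1024, 256)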
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.30.10
+Version: 1.30.12
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3n
 x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=qErIXwntVnZEmdO4MwTTsvxMc_vQPBU8zLLpEh535tE,68690
+x_transformers/x_transformers.py,sha256=hp7upZuaffVtrhZdL1Ra5sqBHcMrUMUlNvtbK8ilBPI,69497
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.30.10.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.30.10.dist-info/METADATA,sha256=Fo2EWkxK4rLMcwUqcyeSW9AljUJA6b_2hZ59W8axIO4,662
-x_transformers-1.30.10.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-x_transformers-1.30.10.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.30.10.dist-info/RECORD,,
+x_transformers-1.30.12.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.30.12.dist-info/METADATA,sha256=agGqwtqqfvOYid5YrtmwaW46jxFU5kQ-QvTke5EhuiE,662
+x_transformers-1.30.12.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+x_transformers-1.30.12.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.30.12.dist-info/RECORD,,