x-transformers 1.30.10__py3-none-any.whl → 1.30.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- x_transformers/x_transformers.py
+++ x_transformers/x_transformers.py
@@ -314,21 +314,29 @@ class CoPE(Module):
         heads,
         max_pos,
         soft_onehot = False,
-        talking_heads = False
+        talking_heads = False,
+        soft_onehot_temp = 5e-2
     ):
         super().__init__()
         self.max_pos = max_pos
         self.pos_emb = nn.Parameter(torch.zeros(max_pos, dim))
 
-        self.maybe_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if talking_heads else nn.Identity()
+        self.talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if talking_heads else None
         self.soft_onehot = soft_onehot
+        self.soft_onehot_temp = soft_onehot_temp
 
         if soft_onehot:
             self.register_buffer('positions', torch.arange(max_pos))
 
-    def forward(self, query, attn_logits, temp = 5e-2):
+    def forward(self, query, attn_logits):
+
+        if exists(self.talking_heads):
+            i, j = attn_logits.shape[-2:]
+            causal_mask = attn_logits.new_ones(i, j).triu_(j - i + 1).bool()
+
+            attn_logits = self.talking_heads(attn_logits)
 
-        attn_logits = self.maybe_talking_heads(attn_logits)
+            attn_logits = attn_logits.masked_fill(causal_mask, -torch.finfo(attn_logits.dtype).max)
 
         # compute positions
 
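A note on the forward change above: the 1x1 Conv2d mixes attention logits across the head dimension, so after mixing, positions beyond the causal boundary can pick up nonzero values from other heads. Re-applying the causal mask keeps those future positions from influencing the position computation that follows. A minimal, self-contained sketch of that masking step (toy shapes and names, not the library's API):

import torch
from torch import nn

# toy shapes: (batch, heads, query_len, key_len)
b, h, i, j = 1, 4, 5, 5
attn_logits = torch.randn(b, h, i, j)

# 1x1 convolution over the head dimension, mirroring the diff above
talking_heads = nn.Conv2d(h, h, 1, bias = False)

# boolean mask selecting strictly-future positions for each query
causal_mask = attn_logits.new_ones(i, j).triu_(j - i + 1).bool()

mixed = talking_heads(attn_logits)

# re-mask after mixing, since the conv spreads values across heads
mixed = mixed.masked_fill(causal_mask, -torch.finfo(mixed.dtype).max)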
@@ -341,7 +349,7 @@ class CoPE(Module):
 
         if self.soft_onehot:
             diff_pos = (pos[..., None] - self.positions).abs()
-            soft_onehot_pos = F.softmax(-diff_pos / temp, dim = -1)
+            soft_onehot_pos = F.softmax(-diff_pos / self.soft_onehot_temp, dim = -1)
             cope_pos_emb = einsum('b h i j p, b h i p -> b h i j', soft_onehot_pos, logits_int)
         else:
             # interpolate from integer positions
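The second hunk consumes the new hyperparameter: the temperature that was previously a per-call forward argument (temp = 5e-2) is now fixed at construction as soft_onehot_temp. It controls how sharply the soft one-hot weights concentrate on the integer positions nearest each fractional position. A small self-contained sketch of that softmax, with illustrative values (max_pos and pos here are made up for the example):

import torch
import torch.nn.functional as F

max_pos = 8
positions = torch.arange(max_pos)   # integer positions, as in the registered buffer
pos = torch.tensor([2.3])           # a fractional position to be embedded

diff_pos = (pos[..., None] - positions).abs()

# lower temperature -> weights pile onto the nearest integer position
for temp in (5e-2, 5e-1):
    soft_onehot_pos = F.softmax(-diff_pos / temp, dim = -1)
    print(temp, soft_onehot_pos.round(decimals = 3))

Promoting the temperature to __init__ makes it part of the module's configuration rather than something callers must thread through every forward pass.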
--- x_transformers-1.30.10.dist-info/METADATA
+++ x_transformers-1.30.11.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.30.10
+Version: 1.30.11
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
--- x_transformers-1.30.10.dist-info/RECORD
+++ x_transformers-1.30.11.dist-info/RECORD
@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3n
 x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=qErIXwntVnZEmdO4MwTTsvxMc_vQPBU8zLLpEh535tE,68690
+x_transformers/x_transformers.py,sha256=97WREYHfGljhrveGNUUSEWk2xbFvKdM52QXW7cnoBpk,69019
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.30.10.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.30.10.dist-info/METADATA,sha256=Fo2EWkxK4rLMcwUqcyeSW9AljUJA6b_2hZ59W8axIO4,662
-x_transformers-1.30.10.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-x_transformers-1.30.10.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.30.10.dist-info/RECORD,,
+x_transformers-1.30.11.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.30.11.dist-info/METADATA,sha256=kpv8_mEb4DU4zL_oyLwzbWJ7z8WV7mf7c466Hehn-6c,662
+x_transformers-1.30.11.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+x_transformers-1.30.11.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.30.11.dist-info/RECORD,,