x-transformers 1.40.11__py3-none-any.whl → 1.41.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
x_transformers/dpo.py CHANGED
@@ -5,6 +5,7 @@ from torch.nn import Module
 import torch.nn.functional as F
 from x_transformers.x_transformers import TransformerWrapper

+import einx
 from einops import rearrange

 # helper functions
@@ -17,16 +18,18 @@ def freeze_all_layers_(module):
         param.requires_grad = False

 def log_prob_from_model_and_seq(model, seq):
-    logits = model(seq)
+    src_seq, tgt_seq = seq[:, :-1], seq[:, 1:]
+    logits = model(src_seq)
     log_prob = logits.log_softmax(dim = -1)
-    indices = rearrange(seq, '... -> ... 1')
-    log_probs = log_prob.gather(-1, indices)
-    return rearrange(log_probs, '... 1 -> ...')
+    return einx.get_at('b n [l], b n -> b n', log_prob, tgt_seq)

 def masked_mean(log_probs, mask = None):
     if not exists(mask):
         return log_probs.mean(dim = -1)

+    if mask.shape[-1] == (log_probs.shape[-1] + 1):
+        mask = mask[:, :-1]
+
     log_probs = log_probs.masked_fill(~mask, 0.)
     num = log_probs.sum(dim = -1)
     den = mask.sum(dim = -1)
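For reference, a minimal sketch (not part of the package) of what the reworked log_prob_from_model_and_seq now gathers: the log-probability of each next token, with the einx.get_at call matching the old gather-based form. Tensor shapes below are arbitrary toy values.

import torch
import einx

# toy shapes: batch 2, sequence 5, vocab 10
seq = torch.randint(0, 10, (2, 5))
log_prob = torch.randn(2, 4, 10).log_softmax(dim = -1)   # stand-in for model(seq[:, :-1]) log-probs
tgt_seq = seq[:, 1:]                                      # next-token targets

# new form used in 1.41.0: index the vocab axis with the target token ids
per_token_new = einx.get_at('b n [l], b n -> b n', log_prob, tgt_seq)

# old gather-based form, shown here only to check equivalence
per_token_old = log_prob.gather(-1, tgt_seq.unsqueeze(-1)).squeeze(-1)

assert torch.allclose(per_token_new, per_token_old)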
x_transformers/x_transformers.py CHANGED
@@ -16,8 +16,9 @@ from collections import namedtuple
 from contextlib import nullcontext
 from dataclasses import dataclass

-from einops import rearrange, repeat, reduce, pack, unpack
+import einx
 from einops.layers.torch import Rearrange
+from einops import rearrange, repeat, reduce, pack, unpack

 from x_transformers.attend import Attend, Intermediates
 from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
@@ -100,6 +101,12 @@ def log(t, eps = 1e-20):
 def max_neg_value(tensor):
     return -torch.finfo(tensor.dtype).max

+def reverse_cumsum(t, dim = -1):
+    t = t.flip(dims = (dim,))
+    t = t.cumsum(dim = dim)
+    t = t.flip(dims = (dim,))
+    return t
+
 def l2norm(t, groups = 1):
     t = rearrange(t, '... (g d) -> ... g d', g = groups)
     t = F.normalize(t, p = 2, dim = -1)
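A quick illustration (not from the package) of what the new reverse_cumsum helper computes: suffix sums along a dimension, obtained by flipping, cumulatively summing, and flipping back.

import torch

t = torch.tensor([1., 2., 3., 4.])

# same flip -> cumsum -> flip trick as the new helper
suffix_sums = t.flip(dims = (-1,)).cumsum(dim = -1).flip(dims = (-1,))

print(suffix_sums)   # tensor([10., 9., 7., 4.]) - element i holds the sum of t[i:]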
@@ -324,7 +331,7 @@ class RelativePositionBias(Module):
         device = self.device
         q_pos = torch.arange(j - i, j, dtype = torch.long, device = device)
         k_pos = torch.arange(j, dtype = torch.long, device = device)
-        rel_pos = k_pos[None, :] - q_pos[:, None]
+        rel_pos = einx.subtract('j, i -> i j', k_pos, q_pos)
         rp_bucket = self._relative_position_bucket(rel_pos, causal = self.causal, num_buckets = self.num_buckets, max_distance = self.max_distance)
         values = self.relative_attention_bias(rp_bucket)
         bias = rearrange(values, 'i j h -> h i j')
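Several hunks in this release swap manual broadcasting for einx.subtract. A small sketch (toy tensors, not package code) of the equivalence for the pattern used above:

import torch
import einx

q_pos = torch.arange(3)
k_pos = torch.arange(5)

# new einx form: 'j, i -> i j' broadcasts the two 1-d ranges into an (i, j) grid
rel_pos_new = einx.subtract('j, i -> i j', k_pos, q_pos)

# old manual-broadcast form from 1.40.11
rel_pos_old = k_pos[None, :] - q_pos[:, None]

assert torch.equal(rel_pos_new, rel_pos_old)   # shape (3, 5)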
@@ -351,8 +358,10 @@ class CoPE(Module):
         self.soft_onehot = soft_onehot
         self.soft_onehot_temp = soft_onehot_temp

-        if soft_onehot:
-            self.register_buffer('positions', torch.arange(max_pos))
+        if not soft_onehot:
+            return
+
+        self.register_buffer('positions', torch.arange(max_pos))

     def forward(self, query, attn_logits):

@@ -374,7 +383,7 @@ class CoPE(Module):
         logits_int = einsum('b h n d, p d -> b h n p', query, self.pos_emb)

         if self.soft_onehot:
-            diff_pos = (pos[..., None] - self.positions).abs()
+            diff_pos = einx.subtract('i, j -> i j', pos, self.positions).abs()
             soft_onehot_pos = F.softmax(-diff_pos / self.soft_onehot_temp, dim = -1)
             cope_pos_emb = einsum('b h i j p, b h i p -> b h i j', soft_onehot_pos, logits_int)
         else:
@@ -423,7 +432,7 @@ class DynamicPositionBias(Module):
         # get the (n x n) matrix of distances
         seq_arange = torch.arange(n, device = device)
         context_arange = torch.arange(n, device = device)
-        indices = rearrange(seq_arange, 'i -> i 1') - rearrange(context_arange, 'j -> 1 j')
+        indices = einx.subtract('i, j -> i j', seq_arange, context_arange)
         indices += (n - 1)

         # input to continuous positions MLP
@@ -453,9 +462,9 @@ class AlibiPositionalBias(Module):
         self.register_buffer('bias', None, persistent = False)

     def get_bias(self, i, j, device):
-        i_arange = torch.arange(j - i, j, device = device)
-        j_arange = torch.arange(j, device = device)
-        bias = -torch.abs(rearrange(j_arange, 'j -> 1 1 j') - rearrange(i_arange, 'i -> 1 i 1'))
+        seq_arange = torch.arange(j - i, j, device = device)
+        context_arange = torch.arange(j, device = device)
+        bias = -torch.abs(einx.subtract('j, i -> 1 i j', context_arange, seq_arange))
         return bias

     @staticmethod
@@ -490,6 +499,44 @@ class AlibiPositionalBias(Module):

         return self.bias

+class DataDependentAlibi(Module):
+    """ https://openreview.net/forum?id=q2Lnyegkr8 """
+
+    def __init__(
+        self,
+        dim,
+        heads
+    ):
+        super().__init__()
+
+        linear = nn.Linear(dim, heads)
+
+        self.to_forget_gates = nn.Sequential(
+            linear,
+            Rearrange('b n h -> b h n'),
+            nn.Sigmoid()
+        )
+
+        nn.init.constant_(linear.bias, 5.)
+
+    def forward(self, x):
+        seq = x.shape[-2]
+
+        forget_gates = self.to_forget_gates(x).log()
+        forget_gates = repeat(forget_gates, 'b h j -> b h i j', i = seq)
+
+        # causal mask out, including diagonal (so token to itself attention is never masked out)
+
+        causal_mask = torch.ones((seq, seq), dtype = torch.bool, device = x.device).triu()
+
+        forget_gates = forget_gates.masked_fill(causal_mask, 0.)
+
+        # reverse cumulative sum in log space (equivalent to cumprod)
+
+        forget_gates = reverse_cumsum(forget_gates)
+
+        return forget_gates
+
 class RotaryEmbedding(Module):
     def __init__(
         self,
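A hedged usage sketch of the new DataDependentAlibi module, assuming it is importable from x_transformers.x_transformers as added above; all shapes are toy values. It produces a per-head attention bias of shape (batch, heads, seq, seq) that is non-positive and biases attention toward recent tokens based on learned, data dependent forget gates.

import torch
from x_transformers.x_transformers import DataDependentAlibi   # added in 1.41.0

ddalibi = DataDependentAlibi(dim = 512, heads = 8)

x = torch.randn(2, 16, 512)          # (batch, seq, dim)
bias = ddalibi(x)

print(bias.shape)                    # torch.Size([2, 8, 16, 16]) - (batch, heads, query, key)
assert (bias <= 0).all()             # log-sigmoid gates accumulate into non-positive biases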
@@ -938,6 +985,7 @@ class Attention(Module):
         tensor_product = False, # https://arxiv.org/abs/2208.06061
         add_zero_kv = False, # same as add_zero_attn in pytorch
         rotary_embed_values = False,
+        data_dependent_alibi = False,
         use_cope = False,
         cope_max_pos = 16,
         cope_soft_onehot_pos = False,
@@ -1041,6 +1089,19 @@ class Attention(Module):
                 soft_onehot = cope_soft_onehot_pos
             )

+        # data dependent alibi
+        # https://openreview.net/forum?id=q2Lnyegkr8
+
+        self.data_dependent_alibi = None
+
+        if data_dependent_alibi:
+            assert causal, 'data dependent alibi only works for autoregressive for now until further research'
+
+            self.data_dependent_alibi = DataDependentAlibi(
+                dim,
+                heads = heads
+            )
+
         # attend class - includes core attention algorithm + talking heads

         self.attend = Attend(
@@ -1236,7 +1297,7 @@ class Attention(Module):
         if exists(self.max_attend_past):
             range_q = torch.arange(j - i, j, device = device)
             range_k = torch.arange(j, device = device)
-            dist = rearrange(range_q, 'i -> 1 1 i 1') - rearrange(range_k, 'j -> 1 1 1 j')
+            dist = einx.subtract('i, j -> 1 1 i j', range_q, range_k)
             max_attend_past_mask = dist > self.max_attend_past
             max_attend_past_mask = pad_at_dim(max_attend_past_mask, (num_mem_kv, 0), value = False, dim = -1) # handle memory key / values
             masks.append(max_attend_past_mask)
@@ -1251,6 +1312,11 @@ class Attention(Module):
             attn_bias = rel_pos(i, j)
             attn_bias = pad_at_dim(attn_bias, (num_mem_kv, 0), value = 0.) # handle memory key / values

+        # prepare data dependent alibi from forgetting transformers paper, if needed
+
+        if exists(self.data_dependent_alibi):
+            attn_bias = self.data_dependent_alibi(x)
+
         # if previous values passed in for residual, either invoke resformer or neutreno

         if exists(value_residual):
@@ -1291,7 +1357,7 @@ class Attention(Module):

         if exists(self.to_v_head_gate):
             head_gate = self.to_v_head_gate(x)
-            out = out * rearrange(head_gate, 'b n h -> b h n 1').sigmoid()
+            out = einx.multiply('b n h, b h n d ->b h n d', head_gate.sigmoid(), out)

         # merge heads

@@ -1308,8 +1374,7 @@ class Attention(Module):
         out = self.to_out(out)

         if exists(mask):
-            mask = rearrange(mask, 'b n -> b n 1')
-            out = out.masked_fill(~mask, 0.)
+            out = einx.where('b n, b n d, -> b n d', mask, out, 0.)

         if not return_intermediates:
             return out
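The head-gating and output-masking hunks above rely on the same einx broadcasting idea. A toy sketch (not package code) of the new einx.where form versus the old masked_fill:

import torch
import einx

out = torch.randn(2, 4, 8)                       # (batch, seq, dim)
mask = torch.tensor([[True, True, False, False],
                     [True, False, True, False]])

# new form in 1.41.0: broadcast the (b, n) mask over the feature dim, filling with 0.
masked_new = einx.where('b n, b n d, -> b n d', mask, out, 0.)

# old form from 1.40.11
masked_old = out.masked_fill(~mask.unsqueeze(-1), 0.)

assert torch.allclose(masked_new, masked_old)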
@@ -1389,10 +1454,13 @@ class AttentionLayers(Module):
         attn_kwargs, kwargs = groupby_prefix_and_trim('attn_', kwargs)
         cross_attn_kwargs, kwargs = groupby_prefix_and_trim('cross_attn_', kwargs)

+        dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
+        data_dependent_alibi = attn_kwargs.get('data_dependent_alibi', False)
+        neutreno_value_residual = attn_kwargs.get('neutreno_value_residual', False)
+
         assert len(kwargs) == 0, f'unrecognized kwargs passed in {kwargs.keys()}'

-        dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
-        add_value_residual |= attn_kwargs.get('neutreno_value_residual', False)
+        add_value_residual |= neutreno_value_residual

         self.dim = dim
         self.causal = causal
@@ -1405,7 +1473,7 @@ class AttentionLayers(Module):
         assert not (rotary_xpos and not causal), 'rotary xpos is not compatible with bidirectional attention'
         self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim, use_xpos = rotary_xpos, scale_base = rotary_xpos_scale_base, interpolation_factor = rotary_interpolation_factor, base_rescale_factor = rotary_base_rescale_factor) if rotary_pos_emb else None

-        assert not (alibi_pos_bias and rel_pos_bias), 'you can only choose Alibi positional bias or T5 relative positional bias, not both'
+        assert at_most_one_of(alibi_pos_bias, rel_pos_bias, data_dependent_alibi), 'you can only choose one of Alibi positional bias, data dependent Alibi (forgetting transformers), or T5 relative positional bias'
         assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'

         # relative positional bias
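Since AttentionLayers forwards attn_-prefixed kwargs to each Attention block, the new option can be enabled from the wrapper level. A hedged sketch of how a user might turn it on (model hyperparameters are arbitrary; the flag requires a causal Decoder, per the assert above):

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_data_dependent_alibi = True   # routed to Attention as data_dependent_alibi = True
    )
)

x = torch.randint(0, 256, (1, 1024))
logits = model(x)                          # (1, 1024, 256)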
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.40.11
+Version: 1.41.0
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
@@ -15,5 +15,6 @@ Classifier: Programming Language :: Python :: 3.6
 Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: torch >=2.0
+Requires-Dist: einx >=0.3.0
 Requires-Dist: einops >=0.8.0

@@ -2,14 +2,14 @@ x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,79
 x_transformers/attend.py,sha256=SdWlV8Vp5DtpsOzAd0LRhm4VGrJf0lJCGiV2_j_CtoA,17284
 x_transformers/autoregressive_wrapper.py,sha256=DOJJCMMDOqDYKWy_IaG5IyKsXD3AW6amzfUgdAADOLY,10500
 x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
-x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
+x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=RfpihlGygZz4ICq4IGOgGNOipInXUiYWYNs1tej2Orw,88290
+x_transformers/x_transformers.py,sha256=fz71zW2IQ3NQU_csHbzCwFzGHNwrdIF9rZTLhUjmM_Q,90260
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.40.11.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.40.11.dist-info/METADATA,sha256=D97orsPC5EYEtJN6EN75bLOfOY-FBmodr2eaFIovwu8,662
-x_transformers-1.40.11.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
-x_transformers-1.40.11.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.40.11.dist-info/RECORD,,
+x_transformers-1.41.0.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.41.0.dist-info/METADATA,sha256=832V1ChJ77viLmQO7d9RU7R9SV_bspLgVl9vtdRdq5Q,689
+x_transformers-1.41.0.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+x_transformers-1.41.0.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.41.0.dist-info/RECORD,,
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: setuptools (75.2.0)
+Generator: setuptools (75.3.0)
 Root-Is-Purelib: true
 Tag: py3-none-any
