x-transformers 1.40.11__py3-none-any.whl → 1.41.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/dpo.py CHANGED
@@ -5,6 +5,7 @@ from torch.nn import Module
 import torch.nn.functional as F
 from x_transformers.x_transformers import TransformerWrapper

+import einx
 from einops import rearrange

 # helper functions
@@ -17,16 +18,18 @@ def freeze_all_layers_(module):
         param.requires_grad = False

 def log_prob_from_model_and_seq(model, seq):
-    logits = model(seq)
+    src_seq, tgt_seq = seq[:, :-1], seq[:, 1:]
+    logits = model(src_seq)
     log_prob = logits.log_softmax(dim = -1)
-    indices = rearrange(seq, '... -> ... 1')
-    log_probs = log_prob.gather(-1, indices)
-    return rearrange(log_probs, '... 1 -> ...')
+    return einx.get_at('b n [l], b n -> b n', log_prob, tgt_seq)

 def masked_mean(log_probs, mask = None):
     if not exists(mask):
         return log_probs.mean(dim = -1)

+    if mask.shape[-1] == (log_probs.shape[-1] + 1):
+        mask = mask[:, :-1]
+
     log_probs = log_probs.masked_fill(~mask, 0.)
     num = log_probs.sum(dim = -1)
     den = mask.sum(dim = -1)
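The rewritten log_prob_from_model_and_seq now feeds the model the sequence shifted left by one and indexes the resulting log-probabilities with the shifted targets via einx.get_at. A minimal sketch of that indexing, with made-up toy shapes, showing it matches the old gather-based form:

    import torch
    import einx

    # toy shapes (illustrative only): batch 2, 4 predicted positions, vocab of 7
    log_prob = torch.randn(2, 4, 7).log_softmax(dim = -1)   # model(seq[:, :-1]) log-probs
    tgt_seq = torch.randint(0, 7, (2, 4))                   # seq[:, 1:], the tokens being scored

    # new form: pick the log-prob of each target token along the bracketed [l] axis
    picked_einx = einx.get_at('b n [l], b n -> b n', log_prob, tgt_seq)

    # 1.40.11 form, kept here for comparison
    picked_gather = log_prob.gather(-1, tgt_seq[..., None]).squeeze(-1)

    assert torch.allclose(picked_einx, picked_gather)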
x_transformers/x_transformers.py CHANGED
@@ -16,8 +16,9 @@ from collections import namedtuple
 from contextlib import nullcontext
 from dataclasses import dataclass

-from einops import rearrange, repeat, reduce, pack, unpack
+import einx
 from einops.layers.torch import Rearrange
+from einops import rearrange, repeat, reduce, pack, unpack

 from x_transformers.attend import Attend, Intermediates
 from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
@@ -324,7 +325,7 @@ class RelativePositionBias(Module):
         device = self.device
         q_pos = torch.arange(j - i, j, dtype = torch.long, device = device)
         k_pos = torch.arange(j, dtype = torch.long, device = device)
-        rel_pos = k_pos[None, :] - q_pos[:, None]
+        rel_pos = einx.subtract('j, i -> i j', k_pos, q_pos)
         rp_bucket = self._relative_position_bucket(rel_pos, causal = self.causal, num_buckets = self.num_buckets, max_distance = self.max_distance)
         values = self.relative_attention_bias(rp_bucket)
         bias = rearrange(values, 'i j h -> h i j')
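This and the similar hunks below replace explicit broadcasting (None-indexing or rearrange) with einx.subtract, which broadcasts according to the einops-style pattern. A small sketch of the equivalence, using arbitrary toy sizes:

    import torch
    import einx

    q_pos = torch.arange(3, 8)   # i = 5 query positions (toy values)
    k_pos = torch.arange(8)      # j = 8 key positions

    rel_einx = einx.subtract('j, i -> i j', k_pos, q_pos)   # result[i, j] = k_pos[j] - q_pos[i]
    rel_manual = k_pos[None, :] - q_pos[:, None]            # old broadcasting form

    assert torch.equal(rel_einx, rel_manual)                # shape (i, j) = (5, 8)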
@@ -351,8 +352,10 @@ class CoPE(Module):
         self.soft_onehot = soft_onehot
         self.soft_onehot_temp = soft_onehot_temp

-        if soft_onehot:
-            self.register_buffer('positions', torch.arange(max_pos))
+        if not soft_onehot:
+            return
+
+        self.register_buffer('positions', torch.arange(max_pos))

     def forward(self, query, attn_logits):

@@ -374,7 +377,7 @@ class CoPE(Module):
         logits_int = einsum('b h n d, p d -> b h n p', query, self.pos_emb)

         if self.soft_onehot:
-            diff_pos = (pos[..., None] - self.positions).abs()
+            diff_pos = einx.subtract('i, j -> i j', pos, self.positions).abs()
             soft_onehot_pos = F.softmax(-diff_pos / self.soft_onehot_temp, dim = -1)
             cope_pos_emb = einsum('b h i j p, b h i p -> b h i j', soft_onehot_pos, logits_int)
         else:
@@ -423,7 +426,7 @@ class DynamicPositionBias(Module):
         # get the (n x n) matrix of distances
         seq_arange = torch.arange(n, device = device)
         context_arange = torch.arange(n, device = device)
-        indices = rearrange(seq_arange, 'i -> i 1') - rearrange(context_arange, 'j -> 1 j')
+        indices = einx.subtract('i, j -> i j', seq_arange, context_arange)
         indices += (n - 1)

         # input to continuous positions MLP
@@ -453,9 +456,9 @@ class AlibiPositionalBias(Module):
         self.register_buffer('bias', None, persistent = False)

     def get_bias(self, i, j, device):
-        i_arange = torch.arange(j - i, j, device = device)
-        j_arange = torch.arange(j, device = device)
-        bias = -torch.abs(rearrange(j_arange, 'j -> 1 1 j') - rearrange(i_arange, 'i -> 1 i 1'))
+        seq_arange = torch.arange(j - i, j, device = device)
+        context_arange = torch.arange(j, device = device)
+        bias = -torch.abs(einx.subtract('j, i -> 1 i j', context_arange, seq_arange))
         return bias

     @staticmethod
@@ -490,6 +493,35 @@ class AlibiPositionalBias(Module):

         return self.bias

+class DataDependentAlibi(Module):
+    """ https://openreview.net/forum?id=q2Lnyegkr8 """
+
+    def __init__(
+        self,
+        dim,
+        heads
+    ):
+        super().__init__()
+
+        linear = nn.Linear(dim, heads)
+
+        self.to_forget_gates = nn.Sequential(
+            linear,
+            Rearrange('b n h -> b h n'),
+            nn.LogSigmoid()
+        )
+
+        nn.init.constant_(linear.bias, 5.)
+
+    def forward(self, x):
+        seq = x.shape[-2]
+
+        forget_gates = self.to_forget_gates(x)
+        forget_gates = forget_gates.cumsum(dim = -1)
+        forget_gates = einx.subtract('b h i, b h j -> b h i j', forget_gates, forget_gates)
+
+        return forget_gates
+
 class RotaryEmbedding(Module):
     def __init__(
         self,
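The new DataDependentAlibi module implements the data-dependent bias from the forgetting transformers paper: per-token forget gates in log space are cumulatively summed, and the pairwise difference F[i] - F[j] becomes an additive attention bias. A rough sketch of what the einx.subtract construction above computes, with made-up shapes:

    import torch
    import torch.nn.functional as F

    # made-up shapes: batch 1, 2 heads, 5 tokens
    log_forget = F.logsigmoid(torch.randn(1, 2, 5))    # per-token log forget gates, all <= 0

    cum = log_forget.cumsum(dim = -1)                  # F[t] = sum of log gates up to position t
    bias = cum[..., :, None] - cum[..., None, :]       # bias[i, j] = F[i] - F[j]

    # for j <= i this is the summed log forget gates over positions j+1 .. i,
    # so keys further in the past receive a more negative bias; the upper
    # triangle is irrelevant under the causal mask
    assert torch.all(bias.tril() <= 1e-6)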
@@ -938,6 +970,7 @@ class Attention(Module):
         tensor_product = False,   # https://arxiv.org/abs/2208.06061
         add_zero_kv = False,      # same as add_zero_attn in pytorch
         rotary_embed_values = False,
+        data_dependent_alibi = False,
         use_cope = False,
         cope_max_pos = 16,
         cope_soft_onehot_pos = False,
@@ -1041,6 +1074,19 @@ class Attention(Module):
                 soft_onehot = cope_soft_onehot_pos
             )

+        # data dependent alibi
+        # https://openreview.net/forum?id=q2Lnyegkr8
+
+        self.data_dependent_alibi = None
+
+        if data_dependent_alibi:
+            assert causal, 'data dependent alibi only works for autoregressive for now until further research'
+
+            self.data_dependent_alibi = DataDependentAlibi(
+                dim,
+                heads = heads
+            )
+
         # attend class - includes core attention algorithm + talking heads

         self.attend = Attend(
@@ -1236,7 +1282,7 @@ class Attention(Module):
         if exists(self.max_attend_past):
             range_q = torch.arange(j - i, j, device = device)
             range_k = torch.arange(j, device = device)
-            dist = rearrange(range_q, 'i -> 1 1 i 1') - rearrange(range_k, 'j -> 1 1 1 j')
+            dist = einx.subtract('i, j -> 1 1 i j', range_q, range_k)
             max_attend_past_mask = dist > self.max_attend_past
             max_attend_past_mask = pad_at_dim(max_attend_past_mask, (num_mem_kv, 0), value = False, dim = -1) # handle memory key / values
             masks.append(max_attend_past_mask)
@@ -1251,6 +1297,11 @@ class Attention(Module):
             attn_bias = rel_pos(i, j)
             attn_bias = pad_at_dim(attn_bias, (num_mem_kv, 0), value = 0.) # handle memory key / values

+        # prepare data dependent alibi from forgetting transformers paper, if needed
+
+        if exists(self.data_dependent_alibi):
+            attn_bias = self.data_dependent_alibi(x)
+
         # if previous values passed in for residual, either invoke resformer or neutreno

         if exists(value_residual):
@@ -1291,7 +1342,7 @@ class Attention(Module):

         if exists(self.to_v_head_gate):
             head_gate = self.to_v_head_gate(x)
-            out = out * rearrange(head_gate, 'b n h -> b h n 1').sigmoid()
+            out = einx.multiply('b n h, b h n d -> b h n d', head_gate.sigmoid(), out)

         # merge heads

@@ -1308,8 +1359,7 @@ class Attention(Module):
         out = self.to_out(out)

         if exists(mask):
-            mask = rearrange(mask, 'b n -> b n 1')
-            out = out.masked_fill(~mask, 0.)
+            out = einx.where('b n, b n d, -> b n d', mask, out, 0.)

         if not return_intermediates:
             return out
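einx.where collapses the old two-step rearrange plus masked_fill into a single broadcasted select; the empty operand before the arrow stands for the scalar fill value. A quick check of the equivalence on toy tensors:

    import torch
    import einx

    out = torch.randn(2, 6, 8)        # (batch, seq, dim), toy sizes
    mask = torch.rand(2, 6) > 0.5     # (batch, seq) boolean mask

    kept_einx = einx.where('b n, b n d, -> b n d', mask, out, 0.)
    kept_fill = out.masked_fill(~mask[..., None], 0.)

    assert torch.equal(kept_einx, kept_fill)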
@@ -1389,10 +1439,13 @@ class AttentionLayers(Module):
         attn_kwargs, kwargs = groupby_prefix_and_trim('attn_', kwargs)
         cross_attn_kwargs, kwargs = groupby_prefix_and_trim('cross_attn_', kwargs)

+        dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
+        data_dependent_alibi = attn_kwargs.get('data_dependent_alibi', False)
+        neutreno_value_residual = attn_kwargs.get('neutreno_value_residual', False)
+
         assert len(kwargs) == 0, f'unrecognized kwargs passed in {kwargs.keys()}'

-        dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
-        add_value_residual |= attn_kwargs.get('neutreno_value_residual', False)
+        add_value_residual |= neutreno_value_residual

         self.dim = dim
         self.causal = causal
@@ -1405,7 +1458,7 @@ class AttentionLayers(Module):
         assert not (rotary_xpos and not causal), 'rotary xpos is not compatible with bidirectional attention'
         self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim, use_xpos = rotary_xpos, scale_base = rotary_xpos_scale_base, interpolation_factor = rotary_interpolation_factor, base_rescale_factor = rotary_base_rescale_factor) if rotary_pos_emb else None

-        assert not (alibi_pos_bias and rel_pos_bias), 'you can only choose Alibi positional bias or T5 relative positional bias, not both'
+        assert at_most_one_of(alibi_pos_bias, rel_pos_bias, data_dependent_alibi), 'you can only choose one of Alibi positional bias, data dependent Alibi (forgetting transformers), or T5 relative positional bias'
         assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'

         # relative positional bias
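With the assert updated, data-dependent Alibi becomes selectable alongside the other relative position schemes. Presumably it is switched on through the usual attn_ kwarg forwarding (AttentionLayers strips the attn_ prefix and hands the flag to Attention); a hypothetical usage sketch:

    import torch
    from x_transformers import TransformerWrapper, Decoder

    model = TransformerWrapper(
        num_tokens = 256,
        max_seq_len = 1024,
        attn_layers = Decoder(
            dim = 512,
            depth = 6,
            heads = 8,
            attn_data_dependent_alibi = True   # forwarded to Attention(data_dependent_alibi = True)
        )
    )

    x = torch.randint(0, 256, (1, 1024))
    logits = model(x)   # (1, 1024, 256)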
x_transformers-1.40.11.dist-info/METADATA → x_transformers-1.41.1.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.40.11
+Version: 1.41.1
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
@@ -15,5 +15,6 @@ Classifier: Programming Language :: Python :: 3.6
 Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: torch >=2.0
+Requires-Dist: einx >=0.3.0
 Requires-Dist: einops >=0.8.0

x_transformers-1.40.11.dist-info/RECORD → x_transformers-1.41.1.dist-info/RECORD CHANGED
@@ -2,14 +2,14 @@ x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,79
 x_transformers/attend.py,sha256=SdWlV8Vp5DtpsOzAd0LRhm4VGrJf0lJCGiV2_j_CtoA,17284
 x_transformers/autoregressive_wrapper.py,sha256=DOJJCMMDOqDYKWy_IaG5IyKsXD3AW6amzfUgdAADOLY,10500
 x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
-x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
+x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=RfpihlGygZz4ICq4IGOgGNOipInXUiYWYNs1tej2Orw,88290
+x_transformers/x_transformers.py,sha256=n8W19Pnhbz-JxbC7QATApWrhI_yC4oqTHGQ1NLuindY,89814
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.40.11.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.40.11.dist-info/METADATA,sha256=D97orsPC5EYEtJN6EN75bLOfOY-FBmodr2eaFIovwu8,662
-x_transformers-1.40.11.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
-x_transformers-1.40.11.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.40.11.dist-info/RECORD,,
+x_transformers-1.41.1.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.41.1.dist-info/METADATA,sha256=UIPYbEVBLrWDGuezlnyh2tFKPlM_Mdj-pYTGxse_NMI,689
+x_transformers-1.41.1.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+x_transformers-1.41.1.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.41.1.dist-info/RECORD,,
x_transformers-1.40.11.dist-info/WHEEL → x_transformers-1.41.1.dist-info/WHEEL CHANGED
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: setuptools (75.2.0)
+Generator: setuptools (75.3.0)
 Root-Is-Purelib: true
 Tag: py3-none-any
