alphagenome-pytorch 0.0.9-py3-none-any.whl → 0.0.11-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
--- alphagenome_pytorch/__init__.py
+++ alphagenome_pytorch/__init__.py
@@ -1,5 +1,11 @@
 from alphagenome_pytorch.alphagenome import (
     AlphaGenome,
     Attention,
-    TransformerTower
+    PairwiseRowAttention,
+    RelativePosFeatures,
+    RotaryEmbedding,
+    FeedForward,
+    TransformerTower,
+    UpresBlock,
+    DownresBlock,
 )

--- alphagenome_pytorch/alphagenome.py
+++ alphagenome_pytorch/alphagenome.py
@@ -32,6 +32,9 @@ def exists(v):
 def divisible_by(num, den):
     return (num % den) == 0
 
+def last(arr):
+    return arr[-1]
+
 def is_odd(num):
     return not divisible_by(num, 2)
 
@@ -128,15 +131,15 @@ class UpresBlock(Module):
     def forward(
         self,
         x,
-        unet_skip = None
+        skip = None
     ):
 
         residual = x[:, :-self.pad]
         out = self.conv(x) + residual
 
-        if exists(unet_skip):
+        if exists(skip):
             out = repeat(out, 'b c n -> b c (n upsample)', upsample = 2) * self.residual_scale
-            out = out + self.unet_conv(unet_skip)
+            out = out + self.unet_conv(skip)
 
         return self.conv_out(out) + out
 
@@ -240,6 +243,7 @@ class Attention(Module):
         dim,
         dim_head = 64,
         heads = 8,
+        kv_heads = 1,
         dim_head_qk = 128,
         dim_head_v = 192,
         dim_pairwise = None,
@@ -254,8 +258,13 @@ class Attention(Module):
 
         # splitting and merging of attention heads
 
-        self.split_q_heads = Rearrange('b n (h d) -> b h n d', h = heads)
-        self.merge_heads = Rearrange('b h n d -> b n (h d)')
+        assert divisible_by(heads, kv_heads)
+        groups = heads // kv_heads
+
+        self.split_q_heads = Rearrange('b n (g h d) -> b g h n d', h = kv_heads, g = groups)
+        self.split_kv_heads = Rearrange('b n (h d) -> b h n d', h = kv_heads)
+
+        self.merge_heads = Rearrange('b g h n d -> b n (g h d)')
 
         # projections
 
@@ -274,7 +283,7 @@ class Attention(Module):
             nn.RMSNorm(dim_pairwise), # replace with BatchRMSNorm once crafted
             nn.GELU(),
             LinearNoBias(dim_pairwise, heads),
-            Rearrange('b i j h -> b h i j')
+            Rearrange('b i j (g h) -> b g h i j', g = groups)
         )
         # variables
 
@@ -293,6 +302,7 @@ class Attention(Module):
         # they use multi-query attention, with only 1 key / value head - pretty unconventional, but maybe enough for genomic modeling
 
         q = self.split_q_heads(q)
+        k, v = tuple(self.split_kv_heads(t) for t in (k, v))
 
         q, k, v = self.q_norm(q), self.k_norm(k), self.v_norm(v)
 
@@ -305,7 +315,7 @@ class Attention(Module):
 
         # similarities
 
-        sim = einsum(q, k, 'b h i d, b j d -> b h i j')
+        sim = einsum(q, k, 'b g h i d, b h j d -> b g h i j')
 
         # add attention bias + softclamping
 
@@ -315,7 +325,7 @@ class Attention(Module):
             assert divisible_by(sim.shape[-1], attn_bias.shape[-1])
             expand_factor = sim.shape[-1] // attn_bias.shape[-1]
 
-            attn_bias = repeat(attn_bias, 'b h i j -> b h (i r1) (j r2)', r1 = expand_factor, r2 = expand_factor)
+            attn_bias = repeat(attn_bias, 'b g h i j -> b g h (i r1) (j r2)', r1 = expand_factor, r2 = expand_factor)
 
             sim = softclamp(sim + attn_bias, value = self.softclamp_value)
 
@@ -325,7 +335,7 @@ class Attention(Module):
 
         # aggregate
 
-        out = einsum(attn, v, 'b h i j, b j d -> b h i d')
+        out = einsum(attn, v, 'b g h i j, b h j d -> b g h i d')
 
         out = self.merge_heads(out)
         return self.to_out(out)
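In the hunks above, attention moves from a single shared key/value head to grouped-query attention: queries are split into `groups = heads // kv_heads` groups, each group shares the same `kv_heads` key/value heads, and every einsum picks up a `g` axis. The following is an illustrative sketch of just those shapes, using arbitrary sizes and plain `torch`/`einops` calls rather than the package's `Attention` module (scaling, norms, rotary embeddings and the pairwise bias are omitted):

```python
import torch
from einops import einsum, rearrange

b, n = 2, 16
heads, kv_heads, dim_head = 8, 1, 32
groups = heads // kv_heads          # number of query heads sharing each kv head

q = torch.randn(b, n, groups * kv_heads * dim_head)
k = torch.randn(b, n, kv_heads * dim_head)
v = torch.randn(b, n, kv_heads * dim_head)

# split heads, mirroring the new Rearrange layers in the diff
q = rearrange(q, 'b n (g h d) -> b g h n d', g = groups, h = kv_heads)
k = rearrange(k, 'b n (h d) -> b h n d', h = kv_heads)
v = rearrange(v, 'b n (h d) -> b h n d', h = kv_heads)

# each group of query heads attends over the shared key/value heads
sim = einsum(q, k, 'b g h i d, b h j d -> b g h i j')
attn = sim.softmax(dim = -1)
out = einsum(attn, v, 'b g h i j, b h j d -> b g h i d')

out = rearrange(out, 'b g h n d -> b n (g h d)')
print(out.shape)  # torch.Size([2, 16, 256])
```

With the default `kv_heads = 1` this reduces to the multi-query attention described in the existing comment, just with an explicit group axis.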
@@ -384,7 +394,7 @@ class SingleToPairwise(Module):
         rel_q = relative_shift(einsum(q + q_rel_bias, rel_pos_encoding, 'b n h d, p h d -> b h n p'))
         rel_k = relative_shift(einsum(k + k_rel_bias, rel_pos_encoding, 'b n h d, p h d -> b h n p'))
 
-        rel_sim = add('b h i j, b h j i -> b h i j', rel_q, rel_k) * 0.5
+        rel_sim = add('b h i j, b h j i -> b i j h', rel_q, rel_k) * 0.5
 
         sim = sim + rel_sim
 
@@ -460,7 +470,7 @@ class TransformerTower(Module):
         dropout = 0.,
         ff_expansion_factor = 2.,
         max_positions = 8192,
-        dim_pairwise = None,
+        dim_pairwise = 128,
         pairwise_every_num_single_blocks = 2, # how often to do a pairwise block
         single_to_pairwise_heads = 32, # they did 32
         pool_size = 16,
@@ -539,7 +549,7 @@ class TransformerTower(Module):
             pairwise = maybe_pairwise_attn(pairwise) + pairwise
             pairwise = maybe_pairwise_ff(pairwise) + pairwise
 
-            single = attn(single, rotary_emb = rotary_emb, pairwise = None) + single
+            single = attn(single, rotary_emb = rotary_emb, pairwise = pairwise) + single
             single = ff(single) + single
 
         return single, pairwise
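The last hunk also starts feeding the pairwise representation into each single-sequence attention block instead of `None`; per the `Attention` changes above, that tensor is normalized, projected to one logit per head, rearranged into the grouped layout, and tiled up to the single-sequence length (the pairwise map is presumably at a pooled resolution). Below is a small illustrative sketch of that bias path, with hypothetical sizes and a plain `nn.Sequential` standing in for the package's `to_attn_bias`:

```python
import torch
from torch import nn
from einops import repeat
from einops.layers.torch import Rearrange

heads, kv_heads, dim_pairwise = 8, 1, 128
groups = heads // kv_heads

# stand-in for the to_attn_bias projection shown in the Attention hunk
to_attn_bias = nn.Sequential(
    nn.RMSNorm(dim_pairwise),
    nn.GELU(),
    nn.Linear(dim_pairwise, heads, bias = False),   # plain Linear in place of LinearNoBias
    Rearrange('b i j (g h) -> b g h i j', g = groups)
)

pairwise = torch.randn(2, 32, 32, dim_pairwise)   # hypothetical pooled pairwise map
attn_bias = to_attn_bias(pairwise)                # (2, groups, kv_heads, 32, 32)

# tile the bias up to the single-sequence attention resolution, as in the diff
single_len = 512
expand_factor = single_len // attn_bias.shape[-1]
attn_bias = repeat(attn_bias, 'b g h i j -> b g h (i r1) (j r2)', r1 = expand_factor, r2 = expand_factor)
print(attn_bias.shape)  # torch.Size([2, 8, 1, 512, 512])
```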
@@ -556,8 +566,10 @@ class DNAEmbed(Module):
         super().__init__()
         assert is_odd(width)
         self.dim_input = dim_input
-        self.conv = nn.Conv1d(dim_input, dim, width, padding = width // 2)
-        self.pointwise = nn.Conv1d(dim, dim, 1)
+        self.conv = Conv1d(dim_input, dim, width, padding = width // 2)
+        self.pointwise = Conv1d(dim, dim, 1)
+
+        self.pool = Reduce('b d (n pool) -> b d n', 'max', pool = 2)
 
     def forward(
         self,
@@ -568,38 +580,96 @@ class DNAEmbed(Module):
 
         out = self.conv(x)
         out = out + self.pointwise(out)
-        return rearrange(out, 'b d n -> b n d')
+        pooled = self.pool(out) # think they downsample for dna embed block
+
+        return pooled, x
 
 # classes
 
 class AlphaGenome(Module):
     def __init__(
         self,
-        dim = 768,
+        dims: tuple[int, ...] = (
+            768,
+            896,
+            1024,
+            1152,
+            1280,
+            1408,
+            1536
+        ),
         basepairs = 5,
         dna_embed_width = 15,
-        dim_pairwise = None,
         transformer_kwargs: dict = dict()
     ):
         super().__init__()
+
         assert is_odd(dna_embed_width)
 
-        self.to_dna_embed = DNAEmbed(dim, dim_input = basepairs, width = dna_embed_width)
+        assert len(dims) >= 2
+        first_dim, *_, last_dim = dims
+
+        self.dna_embed = DNAEmbed(first_dim, dim_input = basepairs, width = dna_embed_width)
+
+        dim_with_input = (basepairs, *dims)
+        dim_pairs = zip(dim_with_input[:-1], dim_with_input[1:])
+
+        downs = []
+        ups = []
+
+        for layer_num, (dim_in, dim_out) in enumerate(dim_pairs, start = 1):
+            is_first = layer_num == 1
+            channel_diff = dim_out - dim_in
+
+            assert channel_diff > 0
+
+            if not is_first:
+                down = DownresBlock(dim_in, channels_to_add = channel_diff)
+                downs.append(down)
 
-        self.transformer = Transformer(
-            dim = dim,
-            dim_pairwise = dim_pairwise,
+            up = UpresBlock(dim_out, channels_to_remove = channel_diff)
+            ups.insert(0, up)
+
+
+        self.downs = ModuleList(downs)
+        self.ups = ModuleList(ups)
+
+        self.transformer = TransformerTower(
+            dim = last_dim,
             **transformer_kwargs
         )
 
     def forward(
         self,
-        seq,
-        pairwise
+        seq # Int['b n']
    ):
 
-        dna_embed = self.to_dna_embed(seq)
+        skips = []
+
+        # embed with one hot and add skip
+
+        x, skip = self.dna_embed(seq)
+        skips.append(skip)
+
+        # downs
+
+        for down in self.downs:
+            skips.append(x)
+            x = down(x)
+
+        x = rearrange(x, 'b d n -> b n d')
+
+        # attention
+
+        single, pairwise = self.transformer(x)
+
+        # ups with skips from down
+
+        x = rearrange(x, 'b n d -> b d n')
+
+        for up in self.ups:
+            x = up(x, skip = skips.pop())
 
-        attended = self.transformer(dna_embed)
+        pred = rearrange(x, 'b l n -> b n l') # 1bp resolution
 
-        return attended
+        return pred, single, pairwise
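Taken together, the new `AlphaGenome` wraps the transformer in a U-net: `DNAEmbed` pools the input once, the `DownresBlock`s appear to halve the length again at each later stage while widening channels by `channel_diff`, and the `UpresBlock`s undo this with the stored skips down to 1bp resolution. Below is a standalone sketch of the implied bookkeeping, assuming one 2x length reduction per stage, which is consistent with the `(2, 64, 1536)` single representation quoted in the README diff further down:

```python
# pure shape bookkeeping, no learned modules
dims = (768, 896, 1024, 1152, 1280, 1408, 1536)
basepairs, seq_len = 5, 8192

dim_with_input = (basepairs, *dims)
dim_pairs = list(zip(dim_with_input[:-1], dim_with_input[1:]))

length = seq_len
for layer_num, (dim_in, dim_out) in enumerate(dim_pairs, start = 1):
    length //= 2  # DNAEmbed pools at stage 1; each later stage assumed to halve via its DownresBlock
    print(f'stage {layer_num}: {dim_in} -> {dim_out} channels, length {length}')

# length ends at 8192 / 2**7 = 64 with 1536 channels, the shape handed to TransformerTower;
# the 7 UpresBlocks then each upsample 2x, returning to 8192 positions for the prediction
```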
--- alphagenome_pytorch-0.0.9.dist-info/METADATA
+++ alphagenome_pytorch-0.0.11.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: alphagenome-pytorch
-Version: 0.0.9
+Version: 0.0.11
 Summary: AlphaGenome
 Project-URL: Homepage, https://pypi.org/project/alphagenome-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/alphagenome
@@ -58,13 +58,13 @@ $ pip install alphagenome-pytorch
 
 ```python
 import torch
-from alphagenome_pytorch import TransformerTower
+from alphagenome_pytorch import AlphaGenome
 
-transformer = TransformerTower(dim = 768, dim_pairwise = 128)
+model = AlphaGenome()
 
-single = torch.randn(2, 512, 768)
+dna = torch.randint(0, 5, (2, 8192))
 
-attended_single, attended_pairwise = transformer(single)
+pred_nucleotide, single, pairwise = model(dna) # (2, 8192, 5), (2, 64, 1536), (2, 4, 4, 1536)
 
 ```
 
 ## Citations

--- /dev/null
+++ alphagenome_pytorch-0.0.11.dist-info/RECORD
@@ -0,0 +1,6 @@
+alphagenome_pytorch/__init__.py,sha256=XPNDv0q_c3nkiQo-4ROb_RQsbbKNV8KXmD6X5VnErKI,225
+alphagenome_pytorch/alphagenome.py,sha256=7O1bS-_7dlEC_r2rzI_7VvTPsXQaWsIqXd4heocysVA,18206
+alphagenome_pytorch-0.0.11.dist-info/METADATA,sha256=vwetYdPP9P17KBBuDX_q1N1DgRWe_kbqGXtnqdScSvg,3382
+alphagenome_pytorch-0.0.11.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+alphagenome_pytorch-0.0.11.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+alphagenome_pytorch-0.0.11.dist-info/RECORD,,

--- alphagenome_pytorch-0.0.9.dist-info/RECORD
+++ /dev/null
@@ -1,6 +0,0 @@
-alphagenome_pytorch/__init__.py,sha256=7plC_YRm0UapNCl9hJEhFxXE-ELGKVy-DtuO5GUQxGI,101
-alphagenome_pytorch/alphagenome.py,sha256=sRQXd-wvi0iSgGjuzfek7jpAJqJkiCSEtt0tFpAbTGo,16462
-alphagenome_pytorch-0.0.9.dist-info/METADATA,sha256=pqLrVpzTOuFuu6NjvwS6PujEd2BLrUJEB97nxbYTGdc,3386
-alphagenome_pytorch-0.0.9.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-alphagenome_pytorch-0.0.9.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-alphagenome_pytorch-0.0.9.dist-info/RECORD,,