alphagenome-pytorch 0.0.9__py3-none-any.whl → 0.0.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,5 +1,11 @@
1
1
  from alphagenome_pytorch.alphagenome import (
2
2
  AlphaGenome,
3
3
  Attention,
4
- TransformerTower
4
+ PairwiseRowAttention,
5
+ RelativePosFeatures,
6
+ RotaryEmbedding,
7
+ FeedForward,
8
+ TransformerTower,
9
+ UpresBlock,
10
+ DownresBlock,
5
11
  )
@@ -32,6 +32,9 @@ def exists(v):
32
32
  def divisible_by(num, den):
33
33
  return (num % den) == 0
34
34
 
35
+ def last(arr):
36
+ return arr[-1]
37
+
35
38
  def is_odd(num):
36
39
  return not divisible_by(num, 2)
37
40
 
@@ -128,15 +131,15 @@ class UpresBlock(Module):
128
131
  def forward(
129
132
  self,
130
133
  x,
131
- unet_skip = None
134
+ skip = None
132
135
  ):
133
136
 
134
137
  residual = x[:, :-self.pad]
135
138
  out = self.conv(x) + residual
136
139
 
137
- if exists(unet_skip):
140
+ if exists(skip):
138
141
  out = repeat(out, 'b c n -> b c (n upsample)', upsample = 2) * self.residual_scale
139
- out = out + self.unet_conv(unet_skip)
142
+ out = out + self.unet_conv(skip)
140
143
 
141
144
  return self.conv_out(out) + out
142
145
 
@@ -384,7 +387,7 @@ class SingleToPairwise(Module):
384
387
  rel_q = relative_shift(einsum(q + q_rel_bias, rel_pos_encoding, 'b n h d, p h d -> b h n p'))
385
388
  rel_k = relative_shift(einsum(k + k_rel_bias, rel_pos_encoding, 'b n h d, p h d -> b h n p'))
386
389
 
387
- rel_sim = add('b h i j, b h j i -> b h i j', rel_q, rel_k) * 0.5
390
+ rel_sim = add('b h i j, b h j i -> b i j h', rel_q, rel_k) * 0.5
388
391
 
389
392
  sim = sim + rel_sim
390
393
 
@@ -460,7 +463,7 @@ class TransformerTower(Module):
460
463
  dropout = 0.,
461
464
  ff_expansion_factor = 2.,
462
465
  max_positions = 8192,
463
- dim_pairwise = None,
466
+ dim_pairwise = 128,
464
467
  pairwise_every_num_single_blocks = 2, # how often to do a pairwise block
465
468
  single_to_pairwise_heads = 32, # they did 32
466
469
  pool_size = 16,
@@ -556,8 +559,10 @@ class DNAEmbed(Module):
556
559
  super().__init__()
557
560
  assert is_odd(width)
558
561
  self.dim_input = dim_input
559
- self.conv = nn.Conv1d(dim_input, dim, width, padding = width // 2)
560
- self.pointwise = nn.Conv1d(dim, dim, 1)
562
+ self.conv = Conv1d(dim_input, dim, width, padding = width // 2)
563
+ self.pointwise = Conv1d(dim, dim, 1)
564
+
565
+ self.pool = Reduce('b d (n pool) -> b d n', 'max', pool = 2)
561
566
 
562
567
  def forward(
563
568
  self,
@@ -568,38 +573,96 @@ class DNAEmbed(Module):
568
573
 
569
574
  out = self.conv(x)
570
575
  out = out + self.pointwise(out)
571
- return rearrange(out, 'b d n -> b n d')
576
+ pooled = self.pool(out) # think they downsample for dna embed block
577
+
578
+ return pooled, x
572
579
 
573
580
  # classes
574
581
 
575
582
  class AlphaGenome(Module):
576
583
  def __init__(
577
584
  self,
578
- dim = 768,
585
+ dims: tuple[int, ...] = (
586
+ 768,
587
+ 896,
588
+ 1024,
589
+ 1152,
590
+ 1280,
591
+ 1408,
592
+ 1536
593
+ ),
579
594
  basepairs = 5,
580
595
  dna_embed_width = 15,
581
- dim_pairwise = None,
582
596
  transformer_kwargs: dict = dict()
583
597
  ):
584
598
  super().__init__()
599
+
585
600
  assert is_odd(dna_embed_width)
586
601
 
587
- self.to_dna_embed = DNAEmbed(dim, dim_input = basepairs, width = dna_embed_width)
602
+ assert len(dims) >= 2
603
+ first_dim, *_, last_dim = dims
604
+
605
+ self.dna_embed = DNAEmbed(first_dim, dim_input = basepairs, width = dna_embed_width)
606
+
607
+ dim_with_input = (basepairs, *dims)
608
+ dim_pairs = zip(dim_with_input[:-1], dim_with_input[1:])
609
+
610
+ downs = []
611
+ ups = []
612
+
613
+ for layer_num, (dim_in, dim_out) in enumerate(dim_pairs, start = 1):
614
+ is_first = layer_num == 1
615
+ channel_diff = dim_out - dim_in
616
+
617
+ assert channel_diff > 0
618
+
619
+ if not is_first:
620
+ down = DownresBlock(dim_in, channels_to_add = channel_diff)
621
+ downs.append(down)
622
+
623
+ up = UpresBlock(dim_out, channels_to_remove = channel_diff)
624
+ ups.insert(0, up)
588
625
 
589
- self.transformer = Transformer(
590
- dim = dim,
591
- dim_pairwise = dim_pairwise,
626
+
627
+ self.downs = ModuleList(downs)
628
+ self.ups = ModuleList(ups)
629
+
630
+ self.transformer = TransformerTower(
631
+ dim = last_dim,
592
632
  **transformer_kwargs
593
633
  )
594
634
 
595
635
  def forward(
596
636
  self,
597
- seq,
598
- pairwise
637
+ seq # Int['b n']
599
638
  ):
600
639
 
601
- dna_embed = self.to_dna_embed(seq)
640
+ skips = []
641
+
642
+ # embed with one hot and add skip
643
+
644
+ x, skip = self.dna_embed(seq)
645
+ skips.append(skip)
646
+
647
+ # downs
648
+
649
+ for down in self.downs:
650
+ skips.append(x)
651
+ x = down(x)
652
+
653
+ x = rearrange(x, 'b d n -> b n d')
654
+
655
+ # attention
656
+
657
+ single, pairwise = self.transformer(x)
658
+
659
+ # ups with skips from down
660
+
661
+ x = rearrange(x, 'b n d -> b d n')
662
+
663
+ for up in self.ups:
664
+ x = up(x, skip = skips.pop())
602
665
 
603
- attended = self.transformer(dna_embed)
666
+ pred = rearrange(x, 'b l n -> b n l') # 1bp resolution
604
667
 
605
- return attended
668
+ return pred, single, pairwise
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: alphagenome-pytorch
3
- Version: 0.0.9
3
+ Version: 0.0.10
4
4
  Summary: AlphaGenome
5
5
  Project-URL: Homepage, https://pypi.org/project/alphagenome-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/alphagenome
@@ -58,13 +58,13 @@ $ pip install alphagenome-pytorch
58
58
 
59
59
  ```python
60
60
  import torch
61
- from alphagenome_pytorch import TransformerTower
61
+ from alphagenome_pytorch import AlphaGenome
62
62
 
63
- transformer = TransformerTower(dim = 768, dim_pairwise = 128)
63
+ model = AlphaGenome()
64
64
 
65
- single = torch.randn(2, 512, 768)
65
+ dna = torch.randint(0, 5, (2, 8192))
66
66
 
67
- attended_single, attended_pairwise = transformer(single)
67
+ pred_nucleotide, single, pairwise = model(dna) # (2, 8192, 5), (2, 64, 1536), (2, 4, 4, 1536)
68
68
  ```
69
69
 
70
70
  ## Citations
@@ -0,0 +1,6 @@
1
+ alphagenome_pytorch/__init__.py,sha256=XPNDv0q_c3nkiQo-4ROb_RQsbbKNV8KXmD6X5VnErKI,225
2
+ alphagenome_pytorch/alphagenome.py,sha256=gfctPlDfzjl7CUkWJBX-9omT1MrvQDIgsvKSR0lLCQc,17902
3
+ alphagenome_pytorch-0.0.10.dist-info/METADATA,sha256=SWZoRmpq6HVnddFPVMt108RFL-PywE-mniDLVRYkI3o,3382
4
+ alphagenome_pytorch-0.0.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
+ alphagenome_pytorch-0.0.10.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
+ alphagenome_pytorch-0.0.10.dist-info/RECORD,,
@@ -1,6 +0,0 @@
1
- alphagenome_pytorch/__init__.py,sha256=7plC_YRm0UapNCl9hJEhFxXE-ELGKVy-DtuO5GUQxGI,101
2
- alphagenome_pytorch/alphagenome.py,sha256=sRQXd-wvi0iSgGjuzfek7jpAJqJkiCSEtt0tFpAbTGo,16462
3
- alphagenome_pytorch-0.0.9.dist-info/METADATA,sha256=pqLrVpzTOuFuu6NjvwS6PujEd2BLrUJEB97nxbYTGdc,3386
4
- alphagenome_pytorch-0.0.9.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
- alphagenome_pytorch-0.0.9.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
- alphagenome_pytorch-0.0.9.dist-info/RECORD,,