alphagenome-pytorch 0.0.8__py3-none-any.whl → 0.0.10__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
@@ -1,5 +1,11 @@
 from alphagenome_pytorch.alphagenome import (
     AlphaGenome,
     Attention,
-    TransformerTower
+    PairwiseRowAttention,
+    RelativePosFeatures,
+    RotaryEmbedding,
+    FeedForward,
+    TransformerTower,
+    UpresBlock,
+    DownresBlock,
 )
@@ -4,7 +4,9 @@ from functools import partial
 import torch
 from torch import nn, cat, stack, arange, logspace
 import torch.nn.functional as F
-from torch.nn import Linear, Sequential, Module, ModuleList
+from torch.nn import Conv1d, Linear, Sequential, Module, ModuleList
+
+from torch.nn.utils.parametrize import register_parametrization
 
 from einx import add, multiply, greater
 from einops.layers.torch import Rearrange, Reduce
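
The newly imported `register_parametrization` backs the `WeightStandardConv` added further down. As a reading aid (not part of the diff), a minimal sketch of what parametrizing a conv weight with a non-affine `LayerNorm` does, assuming stock PyTorch:

```python
import torch
from torch import nn
from torch.nn import Conv1d
from torch.nn.utils.parametrize import register_parametrization

conv = Conv1d(4, 8, 5)

# normalize the full (out, in, width) weight tensor to ~zero mean / unit variance
# on every access, with no learned affine - i.e. weight standardization
register_parametrization(conv, 'weight', nn.LayerNorm(conv.weight.shape, elementwise_affine = False))

w = conv.weight                            # recomputed through the parametrization
print(w.mean().item(), w.std().item())     # ≈ 0 and ≈ 1
```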
@@ -30,6 +32,9 @@ def exists(v):
 def divisible_by(num, den):
     return (num % den) == 0
 
+def last(arr):
+    return arr[-1]
+
 def is_odd(num):
     return not divisible_by(num, 2)
 
@@ -42,6 +47,104 @@ def default(v, d):
 def softclamp(t, value = 5.):
     return (t / value).tanh() * value
 
+# convolutional unet related
+
+class WeightStandardConv(Conv1d):
+    def __init__(
+        self,
+        dim,
+        dim_out,
+        width,
+        *args,
+        **kwargs
+    ):
+        super().__init__(dim, dim_out, width, *args, **kwargs)
+
+        register_parametrization(self, 'weight', nn.LayerNorm(self.weight.shape, elementwise_affine = False))
+
+class ConvBlock(Module):
+    def __init__(
+        self,
+        dim,
+        width = 5,
+        dim_out = None
+    ):
+        super().__init__()
+        assert is_odd(width)
+        dim_out = default(dim_out, dim)
+
+        conv_klass = Conv1d if width == 1 else WeightStandardConv
+
+        self.conv = conv_klass(dim, dim_out, width, padding = width // 2)
+
+    def forward(self, x):
+
+        x = F.gelu(x)
+        out = self.conv(x)
+        return out
+
+class DownresBlock(Module):
+    def __init__(
+        self,
+        dim,
+        channels_to_add = 128 # this is new as well? instead of doubling channels, they add 128 at a time, and use padding or slicing for the residual
+    ):
+        super().__init__()
+
+        dim_out = dim + channels_to_add
+        self.pad = channels_to_add
+
+        self.conv = ConvBlock(dim, width = 1, dim_out = dim_out)
+        self.conv_out = ConvBlock(dim_out, width = 1)
+
+        self.max_pool = Reduce('b d (n pool) -> b d n', 'max', pool = 2)
+
+    def forward(self, x):
+
+        residual = F.pad(x, (0, 0, 0, self.pad), value = 0.)
+
+        out = self.conv(x) + residual
+
+        out = self.conv_out(out) + out
+
+        return self.max_pool(out)
+
+class UpresBlock(Module):
+    def __init__(
+        self,
+        dim,
+        channels_to_remove = 128,
+        residual_scale_init = .9
+    ):
+        super().__init__()
+
+        dim_out = dim - channels_to_remove
+        self.pad = channels_to_remove
+
+        self.conv = ConvBlock(dim, width = 1, dim_out = dim_out)
+        self.unet_conv = ConvBlock(dim_out, width = 1)
+
+        self.conv_out = ConvBlock(dim_out, width = 1)
+
+        self.residual_scale = nn.Parameter(torch.ones(1,) * residual_scale_init)
+
+    def forward(
+        self,
+        x,
+        skip = None
+    ):
+
+        residual = x[:, :-self.pad]
+        out = self.conv(x) + residual
+
+        if exists(skip):
+            out = repeat(out, 'b c n -> b c (n upsample)', upsample = 2) * self.residual_scale
+            out = out + self.unet_conv(skip)
+
+        return self.conv_out(out) + out
+
+# position related
+
 def relative_shift(t):
     *leading_dims, seq_len, dim = t.shape
     t = F.pad(t, (1, 0), value = 0.)
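
As a reading aid (not part of the diff), a shape round trip through the new blocks, assuming the 0.0.10 exports and the definitions above: `DownresBlock` pads the residual up by `channels_to_add` and max-pools the length by 2, while `UpresBlock` slices the residual down and, when given a skip, upsamples the length by 2.

```python
import torch
from alphagenome_pytorch import DownresBlock, UpresBlock

down = DownresBlock(768, channels_to_add = 128)    # 768 -> 896 channels, length halved
up   = UpresBlock(896, channels_to_remove = 128)   # 896 -> 768 channels, length doubled via skip

x = torch.randn(1, 768, 64)        # (batch, channels, length)
pooled = down(x)                   # expected (1, 896, 32)

skip = torch.randn(1, 768, 64)     # skip from the matching resolution on the way down
out = up(pooled, skip = skip)      # expected (1, 768, 64)

print(pooled.shape, out.shape)
```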
@@ -284,7 +387,7 @@ class SingleToPairwise(Module):
         rel_q = relative_shift(einsum(q + q_rel_bias, rel_pos_encoding, 'b n h d, p h d -> b h n p'))
         rel_k = relative_shift(einsum(k + k_rel_bias, rel_pos_encoding, 'b n h d, p h d -> b h n p'))
 
-        rel_sim = add('b h i j, b h j i -> b h i j', rel_q, rel_k) * 0.5
+        rel_sim = add('b h i j, b h j i -> b i j h', rel_q, rel_k) * 0.5
 
         sim = sim + rel_sim
 
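The only change here is the output layout of the einx pattern, from `b h i j` to `b i j h`, presumably so that `rel_sim` matches the heads-last layout of the `sim` it is added to. A small sketch (not part of the diff) of what the corrected pattern computes, assuming `rel_q` and `rel_k` are `(batch, heads, n, n)`:

```python
import torch
from einx import add

b, h, n = 2, 4, 8
rel_q = torch.randn(b, h, n, n)
rel_k = torch.randn(b, h, n, n)

rel_sim = add('b h i j, b h j i -> b i j h', rel_q, rel_k) * 0.5

# equivalent explicit version: transpose rel_k over (i, j), add, halve, move heads last
reference = 0.5 * (rel_q + rel_k.transpose(-1, -2)).permute(0, 2, 3, 1)

assert rel_sim.shape == (b, n, n, h)
assert torch.allclose(rel_sim, reference, atol = 1e-6)
```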
 
@@ -360,7 +463,7 @@ class TransformerTower(Module):
         dropout = 0.,
         ff_expansion_factor = 2.,
         max_positions = 8192,
-        dim_pairwise = None,
+        dim_pairwise = 128,
         pairwise_every_num_single_blocks = 2, # how often to do a pairwise block
         single_to_pairwise_heads = 32, # they did 32
         pool_size = 16,
@@ -456,8 +559,10 @@ class DNAEmbed(Module):
         super().__init__()
         assert is_odd(width)
         self.dim_input = dim_input
-        self.conv = nn.Conv1d(dim_input, dim, width, padding = width // 2)
-        self.pointwise = nn.Conv1d(dim, dim, 1)
+        self.conv = Conv1d(dim_input, dim, width, padding = width // 2)
+        self.pointwise = Conv1d(dim, dim, 1)
+
+        self.pool = Reduce('b d (n pool) -> b d n', 'max', pool = 2)
 
     def forward(
         self,
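
The `Reduce` layer added to `DNAEmbed` is a plain einops max-pool over the length axis; a minimal sketch (not part of the diff):

```python
import torch
from einops.layers.torch import Reduce

pool = Reduce('b d (n pool) -> b d n', 'max', pool = 2)

x = torch.randn(2, 768, 8192)   # (batch, channels, length)
print(pool(x).shape)            # torch.Size([2, 768, 4096])
```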
@@ -468,38 +573,96 @@ class DNAEmbed(Module):
 
         out = self.conv(x)
         out = out + self.pointwise(out)
-        return rearrange(out, 'b d n -> b n d')
+        pooled = self.pool(out) # think they downsample for dna embed block
+
+        return pooled, x
 
 # classes
 
 class AlphaGenome(Module):
     def __init__(
         self,
-        dim = 768,
+        dims: tuple[int, ...] = (
+            768,
+            896,
+            1024,
+            1152,
+            1280,
+            1408,
+            1536
+        ),
         basepairs = 5,
         dna_embed_width = 15,
-        dim_pairwise = None,
         transformer_kwargs: dict = dict()
     ):
         super().__init__()
+
         assert is_odd(dna_embed_width)
 
-        self.to_dna_embed = DNAEmbed(dim, dim_input = basepairs, width = dna_embed_width)
+        assert len(dims) >= 2
+        first_dim, *_, last_dim = dims
+
+        self.dna_embed = DNAEmbed(first_dim, dim_input = basepairs, width = dna_embed_width)
 
-        self.transformer = Transformer(
-            dim = dim,
-            dim_pairwise = dim_pairwise,
+        dim_with_input = (basepairs, *dims)
+        dim_pairs = zip(dim_with_input[:-1], dim_with_input[1:])
+
+        downs = []
+        ups = []
+
+        for layer_num, (dim_in, dim_out) in enumerate(dim_pairs, start = 1):
+            is_first = layer_num == 1
+            channel_diff = dim_out - dim_in
+
+            assert channel_diff > 0
+
+            if not is_first:
+                down = DownresBlock(dim_in, channels_to_add = channel_diff)
+                downs.append(down)
+
+            up = UpresBlock(dim_out, channels_to_remove = channel_diff)
+            ups.insert(0, up)
+
+
+        self.downs = ModuleList(downs)
+        self.ups = ModuleList(ups)
+
+        self.transformer = TransformerTower(
+            dim = last_dim,
             **transformer_kwargs
         )
 
     def forward(
         self,
-        seq,
-        pairwise
+        seq # Int['b n']
     ):
 
-        dna_embed = self.to_dna_embed(seq)
+        skips = []
+
+        # embed with one hot and add skip
+
+        x, skip = self.dna_embed(seq)
+        skips.append(skip)
+
+        # downs
+
+        for down in self.downs:
+            skips.append(x)
+            x = down(x)
+
+        x = rearrange(x, 'b d n -> b n d')
+
+        # attention
+
+        single, pairwise = self.transformer(x)
+
+        # ups with skips from down
+
+        x = rearrange(x, 'b n d -> b d n')
+
+        for up in self.ups:
+            x = up(x, skip = skips.pop())
 
-        attended = self.transformer(dna_embed)
+        pred = rearrange(x, 'b l n -> b n l') # 1bp resolution
 
-        return attended
+        return pred, single, pairwise
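
A sketch (not part of the diff) of how the new constructor turns consecutive entries of `dims` into down/up blocks, following the loop above: each transition adds `dim_out - dim_in` channels on the way down and removes them on the way up, with the very first transition handled by `DNAEmbed` rather than a `DownresBlock`.

```python
dims = (768, 896, 1024, 1152, 1280, 1408, 1536)
basepairs = 5

dim_with_input = (basepairs, *dims)
dim_pairs = list(zip(dim_with_input[:-1], dim_with_input[1:]))

for layer_num, (dim_in, dim_out) in enumerate(dim_pairs, start = 1):
    diff = dim_out - dim_in
    down = 'DNAEmbed (conv + 2x pool)' if layer_num == 1 else f'DownresBlock({dim_in}, channels_to_add = {diff})'
    print(f'level {layer_num}: down = {down}, up = UpresBlock({dim_out}, channels_to_remove = {diff})')

# 6 DownresBlocks but 7 UpresBlocks: the ups run all the way back down to `basepairs` channels,
# which is why the prediction comes out at (batch, seq_len, 5)
```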
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: alphagenome-pytorch
-Version: 0.0.8
+Version: 0.0.10
 Summary: AlphaGenome
 Project-URL: Homepage, https://pypi.org/project/alphagenome-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/alphagenome
@@ -58,13 +58,13 @@ $ pip install alphagenome-pytorch
 
 ```python
 import torch
-from alphagenome_pytorch import TransformerTower
+from alphagenome_pytorch import AlphaGenome
 
-transformer = TransformerTower(dim = 768, dim_pairwise = 128)
+model = AlphaGenome()
 
-single = torch.randn(2, 512, 768)
+dna = torch.randint(0, 5, (2, 8192))
 
-attended_single, attended_pairwise = transformer(single)
+pred_nucleotide, single, pairwise = model(dna) # (2, 8192, 5), (2, 64, 1536), (2, 4, 4, 1536)
 ```
 
 ## Citations
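
The sequence lengths in the README comment follow from the defaults in this version; a quick worked check (not part of the diff), assuming the 2x pool in `DNAEmbed`, one 2x pool per `DownresBlock`, and `pool_size = 16` in `TransformerTower`:

```python
seq_len = 8192

downsample = 2 ** 7                  # DNAEmbed pools once, then each of the 6 DownresBlocks pools once
single_len = seq_len // downsample   # 8192 // 128 = 64 -> single is (batch, 64, 1536)

pairwise_len = single_len // 16      # pool_size = 16 -> pairwise is (batch, 4, 4, ...)

pred_len = single_len * 2 ** 7       # 7 UpresBlocks each upsample by 2 -> back to 8192, 1bp resolution

print(single_len, pairwise_len, pred_len)   # 64 4 8192
```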
@@ -0,0 +1,6 @@
+alphagenome_pytorch/__init__.py,sha256=XPNDv0q_c3nkiQo-4ROb_RQsbbKNV8KXmD6X5VnErKI,225
+alphagenome_pytorch/alphagenome.py,sha256=gfctPlDfzjl7CUkWJBX-9omT1MrvQDIgsvKSR0lLCQc,17902
+alphagenome_pytorch-0.0.10.dist-info/METADATA,sha256=SWZoRmpq6HVnddFPVMt108RFL-PywE-mniDLVRYkI3o,3382
+alphagenome_pytorch-0.0.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+alphagenome_pytorch-0.0.10.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+alphagenome_pytorch-0.0.10.dist-info/RECORD,,
@@ -1,6 +0,0 @@
-alphagenome_pytorch/__init__.py,sha256=7plC_YRm0UapNCl9hJEhFxXE-ELGKVy-DtuO5GUQxGI,101
-alphagenome_pytorch/alphagenome.py,sha256=UjlOJGTMshcaNqmY0r6IRCgRunu7BTZyhbk2vNi5Mis,13948
-alphagenome_pytorch-0.0.8.dist-info/METADATA,sha256=mhlnCRy7Ovq_Gt9ul_iIqws4zdoOcLjfQTBrlbcNic8,3386
-alphagenome_pytorch-0.0.8.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-alphagenome_pytorch-0.0.8.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-alphagenome_pytorch-0.0.8.dist-info/RECORD,,