alphagenome-pytorch 0.0.4__tar.gz → 0.0.6__tar.gz
This diff shows the differences between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
- {alphagenome_pytorch-0.0.4 → alphagenome_pytorch-0.0.6}/PKG-INFO +1 -1
- {alphagenome_pytorch-0.0.4 → alphagenome_pytorch-0.0.6}/alphagenome_pytorch/alphagenome.py +18 -6
- {alphagenome_pytorch-0.0.4 → alphagenome_pytorch-0.0.6}/pyproject.toml +1 -1
- {alphagenome_pytorch-0.0.4 → alphagenome_pytorch-0.0.6}/.github/workflows/python-publish.yml +0 -0
- {alphagenome_pytorch-0.0.4 → alphagenome_pytorch-0.0.6}/.github/workflows/test.yml +0 -0
- {alphagenome_pytorch-0.0.4 → alphagenome_pytorch-0.0.6}/.gitignore +0 -0
- {alphagenome_pytorch-0.0.4 → alphagenome_pytorch-0.0.6}/LICENSE +0 -0
- {alphagenome_pytorch-0.0.4 → alphagenome_pytorch-0.0.6}/README.md +0 -0
- {alphagenome_pytorch-0.0.4 → alphagenome_pytorch-0.0.6}/alphagenome_pytorch/__init__.py +0 -0
- {alphagenome_pytorch-0.0.4 → alphagenome_pytorch-0.0.6}/extended-figure-1.png +0 -0
- {alphagenome_pytorch-0.0.4 → alphagenome_pytorch-0.0.6}/tests/test_alphagenome.py +0 -0
@@ -15,6 +15,7 @@ from einops import rearrange, repeat, einsum
|
|
15
15
|
# b - batch
|
16
16
|
# h - heads
|
17
17
|
# n - sequence
|
18
|
+
# p - relative positions
|
18
19
|
# d - feature dimension
|
19
20
|
|
20
21
|
# constants
|
@@ -79,9 +80,15 @@ def apply_rotary_pos_emb(pos, t):
|
|
79
80
|
# 'central mask features' - relative positions for constituting pairwise rep
|
80
81
|
|
81
82
|
class RelativePosFeatures(Module):
|
83
|
+
def __init__(self, pool_size = 16):
|
84
|
+
super().__init__()
|
85
|
+
self.pool_size = pool_size
|
86
|
+
|
82
87
|
def forward(self, single):
|
83
88
|
|
84
89
|
_, seq_len, dim = single.shape
|
90
|
+
|
91
|
+
seq_len //= self.pool_size
|
85
92
|
half_dim = dim // 2
|
86
93
|
|
87
94
|
rel_pos = arange(2 * seq_len - 1) - (seq_len - 1)
|
@@ -260,6 +267,8 @@ class SingleToPairwise(Module):
|
|
260
267
|
|
261
268
|
single = self.avg_pool(single)
|
262
269
|
|
270
|
+
pool_seq_len = single.shape[1]
|
271
|
+
|
263
272
|
q, k = self.to_qk(single).chunk(2, dim = -1)
|
264
273
|
q, k = tuple(self.split_heads(t) for t in (q, k))
|
265
274
|
|
@@ -271,12 +280,14 @@ class SingleToPairwise(Module):
|
|
271
280
|
|
272
281
|
q_rel_bias, k_rel_bias = self.qk_rel_pos_bias
|
273
282
|
|
274
|
-
rel_q = relative_shift(einsum(q + q_rel_bias, rel_pos_encoding, 'b
|
275
|
-
rel_k = relative_shift(einsum(k + k_rel_bias, rel_pos_encoding, 'b
|
283
|
+
rel_q = relative_shift(einsum(q + q_rel_bias, rel_pos_encoding, 'b n h d, p h d -> b n p h'))
|
284
|
+
rel_k = relative_shift(einsum(k + k_rel_bias, rel_pos_encoding, 'b n h d, p h d -> b n p h'))
|
285
|
+
|
286
|
+
rel_q, rel_k = tuple(t[..., :pool_seq_len, :] for t in (rel_q, rel_k))
|
276
287
|
|
277
|
-
rel_sim =
|
288
|
+
rel_sim = add('b i j d, b j i d -> b i j d', rel_q, rel_k) * 0.5
|
278
289
|
|
279
|
-
sim =
|
290
|
+
sim = sim + rel_sim
|
280
291
|
|
281
292
|
pairwise_from_sim = self.qk_to_pairwise(sim)
|
282
293
|
|
@@ -353,6 +364,7 @@ class TransformerTower(Module):
|
|
353
364
|
dim_pairwise = None,
|
354
365
|
pairwise_every_num_single_blocks = 2, # how often to do a pairwise block
|
355
366
|
single_to_pairwise_heads = 32, # they did 32
|
367
|
+
pool_size = 16,
|
356
368
|
attn_kwargs: dict = dict(),
|
357
369
|
ff_kwargs: dict = dict()
|
358
370
|
):
|
@@ -363,7 +375,7 @@ class TransformerTower(Module):
|
|
363
375
|
|
364
376
|
self.pairwise_every = pairwise_every_num_single_blocks
|
365
377
|
|
366
|
-
self.rel_pos_features = RelativePosFeatures()
|
378
|
+
self.rel_pos_features = RelativePosFeatures(pool_size)
|
367
379
|
|
368
380
|
self.rotary_emb = RotaryEmbedding(dim_head_qk, max_positions = max_positions)
|
369
381
|
|
@@ -381,7 +393,7 @@ class TransformerTower(Module):
|
|
381
393
|
single_to_pairwise, pairwise_attn, pairwise_ff = None, None, None
|
382
394
|
|
383
395
|
if divisible_by(layer_index, self.pairwise_every):
|
384
|
-
single_to_pairwise = SingleToPairwise(dim = dim, dim_pairwise = dim_pairwise, heads = single_to_pairwise_heads)
|
396
|
+
single_to_pairwise = SingleToPairwise(dim = dim, dim_pairwise = dim_pairwise, heads = single_to_pairwise_heads, pool_size = pool_size)
|
385
397
|
pairwise_attn = PairwiseRowAttention(dim_pairwise)
|
386
398
|
pairwise_ff = FeedForward(dim = dim_pairwise, expansion_factor = ff_expansion_factor)
|
387
399
|
|
{alphagenome_pytorch-0.0.4 → alphagenome_pytorch-0.0.6}/.github/workflows/python-publish.yml
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|