alphagenome-pytorch 0.0.4__tar.gz → 0.0.5__tar.gz
This diff shows the differences between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- {alphagenome_pytorch-0.0.4 → alphagenome_pytorch-0.0.5}/PKG-INFO +1 -1
- {alphagenome_pytorch-0.0.4 → alphagenome_pytorch-0.0.5}/alphagenome_pytorch/alphagenome.py +19 -6
- {alphagenome_pytorch-0.0.4 → alphagenome_pytorch-0.0.5}/pyproject.toml +1 -1
- {alphagenome_pytorch-0.0.4 → alphagenome_pytorch-0.0.5}/.github/workflows/python-publish.yml +0 -0
- {alphagenome_pytorch-0.0.4 → alphagenome_pytorch-0.0.5}/.github/workflows/test.yml +0 -0
- {alphagenome_pytorch-0.0.4 → alphagenome_pytorch-0.0.5}/.gitignore +0 -0
- {alphagenome_pytorch-0.0.4 → alphagenome_pytorch-0.0.5}/LICENSE +0 -0
- {alphagenome_pytorch-0.0.4 → alphagenome_pytorch-0.0.5}/README.md +0 -0
- {alphagenome_pytorch-0.0.4 → alphagenome_pytorch-0.0.5}/alphagenome_pytorch/__init__.py +0 -0
- {alphagenome_pytorch-0.0.4 → alphagenome_pytorch-0.0.5}/extended-figure-1.png +0 -0
- {alphagenome_pytorch-0.0.4 → alphagenome_pytorch-0.0.5}/tests/test_alphagenome.py +0 -0
@@ -15,6 +15,7 @@ from einops import rearrange, repeat, einsum
|
|
15
15
|
# b - batch
|
16
16
|
# h - heads
|
17
17
|
# n - sequence
|
18
|
+
# p - relative positions
|
18
19
|
# d - feature dimension
|
19
20
|
|
20
21
|
# constants
|
@@ -79,9 +80,15 @@ def apply_rotary_pos_emb(pos, t):
|
|
79
80
|
# 'central mask features' - relative positions for constituting pairwise rep
|
80
81
|
|
81
82
|
class RelativePosFeatures(Module):
|
83
|
+
def __init__(self, pool_size = 16):
|
84
|
+
super().__init__()
|
85
|
+
self.pool_size = pool_size
|
86
|
+
|
82
87
|
def forward(self, single):
|
83
88
|
|
84
89
|
_, seq_len, dim = single.shape
|
90
|
+
|
91
|
+
seq_len //= self.pool_size
|
85
92
|
half_dim = dim // 2
|
86
93
|
|
87
94
|
rel_pos = arange(2 * seq_len - 1) - (seq_len - 1)
|
@@ -271,12 +278,17 @@ class SingleToPairwise(Module):
|
|
271
278
|
|
272
279
|
q_rel_bias, k_rel_bias = self.qk_rel_pos_bias
|
273
280
|
|
274
|
-
rel_q = relative_shift(einsum(q + q_rel_bias, rel_pos_encoding, 'b
|
275
|
-
rel_k = relative_shift(einsum(k + k_rel_bias, rel_pos_encoding, 'b
|
281
|
+
rel_q = relative_shift(einsum(q + q_rel_bias, rel_pos_encoding, 'b n h d, p h d -> b n p h'))
|
282
|
+
rel_k = relative_shift(einsum(k + k_rel_bias, rel_pos_encoding, 'b n h d, p h d -> b n p h'))
|
276
283
|
|
277
|
-
|
284
|
+
_, seq, rel_pos_seq, _ = rel_q.shape
|
285
|
+
crop_padding = (rel_pos_seq - seq) // 2
|
278
286
|
|
279
|
-
|
287
|
+
rel_q, rel_k = tuple(t[..., crop_padding:(crop_padding + seq), :] for t in (rel_q, rel_k))
|
288
|
+
|
289
|
+
rel_sim = add('b i j d, b j i d -> b i j d', rel_q, rel_k) * 0.5
|
290
|
+
|
291
|
+
sim = sim + rel_sim
|
280
292
|
|
281
293
|
pairwise_from_sim = self.qk_to_pairwise(sim)
|
282
294
|
|
@@ -353,6 +365,7 @@ class TransformerTower(Module):
|
|
353
365
|
dim_pairwise = None,
|
354
366
|
pairwise_every_num_single_blocks = 2, # how often to do a pairwise block
|
355
367
|
single_to_pairwise_heads = 32, # they did 32
|
368
|
+
pool_size = 16,
|
356
369
|
attn_kwargs: dict = dict(),
|
357
370
|
ff_kwargs: dict = dict()
|
358
371
|
):
|
@@ -363,7 +376,7 @@ class TransformerTower(Module):
|
|
363
376
|
|
364
377
|
self.pairwise_every = pairwise_every_num_single_blocks
|
365
378
|
|
366
|
-
self.rel_pos_features = RelativePosFeatures()
|
379
|
+
self.rel_pos_features = RelativePosFeatures(pool_size)
|
367
380
|
|
368
381
|
self.rotary_emb = RotaryEmbedding(dim_head_qk, max_positions = max_positions)
|
369
382
|
|
@@ -381,7 +394,7 @@ class TransformerTower(Module):
|
|
381
394
|
single_to_pairwise, pairwise_attn, pairwise_ff = None, None, None
|
382
395
|
|
383
396
|
if divisible_by(layer_index, self.pairwise_every):
|
384
|
-
single_to_pairwise = SingleToPairwise(dim = dim, dim_pairwise = dim_pairwise, heads = single_to_pairwise_heads)
|
397
|
+
single_to_pairwise = SingleToPairwise(dim = dim, dim_pairwise = dim_pairwise, heads = single_to_pairwise_heads, pool_size = pool_size)
|
385
398
|
pairwise_attn = PairwiseRowAttention(dim_pairwise)
|
386
399
|
pairwise_ff = FeedForward(dim = dim_pairwise, expansion_factor = ff_expansion_factor)
|
387
400
|
|
{alphagenome_pytorch-0.0.4 → alphagenome_pytorch-0.0.5}/.github/workflows/python-publish.yml
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|