alphagenome-pytorch 0.0.10__py3-none-any.whl → 0.0.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- alphagenome_pytorch/alphagenome.py +14 -7
- {alphagenome_pytorch-0.0.10.dist-info → alphagenome_pytorch-0.0.11.dist-info}/METADATA +1 -1
- alphagenome_pytorch-0.0.11.dist-info/RECORD +6 -0
- alphagenome_pytorch-0.0.10.dist-info/RECORD +0 -6
- {alphagenome_pytorch-0.0.10.dist-info → alphagenome_pytorch-0.0.11.dist-info}/WHEEL +0 -0
- {alphagenome_pytorch-0.0.10.dist-info → alphagenome_pytorch-0.0.11.dist-info}/licenses/LICENSE +0 -0
alphagenome_pytorch/alphagenome.py

@@ -243,6 +243,7 @@ class Attention(Module):
         dim,
         dim_head = 64,
         heads = 8,
+        kv_heads = 1,
         dim_head_qk = 128,
         dim_head_v = 192,
         dim_pairwise = None,
@@ -257,8 +258,13 @@ class Attention(Module):
 
         # splitting and merging of attention heads
 
-
-
+        assert divisible_by(heads, kv_heads)
+        groups = heads // kv_heads
+
+        self.split_q_heads = Rearrange('b n (g h d) -> b g h n d', h = kv_heads, g = groups)
+        self.split_kv_heads = Rearrange('b n (h d) -> b h n d', h = kv_heads)
+
+        self.merge_heads = Rearrange('b g h n d -> b n (g h d)')
 
         # projections
 
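Note on the hunk above: the new kv_heads parameter generalizes the previous multi-query layout. kv_heads = 1 keeps multi-query attention, kv_heads = heads recovers standard multi-head attention, and anything in between is grouped-query attention. A minimal standalone sketch of the new split patterns, assuming torch and einops with made-up sizes (illustration only, not code from the package):

import torch
from einops.layers.torch import Rearrange

heads, kv_heads, dim_head = 8, 2, 16            # illustrative sizes
groups = heads // kv_heads                      # 4 query heads share each kv head

split_q_heads  = Rearrange('b n (g h d) -> b g h n d', h = kv_heads, g = groups)
split_kv_heads = Rearrange('b n (h d) -> b h n d', h = kv_heads)

q = torch.randn(2, 10, heads * dim_head)        # queries carry all 8 heads
k = torch.randn(2, 10, kv_heads * dim_head)     # keys carry only the 2 kv heads

print(split_q_heads(q).shape)                   # torch.Size([2, 4, 2, 10, 16])
print(split_kv_heads(k).shape)                  # torch.Size([2, 2, 10, 16])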
@@ -277,7 +283,7 @@ class Attention(Module):
             nn.RMSNorm(dim_pairwise), # replace with BatchRMSNorm once crafted
             nn.GELU(),
             LinearNoBias(dim_pairwise, heads),
-            Rearrange('b i j h -> b h i j')
+            Rearrange('b i j (g h) -> b g h i j', g = groups)
         )
         # variables
 
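The pairwise-to-bias projection still produces heads channels; only the final Rearrange changes, unflattening the head axis into (groups, kv_heads) so the bias lines up with the grouped similarity tensor. A hypothetical shape check under the same illustrative sizes:

import torch
from einops.layers.torch import Rearrange

b, i, j = 2, 10, 10
groups, kv_heads = 4, 2                          # heads = groups * kv_heads = 8
bias = torch.randn(b, i, j, groups * kv_heads)   # as if from LinearNoBias(dim_pairwise, heads)

to_grouped_bias = Rearrange('b i j (g h) -> b g h i j', g = groups)
print(to_grouped_bias(bias).shape)               # torch.Size([2, 4, 2, 10, 10])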
@@ -296,6 +302,7 @@ class Attention(Module):
         # they use multi-query attention, with only 1 key / value head - pretty unconventional, but maybe enough for genomic modeling
 
         q = self.split_q_heads(q)
+        k, v = tuple(self.split_kv_heads(t) for t in (k, v))
 
         q, k, v = self.q_norm(q), self.k_norm(k), self.v_norm(v)
 
@@ -308,7 +315,7 @@ class Attention(Module):
 
         # similarities
 
-        sim = einsum(q, k, 'b h i d, b j d -> b h i j')
+        sim = einsum(q, k, 'b g h i d, b h j d -> b g h i j')
 
         # add attention bias + softclamping
 
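In the updated einsum the kv-head axis h appears in both operands while the group axis g appears only in the queries, so each kv head is implicitly broadcast across its group of query heads. A small sketch with illustrative shapes:

import torch
from einops import einsum

b, g, h, n, d = 2, 4, 2, 10, 16
q = torch.randn(b, g, h, n, d)                   # grouped query heads
k = torch.randn(b, h, n, d)                      # shared key heads

sim = einsum(q, k, 'b g h i d, b h j d -> b g h i j')
print(sim.shape)                                 # torch.Size([2, 4, 2, 10, 10])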
@@ -318,7 +325,7 @@ class Attention(Module):
         assert divisible_by(sim.shape[-1], attn_bias.shape[-1])
         expand_factor = sim.shape[-1] // attn_bias.shape[-1]
 
-        attn_bias = repeat(attn_bias, 'b h i j -> b h (i r1) (j r2)', r1 = expand_factor, r2 = expand_factor)
+        attn_bias = repeat(attn_bias, 'b g h i j -> b g h (i r1) (j r2)', r1 = expand_factor, r2 = expand_factor)
 
         sim = softclamp(sim + attn_bias, value = self.softclamp_value)
 
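The bias expansion only gains the group axis; repeat still tiles the bias along both sequence axes when the similarity matrix is at a finer resolution than the pairwise representation. A toy illustration (sizes invented):

import torch
from einops import repeat

attn_bias = torch.randn(2, 4, 2, 5, 5)          # bias at coarse resolution
expand_factor = 2                                # sim here would be 10 x 10

attn_bias = repeat(attn_bias, 'b g h i j -> b g h (i r1) (j r2)', r1 = expand_factor, r2 = expand_factor)
print(attn_bias.shape)                           # torch.Size([2, 4, 2, 10, 10])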
@@ -328,7 +335,7 @@ class Attention(Module):
 
         # aggregate
 
-        out = einsum(attn, v, 'b h i j, b j d -> b h i d')
+        out = einsum(attn, v, 'b g h i j, b h j d -> b g h i d')
 
         out = self.merge_heads(out)
         return self.to_out(out)
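Aggregation mirrors the similarity einsum, with the shared value heads broadcast across groups, and merge_heads then flattens (g h d) back into the model dimension. Continuing the sketch from above:

import torch
from einops import einsum
from einops.layers.torch import Rearrange

b, g, h, n, d = 2, 4, 2, 10, 16
attn = torch.randn(b, g, h, n, n).softmax(dim = -1)   # stand-in attention weights
v = torch.randn(b, h, n, d)                           # shared value heads

out = einsum(attn, v, 'b g h i j, b h j d -> b g h i d')
merge_heads = Rearrange('b g h n d -> b n (g h d)')
print(merge_heads(out).shape)                         # torch.Size([2, 10, 128])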
@@ -542,7 +549,7 @@ class TransformerTower(Module):
             pairwise = maybe_pairwise_attn(pairwise) + pairwise
             pairwise = maybe_pairwise_ff(pairwise) + pairwise
 
-            single = attn(single, rotary_emb = rotary_emb, pairwise =
+            single = attn(single, rotary_emb = rotary_emb, pairwise = pairwise) + single
             single = ff(single) + single
 
         return single, pairwise
alphagenome_pytorch-0.0.11.dist-info/RECORD

@@ -0,0 +1,6 @@
+alphagenome_pytorch/__init__.py,sha256=XPNDv0q_c3nkiQo-4ROb_RQsbbKNV8KXmD6X5VnErKI,225
+alphagenome_pytorch/alphagenome.py,sha256=7O1bS-_7dlEC_r2rzI_7VvTPsXQaWsIqXd4heocysVA,18206
+alphagenome_pytorch-0.0.11.dist-info/METADATA,sha256=vwetYdPP9P17KBBuDX_q1N1DgRWe_kbqGXtnqdScSvg,3382
+alphagenome_pytorch-0.0.11.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+alphagenome_pytorch-0.0.11.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+alphagenome_pytorch-0.0.11.dist-info/RECORD,,
alphagenome_pytorch-0.0.10.dist-info/RECORD

@@ -1,6 +0,0 @@
-alphagenome_pytorch/__init__.py,sha256=XPNDv0q_c3nkiQo-4ROb_RQsbbKNV8KXmD6X5VnErKI,225
-alphagenome_pytorch/alphagenome.py,sha256=gfctPlDfzjl7CUkWJBX-9omT1MrvQDIgsvKSR0lLCQc,17902
-alphagenome_pytorch-0.0.10.dist-info/METADATA,sha256=SWZoRmpq6HVnddFPVMt108RFL-PywE-mniDLVRYkI3o,3382
-alphagenome_pytorch-0.0.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-alphagenome_pytorch-0.0.10.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-alphagenome_pytorch-0.0.10.dist-info/RECORD,,
{alphagenome_pytorch-0.0.10.dist-info → alphagenome_pytorch-0.0.11.dist-info}/WHEEL
RENAMED, file without changes

{alphagenome_pytorch-0.0.10.dist-info → alphagenome_pytorch-0.0.11.dist-info}/licenses/LICENSE
RENAMED, file without changes