alphagenome-pytorch 0.0.10.tar.gz → 0.0.11.tar.gz

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: alphagenome-pytorch
-Version: 0.0.10
+Version: 0.0.11
 Summary: AlphaGenome
 Project-URL: Homepage, https://pypi.org/project/alphagenome-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/alphagenome
@@ -243,6 +243,7 @@ class Attention(Module):
         dim,
         dim_head = 64,
         heads = 8,
+        kv_heads = 1,
         dim_head_qk = 128,
         dim_head_v = 192,
         dim_pairwise = None,
@@ -257,8 +258,13 @@ class Attention(Module):
 
         # splitting and merging of attention heads
 
-        self.split_q_heads = Rearrange('b n (h d) -> b h n d', h = heads)
-        self.merge_heads = Rearrange('b h n d -> b n (h d)')
+        assert divisible_by(heads, kv_heads)
+        groups = heads // kv_heads
+
+        self.split_q_heads = Rearrange('b n (g h d) -> b g h n d', h = kv_heads, g = groups)
+        self.split_kv_heads = Rearrange('b n (h d) -> b h n d', h = kv_heads)
+
+        self.merge_heads = Rearrange('b g h n d -> b n (g h d)')
 
         # projections
 
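This release generalizes the attention from a single shared key/value head (multi-query) to grouped-query attention: the heads query heads are divided into groups = heads // kv_heads groups, each attending to one of the kv_heads key/value heads, and the default kv_heads = 1 reproduces the previous multi-query behaviour. Below is a minimal shape sketch of the new split/merge rearranges, assuming a single illustrative head dimension of 64 (the module itself uses separate dim_head_qk / dim_head_v).

    # shape sketch only -- heads, kv_heads and dim_head values are illustrative
    import torch
    from einops.layers.torch import Rearrange

    heads, kv_heads, dim_head = 8, 1, 64
    groups = heads // kv_heads   # query heads per shared key/value head

    split_q_heads  = Rearrange('b n (g h d) -> b g h n d', h = kv_heads, g = groups)
    split_kv_heads = Rearrange('b n (h d) -> b h n d', h = kv_heads)
    merge_heads    = Rearrange('b g h n d -> b n (g h d)')

    q = torch.randn(2, 512, heads * dim_head)      # all query heads projected
    k = torch.randn(2, 512, kv_heads * dim_head)   # only kv_heads worth of keys

    print(split_q_heads(q).shape)                # torch.Size([2, 8, 1, 512, 64])
    print(split_kv_heads(k).shape)               # torch.Size([2, 1, 512, 64])
    print(merge_heads(split_q_heads(q)).shape)   # torch.Size([2, 512, 512])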
 
@@ -277,7 +283,7 @@ class Attention(Module):
             nn.RMSNorm(dim_pairwise), # replace with BatchRMSNorm once crafted
             nn.GELU(),
             LinearNoBias(dim_pairwise, heads),
-            Rearrange('b i j h -> b h i j')
+            Rearrange('b i j (g h) -> b g h i j', g = groups)
         )
         # variables
 
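With the extra group axis, the pairwise-to-attention-bias projection now emits one bias map per (group, key/value head) pair rather than per flat head. A rough shape sketch, assuming dim_pairwise = 128 and substituting a plain bias-free nn.Linear for the repository's LinearNoBias helper (nn.RMSNorm needs a recent PyTorch, 2.4 or later):

    # illustrative shapes only
    import torch
    from torch import nn
    from einops.layers.torch import Rearrange

    dim_pairwise, heads, kv_heads = 128, 8, 1
    groups = heads // kv_heads

    to_attn_bias = nn.Sequential(
        nn.RMSNorm(dim_pairwise),
        nn.GELU(),
        nn.Linear(dim_pairwise, heads, bias = False),   # stand-in for LinearNoBias
        Rearrange('b i j (g h) -> b g h i j', g = groups)
    )

    pairwise = torch.randn(1, 64, 64, dim_pairwise)   # (batch, i, j, dim_pairwise)
    print(to_attn_bias(pairwise).shape)               # torch.Size([1, 8, 1, 64, 64])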
 
@@ -296,6 +302,7 @@ class Attention(Module):
         # they use multi-query attention, with only 1 key / value head - pretty unconventional, but maybe enough for genomic modeling
 
         q = self.split_q_heads(q)
+        k, v = tuple(self.split_kv_heads(t) for t in (k, v))
 
         q, k, v = self.q_norm(q), self.k_norm(k), self.v_norm(v)
 
@@ -308,7 +315,7 @@ class Attention(Module):
 
         # similarities
 
-        sim = einsum(q, k, 'b h i d, b j d -> b h i j')
+        sim = einsum(q, k, 'b g h i d, b h j d -> b g h i j')
 
         # add attention bias + softclamping
 
@@ -318,7 +325,7 @@ class Attention(Module):
             assert divisible_by(sim.shape[-1], attn_bias.shape[-1])
             expand_factor = sim.shape[-1] // attn_bias.shape[-1]
 
-            attn_bias = repeat(attn_bias, 'b h i j -> b h (i r1) (j r2)', r1 = expand_factor, r2 = expand_factor)
+            attn_bias = repeat(attn_bias, 'b g h i j -> b g h (i r1) (j r2)', r1 = expand_factor, r2 = expand_factor)
 
             sim = softclamp(sim + attn_bias, value = self.softclamp_value)
 
@@ -328,7 +335,7 @@ class Attention(Module):
 
         # aggregate
 
-        out = einsum(attn, v, 'b h i j, b j d -> b h i d')
+        out = einsum(attn, v, 'b g h i j, b h j d -> b g h i d')
 
         out = self.merge_heads(out)
         return self.to_out(out)
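
Taken together, the similarity and aggregation einsums now carry both the group axis g and the shared key/value head axis h, which is numerically the same as repeating each key/value head across its group. A self-contained sketch of just the attention core (softmax only, omitting the rotary embeddings, attention bias and softclamping used in the actual module), under the same illustrative shapes as above:

    # minimal grouped-query attention core; shapes match the sketches above
    import torch
    from einops import einsum, repeat

    b, g, h, n, d = 2, 8, 1, 512, 64

    q = torch.randn(b, g, h, n, d)
    k = torch.randn(b, h, n, d)
    v = torch.randn(b, h, n, d)

    sim  = einsum(q, k, 'b g h i d, b h j d -> b g h i j')
    attn = sim.softmax(dim = -1)
    out  = einsum(attn, v, 'b g h i j, b h j d -> b g h i d')

    # equivalent to ordinary attention with the key/value heads repeated per group
    k_rep   = repeat(k, 'b h n d -> b (g h) n d', g = g)
    sim_ref = einsum(q.flatten(1, 2), k_rep, 'b h i d, b h j d -> b h i j')
    assert torch.allclose(sim, sim_ref.unflatten(1, (g, h)), atol = 1e-5)
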
@@ -542,7 +549,7 @@ class TransformerTower(Module):
             pairwise = maybe_pairwise_attn(pairwise) + pairwise
             pairwise = maybe_pairwise_ff(pairwise) + pairwise
 
-            single = attn(single, rotary_emb = rotary_emb, pairwise = None) + single
+            single = attn(single, rotary_emb = rotary_emb, pairwise = pairwise) + single
             single = ff(single) + single
 
         return single, pairwise
@@ -1,6 +1,6 @@
 [project]
 name = "alphagenome-pytorch"
-version = "0.0.10"
+version = "0.0.11"
 description = "AlphaGenome"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }