lalamo 0.6.4__py3-none-any.whl → 0.6.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lalamo/__init__.py +1 -1
- lalamo/commands.py +247 -14
- lalamo/common.py +33 -0
- lalamo/data/__init__.py +3 -2
- lalamo/data/huggingface_message.py +4 -5
- lalamo/main.py +274 -9
- lalamo/message_processor.py +19 -1
- lalamo/model_import/common.py +17 -1
- lalamo/model_import/model_specs/mistral.py +5 -0
- lalamo/model_import/remote_registry.py +44 -0
- lalamo/models/__init__.py +3 -0
- lalamo/models/common.py +22 -0
- lalamo/models/compile_helpers.py +58 -0
- lalamo/models/language_model.py +342 -56
- lalamo/models/lm_helpers.py +198 -0
- lalamo/modules/decoder.py +4 -0
- lalamo/modules/token_mixers/mamba.py +345 -105
- lalamo/speculator/__init__.py +0 -2
- lalamo/speculator/inference.py +35 -61
- {lalamo-0.6.4.dist-info → lalamo-0.6.6.dist-info}/METADATA +1 -1
- {lalamo-0.6.4.dist-info → lalamo-0.6.6.dist-info}/RECORD +25 -23
- lalamo/speculator/estimator.py +0 -127
- {lalamo-0.6.4.dist-info → lalamo-0.6.6.dist-info}/WHEEL +0 -0
- {lalamo-0.6.4.dist-info → lalamo-0.6.6.dist-info}/entry_points.txt +0 -0
- {lalamo-0.6.4.dist-info → lalamo-0.6.6.dist-info}/licenses/LICENSE +0 -0
- {lalamo-0.6.4.dist-info → lalamo-0.6.6.dist-info}/top_level.txt +0 -0
|
@@ -25,9 +25,65 @@ __all__ = [
|
|
|
25
25
|
"Mamba2Result",
|
|
26
26
|
"SeparableCausalConv",
|
|
27
27
|
"SeparableCausalConvConfig",
|
|
28
|
+
"exp_segsum",
|
|
29
|
+
"fused_ssd_intra_chunk",
|
|
28
30
|
]
|
|
29
31
|
|
|
30
32
|
|
|
33
|
+
def exp_segsum(x: Float[Array, "... T"]) -> Float[Array, "... T T"]:
    """Return ``exp(segsum(x))`` as a lower-triangular matrix.

    Entry ``[..., i, j]`` equals ``exp(x[j + 1] + ... + x[i])`` for ``i >= j``
    (so the diagonal is 1) and exactly 0 above the diagonal. Segment sums are
    taken as differences of the inclusive cumulative sum along the last axis.
    """
    seq_len = x.shape[-1]
    cs = jnp.cumsum(x, axis=-1)
    diff = cs[..., :, None] - cs[..., None, :]
    mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
    # Mask *before* exponentiating. Above the diagonal `diff` is the negated
    # segment sum, so strongly negative inputs overflow exp() to inf there.
    # where(mask, inf, 0) still gives 0 in the forward pass but yields NaN
    # gradients (0 * inf); exp(-inf) is an exact 0 with a zero gradient.
    return jnp.exp(jnp.where(mask, diff, -jnp.inf))
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def fused_ssd_intra_chunk(
    a_cumsum: Float[Array, "groups heads_per_group chunks chunk_size"],
    cb: Float[Array, "chunks chunk_size chunk_size groups"],
    x: Float[Array, "chunks chunk_size groups heads_per_group head_dim"],
) -> Float[Array, "chunks chunk_size groups heads_per_group head_dim"]:
    """Compute intra-chunk diagonal block outputs for SSD.

    For every (chunk, group, head) a local ``chunk_size x chunk_size`` causal
    decay matrix is built from that slice of ``a_cumsum``, multiplied
    element-wise with the matching ``cb`` slice and applied to the matching
    ``x`` slice. Decay is computed locally per (chunk, group, head) instead of
    materializing a single global L matrix.

    Args:
        a_cumsum: Per-chunk cumulative sums, with chunk positions on the last axis.
        cb: Per-chunk (query position, key position) products for each group.
        x: Per-chunk inputs for each group and head.

    Returns:
        The intra-chunk outputs, laid out like ``x``.
    """
    chunk_size = a_cumsum.shape[-1]
    causal_mask = jnp.tril(jnp.ones((chunk_size, chunk_size), dtype=jnp.bool_))

    def compute_one(
        a_cs: Float[Array, " chunk_size"],
        cb_slice: Float[Array, "chunk_size chunk_size"],
        x_slice: Float[Array, "chunk_size head_dim"],
    ) -> Float[Array, "chunk_size head_dim"]:
        diff = a_cs[:, None] - a_cs[None, :]
        # Mask inside the exp: above the diagonal `diff` is positive for a
        # decreasing cumulative sum and may overflow to inf, which would turn
        # into NaN gradients if it were masked only after exponentiation.
        decay_local = jnp.exp(jnp.where(causal_mask, diff, -jnp.inf))
        return (decay_local * cb_slice) @ x_slice

    # Map over the array axes directly (instead of gathering with index
    # arrays); out_axes places every mapped axis where the result layout
    # expects it, so no final transpose is required.
    over_heads = jax.vmap(compute_one, in_axes=(0, None, 1), out_axes=1)
    over_groups = jax.vmap(over_heads, in_axes=(0, 2, 1), out_axes=1)
    over_chunks = jax.vmap(over_groups, in_axes=(2, 0, 0), out_axes=0)
    return over_chunks(a_cumsum, cb, x)
|
|
85
|
+
|
|
86
|
+
|
|
31
87
|
Mamba2Result = TokenMixerResult[Mamba2StateLayer]
|
|
32
88
|
|
|
33
89
|
|
|
@@ -156,6 +212,19 @@ class SeparableCausalConv(LalamoModule[SeparableCausalConvConfig]):
|
|
|
156
212
|
updated_state,
|
|
157
213
|
)
|
|
158
214
|
|
|
215
|
+
def step(
    self,
    token: Float[Array, " channels"],
    state: Float[Array, "kernel_minus_1 channels"],
) -> tuple[Float[Array, " channels"], Float[Array, "kernel_minus_1 channels"]]:
    """Apply the convolution to a single new token.

    The stored rows in ``state`` are stacked with ``token`` into one window,
    which is contracted with ``self.weights`` over the kernel axis; the bias
    is added when present.

    Returns:
        A pair ``(output, new_state)`` where ``new_state`` is ``state`` with
        its oldest row dropped and ``token`` appended.
    """
    token_row = token[None, :]
    window = jnp.concatenate([state, token_row], axis=0)

    conv_out = einsum(window, self.weights, "kernel channels, channels kernel -> channels")
    if self.biases is not None:
        conv_out = conv_out + self.biases

    # Slide the window forward by one position for the next call.
    next_state = jnp.concatenate([state[1:], token_row], axis=0)
    return conv_out, next_state
|
|
227
|
+
|
|
159
228
|
def export_weights(self) -> ParameterTree:
|
|
160
229
|
result: dict[str, Array] = {"weights": self.weights}
|
|
161
230
|
if self.biases is not None:
|
|
@@ -188,6 +257,8 @@ class Mamba2Config(TokenMixerConfigBase):
|
|
|
188
257
|
has_in_biases: bool
|
|
189
258
|
has_out_biases: bool
|
|
190
259
|
|
|
260
|
+
chunk_size: int = 256
|
|
261
|
+
|
|
191
262
|
@property
|
|
192
263
|
def inner_dim(self) -> int:
|
|
193
264
|
return self.num_heads * self.head_dim
|
|
@@ -330,99 +401,258 @@ class Mamba2(TokenMixerBase[Mamba2Config, Mamba2StateLayer]):
|
|
|
330
401
|
f"Number of value heads ({self.num_heads}) must be divisible by number of groups ({self.num_groups})",
|
|
331
402
|
)
|
|
332
403
|
|
|
333
|
-
def
|
|
404
|
+
def _step(
    self,
    values: Float[Array, "heads head_dim"],
    keys: Float[Array, "groups state_dim"],
    queries: Float[Array, "groups state_dim"],
    dt_log: Float[Array, " heads"],
    state: Float[Array, "heads head_dim state_dim"],
) -> tuple[Float[Array, "heads head_dim"], Float[Array, "heads head_dim state_dim"]]:
    """Advance the SSM state by exactly one token.

    With ``dt = softplus(dt_log)``, the previous state is scaled by
    ``exp(-dt)`` per head and receives the outer product of the values with
    the keys of the head's group. The values are divided by ``dt + 1e-8`` and
    multiplied by ``dt`` again on the way in. The output is the new state
    contracted with the group's queries over ``state_dim``.

    Returns:
        A pair ``(output, new_state)``.
    """
    group_size = self.num_heads // self.num_groups

    step_size = jax.nn.softplus(dt_log)
    retain = jnp.exp(-step_size)[:, None, None]
    step_size_bcast = step_size[:, None, None]

    # Each group's keys/queries are shared by `group_size` consecutive heads.
    per_head_keys = jnp.repeat(keys, group_size, axis=0)
    per_head_queries = jnp.repeat(queries, group_size, axis=0)
    scaled_values = values / (step_size[:, None] + 1e-8)

    injected = step_size_bcast * scaled_values[:, :, None] * per_head_keys[:, None, :]
    updated_state = retain * state + injected
    readout = einsum(
        updated_state,
        per_head_queries,
        "heads head_dim state_dim, heads state_dim -> heads head_dim",
    )
    return readout, updated_state
|
|
428
|
+
|
|
429
|
+
def _decode_step(
    self,
    inputs: Float[Array, "1 channels"],
    state: Mamba2StateLayer,
) -> Mamba2Result:
    """Process a single token using the per-token conv and SSM updates.

    The token is projected, passed through ``self.conv.step`` and the
    configured activation, split into values / keys / queries, run through
    ``self._step``, combined with the skip connection, gated, and projected
    out.

    Returns:
        A ``Mamba2Result`` whose ``outputs`` has a leading axis of length 1 and
        whose ``state`` holds the updated conv and SSM states.
    """
    (single_token,) = inputs[:1]

    conv_in, gate, dt_log = self.in_projection(single_token)
    conv_out, next_conv_state = self.conv.step(conv_in, state.conv_state)
    activated = self.config.activation(conv_out)

    # Channel layout after the conv: [values | keys | queries].
    values_end = self.inner_dim
    keys_end = self.inner_dim + self.num_groups * self.state_dim
    flat_values, flat_keys, flat_queries = jnp.split(activated, [values_end, keys_end])

    values = rearrange(flat_values, "(heads head_dim) -> heads head_dim", heads=self.num_heads)
    keys = rearrange(flat_keys, "(groups state_dim) -> groups state_dim", groups=self.num_groups)
    queries = rearrange(flat_queries, "(groups state_dim) -> groups state_dim", groups=self.num_groups)

    ssm_out, next_ssm_state = self._step(values, keys, queries, dt_log, state.ssm_state)

    with_skip = ssm_out + self.skip_connection_weight[:, None] * values
    flat_out = rearrange(with_skip, "heads head_dim -> (heads head_dim)")
    gated = flat_out * jax.nn.silu(gate + self.gate_bias)
    (output,) = self.out_projection(gated)

    return Mamba2Result(
        outputs=output[None, :],
        state=Mamba2StateLayer(next_conv_state, next_ssm_state),
    )
|
|
460
|
+
|
|
461
|
+
def _chunked_scan(
    self,
    values: Float[Array, "suffix_tokens heads head_dim"],
    keys: Float[Array, "suffix_tokens groups state_dim"],
    queries: Float[Array, "suffix_tokens groups state_dim"],
    dt: Float[Array, "suffix_tokens heads"],
    initial_state: Float[Array, "heads head_dim state_dim"],
    chunk_size: int,
    num_steps: Int[Array, ""] | int,
    d: Float[Array, " heads"] | None = None,
    z: Float[Array, "suffix_tokens heads head_dim"] | None = None,
    z_bias: Float[Array, "heads head_dim"] | None = None,
) -> tuple[Float[Array, "suffix_tokens heads head_dim"], Float[Array, "heads head_dim state_dim"]]:
    """Run the chunked (SSD) scan over a token sequence.

    The sequence is zero-padded to a whole number of chunks, and values and
    keys at positions ``>= num_steps`` are zeroed. Outputs are the sum of an
    intra-chunk term and a term read out from the state carried into each
    chunk. Optionally a skip term ``d * values`` is added and the result is
    gated with ``silu(z + z_bias)``.

    Args:
        values, keys, queries, dt: Per-token inputs; ``-dt`` is the log-decay.
        initial_state: SSM state before the first token.
        chunk_size: Number of tokens per chunk.
        num_steps: Number of leading tokens that count as real input.
        d: Optional per-head skip weight.
        z: Optional gate pre-activation.
        z_bias: Optional bias added to ``z`` before the SiLU.

    Returns:
        A pair ``(outputs, final_state)``; outputs are cropped to the original
        sequence length and ``final_state`` comes from
        ``self._compute_final_state``.
    """
    seq_len = values.shape[0]
    num_steps = jnp.asarray(num_steps, dtype=jnp.int32)

    # Right-pad the token axis up to a multiple of chunk_size.
    pad_len = -seq_len % chunk_size
    if pad_len > 0:
        pad_tokens_3d = ((0, pad_len), (0, 0), (0, 0))
        values = jnp.pad(values, pad_tokens_3d)
        keys = jnp.pad(keys, pad_tokens_3d)
        queries = jnp.pad(queries, pad_tokens_3d)
        dt = jnp.pad(dt, ((0, pad_len), (0, 0)))
        if z is not None:
            z = jnp.pad(z, pad_tokens_3d)

    # Padded but unmasked copies feed the final-state computation below.
    unmasked_values, unmasked_keys, unmasked_dt = values, keys, dt

    # Zero out the contribution of every position at or beyond num_steps.
    is_valid = (jnp.arange(values.shape[0]) < num_steps).astype(values.dtype)
    values = values * is_valid[:, None, None]
    keys = keys * is_valid[:, None, None]

    num_groups = self.num_groups
    values = rearrange(
        values,
        "(chunks chunk_size) (groups heads_per_group) head_dim"
        " -> chunks chunk_size groups heads_per_group head_dim",
        chunk_size=chunk_size,
        groups=num_groups,
    )
    log_decay = rearrange(
        -dt,
        "(chunks chunk_size) (groups heads_per_group) -> groups heads_per_group chunks chunk_size",
        chunk_size=chunk_size,
        groups=num_groups,
    )
    keys_chunked = rearrange(
        keys,
        "(chunks chunk_size) groups state_dim -> chunks chunk_size groups state_dim",
        chunk_size=chunk_size,
    )
    queries_chunked = rearrange(
        queries,
        "(chunks chunk_size) groups state_dim -> chunks chunk_size groups state_dim",
        chunk_size=chunk_size,
    )
    log_decay_cumsum = jnp.cumsum(log_decay, axis=-1)

    # Intra-chunk (diagonal block) outputs.
    qk = einsum(
        queries_chunked,
        keys_chunked,
        "chunks query_pos groups state_dim, chunks key_pos groups state_dim -> chunks query_pos key_pos groups",
    )
    intra_chunk_out = fused_ssd_intra_chunk(log_decay_cumsum, qk, values)

    # State accumulated inside each chunk, decayed to the end of that chunk.
    decay_to_chunk_end = jnp.exp(log_decay_cumsum[:, :, :, -1:] - log_decay_cumsum)
    local_states = einsum(
        keys_chunked,
        decay_to_chunk_end,
        values,
        "chunks chunk_size groups state_dim, groups heads_per_group chunks chunk_size,"
        " chunks chunk_size groups heads_per_group head_dim"
        " -> chunks groups heads_per_group head_dim state_dim",
    )

    # Propagate states across chunk boundaries, seeded with the initial state.
    initial_state_grouped = rearrange(
        initial_state,
        "(groups heads_per_group) head_dim state_dim -> groups heads_per_group head_dim state_dim",
        groups=num_groups,
    )
    local_states = jnp.concatenate([initial_state_grouped[None, ...], local_states], axis=0)
    chunk_end_log_decay = jnp.pad(log_decay_cumsum[:, :, :, -1], ((0, 0), (0, 0), (1, 0)))
    inter_chunk_decay = exp_segsum(chunk_end_log_decay)
    new_states = einsum(
        inter_chunk_decay,
        local_states,
        "groups heads_per_group out_idx chunks,"
        " chunks groups heads_per_group head_dim state_dim"
        " -> out_idx groups heads_per_group head_dim state_dim",
    )
    # The last entry is not used for outputs; the rest are paired with the chunks.
    carried_states = new_states[:-1]

    # Contribution of the carried-in state to every position of its chunk.
    decay_from_chunk_start = jnp.exp(log_decay_cumsum)
    inter_chunk_out = einsum(
        queries_chunked,
        carried_states,
        decay_from_chunk_start,
        "chunks chunk_size groups state_dim,"
        " chunks groups heads_per_group head_dim state_dim,"
        " groups heads_per_group chunks chunk_size"
        " -> chunks chunk_size groups heads_per_group head_dim",
    )

    y = intra_chunk_out + inter_chunk_out
    if d is not None:
        d_grouped = rearrange(d, "(groups heads_per_group) -> groups heads_per_group", groups=num_groups)
        y = y + d_grouped[None, None, :, :, None] * values
    y = rearrange(
        y,
        "chunks chunk_size groups heads_per_group head_dim"
        " -> (chunks chunk_size) (groups heads_per_group) head_dim",
    )

    if z is not None:
        if z_bias is not None:
            gate = z + z_bias[None, :, :]
        else:
            gate = z
        y = y * jax.nn.silu(gate)

    # Drop the padded tail.
    y = y[:seq_len]

    boundary_states = rearrange(
        new_states,
        "chunks groups heads_per_group head_dim state_dim -> chunks (groups heads_per_group) head_dim state_dim",
    )
    final_state = self._compute_final_state(
        unmasked_values,
        unmasked_keys,
        unmasked_dt,
        boundary_states,
        num_steps,
        chunk_size,
    )

    return y, final_state
|
|
597
|
+
|
|
598
|
+
def _compute_final_state(
    self,
    values: Float[Array, "suffix_tokens heads head_dim"],
    keys: Float[Array, "suffix_tokens groups state_dim"],
    dt: Float[Array, "suffix_tokens heads"],
    chunk_states: Float[Array, "chunks_plus_1 heads head_dim state_dim"],
    num_steps: Int[Array, ""],
    chunk_size: int,
) -> Float[Array, "heads head_dim state_dim"]:
    """Compute the exact final state at position num_steps using precomputed chunk_states.

    ``chunk_states`` is indexed by ``num_steps // chunk_size``. When
    ``num_steps`` is a multiple of ``chunk_size`` that entry is returned
    unchanged. Otherwise the entry is decayed over the first
    ``num_steps % chunk_size`` positions of the chunk and the contributions of
    those positions are added.

    Args:
        values, keys, dt: Per-token inputs; slices of ``chunk_size`` tokens are
            taken from them along the first axis.
        chunk_states: States indexed along the first axis by chunk index.
        num_steps: Number of tokens the returned state should account for.
        chunk_size: Number of tokens per chunk.

    Returns:
        The SSM state after ``num_steps`` tokens.
    """
    heads_per_group = self.num_heads // self.num_groups

    chunk_idx = num_steps // chunk_size
    pos_in_chunk = num_steps % chunk_size
    chunk_start_state = jax.lax.dynamic_index_in_dim(chunk_states, chunk_idx, axis=0, keepdims=False)

    def at_boundary() -> Float[Array, "heads head_dim state_dim"]:
        return chunk_start_state

    def within_chunk() -> Float[Array, "heads head_dim state_dim"]:
        chunk_start_pos = chunk_idx * chunk_size
        values_chunk = jax.lax.dynamic_slice(
            values,
            (chunk_start_pos, 0, 0),
            (chunk_size, values.shape[1], values.shape[2]),
        )
        keys_chunk = jax.lax.dynamic_slice(
            keys,
            (chunk_start_pos, 0, 0),
            (chunk_size, keys.shape[1], keys.shape[2]),
        )
        dt_chunk = jax.lax.dynamic_slice(dt, (chunk_start_pos, 0), (chunk_size, dt.shape[1]))

        log_decay_cumsum = jnp.cumsum(-dt_chunk, axis=0)
        last_pos_idx = pos_in_chunk - 1
        log_decay_cumsum_at_last = jax.lax.dynamic_index_in_dim(
            log_decay_cumsum,
            last_pos_idx,
            axis=0,
            keepdims=False,
        )

        decayed_start = jnp.exp(log_decay_cumsum_at_last)[:, None, None] * chunk_start_state

        mask = jnp.arange(chunk_size) <= last_pos_idx
        # Mask before exponentiating: for positions after last_pos_idx the
        # exponent is the negated log-decay over the gap and can overflow to
        # inf, which would produce NaN gradients if only masked afterwards.
        # The forward values are unchanged (exp(-inf) == 0).
        log_decay_to_last = log_decay_cumsum_at_last[None, :] - log_decay_cumsum
        masked_decay = jnp.exp(jnp.where(mask[:, None], log_decay_to_last, -jnp.inf))
        keys_expanded = jnp.repeat(keys_chunk, heads_per_group, axis=1)
        input_contrib = einsum(
            masked_decay,
            keys_expanded,
            values_chunk,
            "chunk_size heads, chunk_size heads state_dim, chunk_size heads head_dim -> heads head_dim state_dim",
        )

        return decayed_start + input_contrib

    return jax.lax.cond(pos_in_chunk == 0, at_boundary, within_chunk)
|
|
426
656
|
|
|
427
657
|
@eqx.filter_jit
|
|
428
658
|
def __call__(
|
|
@@ -436,8 +666,6 @@ class Mamba2(TokenMixerBase[Mamba2Config, Mamba2StateLayer]):
|
|
|
436
666
|
if positional_embeddings is not None:
|
|
437
667
|
raise ValueError("Positional embeddings are not supported for Mamba2.")
|
|
438
668
|
|
|
439
|
-
conv_inputs, gate_values, time_delta_log = vmap(self.in_projection)(inputs)
|
|
440
|
-
|
|
441
669
|
if state is None:
|
|
442
670
|
state = Mamba2StateLayer.init(
|
|
443
671
|
self.config.kernel_size,
|
|
@@ -449,6 +677,13 @@ class Mamba2(TokenMixerBase[Mamba2Config, Mamba2StateLayer]):
|
|
|
449
677
|
self.activation_precision,
|
|
450
678
|
)
|
|
451
679
|
|
|
680
|
+
seq_len, _ = inputs.shape
|
|
681
|
+
|
|
682
|
+
if seq_len == 1 and return_updated_state:
|
|
683
|
+
return self._decode_step(inputs, state)
|
|
684
|
+
|
|
685
|
+
conv_inputs, gate_values, time_delta_log = vmap(self.in_projection)(inputs)
|
|
686
|
+
|
|
452
687
|
conv_output, updated_conv_state = self.conv(
|
|
453
688
|
conv_inputs,
|
|
454
689
|
length_without_padding,
|
|
@@ -466,50 +701,55 @@ class Mamba2(TokenMixerBase[Mamba2Config, Mamba2StateLayer]):
|
|
|
466
701
|
axis=-1,
|
|
467
702
|
)
|
|
468
703
|
|
|
469
|
-
|
|
704
|
+
values = rearrange(
|
|
470
705
|
x_channels,
|
|
471
706
|
"suffix_tokens (heads head_channels) -> suffix_tokens heads head_channels",
|
|
472
707
|
heads=self.num_heads,
|
|
473
708
|
)
|
|
474
|
-
|
|
709
|
+
keys = rearrange(
|
|
475
710
|
input_proj_channels,
|
|
476
711
|
"suffix_tokens (groups state_channels) -> suffix_tokens groups state_channels",
|
|
477
712
|
groups=self.num_groups,
|
|
478
713
|
)
|
|
479
|
-
|
|
714
|
+
queries = rearrange(
|
|
480
715
|
output_proj_channels,
|
|
481
716
|
"suffix_tokens (groups state_channels) -> suffix_tokens groups state_channels",
|
|
482
717
|
groups=self.num_groups,
|
|
483
718
|
)
|
|
484
|
-
time_delta_log = rearrange(
|
|
485
|
-
time_delta_log,
|
|
486
|
-
"suffix_tokens heads -> suffix_tokens heads",
|
|
487
|
-
heads=self.num_heads,
|
|
488
|
-
)
|
|
489
719
|
|
|
490
720
|
if length_without_padding is None:
|
|
491
721
|
length_without_padding, _ = inputs.shape
|
|
492
722
|
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
723
|
+
gate_values_reshaped = rearrange(
|
|
724
|
+
gate_values,
|
|
725
|
+
"suffix_tokens (heads head_channels) -> suffix_tokens heads head_channels",
|
|
726
|
+
heads=self.num_heads,
|
|
727
|
+
)
|
|
728
|
+
gate_bias_reshaped = rearrange(
|
|
729
|
+
self.gate_bias,
|
|
730
|
+
"(heads head_channels) -> heads head_channels",
|
|
731
|
+
heads=self.num_heads,
|
|
732
|
+
)
|
|
733
|
+
|
|
734
|
+
dt = jax.nn.softplus(time_delta_log)
|
|
735
|
+
ssm_outputs, final_ssm_state = self._chunked_scan(
|
|
736
|
+
values,
|
|
737
|
+
keys,
|
|
738
|
+
queries,
|
|
739
|
+
dt,
|
|
498
740
|
state.ssm_state,
|
|
741
|
+
self.config.chunk_size,
|
|
499
742
|
length_without_padding,
|
|
743
|
+
d=self.skip_connection_weight,
|
|
744
|
+
z=gate_values_reshaped,
|
|
745
|
+
z_bias=gate_bias_reshaped,
|
|
500
746
|
)
|
|
501
747
|
|
|
502
|
-
|
|
503
|
-
ssm_outputs = ssm_outputs + skip_contribution
|
|
504
|
-
|
|
505
|
-
ssm_outputs = rearrange(
|
|
748
|
+
ssm_outputs_flat = rearrange(
|
|
506
749
|
ssm_outputs,
|
|
507
750
|
"suffix_tokens heads head_channels -> suffix_tokens (heads head_channels)",
|
|
508
751
|
)
|
|
509
|
-
|
|
510
|
-
gated_outputs = ssm_outputs * jax.nn.silu(gate_values + self.gate_bias)
|
|
511
|
-
|
|
512
|
-
(outputs,) = vmap(self.out_projection)(gated_outputs)
|
|
752
|
+
(outputs,) = vmap(self.out_projection)(ssm_outputs_flat)
|
|
513
753
|
|
|
514
754
|
if return_updated_state:
|
|
515
755
|
assert updated_conv_state is not None
|
lalamo/speculator/__init__.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
|
1
1
|
from .common import Speculator
|
|
2
|
-
from .estimator import estimate_batchsize_from_memory
|
|
3
2
|
from .inference import CollectTracesEvent, inference_collect_traces
|
|
4
3
|
from .ngram import NGramSpeculator
|
|
5
4
|
from .utils import SpeculatorTrainingEvent, train_speculator
|
|
@@ -9,7 +8,6 @@ __all__ = [
|
|
|
9
8
|
"NGramSpeculator",
|
|
10
9
|
"Speculator",
|
|
11
10
|
"SpeculatorTrainingEvent",
|
|
12
|
-
"estimate_batchsize_from_memory",
|
|
13
11
|
"inference_collect_traces",
|
|
14
12
|
"train_speculator",
|
|
15
13
|
]
|