lalamo 0.6.4__py3-none-any.whl → 0.6.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -25,9 +25,65 @@ __all__ = [
25
25
  "Mamba2Result",
26
26
  "SeparableCausalConv",
27
27
  "SeparableCausalConvConfig",
28
+ "exp_segsum",
29
+ "fused_ssd_intra_chunk",
28
30
  ]
29
31
 
30
32
 
33
+ def exp_segsum(x: Float[Array, "... T"]) -> Float[Array, "... T T"]:
34
+ """Compute exp(segsum(x)) as lower-triangular matrix using cumsum difference."""
35
+ seq_len = x.shape[-1]
36
+ cs = jnp.cumsum(x, axis=-1)
37
+ diff = cs[..., :, None] - cs[..., None, :]
38
+ mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
39
+ return jnp.where(mask, jnp.exp(diff), 0.0)
40
+
41
+
42
+ def fused_ssd_intra_chunk(
43
+ a_cumsum: Float[Array, "groups heads_per_group chunks chunk_size"],
44
+ cb: Float[Array, "chunks chunk_size chunk_size groups"],
45
+ x: Float[Array, "chunks chunk_size groups heads_per_group head_dim"],
46
+ ) -> Float[Array, "chunks chunk_size groups heads_per_group head_dim"]:
47
+ """Compute intra-chunk diagonal block outputs for SSD.
48
+
49
+ Avoids materializing the full global L matrix by computing decay locally per (chunk, group, head).
50
+ """
51
+ groups, heads_per_group, chunks, chunk_size = a_cumsum.shape
52
+
53
+ def compute_one(
54
+ a_cs: Float[Array, " chunk_size"],
55
+ cb_slice: Float[Array, "chunk_size chunk_size"],
56
+ x_slice: Float[Array, "chunk_size head_dim"],
57
+ ) -> Float[Array, "chunk_size head_dim"]:
58
+ diff = a_cs[:, None] - a_cs[None, :]
59
+ mask = jnp.tril(jnp.ones((chunk_size, chunk_size), dtype=jnp.bool_))
60
+ decay_local = jnp.where(mask, jnp.exp(diff), 0.0)
61
+ weighted = decay_local * cb_slice
62
+ return weighted @ x_slice
63
+
64
+ def compute_chunk_group_head(chunk_idx: int, group_idx: int, head_idx: int) -> Float[Array, "chunk_size head_dim"]:
65
+ return compute_one(
66
+ a_cumsum[group_idx, head_idx, chunk_idx, :],
67
+ cb[chunk_idx, :, :, group_idx],
68
+ x[chunk_idx, :, group_idx, head_idx, :],
69
+ )
70
+
71
+ def over_heads(chunk_idx: int, group_idx: int) -> Float[Array, "heads_per_group chunk_size head_dim"]:
72
+ return jax.vmap(lambda head_idx: compute_chunk_group_head(chunk_idx, group_idx, head_idx))(
73
+ jnp.arange(heads_per_group),
74
+ )
75
+
76
+ def over_groups(chunk_idx: int) -> Float[Array, "groups heads_per_group chunk_size head_dim"]:
77
+ return jax.vmap(lambda group_idx: over_heads(chunk_idx, group_idx))(jnp.arange(groups))
78
+
79
+ result = jax.vmap(over_groups)(jnp.arange(chunks))
80
+
81
+ return rearrange(
82
+ result,
83
+ "chunks groups heads_per_group chunk_size head_dim -> chunks chunk_size groups heads_per_group head_dim",
84
+ )
85
+
86
+
31
87
  Mamba2Result = TokenMixerResult[Mamba2StateLayer]
32
88
 
33
89
 
@@ -156,6 +212,19 @@ class SeparableCausalConv(LalamoModule[SeparableCausalConvConfig]):
156
212
  updated_state,
157
213
  )
158
214
 
215
+ def step(
216
+ self,
217
+ token: Float[Array, " channels"],
218
+ state: Float[Array, "kernel_minus_1 channels"],
219
+ ) -> tuple[Float[Array, " channels"], Float[Array, "kernel_minus_1 channels"]]:
220
+ """Single-token conv update without full convolution overhead."""
221
+ full_input = jnp.concatenate([state, token[None, :]], axis=0)
222
+ output = einsum(full_input, self.weights, "kernel channels, channels kernel -> channels")
223
+ if self.biases is not None:
224
+ output = output + self.biases
225
+ new_state = jnp.concatenate([state[1:], token[None, :]], axis=0)
226
+ return output, new_state
227
+
159
228
  def export_weights(self) -> ParameterTree:
160
229
  result: dict[str, Array] = {"weights": self.weights}
161
230
  if self.biases is not None:
@@ -188,6 +257,8 @@ class Mamba2Config(TokenMixerConfigBase):
188
257
  has_in_biases: bool
189
258
  has_out_biases: bool
190
259
 
260
+ chunk_size: int = 256
261
+
191
262
  @property
192
263
  def inner_dim(self) -> int:
193
264
  return self.num_heads * self.head_dim
@@ -330,99 +401,258 @@ class Mamba2(TokenMixerBase[Mamba2Config, Mamba2StateLayer]):
330
401
  f"Number of value heads ({self.num_heads}) must be divisible by number of groups ({self.num_groups})",
331
402
  )
332
403
 
333
- def _scan(
404
+ def _step(
405
+ self,
406
+ values: Float[Array, "heads head_dim"],
407
+ keys: Float[Array, "groups state_dim"],
408
+ queries: Float[Array, "groups state_dim"],
409
+ dt_log: Float[Array, " heads"],
410
+ state: Float[Array, "heads head_dim state_dim"],
411
+ ) -> tuple[Float[Array, "heads head_dim"], Float[Array, "heads head_dim state_dim"]]:
412
+ """Single-token SSM state update without scan overhead."""
413
+ heads_per_group = self.num_heads // self.num_groups
414
+
415
+ dt = jax.nn.softplus(dt_log)
416
+ decay = jnp.exp(-dt)[:, None, None]
417
+ mix = dt[:, None, None]
418
+
419
+ keys_expanded = jnp.repeat(keys, heads_per_group, axis=0)
420
+ queries_expanded = jnp.repeat(queries, heads_per_group, axis=0)
421
+ values_norm = values / (dt[:, None] + 1e-8)
422
+
423
+ input_contribution = mix * values_norm[:, :, None] * keys_expanded[:, None, :]
424
+ new_state = decay * state + input_contribution
425
+ output = einsum(new_state, queries_expanded, "heads head_dim state_dim, heads state_dim -> heads head_dim")
426
+
427
+ return output, new_state
428
+
429
+ def _decode_step(
334
430
  self,
335
- hidden_states: Float[Array, "suffix_tokens heads head_channels"],
336
- input_projection: Float[Array, "suffix_tokens groups state_channels"],
337
- output_projection: Float[Array, "suffix_tokens groups state_channels"],
338
- time_delta_log: Float[Array, "suffix_tokens heads"],
339
- initial_state: Float[Array, "heads head_channels state_channels"],
431
+ inputs: Float[Array, "1 channels"],
432
+ state: Mamba2StateLayer,
433
+ ) -> Mamba2Result:
434
+ """Optimized path for single-token decode without scan machinery."""
435
+ token = inputs[0]
436
+
437
+ conv_in, gate, dt_log = self.in_projection(token)
438
+ conv_out, new_conv_state = self.conv.step(conv_in, state.conv_state)
439
+ conv_activated = self.config.activation(conv_out)
440
+
441
+ values_flat, input_proj_flat, output_proj_flat = jnp.split(
442
+ conv_activated,
443
+ [self.inner_dim, self.inner_dim + self.num_groups * self.state_dim],
444
+ )
445
+ values = rearrange(values_flat, "(heads head_dim) -> heads head_dim", heads=self.num_heads)
446
+ keys = rearrange(input_proj_flat, "(groups state_dim) -> groups state_dim", groups=self.num_groups)
447
+ queries = rearrange(output_proj_flat, "(groups state_dim) -> groups state_dim", groups=self.num_groups)
448
+
449
+ y, new_ssm_state = self._step(values, keys, queries, dt_log, state.ssm_state)
450
+
451
+ y = y + self.skip_connection_weight[:, None] * values
452
+ y = rearrange(y, "heads head_dim -> (heads head_dim)")
453
+ gated = y * jax.nn.silu(gate + self.gate_bias)
454
+ (output,) = self.out_projection(gated)
455
+
456
+ return Mamba2Result(
457
+ outputs=output[None, :],
458
+ state=Mamba2StateLayer(new_conv_state, new_ssm_state),
459
+ )
460
+
461
+ def _chunked_scan(
462
+ self,
463
+ values: Float[Array, "suffix_tokens heads head_dim"],
464
+ keys: Float[Array, "suffix_tokens groups state_dim"],
465
+ queries: Float[Array, "suffix_tokens groups state_dim"],
466
+ dt: Float[Array, "suffix_tokens heads"],
467
+ initial_state: Float[Array, "heads head_dim state_dim"],
468
+ chunk_size: int,
340
469
  num_steps: Int[Array, ""] | int,
341
- ) -> tuple[
342
- Float[Array, "suffix_tokens heads head_channels"],
343
- Float[Array, "heads head_channels state_channels"],
344
- ]:
345
- def scan_fn(
346
- index_and_carry_state: tuple[Int[Array, ""], Float[Array, "heads head_channels state_channels"]],
347
- step_inputs: tuple[
348
- Float[Array, "heads head_channels"],
349
- Float[Array, "groups state_channels"],
350
- Float[Array, "groups state_channels"],
351
- Float[Array, " heads"],
352
- ],
353
- ) -> tuple[
354
- tuple[Int[Array, ""], Float[Array, "heads head_channels state_channels"]],
355
- Float[Array, "heads head_channels"],
356
- ]:
357
- index, carry_state = index_and_carry_state
358
- hidden_state_t, input_proj_t, output_proj_t, time_delta_log_t = step_inputs
359
- dt = jax.nn.softplus(time_delta_log_t)[:, None]
360
- heads_per_group = self.num_heads // self.num_groups
361
-
362
- hidden_grouped = rearrange(
363
- hidden_state_t,
364
- "(groups heads) head_channels -> groups heads head_channels",
365
- groups=self.num_groups,
366
- heads=heads_per_group,
367
- )
368
- x_norm_grouped = hidden_grouped / (
369
- dt.reshape(self.num_heads)[
370
- rearrange(
371
- jnp.arange(self.num_heads),
372
- "(groups heads)-> groups heads",
373
- groups=self.num_groups,
374
- heads=heads_per_group,
375
- )
376
- ][:, :, None]
377
- + 1e-8
378
- )
470
+ d: Float[Array, " heads"] | None = None,
471
+ z: Float[Array, "suffix_tokens heads head_dim"] | None = None,
472
+ z_bias: Float[Array, "heads head_dim"] | None = None,
473
+ ) -> tuple[Float[Array, "suffix_tokens heads head_dim"], Float[Array, "heads head_dim state_dim"]]:
474
+ """Chunked parallel scan implementing the SSD algorithm."""
475
+ seq_len = values.shape[0]
476
+ num_steps = jnp.asarray(num_steps, dtype=jnp.int32)
477
+
478
+ pad_len = (chunk_size - seq_len % chunk_size) % chunk_size
479
+ if pad_len > 0:
480
+ values = jnp.pad(values, ((0, pad_len), (0, 0), (0, 0)))
481
+ keys = jnp.pad(keys, ((0, pad_len), (0, 0), (0, 0)))
482
+ queries = jnp.pad(queries, ((0, pad_len), (0, 0), (0, 0)))
483
+ dt = jnp.pad(dt, ((0, pad_len), (0, 0)))
484
+ if z is not None:
485
+ z = jnp.pad(z, ((0, pad_len), (0, 0), (0, 0)))
486
+
487
+ values_orig = values
488
+ keys_orig = keys
489
+ dt_orig = dt
490
+
491
+ padded_len = values.shape[0]
492
+ position_indices = jnp.arange(padded_len)
493
+ valid_mask = (position_indices < num_steps).astype(values.dtype)
494
+ values = values * valid_mask[:, None, None]
495
+ keys = keys * valid_mask[:, None, None]
496
+
497
+ values = rearrange(
498
+ values,
499
+ "(chunks chunk_size) (groups heads_per_group) head_dim"
500
+ " -> chunks chunk_size groups heads_per_group head_dim",
501
+ chunk_size=chunk_size,
502
+ groups=self.num_groups,
503
+ )
504
+ log_decay = rearrange(
505
+ -dt,
506
+ "(chunks chunk_size) (groups heads_per_group) -> groups heads_per_group chunks chunk_size",
507
+ chunk_size=chunk_size,
508
+ groups=self.num_groups,
509
+ )
510
+ keys_chunked = rearrange(
511
+ keys,
512
+ "(chunks chunk_size) groups state_dim -> chunks chunk_size groups state_dim",
513
+ chunk_size=chunk_size,
514
+ )
515
+ queries_chunked = rearrange(
516
+ queries,
517
+ "(chunks chunk_size) groups state_dim -> chunks chunk_size groups state_dim",
518
+ chunk_size=chunk_size,
519
+ )
520
+ log_decay_cumsum = jnp.cumsum(log_decay, axis=-1)
379
521
 
380
- decay = jnp.exp(-dt)[:, :, None]
381
- mix = dt[:, :, None]
382
- decay_group = rearrange(
383
- decay,
384
- "(groups heads) 1 1 -> groups heads 1 1",
385
- groups=self.num_groups,
386
- heads=heads_per_group,
387
- )
388
- mix_group = rearrange(
389
- mix,
390
- "(groups heads) 1 1 -> groups heads 1 1",
391
- groups=self.num_groups,
392
- heads=heads_per_group,
393
- )
522
+ queries_keys_prod = einsum(
523
+ queries_chunked,
524
+ keys_chunked,
525
+ "chunks query_pos groups state_dim, chunks key_pos groups state_dim -> chunks query_pos key_pos groups",
526
+ )
527
+ y_diag = fused_ssd_intra_chunk(log_decay_cumsum, queries_keys_prod, values)
528
+
529
+ decay_states = jnp.exp(log_decay_cumsum[:, :, :, -1:] - log_decay_cumsum)
530
+ states = einsum(
531
+ keys_chunked,
532
+ decay_states,
533
+ values,
534
+ "chunks chunk_size groups state_dim, groups heads_per_group chunks chunk_size,"
535
+ " chunks chunk_size groups heads_per_group head_dim"
536
+ " -> chunks groups heads_per_group head_dim state_dim",
537
+ )
394
538
 
395
- input_contribution_group = mix_group * x_norm_grouped[:, :, :, None] * input_proj_t[:, None, None, :]
396
- carry_state_group = rearrange(
397
- carry_state,
398
- "(groups heads) head_channels state_channels -> groups heads head_channels state_channels",
399
- groups=self.num_groups,
400
- heads=heads_per_group,
401
- )
402
- updated_state_group = decay_group * carry_state_group + input_contribution_group
539
+ initial_state_grouped = rearrange(
540
+ initial_state,
541
+ "(groups heads_per_group) head_dim state_dim -> groups heads_per_group head_dim state_dim",
542
+ groups=self.num_groups,
543
+ )
544
+ states = jnp.concatenate([initial_state_grouped[None, ...], states], axis=0)
545
+ log_decay_chunk_ends = jnp.pad(log_decay_cumsum[:, :, :, -1], ((0, 0), (0, 0), (1, 0)))
546
+ decay_chunk = exp_segsum(log_decay_chunk_ends)
547
+ new_states = einsum(
548
+ decay_chunk,
549
+ states,
550
+ "groups heads_per_group out_idx chunks,"
551
+ " chunks groups heads_per_group head_dim state_dim"
552
+ " -> out_idx groups heads_per_group head_dim state_dim",
553
+ )
554
+ states = new_states[:-1]
555
+
556
+ state_decay_out = jnp.exp(log_decay_cumsum)
557
+ y_off = einsum(
558
+ queries_chunked,
559
+ states,
560
+ state_decay_out,
561
+ "chunks chunk_size groups state_dim,"
562
+ " chunks groups heads_per_group head_dim state_dim,"
563
+ " groups heads_per_group chunks chunk_size"
564
+ " -> chunks chunk_size groups heads_per_group head_dim",
565
+ )
566
+
567
+ y = y_diag + y_off
568
+ if d is not None:
569
+ d_grouped = rearrange(d, "(groups heads_per_group) -> groups heads_per_group", groups=self.num_groups)
570
+ y = y + d_grouped[None, None, :, :, None] * values
571
+ y = rearrange(
572
+ y,
573
+ "chunks chunk_size groups heads_per_group head_dim"
574
+ " -> (chunks chunk_size) (groups heads_per_group) head_dim",
575
+ )
576
+
577
+ if z is not None:
578
+ gate = z + z_bias[None, :, :] if z_bias is not None else z
579
+ y = y * jax.nn.silu(gate)
403
580
 
404
- output_group = einsum(
405
- updated_state_group,
406
- output_proj_t,
407
- "groups heads head_channels state_channels, groups state_channels -> groups heads head_channels",
581
+ y = y[:seq_len]
582
+
583
+ new_states_flat = rearrange(
584
+ new_states,
585
+ "chunks groups heads_per_group head_dim state_dim -> chunks (groups heads_per_group) head_dim state_dim",
586
+ )
587
+ final_state = self._compute_final_state(
588
+ values_orig,
589
+ keys_orig,
590
+ dt_orig,
591
+ new_states_flat,
592
+ num_steps,
593
+ chunk_size,
594
+ )
595
+
596
+ return y, final_state
597
+
598
+ def _compute_final_state(
599
+ self,
600
+ values: Float[Array, "suffix_tokens heads head_dim"],
601
+ keys: Float[Array, "suffix_tokens groups state_dim"],
602
+ dt: Float[Array, "suffix_tokens heads"],
603
+ chunk_states: Float[Array, "chunks_plus_1 heads head_dim state_dim"],
604
+ num_steps: Int[Array, ""],
605
+ chunk_size: int,
606
+ ) -> Float[Array, "heads head_dim state_dim"]:
607
+ """Compute the exact final state at position num_steps using precomputed chunk_states."""
608
+ heads_per_group = self.num_heads // self.num_groups
609
+
610
+ chunk_idx = num_steps // chunk_size
611
+ pos_in_chunk = num_steps % chunk_size
612
+ chunk_start_state = jax.lax.dynamic_index_in_dim(chunk_states, chunk_idx, axis=0, keepdims=False)
613
+
614
+ def at_boundary() -> Float[Array, "heads head_dim state_dim"]:
615
+ return chunk_start_state
616
+
617
+ def within_chunk() -> Float[Array, "heads head_dim state_dim"]:
618
+ chunk_start_pos = chunk_idx * chunk_size
619
+ values_chunk = jax.lax.dynamic_slice(
620
+ values,
621
+ (chunk_start_pos, 0, 0),
622
+ (chunk_size, values.shape[1], values.shape[2]),
408
623
  )
409
- updated_state = rearrange(
410
- updated_state_group,
411
- "groups heads head_channels state_channels -> (groups heads) head_channels state_channels",
624
+ keys_chunk = jax.lax.dynamic_slice(
625
+ keys,
626
+ (chunk_start_pos, 0, 0),
627
+ (chunk_size, keys.shape[1], keys.shape[2]),
412
628
  )
413
- output_t = rearrange(output_group, "groups heads head_channels -> (groups heads) head_channels")
629
+ dt_chunk = jax.lax.dynamic_slice(dt, (chunk_start_pos, 0), (chunk_size, dt.shape[1]))
414
630
 
415
- propagated_state = jax.lax.cond(index < num_steps, lambda: updated_state, lambda: carry_state)
631
+ log_decay_cumsum = jnp.cumsum(-dt_chunk, axis=0)
632
+ last_pos_idx = pos_in_chunk - 1
633
+ log_decay_cumsum_at_last = jax.lax.dynamic_index_in_dim(
634
+ log_decay_cumsum,
635
+ last_pos_idx,
636
+ axis=0,
637
+ keepdims=False,
638
+ )
416
639
 
417
- return (index + 1, propagated_state), output_t
640
+ decayed_start = jnp.exp(log_decay_cumsum_at_last)[:, None, None] * chunk_start_state
641
+
642
+ decay_to_last = jnp.exp(log_decay_cumsum_at_last[None, :] - log_decay_cumsum)
643
+ mask = jnp.arange(chunk_size) <= last_pos_idx
644
+ masked_decay = jnp.where(mask[:, None], decay_to_last, 0.0)
645
+ keys_expanded = jnp.repeat(keys_chunk, heads_per_group, axis=1)
646
+ input_contrib = einsum(
647
+ masked_decay,
648
+ keys_expanded,
649
+ values_chunk,
650
+ "chunk_size heads, chunk_size heads state_dim, chunk_size heads head_dim -> heads head_dim state_dim",
651
+ )
418
652
 
419
- (_, final_state), outputs = jax.lax.scan(
420
- scan_fn,
421
- (jnp.zeros((), dtype=jnp.int32), initial_state),
422
- (hidden_states, input_projection, output_projection, time_delta_log),
423
- )
653
+ return decayed_start + input_contrib
424
654
 
425
- return outputs, final_state
655
+ return jax.lax.cond(pos_in_chunk == 0, at_boundary, within_chunk)
426
656
 
427
657
  @eqx.filter_jit
428
658
  def __call__(
@@ -436,8 +666,6 @@ class Mamba2(TokenMixerBase[Mamba2Config, Mamba2StateLayer]):
436
666
  if positional_embeddings is not None:
437
667
  raise ValueError("Positional embeddings are not supported for Mamba2.")
438
668
 
439
- conv_inputs, gate_values, time_delta_log = vmap(self.in_projection)(inputs)
440
-
441
669
  if state is None:
442
670
  state = Mamba2StateLayer.init(
443
671
  self.config.kernel_size,
@@ -449,6 +677,13 @@ class Mamba2(TokenMixerBase[Mamba2Config, Mamba2StateLayer]):
449
677
  self.activation_precision,
450
678
  )
451
679
 
680
+ seq_len, _ = inputs.shape
681
+
682
+ if seq_len == 1 and return_updated_state:
683
+ return self._decode_step(inputs, state)
684
+
685
+ conv_inputs, gate_values, time_delta_log = vmap(self.in_projection)(inputs)
686
+
452
687
  conv_output, updated_conv_state = self.conv(
453
688
  conv_inputs,
454
689
  length_without_padding,
@@ -466,50 +701,55 @@ class Mamba2(TokenMixerBase[Mamba2Config, Mamba2StateLayer]):
466
701
  axis=-1,
467
702
  )
468
703
 
469
- hidden_states = rearrange(
704
+ values = rearrange(
470
705
  x_channels,
471
706
  "suffix_tokens (heads head_channels) -> suffix_tokens heads head_channels",
472
707
  heads=self.num_heads,
473
708
  )
474
- input_projection = rearrange(
709
+ keys = rearrange(
475
710
  input_proj_channels,
476
711
  "suffix_tokens (groups state_channels) -> suffix_tokens groups state_channels",
477
712
  groups=self.num_groups,
478
713
  )
479
- output_projection = rearrange(
714
+ queries = rearrange(
480
715
  output_proj_channels,
481
716
  "suffix_tokens (groups state_channels) -> suffix_tokens groups state_channels",
482
717
  groups=self.num_groups,
483
718
  )
484
- time_delta_log = rearrange(
485
- time_delta_log,
486
- "suffix_tokens heads -> suffix_tokens heads",
487
- heads=self.num_heads,
488
- )
489
719
 
490
720
  if length_without_padding is None:
491
721
  length_without_padding, _ = inputs.shape
492
722
 
493
- ssm_outputs, final_ssm_state = self._scan(
494
- hidden_states,
495
- input_projection,
496
- output_projection,
497
- time_delta_log,
723
+ gate_values_reshaped = rearrange(
724
+ gate_values,
725
+ "suffix_tokens (heads head_channels) -> suffix_tokens heads head_channels",
726
+ heads=self.num_heads,
727
+ )
728
+ gate_bias_reshaped = rearrange(
729
+ self.gate_bias,
730
+ "(heads head_channels) -> heads head_channels",
731
+ heads=self.num_heads,
732
+ )
733
+
734
+ dt = jax.nn.softplus(time_delta_log)
735
+ ssm_outputs, final_ssm_state = self._chunked_scan(
736
+ values,
737
+ keys,
738
+ queries,
739
+ dt,
498
740
  state.ssm_state,
741
+ self.config.chunk_size,
499
742
  length_without_padding,
743
+ d=self.skip_connection_weight,
744
+ z=gate_values_reshaped,
745
+ z_bias=gate_bias_reshaped,
500
746
  )
501
747
 
502
- skip_contribution = self.skip_connection_weight[None, :, None] * hidden_states
503
- ssm_outputs = ssm_outputs + skip_contribution
504
-
505
- ssm_outputs = rearrange(
748
+ ssm_outputs_flat = rearrange(
506
749
  ssm_outputs,
507
750
  "suffix_tokens heads head_channels -> suffix_tokens (heads head_channels)",
508
751
  )
509
-
510
- gated_outputs = ssm_outputs * jax.nn.silu(gate_values + self.gate_bias)
511
-
512
- (outputs,) = vmap(self.out_projection)(gated_outputs)
752
+ (outputs,) = vmap(self.out_projection)(ssm_outputs_flat)
513
753
 
514
754
  if return_updated_state:
515
755
  assert updated_conv_state is not None
@@ -1,5 +1,4 @@
1
1
  from .common import Speculator
2
- from .estimator import estimate_batchsize_from_memory
3
2
  from .inference import CollectTracesEvent, inference_collect_traces
4
3
  from .ngram import NGramSpeculator
5
4
  from .utils import SpeculatorTrainingEvent, train_speculator
@@ -9,7 +8,6 @@ __all__ = [
9
8
  "NGramSpeculator",
10
9
  "Speculator",
11
10
  "SpeculatorTrainingEvent",
12
- "estimate_batchsize_from_memory",
13
11
  "inference_collect_traces",
14
12
  "train_speculator",
15
13
  ]