tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.12.0.dev20251213__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (59)
  1. tests/kernels/fused_moe_v1_test.py +303 -34
  2. tests/kernels/mla_v1_test.py +129 -41
  3. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  4. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +3 -1
  5. tests/kernels/ragged_paged_attention_kernel_v3_test.py +3 -1
  6. tests/lora/test_layers.py +4 -1
  7. tests/lora/test_lora_perf.py +53 -0
  8. tests/test_envs.py +110 -12
  9. tests/test_quantization.py +3 -0
  10. tests/test_utils.py +1 -2
  11. tpu_inference/distributed/tpu_connector.py +1 -1
  12. tpu_inference/envs.py +92 -8
  13. tpu_inference/executors/ray_distributed_executor.py +5 -1
  14. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  15. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  16. tpu_inference/kernels/fused_moe/v1/kernel.py +712 -143
  17. tpu_inference/kernels/mla/v1/kernel.py +98 -120
  18. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  19. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  20. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  21. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +82 -32
  22. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +146 -85
  23. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +2 -1
  24. tpu_inference/kernels/ragged_paged_attention/v3/util.py +2 -1
  25. tpu_inference/layers/common/attention_interface.py +7 -1
  26. tpu_inference/layers/common/sharding.py +11 -7
  27. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +232 -64
  28. tpu_inference/layers/jax/attention/gpt_oss_attention.py +5 -5
  29. tpu_inference/layers/vllm/fused_moe.py +170 -208
  30. tpu_inference/layers/vllm/linear_common.py +43 -21
  31. tpu_inference/layers/vllm/quantization/common.py +11 -6
  32. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -3
  33. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +74 -65
  34. tpu_inference/layers/vllm/quantization/mxfp4.py +140 -94
  35. tpu_inference/layers/vllm/quantization/unquantized.py +103 -80
  36. tpu_inference/models/common/model_loader.py +78 -22
  37. tpu_inference/models/jax/deepseek_v3.py +185 -64
  38. tpu_inference/models/jax/gpt_oss.py +3 -3
  39. tpu_inference/models/jax/llama_eagle3.py +4 -5
  40. tpu_inference/models/jax/qwen2_5_vl.py +161 -47
  41. tpu_inference/models/jax/utils/quantization/quantization_utils.py +7 -8
  42. tpu_inference/models/jax/utils/weight_utils.py +203 -155
  43. tpu_inference/models/vllm/vllm_model_wrapper.py +11 -5
  44. tpu_inference/platforms/tpu_platform.py +29 -48
  45. tpu_inference/runner/compilation_manager.py +112 -46
  46. tpu_inference/runner/kv_cache.py +40 -20
  47. tpu_inference/runner/kv_cache_manager.py +40 -31
  48. tpu_inference/runner/persistent_batch_manager.py +40 -2
  49. tpu_inference/runner/structured_decoding_manager.py +2 -3
  50. tpu_inference/runner/tpu_runner.py +94 -51
  51. tpu_inference/runner/utils.py +2 -2
  52. tpu_inference/spec_decode/jax/eagle3.py +71 -22
  53. tpu_inference/utils.py +41 -14
  54. tpu_inference/worker/tpu_worker.py +43 -45
  55. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/METADATA +8 -9
  56. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/RECORD +59 -58
  57. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/WHEEL +0 -0
  58. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/licenses/LICENSE +0 -0
  59. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/top_level.txt +0 -0
tpu_inference/layers/jax/attention/deepseek_v3_attention.py

@@ -11,9 +11,13 @@ from jax.sharding import Mesh
 from jax.sharding import PartitionSpec as P

 from tpu_inference import utils
+from tpu_inference.kernels.mla.v1.kernel import mla_ragged_paged_attention
 from tpu_inference.kernels.ragged_paged_attention.v3.kernel import \
     ragged_paged_attention
+from tpu_inference.kernels.ragged_paged_attention.v3.tuned_block_sizes import \
+    get_tuned_block_sizes
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.layers.jax.base import create_param
 from tpu_inference.layers.jax.layers import RMSNorm
 from tpu_inference.layers.jax.rope import DeepseekScalingRotaryEmbedding
@@ -66,6 +70,7 @@ class MLA(nnx.Module):
     rope_input_ordering: str = "split"
     quant: Any | None = None
     rope_mscale_all_dim: float = 1.0
+    use_mla_kernel: bool = False

     rngs: InitVar[nnx.Rngs]

@@ -77,10 +82,10 @@ class MLA(nnx.Module):
         self.N = self.num_attention_heads
         self.K = self.num_key_value_heads
         self.D = self.hidden_size
-
         self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim

-        assert self.N == self.K, "N and K must be equal for MLA"
+        if not self.use_mla_kernel:
+            assert self.N == self.K, "N and K must be equal for MLA"

         if self.rope_scaling["factor"] <= 1.0:
             yarn_mscale = 1.0
@@ -122,14 +127,30 @@ class MLA(nnx.Module):
             self.dtype,
             random_init=self.random_init,
         )
-        self.kernel_kv_up_proj_ANH = create_param(
-            rngs,
-            (self.kv_lora_rank, self.N,
-             self.qk_nope_head_dim + self.v_head_dim),
-            self.anh_sharding,
-            self.dtype,
-            random_init=self.random_init,
-        )
+        if self.use_mla_kernel:
+            self.kernel_k_up_proj_ANH = create_param(
+                rngs,
+                (self.kv_lora_rank, self.N, self.qk_nope_head_dim),
+                self.anh_sharding,
+                self.dtype,
+                random_init=self.random_init,
+            )
+            self.kernel_v_up_proj_ANH = create_param(
+                rngs,
+                (self.kv_lora_rank, self.N, self.v_head_dim),
+                self.anh_sharding,
+                self.dtype,
+                random_init=self.random_init,
+            )
+        else:
+            self.kernel_kv_up_proj_ANH = create_param(
+                rngs,
+                (self.kv_lora_rank, self.N,
+                 self.qk_nope_head_dim + self.v_head_dim),
+                self.anh_sharding,
+                self.dtype,
+                random_init=self.random_init,
+            )
         self.kernel_o_proj_NHD = create_param(
             rngs, (self.N, self.v_head_dim, self.D),
             self.nhd_sharding,
@@ -195,10 +216,16 @@ class MLA(nnx.Module):
             q_nope_TNH = q_TNH[..., :self.qk_nope_head_dim]
             q_rope_TNH = q_TNH[..., self.qk_nope_head_dim:]
             q_rope_TNH = self.rope.apply_rope(md.input_positions, q_rope_TNH)
-            # Concatenate the nope and rope queries.
-            q_TNH = jnp.concatenate([q_nope_TNH, q_rope_TNH], axis=-1)
-            # Multiple the query by scaling factor
-            q_TNH = nnx.with_sharding_constraint(q_TNH, self.query_tnh)
+            if self.use_mla_kernel:
+                # Absorb the k up-projection matrix into q
+                q_TNA = jnp.einsum("TNH,ANH -> TNA", q_nope_TNH,
+                                   self.kernel_k_up_proj_ANH.value)
+                q_TNA = nnx.with_sharding_constraint(q_TNA, self.query_tnh)
+            else:
+                # Concatenate the nope and rope queries.
+                q_TNH = jnp.concatenate([q_nope_TNH, q_rope_TNH], axis=-1)
+                # Multiply the query by scaling factor
+                q_TNH = nnx.with_sharding_constraint(q_TNH, self.query_tnh)

         with jax.named_scope("kv_proj"):
             # KV down projection.
@@ -209,21 +236,27 @@
             # Reshape k_rope_BSH to include head dimension for RoPE application
             k_rope_SNH = k_rope_SH[..., None, :]
             k_rope_SNH = self.rope.apply_rope(md.input_positions, k_rope_SNH)
-            k_rope_SNH = jnp.broadcast_to(
-                k_rope_SNH,
-                (k_rope_SNH.shape[0], self.N, self.qk_rope_head_dim))
+            assert k_rope_SNH.shape[1] == 1
+            k_rope_SH = k_rope_SNH[:, 0, :]
+
             kv_SA = kv_SA[..., :self.kv_lora_rank]
             kv_SA = self.kv_rms_norm(kv_SA)
-            # KV up projection.
-            kv_nope_SNH = jnp.einsum("SA,ANH -> SNH", kv_SA,
-                                     self.kernel_kv_up_proj_ANH.value)
-            # Split the latent kv vector into k nope vector and v vector.
-            k_nope_SNH = kv_nope_SNH[..., :self.qk_nope_head_dim]
-            v_SNH = kv_nope_SNH[..., self.qk_nope_head_dim:]
-            # Concatenate the key vector.
-            k_SNH = jnp.concatenate([k_nope_SNH, k_rope_SNH], axis=-1)
-            k_SNH = nnx.with_sharding_constraint(k_SNH, self.keyvalue_skh)
-            v_SNH = nnx.with_sharding_constraint(v_SNH, self.keyvalue_skh)
+            kv_SA = nnx.with_sharding_constraint(kv_SA, self.keyvalue_skh)
+
+            if not self.use_mla_kernel:
+                k_rope_SNH = jnp.broadcast_to(
+                    k_rope_SNH,
+                    (k_rope_SNH.shape[0], self.N, self.qk_rope_head_dim))
+                # KV up projection.
+                kv_nope_SNH = jnp.einsum("SA,ANH -> SNH", kv_SA,
+                                         self.kernel_kv_up_proj_ANH.value)
+                # Split the latent kv vector into k nope vector and v vector.
+                k_nope_SNH = kv_nope_SNH[..., :self.qk_nope_head_dim]
+                v_SNH = kv_nope_SNH[..., self.qk_nope_head_dim:]
+                # Concatenate the key vector.
+                k_SNH = jnp.concatenate([k_nope_SNH, k_rope_SNH], axis=-1)
+                k_SNH = nnx.with_sharding_constraint(k_SNH, self.keyvalue_skh)
+                v_SNH = nnx.with_sharding_constraint(v_SNH, self.keyvalue_skh)

         with jax.named_scope("attn_op"):
             # TODO(wenxindongwork): K and V have different head dimension,
@@ -234,44 +267,66 @@
             # q, k, v head dimension to be multiple of 128. For now, we will
             # pad the q, k, v dimension to multiple of 128.
             # We should update the MLA kv cache implementation in the future.
-            multiple_of_128 = ((self.qk_head_dim - 1) // 128 + 1) * 128
-            q_TNH = jnp.pad(q_TNH, ((0, 0), (0, 0),
-                                    (0, multiple_of_128 - self.qk_head_dim)))
-            k_SNH = jnp.pad(k_SNH, ((0, 0), (0, 0),
-                                    (0, multiple_of_128 - self.qk_head_dim)))
-            v_SNH = jnp.pad(v_SNH, ((0, 0), (0, 0),
-                                    (0, multiple_of_128 - self.v_head_dim)))
+            if not self.use_mla_kernel:  # MLA kernel handles padding
+                multiple_of_128 = ((self.qk_head_dim - 1) // 128 + 1) * 128
+                q_TNH = jnp.pad(q_TNH,
+                                ((0, 0), (0, 0),
+                                 (0, multiple_of_128 - self.qk_head_dim)))
+                k_SNH = jnp.pad(k_SNH,
+                                ((0, 0), (0, 0),
+                                 (0, multiple_of_128 - self.qk_head_dim)))
+                v_SNH = jnp.pad(v_SNH,
+                                ((0, 0), (0, 0),
+                                 (0, multiple_of_128 - self.v_head_dim)))
+
             q_scale = k_scale = v_scale = None
-            if self.kv_cache_quantized_dtype:
-                # TODO(kyuyeunk/jacobplatin): Enable w8a8 when VREG spill issue is resolved.
-                # q_scale = self._q_scale
-                k_scale = self._k_scale
-                v_scale = self._v_scale
-                k_SNH, v_SNH = utils.quantize_kv(k_SNH, v_SNH,
-                                                 self.kv_cache_quantized_dtype,
-                                                 k_scale, v_scale)
-            new_kv_cache, outputs_TNH = self.attention(
-                is_prefill,
-                kv_cache,
-                q_TNH,
-                k_SNH,
-                v_SNH,
-                attention_metadata,
-                self.mesh,
-                q_scale,
-                k_scale,
-                v_scale,
-            )
-            # TODO(wenxindongwork): For now, unpad the outputs_TNH to match the v_head_dim.
-            # We shall add the MLA kv cache implementation in the future.
-            outputs_TNH = outputs_TNH[..., :self.v_head_dim]

-        with jax.named_scope("o_proj"):
-            o_TD = jnp.einsum("TNH,NHD -> TD", outputs_TNH,
-                              self.kernel_o_proj_NHD.value)
-            o_TD = nnx.with_sharding_constraint(
-                o_TD, self.activation_attention_out_td)
-        return new_kv_cache, o_TD
+            # TODO(gpolovets): MLA does not currently support quantized KV!
+            if not self.use_mla_kernel:
+                if self.kv_cache_quantized_dtype:
+                    # TODO(kyuyeunk/jacobplatin): Enable w8a8 when VREG spill issue is resolved.
+                    k_scale = self._k_scale
+                    v_scale = self._v_scale
+                    k_SNH, v_SNH = utils.quantize_kv(
+                        k_SNH, v_SNH, self.kv_cache_quantized_dtype, k_scale,
+                        v_scale)
+
+                new_kv_cache, outputs_TNH = self.attention(
+                    is_prefill,
+                    kv_cache,
+                    q_TNH,
+                    k_SNH,
+                    v_SNH,
+                    attention_metadata,
+                    self.mesh,
+                    q_scale,
+                    k_scale,
+                    v_scale,
+                )
+                # TODO(wenxindongwork): For now, unpad the outputs_TNH to match the v_head_dim.
+                # We shall add the MLA kv cache implementation in the future.
+                outputs_TNH = outputs_TNH[..., :self.v_head_dim]
+
+            else:
+                new_kv_cache, outputs_TNA = self.mla_attention(
+                    kv_cache,
+                    q_TNA,
+                    q_rope_TNH,
+                    kv_SA,
+                    k_rope_SH,
+                    attention_metadata,
+                    self.mesh,
+                )
+                outputs_TNH = jnp.einsum("TNA,ANH -> TNH", outputs_TNA,
+                                         self.kernel_v_up_proj_ANH.value)
+
+        with jax.named_scope("o_proj"):
+            outputs_TNH = nnx.with_sharding_constraint(
+                outputs_TNH, self.activation_attention_out_td)
+            o_TD = jnp.einsum("TNH,NHD -> TD", outputs_TNH,
+                              self.kernel_o_proj_NHD.value)
+
+        return new_kv_cache, o_TD

     def attention(
         self,
@@ -326,13 +381,14 @@
         out_specs = (self.attn_o_tnh, P(None, None, "model"))

         def _ragged_paged_attention(*args):
-            return ragged_paged_attention(
+            outputs = ragged_paged_attention(
                *args,
                sm_scale=self.scale,
                q_scale=q_scale,
                k_scale=k_scale,
                v_scale=v_scale,
            )
+            return outputs

         output_TNH, kv_cache = jax.jit(
             shard_map.shard_map(
@@ -352,3 +408,115 @@
             md.request_distribution,
         )
         return kv_cache, output_TNH
+
+    def mla_attention(
+        self,
+        kv_cache: KVCache,
+        q_TNA: jax.Array,
+        q_rope_TNH: jax.Array,
+        k_SA: jax.Array,
+        k_rope_SH: jax.Array,
+        attention_metadata: AttentionMetadata,
+        mesh: Mesh,
+    ) -> Tuple[KVCache, jax.Array]:
+        """Performs scaled dot-product attention and updates the KV cache.
+
+        This function handles the core attention logic, which varies between
+        prefill and generation modes. In prefill, it computes self-attention
+        over the input sequence with a causal mask. In generation, it attends
+        to the full history of keys and values stored in the cache.
+
+        Args:
+            kv_cache: The key-value cache to be updated and used.
+            q_TNA: Query tensor of shape `(query_seq, num_attention_heads, lkv_dim)`.
+            q_rope_TNH: Query rope tensor of shape `(query_seq, num_attention_heads, rope_dim)`.
+            k_SA: Key tensor of shape `(kv_seq, lkv_dim)`.
+            k_rope_SH: Key rope tensor of shape `(kv_seq, rope_dim)`.
+            attention_metadata: Metadata containing sequence lengths.
+            mesh: The JAX device mesh (unused in this specific function but
+                kept for potential future use or API consistency).
+            q_scale: Quantization scale for q.
+            k_scale: Quantization scale for k.
+            v_scale: Quantization scale for v.
+
+        Returns:
+            A tuple containing:
+                - The updated KV cache.
+                - The attention output tensor of shape
+                  `(seq, num_q_heads, head_dim)`.
+        """
+        md = attention_metadata
+        in_specs = (
+            self.query_tnh,  # q
+            self.query_tnh,  # q_rope
+            self.keyvalue_skh,  # k
+            self.keyvalue_skh,  # k_rope
+            P(ShardingAxisName.MLP_TENSOR),  # kv_cache
+            P(ShardingAxisName.ATTN_DATA),  # md.seq_lens: Replicated
+            P(ShardingAxisName.ATTN_DATA),  # page_indices_flat: Replicated
+            P(ShardingAxisName.ATTN_DATA),  # query_start_loc: Replicated
+            P(ShardingAxisName.ATTN_DATA),  # distribution: Replicated
+        )
+
+        out_specs = (self.attn_o_tnh, P(ShardingAxisName.MLP_TENSOR))
+
+        def _mla_ragged_paged_attention(q, q_rope, k, k_rope, kv_cache, *args):
+
+            def _initialize_block_sizes():
+                # Set reasonable starting estimates for block sizes. (TODO(gpolovets): update this to use tuned sizes)
+                # Referring to get_tuned_block_sizes() in kernels/ragged_paged_attention/v3/tuned_block_sizes.py: 'TPU v7'/128/'q_bfloat16_kv_bfloat16/q_head-128_kv_head-1_head-128'/4096
+                max_num_tokens = q.shape[0]
+                max_num_seqs = md.seq_lens.shape[0]
+                num_page_indices = md.block_tables.shape[0]
+                assert num_page_indices % max_num_seqs == 0
+                pages_per_seq = num_page_indices // max_num_seqs
+                # num_kv_pages_per_block = min(pages_per_seq, 16)
+                bkv_p, bq_sz = get_tuned_block_sizes(
+                    q.dtype,
+                    kv_cache.dtype,
+                    self.num_attention_heads,
+                    1,
+                    self.qk_nope_head_dim,
+                    kv_cache.shape[1],  # page size
+                    max_num_tokens,
+                    pages_per_seq,
+                )
+                num_kv_pages_per_block = min(min(pages_per_seq, bkv_p), 4)
+                num_queries_per_block = min(min(max_num_tokens, bq_sz),
+                                            4)  # OOMS at 8
+                return num_kv_pages_per_block, num_queries_per_block
+
+            num_kv_pages_per_block, num_queries_per_block = _initialize_block_sizes(
+            )
+            output, kv_cache = mla_ragged_paged_attention(
+                q,
+                q_rope,
+                k,
+                k_rope,
+                kv_cache,
+                *args,
+                sm_scale=self.scale,
+                num_kv_pages_per_block=num_kv_pages_per_block,
+                num_queries_per_block=num_queries_per_block)
+
+            return kv_cache, output
+
+        kv_cache, output_TNH = jax.jit(
+            shard_map.shard_map(
+                _mla_ragged_paged_attention,
+                mesh=mesh,
+                in_specs=in_specs,
+                out_specs=out_specs,
+                check_rep=False,
+            ), )(
+                q_TNA,
+                q_rope_TNH,
+                k_SA,
+                k_rope_SH,
+                kv_cache,
+                md.seq_lens,
+                md.block_tables,
+                md.query_start_loc,
+                md.request_distribution,
+            )
+        return kv_cache, output_TNH
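
The use_mla_kernel path above absorbs the k up-projection into the query (the "TNH,ANH -> TNA" einsum) so the MLA kernel can attend directly over the rank-A latent KV cache instead of materializing per-head keys. Below is a minimal sketch of that identity, with invented shapes and random weights that are not taken from the package, checking that the absorbed form produces the same nope attention scores as the explicit up-projection:

import jax
import jax.numpy as jnp

# Hypothetical sizes for the check only: T query tokens, S cached tokens,
# N heads, H nope head dim, A kv LoRA rank.
T, S, N, H, A = 4, 6, 2, 8, 16
kq, kkv, kw = jax.random.split(jax.random.PRNGKey(0), 3)
q_nope_TNH = jax.random.normal(kq, (T, N, H))
kv_SA = jax.random.normal(kkv, (S, A))         # latent kv per token
w_k_up_ANH = jax.random.normal(kw, (A, N, H))  # k up-projection weight

# Explicit path: up-project the latent kv to per-head keys, then score.
k_nope_SNH = jnp.einsum("SA,ANH -> SNH", kv_SA, w_k_up_ANH)
scores_explicit = jnp.einsum("TNH,SNH -> TNS", q_nope_TNH, k_nope_SNH)

# Absorbed path (as in the diff): fold the up-projection into q and score
# directly against the latent kv.
q_TNA = jnp.einsum("TNH,ANH -> TNA", q_nope_TNH, w_k_up_ANH)
scores_absorbed = jnp.einsum("TNA,SA -> TNS", q_TNA, kv_SA)

assert jnp.allclose(scores_explicit, scores_absorbed, atol=1e-4)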
tpu_inference/layers/jax/attention/gpt_oss_attention.py

@@ -158,17 +158,17 @@ class GptOssAttention(nnx.Module):
     ) -> Tuple[KVCache, jax.Array]:
         """Performs scaled dot-product attention by calling the ragged_paged_attention kernel."""
         md = attention_metadata
-        kv_cache_spec = P(None, None, "model")
+        kv_cache_spec = P("data", None, "model")

         in_specs = (
             self.query_tnh,  # q
             self.keyvalue_skh,  # k
             self.keyvalue_skh,  # v
             kv_cache_spec,  # kv_cache
-            P(),  # md.seq_lens: Replicated
-            P(),  # page_indices_flat: Replicated
-            P(),  # query_start_loc: Replicated
-            P(),  # distribution: Replicated
+            P("data"),  # md.seq_lens
+            P("data"),  # page_indices_flat
+            P("data"),  # query_start_loc
+            P("data"),  # distribution
             P(('model')),  # sinks
         )
         out_specs = (self.attn_o_tnh, kv_cache_spec)
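
The GptOssAttention change above replaces the replicated specs P() with P("data") for the KV cache's leading axis and the attention metadata, so those arrays are partitioned along the data axis of the mesh instead of copied to every device. A minimal sketch of what the two specs mean in plain jax.sharding terms, using a made-up ("data", "model") mesh that is not the package's actual sharding setup:

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Toy mesh: all available devices on the "data" axis, a single "model" shard.
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, axis_names=("data", "model"))

# Stand-in for md.seq_lens, sized to divide evenly across the data axis.
seq_lens = jnp.arange(mesh.shape["data"] * 2, dtype=jnp.int32)

replicated = jax.device_put(seq_lens, NamedSharding(mesh, P()))       # old spec: full copy on every device
sharded = jax.device_put(seq_lens, NamedSharding(mesh, P("data")))    # new spec: first axis split over "data"

print(replicated.sharding)
print(sharded.sharding)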