litert-torch-nightly 0.9.0.dev20260202__py3-none-any.whl → 0.9.0.dev20260203__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (18)
  1. litert_torch/generative/export_hf/core/attention.py +86 -8
  2. litert_torch/generative/export_hf/core/attention_test.py +7 -2
  3. litert_torch/generative/export_hf/core/cache.py +112 -64
  4. litert_torch/generative/export_hf/core/cache_base.py +19 -2
  5. litert_torch/generative/export_hf/core/export_lib.py +55 -6
  6. litert_torch/generative/export_hf/core/exportable_module.py +30 -34
  7. litert_torch/generative/export_hf/core/exportable_module_config.py +39 -0
  8. litert_torch/generative/export_hf/core/split_cache/attention.py +28 -5
  9. litert_torch/generative/export_hf/core/split_cache/cache.py +113 -33
  10. litert_torch/generative/export_hf/core/split_cache/exportable_module.py +21 -14
  11. litert_torch/generative/export_hf/export.py +35 -2
  12. litert_torch/version.py +1 -1
  13. {litert_torch_nightly-0.9.0.dev20260202.dist-info → litert_torch_nightly-0.9.0.dev20260203.dist-info}/METADATA +1 -1
  14. {litert_torch_nightly-0.9.0.dev20260202.dist-info → litert_torch_nightly-0.9.0.dev20260203.dist-info}/RECORD +18 -17
  15. {litert_torch_nightly-0.9.0.dev20260202.dist-info → litert_torch_nightly-0.9.0.dev20260203.dist-info}/WHEEL +0 -0
  16. {litert_torch_nightly-0.9.0.dev20260202.dist-info → litert_torch_nightly-0.9.0.dev20260203.dist-info}/entry_points.txt +0 -0
  17. {litert_torch_nightly-0.9.0.dev20260202.dist-info → litert_torch_nightly-0.9.0.dev20260203.dist-info}/licenses/LICENSE +0 -0
  18. {litert_torch_nightly-0.9.0.dev20260202.dist-info → litert_torch_nightly-0.9.0.dev20260203.dist-info}/top_level.txt +0 -0
litert_torch/generative/export_hf/core/exportable_module.py CHANGED
@@ -15,38 +15,35 @@
  """Exportable modules."""

  import abc
- import dataclasses
  from litert_torch.generative.export_hf.core import cache as _
  from litert_torch.generative.export_hf.core import cache_base as kv_cache_lib
+ from litert_torch.generative.export_hf.core import exportable_module_config
  from litert_torch.generative.export_hf.core import utils
  import torch


- @dataclasses.dataclass
- class ExportableModuleConfig:
- """Config for exportable modules."""
+ ExportableModuleConfig = exportable_module_config.ExportableModuleConfig

- batch_size: int = 1
- cache_length: int = 1280
- prefill_lengths: list[int] = dataclasses.field(default_factory=lambda: [128])
- # For dynamic shape
- cache_length_dim: torch.export.Dim | None = None
- prefill_length_dim: torch.export.Dim | None = None

- # Export configs
- externalize_embedder: bool = False
- externalize_rope: bool = False
- split_cache: bool = False
+ class ExportableModuleBase(torch.nn.Module, abc.ABC):
+ """Base class for exportable modules."""

- cache_implementation: str = "LiteRTLMCache"
+ def __init__(self, export_config: ExportableModuleConfig):
+ super().__init__()
+ self._export_config = export_config

+ @property
+ def export_config(self) -> ExportableModuleConfig:
+ return self._export_config

- class ExportableModuleBase(torch.nn.Module, abc.ABC):
- """Base class for exportable modules."""
+ def attention_kwargs(self):
+ k_ts_idx = self.export_config.k_ts_idx
+ v_ts_idx = self.export_config.v_ts_idx
+ return {"k_ts_idx": k_ts_idx, "v_ts_idx": v_ts_idx}

  @abc.abstractmethod
  def get_sample_inputs(
- self, model_config, export_config: ExportableModuleConfig
+ self, model_config
  ) -> dict[str, tuple[dict[str, torch.Tensor], dict[str, torch.export.Dim]]]:
  """Returns the sample inputs for the model."""
  ...
@@ -55,8 +52,10 @@ class ExportableModuleBase(torch.nn.Module, abc.ABC):
  class LiteRTExportableModuleForDecoderOnlyLM(ExportableModuleBase):
  """Base class for exportable modules for decoder-only LM."""

- def __init__(self, model: torch.nn.Module):
- super().__init__()
+ def __init__(
+ self, model: torch.nn.Module, export_config: ExportableModuleConfig
+ ):
+ super().__init__(export_config)
  self.model = model

  def adapt_inputs(
@@ -108,16 +107,13 @@ class LiteRTExportableModuleForDecoderOnlyLM(ExportableModuleBase):
  })
  return ret

- def get_sample_kv_cache(
- self, model_config, export_config: ExportableModuleConfig
- ):
+ def get_sample_kv_cache(self, model_config):
  """Returns the input sample KV cache for the model."""
+ export_config = self.export_config
  num_layers = model_config.num_hidden_layers
- batch_size = export_config.batch_size
- cache_length = export_config.cache_length
  kv_cache = kv_cache_lib.CACHE_REGISTRY[
  export_config.cache_implementation
- ].create_from_config(model_config, cache_length, batch_size)
+ ].create_from_config(model_config, export_config)
  inputs = {"kv_cache": kv_cache}
  if export_config.cache_length_dim is not None:
  all_k_shapes = tuple(
@@ -150,6 +146,7 @@ class LiteRTExportableModuleForDecoderOnlyLMPrefill(
  mask,
  ):
  inputs = self.adapt_inputs(tokens, None, input_pos, kv_cache, mask)
+ inputs |= self.attention_kwargs()
  output = self.model(**inputs)
  return {"kv_cache": output.past_key_values}

@@ -165,11 +162,10 @@
  )
  return tokens, tokens_dynamic_shape

- def get_sample_inputs(
- self, model_config, export_config: ExportableModuleConfig
- ):
+ def get_sample_inputs(self, model_config):
+ export_config = self.export_config
  kv_cache_inputs, kv_cache_dynamic_shapes = self.get_sample_kv_cache(
- model_config, export_config
+ model_config
  )
  batch_size = export_config.batch_size
  cache_length = export_config.cache_length
@@ -218,6 +214,7 @@ class LiteRTExportableModuleForDecoderOnlyLMGenerate(
  mask,
  ):
  inputs = self.adapt_inputs(tokens, None, input_pos, kv_cache, mask)
+ inputs |= self.attention_kwargs()
  output = self.model(**inputs)
  return {"kv_cache": output.past_key_values, "logits": output.logits}

@@ -231,11 +228,10 @@
  tokens_dynamic_shape = {"tokens": None} if decode_length_dim else {}
  return tokens, tokens_dynamic_shape

- def get_sample_inputs(
- self, model_config, export_config: ExportableModuleConfig
- ):
+ def get_sample_inputs(self, model_config):
+ export_config = self.export_config
  kv_cache_inputs, kv_cache_dynamic_shapes = self.get_sample_kv_cache(
- model_config, export_config
+ model_config
  )
  batch_size = export_config.batch_size
  cache_length = export_config.cache_length
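
Note: the new attention_kwargs() helper above is how the timestamp indices reach the attention implementation: each prefill/generate forward merges them into the model's keyword arguments with the dict-union operator, and the split-cache attention function reads them back out of **kwargs (see the attention.py hunks below). A standalone sketch of that plumbing, with stub names standing in for the packaged classes:

# Sketch only: the names below are stand-ins, not the packaged classes.
def attention_kwargs():
  # Mirrors ExportableModuleBase.attention_kwargs() with the dataclass defaults.
  return {"k_ts_idx": 2, "v_ts_idx": 3}

def model_stub(**kwargs):
  # split_cache_attention later pulls the indices back out of **kwargs.
  k_ts_idx = kwargs.get("k_ts_idx", None)
  v_ts_idx = kwargs.get("v_ts_idx", None)
  if k_ts_idx is None or v_ts_idx is None:
    raise ValueError("Timestamp indices not passed to attention module.")
  return k_ts_idx, v_ts_idx

inputs = {"use_cache": True}
inputs |= attention_kwargs()  # the merge performed in the forward passes
assert model_stub(**inputs) == (2, 3)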
litert_torch/generative/export_hf/core/exportable_module_config.py ADDED
@@ -0,0 +1,39 @@
+ # Copyright 2025 The LiteRT Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ """Exportable modules."""
+
+ import dataclasses
+ import torch
+
+
+ @dataclasses.dataclass
+ class ExportableModuleConfig:
+ """Config for exportable modules."""
+
+ batch_size: int = 1
+ cache_length: int = 1280
+ prefill_lengths: list[int] = dataclasses.field(default_factory=lambda: [128])
+ # For dynamic shape
+ cache_length_dim: torch.export.Dim | None = None
+ prefill_length_dim: torch.export.Dim | None = None
+
+ # Export configs
+ externalize_embedder: bool = False
+ externalize_rope: bool = False
+
+ split_cache: bool = False
+ cache_implementation: str = "LiteRTLMCache"
+ k_ts_idx: int = 2
+ v_ts_idx: int = 3
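
Note: the dataclass above now gathers every knob the exporters read, including the key/value timestamp indices used by the split-cache path. A minimal construction sketch; only the field names come from the hunk above, the values are illustrative rather than recommended settings:

from litert_torch.generative.export_hf.core import exportable_module_config

config = exportable_module_config.ExportableModuleConfig(
    batch_size=1,
    cache_length=4096,
    prefill_lengths=[128, 512],
    split_cache=True,
    externalize_embedder=True,
    externalize_rope=True,
    cache_implementation="LiteRTLMSplitCache",
    k_ts_idx=2,  # key cache keeps its time axis at index 2: (1, B*KV, S, H)
    v_ts_idx=3,  # value cache keeps its time axis at index 3: (1, B*KV, H, S)
)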
litert_torch/generative/export_hf/core/split_cache/attention.py CHANGED
@@ -17,6 +17,7 @@
  import math
  from typing import Optional

+ from litert_torch.generative.custom_ops import bmm_4d as bmm_lib
  from litert_torch.generative.export_hf.core.split_cache import cache as kv_cache_lib
  import torch
  import torch.nn.functional as F
@@ -28,6 +29,8 @@ def _scaled_dot_product_attention(
  key_cache: kv_cache_lib.KeyCacheEntry,
  value_cache: kv_cache_lib.ValueCacheEntry,
  head_size: int,
+ k_ts_idx: int,
+ v_ts_idx: int,
  mask: Optional[torch.Tensor] = None,
  scale: Optional[float] = None,
  softcap: Optional[float] = None,
@@ -40,6 +43,8 @@ def _scaled_dot_product_attention(
  value_cache: A tuple of Value tensor. 1(bk)sh
  head_size (int): head dimension.
  mask (torch.Tensor): the optional mask tensor.
+ k_ts_idx (int): the timestamp index of the key tensor.
+ v_ts_idx (int): the timestamp index of the value tensor.
  scale (float): the optional scale factor.
  softcap (float): the optional softcap for the logits.

@@ -60,8 +65,13 @@
  assert mask is not None, "Mask should not be None!"
  t = mask.shape[2]

- logits0 = torch.einsum("abth,abhs->abts", query, key_past)
- logits1 = torch.einsum("abth,abhs->abts", query, key)
+ if k_ts_idx == 2:
+ bmm_fn = bmm_lib.bmm_4d
+ else:
+ assert k_ts_idx == 3, "k_ts_idx must be 2 or 3."
+ bmm_fn = lambda x, y: torch.einsum("abth,abhs->abts", x, y)
+ logits0 = bmm_fn(query, key_past)
+ logits1 = bmm_fn(query, key)
  logits = torch.cat([logits0, logits1], dim=-1)

  _, bk, gt, s = logits.shape
@@ -76,8 +86,13 @@
  probs = F.softmax(padded_logits, dim=-1).type_as(key)
  probs0, probs1 = probs[..., :-t], probs[..., -t:]

- encoded0 = torch.einsum("abts,absh->abth", probs0, value_past)
- encoded1 = torch.einsum("abts,absh->abth", probs1, value)
+ if v_ts_idx == 3:
+ bmm_fn = bmm_lib.bmm_4d
+ else:
+ assert v_ts_idx == 2, "v_ts_idx must be 2 or 3."
+ bmm_fn = lambda x, y: torch.einsum("abts,absh->abth", x, y)
+ encoded0 = bmm_fn(probs0, value_past)
+ encoded1 = bmm_fn(probs1, value)
  encoded = encoded0 + encoded1

  return encoded # 1, bk, gt, h
@@ -94,7 +109,6 @@ def split_cache_attention(
  **kwargs, # You need to accept **kwargs as models will pass other args
  ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
  """ODML transposed attention implementation for NPU."""
- del kwargs

  b, n, seq_len, h = query.shape
  if hasattr(module, "num_key_value_groups"):
@@ -102,6 +116,13 @@
  else:
  g = 1
  num_query_groups = n // g
+ k_ts_idx: int | None = kwargs.get("k_ts_idx", None)
+ v_ts_idx: int | None = kwargs.get("v_ts_idx", None)
+ if k_ts_idx is None or v_ts_idx is None:
+ raise ValueError(
+ "Timestamp indices not passed to attention module. The model is not"
+ " passing the kwargs correctly."
+ )
  # bnth -> b(kg)th -> 1(bk)(gt)h
  query = query.reshape(1, b * num_query_groups, g * seq_len, h)

@@ -113,6 +134,8 @@
  mask=attention_mask,
  scale=scaling,
  softcap=softcap,
+ k_ts_idx=k_ts_idx,
+ v_ts_idx=v_ts_idx,
  ) # 1, bk, gt, h
  sdpa_out = sdpa_out.reshape(b, -1, seq_len, h).permute(0, 2, 1, 3)
  return sdpa_out, None
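
Note: the k_ts_idx/v_ts_idx branches above switch between the custom bmm_4d op and a plain einsum depending on which axis of the cached tensor carries time. A small, self-contained shape check of the einsum fallback; the shapes follow the 1(bk)(gt)h and 1(bk)hs annotations in the hunk, and the sizes are invented for the example:

import torch

query = torch.randn(1, 8, 16, 64)      # 1, b*k, g*t, h
key_past = torch.randn(1, 8, 64, 128)  # 1, b*k, h, s  (k_ts_idx == 3 layout)

logits0 = torch.einsum("abth,abhs->abts", query, key_past)
assert logits0.shape == (1, 8, 16, 128)  # 1, b*k, g*t, s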
litert_torch/generative/export_hf/core/split_cache/cache.py CHANGED
@@ -25,16 +25,26 @@ Shape annotations used here:
  """

  from typing import Any, List, Optional, Self, Tuple
- import litert_torch.generative.export_hf.core.cache_base as cache_base_lib
  import jaxtyping as jt
+ from litert_torch.generative.export_hf.core import exportable_module_config
+ import litert_torch.generative.export_hf.core.cache_base as cache_base_lib
  import torch
  import torch.utils._pytree as pytree

+ ExportableModuleConfig = exportable_module_config.ExportableModuleConfig

- KeyCache = jt.Shaped[torch.Tensor, "1 BK H S"]
- KeySlice = jt.Shaped[torch.Tensor, "1 BK H T"]
- ValueCache = jt.Shaped[torch.Tensor, "1 BK S H"]
- ValueSlice = jt.Shaped[torch.Tensor, "1 BK T H"]
+ KeyCache = (
+ jt.Shaped[torch.Tensor, "1 BK H S"] | jt.Shaped[torch.Tensor, "1 BK S H"]
+ )
+ KeySlice = (
+ jt.Shaped[torch.Tensor, "1 BK H T"] | jt.Shaped[torch.Tensor, "1 BK T H"]
+ )
+ ValueCache = (
+ jt.Shaped[torch.Tensor, "1 BK S H"] | jt.Shaped[torch.Tensor, "1 BK H S"]
+ )
+ ValueSlice = (
+ jt.Shaped[torch.Tensor, "1 BK T H"] | jt.Shaped[torch.Tensor, "1 BK H T"]
+ )

  KeyCacheEntry = Tuple[KeyCache, KeySlice | None]
  ValueCacheEntry = Tuple[ValueCache, ValueSlice | None]
@@ -51,6 +61,8 @@ class LiteRTLMSplitCacheLayer(cache_base_lib.LiteRTLMCacheLayerMixin):
  key_cache: KeyCacheEntry,
  value_cache: ValueCacheEntry,
  batch_size: int = 1,
+ k_ts_idx: int = 2,
+ v_ts_idx: int = 3,
  **kwargs,
  ):
  super().__init__()
@@ -62,12 +74,16 @@ class LiteRTLMSplitCacheLayer(cache_base_lib.LiteRTLMCacheLayerMixin):
  self.values = value_cache
  self.is_initialized = True

+ self.k_ts_idx = k_ts_idx
+ self.v_ts_idx = v_ts_idx
+
  self.k_cache_shape = self.keys[0].shape
  self.v_cache_shape = self.values[0].shape
- self.max_cache_len = self.k_cache_shape[3]
+ self.max_cache_len = self.k_cache_shape[self.k_ts_idx]
  self.batch_size = batch_size
- self.num_key_value_heads = self.k_cache_shape[1] // self.batch_size
- self.head_dim = self.k_cache_shape[2]
+ self.head_dim = (
+ self.k_cache_shape[2] if self.k_ts_idx == 3 else self.k_cache_shape[3]
+ )
  self.additional_states = kwargs.get("additional_states", None)

  self.cumulative_length = 0
@@ -75,6 +91,12 @@ class LiteRTLMSplitCacheLayer(cache_base_lib.LiteRTLMCacheLayerMixin):
  def get_batch_size(self) -> int:
  return self.batch_size

+ def get_k_ts_idx(self) -> int:
+ return self.k_ts_idx
+
+ def get_v_ts_idx(self) -> int:
+ return self.v_ts_idx
+
  def lazy_initialization(self, key_states: torch.Tensor):
  # Since we don't support real lazy initialization, this function could only
  # be called by Cache.early_initialization, where uses a standard cache
@@ -97,12 +119,25 @@ class LiteRTLMSplitCacheLayer(cache_base_lib.LiteRTLMCacheLayerMixin):

  value_states = value_states.to(self.values[0].dtype)

- key_states = key_states.permute(0, 1, 3, 2).reshape(
- 1, -1, self.head_dim, seq_len
- ) # 1, bk, h, s
- value_states = value_states.reshape(
- 1, -1, seq_len, self.head_dim
- ) # 1, bk, s, h
+ if self.k_ts_idx == 2:
+ key_states = key_states.reshape(
+ 1, -1, seq_len, self.head_dim
+ ) # 1, bk, s, h
+ else:
+ assert self.k_ts_idx == 3, "k_ts_idx must be 2 or 3."
+ key_states = key_states.permute(0, 1, 3, 2).reshape(
+ 1, -1, self.head_dim, seq_len
+ ) # 1, bk, h, s
+
+ if self.v_ts_idx == 2:
+ value_states = value_states.reshape(
+ 1, -1, seq_len, self.head_dim
+ ) # 1, bk, s, h
+ else:
+ assert self.v_ts_idx == 3, "v_ts_idx must be 2 or 3."
+ value_states = value_states.permute(0, 1, 3, 2).reshape(
+ 1, -1, self.head_dim, seq_len
+ ) # 1, bk, h, s

  self.keys = (self.keys[0], key_states)
  self.values = (self.values[0], value_states)
@@ -123,37 +158,68 @@ class LiteRTLMSplitCacheLayer(cache_base_lib.LiteRTLMCacheLayerMixin):

  @classmethod
  def _infer_cache_shape_from_config(
- cls, model_config, layer_index, cache_length, batch_size=1
+ cls,
+ model_config,
+ layer_index,
+ export_config: ExportableModuleConfig,
+ **kwargs,
  ):
  """Infers the KV cache shape from the model config."""
  del layer_index # Unused.
+ del kwargs # Unused.
+ cache_length = export_config.cache_length
+ batch_size = export_config.batch_size
+ k_ts_idx = export_config.k_ts_idx
+ v_ts_idx = export_config.v_ts_idx
  num_kv_heads = model_config.num_key_value_heads
  embed_size_per_head = (
  getattr(model_config, "head_dim", None)
  or model_config.hidden_size // model_config.num_attention_heads
  )

- k_cache_shape = (
- 1,
- batch_size * num_kv_heads,
- embed_size_per_head,
- cache_length,
- )
- v_cache_shape = (
- 1,
- batch_size * num_kv_heads,
- cache_length,
- embed_size_per_head,
- )
+ if k_ts_idx == 2:
+ k_cache_shape = (
+ 1,
+ batch_size * num_kv_heads,
+ cache_length,
+ embed_size_per_head,
+ )
+ else:
+ assert k_ts_idx == 3, "k_ts_idx must be 2 or 3."
+ k_cache_shape = (
+ 1,
+ batch_size * num_kv_heads,
+ embed_size_per_head,
+ cache_length,
+ )
+ if v_ts_idx == 2:
+ v_cache_shape = (
+ 1,
+ batch_size * num_kv_heads,
+ cache_length,
+ embed_size_per_head,
+ )
+ else:
+ assert v_ts_idx == 3, "v_ts_idx must be 2 or 3."
+ v_cache_shape = (
+ 1,
+ batch_size * num_kv_heads,
+ embed_size_per_head,
+ cache_length,
+ )
  return k_cache_shape, v_cache_shape

  @classmethod
  def create_from_config(
- cls, model_config, layer_index, cache_length, batch_size=1, **kwargs
+ cls,
+ model_config,
+ layer_index,
+ export_config: ExportableModuleConfig,
+ **kwargs,
  ) -> Self:
  """Creates a KV cache from the model config."""
  k_cache_shape, v_cache_shape = cls._infer_cache_shape_from_config(
- model_config, layer_index, cache_length, batch_size
+ model_config, layer_index, export_config, **kwargs
  )
  keys = torch.zeros(k_cache_shape, dtype=torch.float32)
  values = torch.zeros(v_cache_shape, dtype=torch.float32)
@@ -165,14 +231,22 @@ class LiteRTLMSplitCache(cache_base_lib.LiteRTLMCacheMixin):
  """Optimized Cache class for HuggingFace integration."""

  @classmethod
- def create_from_config(cls, model_config, cache_length, batch_size=1) -> Self:
+ def create_from_config(
+ cls,
+ model_config,
+ export_config: ExportableModuleConfig,
+ **kwargs,
+ ) -> Self:
  """Creates a KV cache from the model config."""
  num_layers = model_config.num_hidden_layers
  layers = []
  for layer_index in range(num_layers):
  layers.append(
  LiteRTLMSplitCacheLayer.create_from_config(
- model_config, layer_index, cache_length, batch_size=batch_size
+ model_config,
+ layer_index,
+ export_config,
+ **kwargs,
  )
  )
  return cls(layers)
@@ -188,6 +262,8 @@ def _flatten_kvc_t(
  layer_0 = kvc.layers[0]
  assert isinstance(layer_0, cache_base_lib.LiteRTLMCacheLayerMixin)
  batch_size = layer_0.get_batch_size()
+ k_ts_idx = layer_0.get_k_ts_idx()
+ v_ts_idx = layer_0.get_v_ts_idx()
  for i, cache_layer in enumerate(kvc.layers):
  flattened.append(cache_layer.keys[0])
  flat_names.append(f"k_{i}")
@@ -199,16 +275,18 @@
  flat_names.append(f"k_{i}_slice")
  flattened.append(cache_layer.values[1])
  flat_names.append(f"v_{i}_slice")
- return flattened, [flat_names, (batch_size, num_layers)]
+ return flattened, [flat_names, (batch_size, num_layers, k_ts_idx, v_ts_idx)]


  def _unflatten_kvc_t(
  values: List[torch.Tensor],
- context: Tuple[List[str], Tuple[int, int]],
+ context: Tuple[List[str], Tuple[int, int, int, int]],
  ) -> LiteRTLMSplitCache:
  """Unflattens the KV cache from a list of tensors."""
  flat_names = context[0]
  batch_size = context[1][0]
+ k_ts_idx = context[1][2]
+ v_ts_idx = context[1][3]
  num_layers = context[1][1]
  kv_entries = []
  for i in range(num_layers):
@@ -231,6 +309,8 @@ def _unflatten_kvc_t(
  key_cache=(k_cache, k_cache_update),
  value_cache=(v_cache, v_cache_update),
  batch_size=batch_size,
+ k_ts_idx=k_ts_idx,
+ v_ts_idx=v_ts_idx,
  )
  )
  obj = LiteRTLMSplitCache(kv_entries)
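
Note: with the change above, the timestamp index decides which of the last two axes of the (1, batch*kv_heads, ..., ...) cache tensor is the sequence axis. A quick illustration of the shapes _infer_cache_shape_from_config produces for the default indices; the dimension sizes are invented for the example:

batch_size, num_kv_heads, cache_length, head_dim = 1, 4, 1280, 64

# k_ts_idx == 2: time axis at index 2 -> 1 BK S H.
k_cache_shape = (1, batch_size * num_kv_heads, cache_length, head_dim)
# v_ts_idx == 3: time axis at index 3 -> 1 BK H S.
v_cache_shape = (1, batch_size * num_kv_heads, head_dim, cache_length)

assert k_cache_shape[2] == cache_length and v_cache_shape[3] == cache_length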
litert_torch/generative/export_hf/core/split_cache/exportable_module.py CHANGED
@@ -14,12 +14,12 @@
  # ==============================================================================
  """Exportable module for split cache attention models."""

+ import copy
  from litert_torch.generative.export_hf.core import cache as base_cache_lib
  from litert_torch.generative.export_hf.core import exportable_module as base_exportable_module
  from litert_torch.generative.export_hf.core import utils
  from litert_torch.generative.export_hf.core.split_cache import attention_mask
  from litert_torch.generative.export_hf.core.split_cache import cache as kv_cache_lib
- import numpy as np
  import torch
  from torch import nn

@@ -64,9 +64,9 @@ class LiteRTSplitCacheExportableModuleForDecoderOnlyLM(
  ret['inputs_embeds'] = embeddings

  ret.update({
- 'position_ids': np.arange(embeddings.shape[1])[None, :],
+ 'position_ids': torch.arange(embeddings.shape[1])[None, :],
  'past_key_values': kv_cache,
- 'cache_position': np.arange(embeddings.shape[1]),
+ 'cache_position': torch.arange(embeddings.shape[1]),
  'attention_mask': masks,
  # Other common settings
  'use_cache': True,
@@ -164,6 +164,7 @@ class LiteRTSplitCacheExportableModuleForDecoderOnlyLMPrefill(
  mask,
  kv_cache,
  )
+ inputs |= self.attention_kwargs()
  output = self.model(**inputs)
  output_cache = output.past_key_values
  return self.post_process_kv_cache(output_cache)
@@ -171,9 +172,9 @@
  def get_sample_inputs(
  self,
  model_config,
- export_config: base_exportable_module.ExportableModuleConfig,
  ):
- kv_cache_inputs, _ = self.get_sample_kv_cache(model_config, export_config)
+ export_config = self.export_config
+ kv_cache_inputs, _ = self.get_sample_kv_cache(model_config)

  sample_inputs = {}
  for prefill_length in export_config.prefill_lengths:
@@ -207,6 +208,7 @@ class LiteRTSplitCacheExportableModuleForDecoderOnlyLMGenerate(
  mask,
  kv_cache,
  )
+ inputs |= self.attention_kwargs()
  output = self.model(**inputs)
  output_cache = output.past_key_values
  ret = self.post_process_kv_cache(output_cache)
@@ -216,9 +218,9 @@
  def get_sample_inputs(
  self,
  model_config,
- export_config: base_exportable_module.ExportableModuleConfig,
  ):
- kv_cache_inputs, _ = self.get_sample_kv_cache(model_config, export_config)
+ export_config = self.export_config
+ kv_cache_inputs, _ = self.get_sample_kv_cache(model_config)
  sample_inputs = {
  **kv_cache_inputs,
  **self._get_input(
@@ -322,13 +324,20 @@ class CacheUpdate(torch.nn.Module):
  return {'kv_cache': kv_cache}

  @classmethod
- def _get_input(cls, model_config, input_length, cache_length, batch_size=1):
+ def _get_input(
+ cls,
+ model_config,
+ input_length,
+ export_config: base_exportable_module.ExportableModuleConfig,
+ ):
  """Gets sample inputs for the model."""
  kv_cache = base_cache_lib.LiteRTLMCache.create_from_config(
- model_config, cache_length, batch_size, reverse_kv=True
+ model_config, export_config
  )
+ slice_export_config = copy.deepcopy(export_config)
+ slice_export_config.cache_length = input_length
  kv_slice = base_cache_lib.LiteRTLMCache.create_from_config(
- model_config, input_length, batch_size, reverse_kv=True
+ model_config, slice_export_config
  )
  return {
  'kv_cache': kv_cache,
@@ -348,15 +357,13 @@ class CacheUpdate(torch.nn.Module):
  inputs = cls._get_input(
  model_config,
  prefill_length,
- export_config.cache_length,
- export_config.batch_size,
+ export_config,
  )
  sample_inputs[f'prefill_cache_update_{prefill_length}'] = (inputs, {})
  decode_inputs = cls._get_input(
  model_config,
  1,
- export_config.cache_length,
- export_config.batch_size,
+ export_config,
  )
  sample_inputs['decode_cache_update'] = (decode_inputs, {})
  return sample_inputs
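
Note: CacheUpdate._get_input now sizes the slice cache by deep-copying the export config and overriding only cache_length, instead of passing lengths positionally. A tiny sketch of that pattern using the config class introduced earlier in this diff; the lengths are illustrative:

import copy
from litert_torch.generative.export_hf.core import exportable_module_config

export_config = exportable_module_config.ExportableModuleConfig(cache_length=4096)

# Keep every other field, but size the slice cache to the current input length.
slice_export_config = copy.deepcopy(export_config)
slice_export_config.cache_length = 128

assert export_config.cache_length == 4096
assert slice_export_config.cache_length == 128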
litert_torch/generative/export_hf/export.py CHANGED
@@ -30,7 +30,10 @@ def export(
  cache_length=4096,
  quantization_recipe: str = 'dynamic_wi8_afp32',
  enable_dynamic_shape: bool = False,
- # externalize_embedder: bool = False,
+ externalize_embedder: bool = False,
+ key_ts_idx: int = 2,
+ value_ts_idx: int = 3,
+ split_cache: bool = False,
  auto_model_override: str | None = None,
  # target_accelerator: str | None = None,
  trust_remote_code: bool = False,
@@ -46,6 +49,8 @@ def export(
  auto_model_override=auto_model_override,
  )
  del config # Unused.
+ if split_cache and not externalize_embedder:
+ raise ValueError('Split cache requires externalize embedder to be enabled.')
  export_config = exportable_module.ExportableModuleConfig(
  batch_size=1,
  prefill_lengths=prefill_lengths,
@@ -56,17 +61,45 @@
  cache_length_dim=torch.export.Dim('cache_length')
  if enable_dynamic_shape
  else None,
- externalize_embedder=False,
+ externalize_embedder=externalize_embedder,
+ k_ts_idx=key_ts_idx,
+ v_ts_idx=value_ts_idx,
+ split_cache=split_cache,
+ externalize_rope=split_cache,
+ cache_implementation='LiteRTLMSplitCache'
+ if split_cache
+ else 'LiteRTLMCache',
  )
  export_lib.export_text_prefill_decode_model(
  pt_model, text_model_config, export_config, work_dir, quantization_recipe
  )
  gc.collect()
+ if externalize_embedder:
+ export_lib.export_embedder_model(
+ pt_model,
+ text_model_config,
+ export_config,
+ work_dir,
+ quantization_recipe,
+ )
+ gc.collect()
+ if split_cache:
+ export_lib.export_auxiliary_model(
+ pt_model,
+ text_model_config,
+ export_config,
+ work_dir,
+ quantization_recipe,
+ )
+ gc.collect()
  tokenizer_model_path = export_lib.export_tokenizer(tokenizer, work_dir)
  tflite_model_path = os.path.join(
  work_dir,
  'model_quantized.tflite' if quantization_recipe else 'model.tflite',
  )
+ if externalize_embedder or split_cache:
+ # TODO(weiyiw): Add support for packaging models.
+ return
  litert_lm_builder.package_model(
  pt_model,
  tokenizer,
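
Note: the new export() flags are coupled: split_cache requires externalize_embedder, forces externalize_rope, selects the split-cache implementation, and (for now) skips packaging when either embedder externalization or split cache is requested. A standalone sketch of just that selection logic, mirroring the hunk above rather than calling the packaged function:

split_cache = True
externalize_embedder = True

if split_cache and not externalize_embedder:
  raise ValueError('Split cache requires externalize embedder to be enabled.')

cache_implementation = 'LiteRTLMSplitCache' if split_cache else 'LiteRTLMCache'
externalize_rope = split_cache
skip_packaging = externalize_embedder or split_cache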
litert_torch/version.py CHANGED
@@ -15,4 +15,4 @@

  # The next version of litert-torch.
  # The minor version code should be bumped after every release.
- __version__ = "0.9.0.dev20260202"
+ __version__ = "0.9.0.dev20260203"