litert-torch-nightly 0.9.0.dev20260201__py3-none-any.whl → 0.9.0.dev20260203__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (18)
  1. litert_torch/generative/export_hf/core/attention.py +86 -8
  2. litert_torch/generative/export_hf/core/attention_test.py +7 -2
  3. litert_torch/generative/export_hf/core/cache.py +112 -64
  4. litert_torch/generative/export_hf/core/cache_base.py +19 -2
  5. litert_torch/generative/export_hf/core/export_lib.py +55 -6
  6. litert_torch/generative/export_hf/core/exportable_module.py +30 -34
  7. litert_torch/generative/export_hf/core/exportable_module_config.py +39 -0
  8. litert_torch/generative/export_hf/core/split_cache/attention.py +28 -5
  9. litert_torch/generative/export_hf/core/split_cache/cache.py +113 -33
  10. litert_torch/generative/export_hf/core/split_cache/exportable_module.py +21 -14
  11. litert_torch/generative/export_hf/export.py +35 -2
  12. litert_torch/version.py +1 -1
  13. {litert_torch_nightly-0.9.0.dev20260201.dist-info → litert_torch_nightly-0.9.0.dev20260203.dist-info}/METADATA +1 -1
  14. {litert_torch_nightly-0.9.0.dev20260201.dist-info → litert_torch_nightly-0.9.0.dev20260203.dist-info}/RECORD +18 -17
  15. {litert_torch_nightly-0.9.0.dev20260201.dist-info → litert_torch_nightly-0.9.0.dev20260203.dist-info}/WHEEL +0 -0
  16. {litert_torch_nightly-0.9.0.dev20260201.dist-info → litert_torch_nightly-0.9.0.dev20260203.dist-info}/entry_points.txt +0 -0
  17. {litert_torch_nightly-0.9.0.dev20260201.dist-info → litert_torch_nightly-0.9.0.dev20260203.dist-info}/licenses/LICENSE +0 -0
  18. {litert_torch_nightly-0.9.0.dev20260201.dist-info → litert_torch_nightly-0.9.0.dev20260203.dist-info}/top_level.txt +0 -0
litert_torch/generative/export_hf/core/attention.py
@@ -14,13 +14,83 @@
 # ==============================================================================
 """Optimized Attention layer for HuggingFace integration."""
 
-
-from litert_torch.generative.layers import scaled_dot_product_attention as sdpa_lib
+import math
+from typing import Optional
 import jaxtyping as jt
+from litert_torch.generative.custom_ops import bmm_4d as bmm_lib
 import torch
+import torch.nn.functional as F
 import transformers
 
 
+def scaled_dot_product_attention_transposed(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    head_size: int,
+    k_ts_idx: int,
+    v_ts_idx: int,
+    mask: Optional[torch.Tensor] = None,
+    scale: Optional[float] = None,
+    softcap: Optional[float] = None,
+    alibi_bias: Optional[torch.Tensor] = None,
+):
+  """Scaled dot product attention with transposed key and value.
+
+  Args:
+    query: Query tensor, with shape [B, T, N, H].
+    key: Key tensor, with shape [B, T, KV_LEN, H].
+    value: Value tensor, with shape [B, T, H, KV_LEN].
+    head_size (int): head dimension.
+    mask (torch.Tensor): the optional mask tensor.
+    scale (float): the optional scale factor.
+    softcap (float): the optional softcap for the logits.
+    alibi_bias (torch.Tensor): optional alibi bias tensor.
+
+  Returns:
+    The output tensor of scaled_dot_product_attention_transposed.
+  """
+  if scale is None:
+    scale = 1.0 / math.sqrt(head_size)
+
+  if alibi_bias is not None:
+    alibi_bias = alibi_bias * scale
+    if mask is None:
+      mask = alibi_bias
+    else:
+      mask = mask + alibi_bias
+
+  query = query * scale
+
+  assert mask is not None, "Mask should not be None!"
+  t = mask.shape[2]
+  if k_ts_idx == 2:
+    bmm_fn = bmm_lib.bmm_4d
+  else:
+    assert k_ts_idx == 3, "k_ts_idx must be 2 or 3."
+    bmm_fn = lambda x, y: torch.einsum("abth,abhs->abts", x, y)
+  logits = bmm_fn(query, key)
+
+  _, bk, gt, s = logits.shape
+  g = gt // t
+  logits = logits.reshape((bk, g, t, s))
+  if softcap is not None:
+    logits = torch.tanh(logits / softcap)
+    logits = logits * softcap
+
+  padded_logits = logits + mask
+  padded_logits = padded_logits.reshape(1, bk, gt, s)
+  probs = F.softmax(padded_logits, dim=-1).type_as(key)
+  if v_ts_idx == 3:
+    bmm_fn = bmm_lib.bmm_4d
+  else:
+    assert v_ts_idx == 2, "v_ts_idx must be 2 or 3."
+    bmm_fn = lambda x, y: torch.einsum("abts,absh->abth", x, y)
+  encoded = bmm_fn(probs, value)
+
+  return encoded  # 1, bk, gt, h
+
+
 def transposed_attention(
     module: torch.nn.Module,
     query: jt.Float[torch.Tensor, "b n t h"],
@@ -46,20 +116,28 @@ def transposed_attention(
   Returns:
     The attention output tensor.
   """
-  del kwargs  # Unused in this implementation but required by the interface.
 
   b, n, seq_len, h = query.shape
   g = getattr(module, "num_key_value_groups", 1)
   num_query_groups = n // g
   # bnth -> b(kg)th -> 1(bk)(gt)h
   query = query.reshape(1, b * num_query_groups, g * seq_len, h)
+  key_ts_idx: int | None = kwargs.get("k_ts_idx", None)
+  value_ts_idx: int | None = kwargs.get("v_ts_idx", None)
+  if key_ts_idx is None or value_ts_idx is None:
+    raise ValueError(
+        "Timestamp indices not passed to attention module. The model is not"
+        " passing the kwargs correctly."
+    )
 
   # 1, bk, gt, h
-  sdpa_out = sdpa_lib.scaled_dot_product_attention_transposed(
-      query,
-      key,
-      value,
-      h,
+  sdpa_out = scaled_dot_product_attention_transposed(
+      query=query,
+      key=key,
+      value=value,
+      head_size=h,
+      k_ts_idx=key_ts_idx,
+      v_ts_idx=value_ts_idx,
       mask=attention_mask,
       scale=scaling,
       softcap=softcap,
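
Note: the new in-module scaled_dot_product_attention_transposed keys its matmul path off k_ts_idx/v_ts_idx, i.e. which of axes 2 or 3 of the cached key/value holds the KV sequence (the custom bmm_4d op when no transpose is needed, an einsum otherwise). As a rough, self-contained PyTorch illustration of only that layout handling (hypothetical helper name, single KV group, no softcap/ALiBi, not the package implementation):

import math
import torch
import torch.nn.functional as F

def reference_transposed_sdpa(
    query, key, value, head_size, k_ts_idx, v_ts_idx, mask
):
  # query: [1, BK, T, H]; key/value carry the KV length S on axis
  # k_ts_idx / v_ts_idx respectively (either 2 or 3).
  scale = 1.0 / math.sqrt(head_size)
  q = query * scale
  # Bring key to [1, BK, H, S] so q @ k yields [1, BK, T, S] logits.
  k = key if k_ts_idx == 3 else key.transpose(2, 3)
  logits = torch.einsum("abth,abhs->abts", q, k)
  probs = F.softmax(logits + mask, dim=-1)
  # Bring value to [1, BK, S, H] so probs @ v yields [1, BK, T, H].
  v = value if v_ts_idx == 2 else value.transpose(2, 3)
  return torch.einsum("abts,absh->abth", probs, v)

# k_ts_idx=2, v_ts_idx=3 corresponds to the default cache layouts
# K: [1, BK, S, H] and V: [1, BK, H, S] used elsewhere in this release.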
litert_torch/generative/export_hf/core/attention_test.py
@@ -71,7 +71,7 @@ class DummyAttentionModule(torch.nn.Module):
     self.scaling = scaling
     self.softcap = softcap
 
-  def forward(self, query, key, value, attention_mask):
+  def forward(self, query, key, value, attention_mask, **kwargs):
     attention_interface = modeling_utils.ALL_ATTENTION_FUNCTIONS[
         self.attention_implementation
     ]
@@ -84,6 +84,7 @@
         attention_mask,
         scaling=self.scaling,
         softcap=self.softcap,
+        **kwargs,
     )[0]
 
 
@@ -139,8 +140,12 @@ class AttentionTest(parameterized.TestCase):
         scaling=scl,
         softcap=scp,
     )
+    attention_kwargs = {
+        'k_ts_idx': 2,
+        'v_ts_idx': 3,
+    }
     expected = attn(query, key, value, mask)
-    actual = test_attn(query, key, value, mask)
+    actual = test_attn(query, key, value, mask, **attention_kwargs)
     self.assertTrue(
         torch.allclose(
             expected, actual, rtol=1e-2, atol=1e-2, equal_nan=True
litert_torch/generative/export_hf/core/cache.py
@@ -25,18 +25,30 @@ Shape annotations used here:
 """
 
 from typing import Any, List, Optional, Tuple
+
+import jaxtyping as jt
 import litert_torch.generative.custom_ops.dynamic_update_slice as tfl_dus
+from litert_torch.generative.export_hf.core import exportable_module_config
 import litert_torch.generative.export_hf.core.cache_base as cache_base_lib
-import jaxtyping as jt
 import torch
 import torch.utils._pytree as pytree
 
+ExportableModuleConfig = exportable_module_config.ExportableModuleConfig
+
 
 # Shape annotations for the cache entries.
-KeyCache = jt.Shaped[torch.Tensor, "1 BK S H"]
-KeySlice = jt.Shaped[torch.Tensor, "1 BK T H"]
-ValueCache = jt.Shaped[torch.Tensor, "1 BK H S"]
-ValueSlice = jt.Shaped[torch.Tensor, "1 BK H T"]
+KeyCache = (
+    jt.Shaped[torch.Tensor, "1 BK S H"] | jt.Shaped[torch.Tensor, "1 BK H S"]
+)
+KeySlice = (
+    jt.Shaped[torch.Tensor, "1 BK T H"] | jt.Shaped[torch.Tensor, "1 BK H T"]
+)
+ValueCache = (
+    jt.Shaped[torch.Tensor, "1 BK H S"] | jt.Shaped[torch.Tensor, "1 BK S H"]
+)
+ValueSlice = (
+    jt.Shaped[torch.Tensor, "1 BK H T"] | jt.Shaped[torch.Tensor, "1 BK T H"]
+)
 
 
 def _get_slice_indices(
@@ -77,15 +89,11 @@ def _update_kv_impl(
     k_slice: KeySlice,
     v_slice: ValueSlice,
     cache_position: jt.Int32[torch.Tensor, "T"],
-    reverse_kv: bool = False,
+    k_ts_idx: int,
+    v_ts_idx: int,
 ):
   """Updates the cache buffer using tfl.dynamic_update_slice."""
   cache_dim = 4
-  k_ts_idx = 2  # K Cache shape is 1 BK S H
-  v_ts_idx = 3  # V Cache shape is 1 BK H S
-  if reverse_kv:
-    k_ts_idx = 3  # K Cache shape is 1 BK H S
-    v_ts_idx = 2  # V Cache shape is 1 BK S H
   positions = cache_position[0]  # The position of the first input token.
   k_slice_indices = _get_slice_indices(positions.clone(), cache_dim, k_ts_idx)
   v_slice_indices = _get_slice_indices(positions.clone(), cache_dim, v_ts_idx)
@@ -109,27 +117,26 @@ class LiteRTLMCacheLayer(cache_base_lib.LiteRTLMCacheLayerMixin):
       key_cache: KeyCache,
       value_cache: ValueCache,
       batch_size: int = 1,
-      reverse_kv: bool = False,
+      k_ts_idx: int = 2,
+      v_ts_idx: int = 3,
       **kwargs,
   ):
     super().__init__()
     self.keys = key_cache
     self.values = value_cache
-    self.reverse_kv = reverse_kv
+    self.k_ts_idx = k_ts_idx  # The index of the sequence dimension in K cache.
+    self.v_ts_idx = v_ts_idx  # The index of the sequence dimension in V cache.
+    assert k_ts_idx in [2, 3]
+    assert v_ts_idx in [2, 3]
     self.is_initialized = True
 
     self.k_cache_shape = self.keys.shape
     self.v_cache_shape = self.values.shape
-    self.max_cache_len = (
-        self.v_cache_shape[2] if reverse_kv else self.k_cache_shape[2]
-    )
+    self.max_cache_len = self.v_cache_shape[self.v_ts_idx]
     self.batch_size = batch_size
-    self.num_key_value_heads = (
-        self.v_cache_shape[1] if reverse_kv else self.k_cache_shape[1]
-    ) // self.batch_size
-    self.head_dim = (
-        self.v_cache_shape[3] if reverse_kv else self.k_cache_shape[3]
-    )
+    v_head_dim_idx = 3 if self.v_ts_idx == 2 else 2
+    self.head_dim = self.v_cache_shape[v_head_dim_idx]
+
     self.additional_states = kwargs.get("additional_states", None)
 
     self.cumulative_length = 0
@@ -137,6 +144,12 @@
   def get_batch_size(self) -> int:
     return self.batch_size
 
+  def get_k_ts_idx(self) -> int:
+    return self.k_ts_idx
+
+  def get_v_ts_idx(self) -> int:
+    return self.v_ts_idx
+
   def lazy_initialization(self, key_states: torch.Tensor):
     # Since we don't support real lazy initialization, this function could only
    # be called by Cache.early_initialization, where uses a standard cache
@@ -162,13 +175,24 @@
       value_states = value_states.to(self.values.dtype)
 
     if not cache_kwargs.get("kv_slice_preprocessed", False):
-      assert not self.reverse_kv, "Reverse KV is not supported."
-      key_states = key_states.reshape(
-          1, -1, seq_len, self.head_dim
-      )  # 1, bk, s, h
-      value_states = value_states.permute(0, 1, 3, 2).reshape(
-          1, -1, self.head_dim, seq_len
-      )  # 1, bk, h, s
+      if self.k_ts_idx == 3:
+        key_target_shape = (1, -1, self.head_dim, seq_len)
+        key_states = key_states.permute(0, 1, 3, 2).reshape(*key_target_shape)
+      elif self.k_ts_idx == 2:
+        key_target_shape = (1, -1, seq_len, self.head_dim)
+        key_states = key_states.reshape(*key_target_shape)
+      else:
+        raise ValueError(f"Unsupported k_ts_idx: {self.k_ts_idx}")
+      if self.v_ts_idx == 3:
+        value_target_shape = (1, -1, self.head_dim, seq_len)
+        value_states = value_states.permute(0, 1, 3, 2).reshape(
+            *value_target_shape
+        )
+      elif self.v_ts_idx == 2:
+        value_target_shape = (1, -1, seq_len, self.head_dim)
+        value_states = value_states.reshape(*value_target_shape)
+      else:
+        raise ValueError(f"Unsupported v_ts_idx: {self.v_ts_idx}")
 
     cache_position: jt.Int32[torch.Tensor, "T"] = cache_kwargs.get(
         "cache_position"
@@ -182,7 +206,8 @@
         key_states,
         value_states,
         cache_position,
-        self.reverse_kv,
+        self.k_ts_idx,
+        self.v_ts_idx,
     )
     return self.keys, self.values
 
@@ -203,32 +228,52 @@
       cls,
       model_config,
       layer_index,
-      cache_length,
-      batch_size=1,
-      reverse_kv=False,
+      export_config: ExportableModuleConfig,
   ):
     """Infers the KV cache shape from the model config."""
     del layer_index  # Unused.
+    cache_length = export_config.cache_length
+    batch_size = export_config.batch_size
+    k_ts_idx = export_config.k_ts_idx
+    v_ts_idx = export_config.v_ts_idx
     num_kv_heads = model_config.num_key_value_heads
     embed_size_per_head = (
         getattr(model_config, "head_dim", None)
         or model_config.hidden_size // model_config.num_attention_heads
     )
 
-    k_cache_shape = (
-        1,
-        batch_size * num_kv_heads,
-        cache_length,
-        embed_size_per_head,
-    )
-    v_cache_shape = (
-        1,
-        batch_size * num_kv_heads,
-        embed_size_per_head,
-        cache_length,
-    )
-    if reverse_kv:
-      k_cache_shape, v_cache_shape = v_cache_shape, k_cache_shape
+    if k_ts_idx == 2:
+      k_cache_shape = (
+          1,
+          batch_size * num_kv_heads,
+          cache_length,
+          embed_size_per_head,
+      )
+    elif k_ts_idx == 3:
+      k_cache_shape = (
+          1,
+          batch_size * num_kv_heads,
+          embed_size_per_head,
+          cache_length,
+      )
+    else:
+      raise ValueError(f"Unsupported k_ts_idx: {k_ts_idx}")
+    if v_ts_idx == 2:
+      v_cache_shape = (
+          1,
+          batch_size * num_kv_heads,
+          cache_length,
+          embed_size_per_head,
+      )
+    elif v_ts_idx == 3:
+      v_cache_shape = (
+          1,
+          batch_size * num_kv_heads,
+          embed_size_per_head,
+          cache_length,
+      )
+    else:
+      raise ValueError(f"Unsupported v_ts_idx: {v_ts_idx}")
     return k_cache_shape, v_cache_shape
 
   @classmethod
@@ -236,18 +281,22 @@
       cls,
       model_config,
       layer_index,
-      cache_length,
-      batch_size=1,
-      reverse_kv=False,
+      export_config: ExportableModuleConfig,
       **kwargs,
   ) -> "LiteRTLMCacheLayer":
     """Creates a KV cache from the model config."""
     k_cache_shape, v_cache_shape = cls._infer_cache_shape_from_config(
-        model_config, layer_index, cache_length, batch_size, reverse_kv
+        model_config, layer_index, export_config
     )
     keys = torch.zeros(k_cache_shape, dtype=torch.float32)
     values = torch.zeros(v_cache_shape, dtype=torch.float32)
-    return cls(keys, values, reverse_kv=reverse_kv, **kwargs)
+    return cls(
+        keys,
+        values,
+        k_ts_idx=export_config.k_ts_idx,
+        v_ts_idx=export_config.v_ts_idx,
+        **kwargs,
+    )
 
 
 @cache_base_lib.register_cache_implementation
@@ -258,9 +307,7 @@ class LiteRTLMCache(cache_base_lib.LiteRTLMCacheMixin):
   def create_from_config(
       cls,
       model_config,
-      cache_length,
-      batch_size=1,
-      reverse_kv=False,
+      export_config: ExportableModuleConfig,
       **kwargs,
   ) -> "LiteRTLMCache":
     """Creates a KV cache from the model config."""
@@ -271,9 +318,8 @@
           LiteRTLMCacheLayer.create_from_config(
               model_config,
               layer_index,
-              cache_length,
-              batch_size=batch_size,
-              reverse_kv=reverse_kv,
+              export_config,
+              **kwargs,
           )
       )
     return cls(layers)
@@ -281,7 +327,7 @@
 
 def _flatten_kvc_t(
     kvc: LiteRTLMCache,
-) -> Tuple[List[torch.Tensor], Tuple[List[str], Tuple[int, int, bool]]]:
+) -> Tuple[List[torch.Tensor], Tuple[List[str], Tuple[int, int, int, int]]]:
   """Flattens the cache into a list of tensors."""
   flattened = []
   flat_names = []
@@ -289,22 +335,23 @@ def _flatten_kvc_t(
   layer_0 = kvc.layers[0]
   assert isinstance(layer_0, cache_base_lib.LiteRTLMCacheLayerMixin)
   batch_size = layer_0.get_batch_size()
-  reverse_kv = getattr(layer_0, "reverse_kv", False)
+  k_ts_idx = layer_0.get_k_ts_idx()
+  v_ts_idx = layer_0.get_v_ts_idx()
   for i, layer in enumerate(kvc.layers):
     flattened.append(layer.keys)
     flat_names.append(f"k_{i}")
     flattened.append(layer.values)
     flat_names.append(f"v_{i}")
-  return flattened, (flat_names, (batch_size, num_layers, reverse_kv))
+  return flattened, (flat_names, (batch_size, num_layers, k_ts_idx, v_ts_idx))
 
 
 def _unflatten_kvc_t(
     values: List[torch.Tensor],
-    context: Tuple[List[str], Tuple[int, int, bool]],
+    context: Tuple[List[str], Tuple[int, int, int, int]],
 ) -> LiteRTLMCache:
   """Unflattens the cache from a list of tensors."""
   flat_names = context[0]
-  batch_size, num_layers, reverse_kv = context[1]
+  batch_size, num_layers, k_ts_idx, v_ts_idx = context[1]
   layers = []
   for i in range(num_layers):
     k_cache_idx = flat_names.index(f"k_{i}")
@@ -314,7 +361,8 @@ def _unflatten_kvc_t(
             key_cache=values[k_cache_idx],
             value_cache=values[v_cache_idx],
             batch_size=batch_size,
-            reverse_kv=reverse_kv,
+            k_ts_idx=k_ts_idx,
+            v_ts_idx=v_ts_idx,
         )
     )
   obj = LiteRTLMCache(layers)
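
Note: across these cache.py changes, k_ts_idx/v_ts_idx simply name which trailing axis (2 or 3) of the 4-D cache tensor stores the KV sequence; the other of the two stores the head dimension. A compact sketch of the shape rule that _infer_cache_shape_from_config now applies per tensor (the helper name below is illustrative, not from the package):

def cache_shape(num_kv_heads, head_dim, cache_length, batch_size, ts_idx):
  # ts_idx in (2, 3): the trailing axis that holds the KV sequence length.
  assert ts_idx in (2, 3)
  tail = (cache_length, head_dim) if ts_idx == 2 else (head_dim, cache_length)
  return (1, batch_size * num_kv_heads, *tail)

# The previous reverse_kv=False default maps to k_ts_idx=2, v_ts_idx=3:
#   K cache (1, B*KV, S, H), V cache (1, B*KV, H, S).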
litert_torch/generative/export_hf/core/cache_base.py
@@ -15,8 +15,11 @@
 """Base class for cache."""
 
 import abc
+from litert_torch.generative.export_hf.core import exportable_module_config
 from transformers import cache_utils
 
+ExportableModuleConfig = exportable_module_config.ExportableModuleConfig
+
 
 class LiteRTLMCacheLayerMixin(cache_utils.CacheLayerMixin, abc.ABC):
   """Optimized Cache layer class mixin for HuggingFace integration."""
@@ -26,10 +29,24 @@ class LiteRTLMCacheLayerMixin(cache_utils.CacheLayerMixin, abc.ABC):
     """Returns the batch size of the cache."""
     ...
 
+  @abc.abstractmethod
+  def get_k_ts_idx(self) -> int:
+    """Returns the index of the sequence dimension in K cache."""
+    ...
+
+  @abc.abstractmethod
+  def get_v_ts_idx(self) -> int:
+    """Returns the index of the sequence dimension in V cache."""
+    ...
+
   @classmethod
   @abc.abstractmethod
   def create_from_config(
-      cls, model_config, layer_index, cache_length, batch_size=1, **kwargs
+      cls,
+      model_config,
+      layer_index,
+      export_config: ExportableModuleConfig,
+      **kwargs
   ) -> "LiteRTLMCacheLayerMixin":
     ...
 
@@ -40,7 +57,7 @@ class LiteRTLMCacheMixin(cache_utils.Cache, abc.ABC):
   @classmethod
   @abc.abstractmethod
   def create_from_config(
-      cls, model_config, cache_length, batch_size=1
+      cls, model_config, export_config: ExportableModuleConfig, **kwargs
   ) -> "LiteRTLMCacheMixin":
     """Creates a KV cache from the model config."""
     ...
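
Note: create_from_config in both mixins now takes a single ExportableModuleConfig in place of the loose cache_length/batch_size/reverse_kv arguments. The class itself is defined in the new exportable_module_config.py, which this excerpt does not show; judging only from the fields accessed in this diff (cache_length, batch_size, k_ts_idx, v_ts_idx), a minimal stand-in could be sketched as follows (names and defaults are assumptions, not the released definition):

import dataclasses

@dataclasses.dataclass
class ExportConfigSketch:
  """Hypothetical stand-in for ExportableModuleConfig; real fields may differ."""
  cache_length: int
  batch_size: int = 1
  k_ts_idx: int = 2  # axis of the KV sequence in the K cache
  v_ts_idx: int = 3  # axis of the KV sequence in the V cache

# e.g. LiteRTLMCache.create_from_config(model_config, ExportConfigSketch(1024))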
litert_torch/generative/export_hf/core/export_lib.py
@@ -26,6 +26,7 @@ from litert_torch.generative.export_hf.core import exportable_module
 from litert_torch.generative.export_hf.core import patches as _
 from litert_torch.generative.export_hf.core import utils
 from litert_torch.generative.export_hf.core.external_emb import exportable_module as external_emb_module
+from litert_torch.generative.export_hf.core.external_rope import exportable_module as external_rope_module
 from litert_torch.generative.export_hf.core.external_rope import preprocess_model as external_rope_preprocess_model
 from litert_torch.generative.export_hf.core.mu import mu_pass_lib
 from litert_torch.generative.export_hf.core.split_cache import attention as _
@@ -34,6 +35,7 @@ from litert_torch.generative.tools import tokenizer_to_sentencepiece_lib as toke
 from litert_torch.odml_torch.experimental import torch_tfl
 import torch
 import transformers
+
 from ai_edge_quantizer import quantizer as quantizer_lib
 from ai_edge_quantizer import recipe as recipe_lib
 
@@ -174,12 +176,10 @@ def export_text_prefill_decode_model(
   prefill_module_cls, decode_module_cls = get_prefill_decode_exportable_cls(
       export_config
   )
-  prefill_module = prefill_module_cls(model)
-  decode_module = decode_module_cls(model)
+  prefill_module = prefill_module_cls(model, export_config)
+  decode_module = decode_module_cls(model, export_config)
   converter = converter_utils.Converter()
-  sample_prefill_inputs = prefill_module.get_sample_inputs(
-      text_model_config, export_config
-  )
+  sample_prefill_inputs = prefill_module.get_sample_inputs(text_model_config)
   for signature_name, (
       sample_prefill_inputs,
       prefill_dynamic_shapes,
@@ -213,7 +213,7 @@
         sample_kwargs=sample_prefill_inputs,
     )
   sample_decode_inputs, decode_dynamic_shapes = decode_module.get_sample_inputs(
-      text_model_config, export_config
+      text_model_config
   )['decode']
   if has_dynamic_shape:
     print('Exporting decode_module...')
@@ -337,6 +337,55 @@
   return model_path
 
 
+def export_auxiliary_model(
+    model,
+    text_model_config,
+    export_config: exportable_module.ExportableModuleConfig,
+    work_dir: str,
+    quantization_recipe: str | None = None,
+):
+  """Exports auxiliary model."""
+  del quantization_recipe  # Unused.
+  converter = converter_utils.Converter()
+  # RoPE
+  rope_module = external_rope_module.RoPEEmbedder(model)
+  sample_inputs = rope_module.get_sample_inputs(
+      text_model_config, export_config
+  )
+  for signature_name, (sample_input, _) in sample_inputs.items():
+    converter.add_signature(
+        signature_name,
+        rope_module.eval(),
+        sample_kwargs=sample_input,
+    )
+  # Attention Mask
+  attention_mask_module = split_cache_module.SplitAttentionMaskBuilder(model)
+  sample_inputs = attention_mask_module.get_sample_inputs(
+      text_model_config, export_config
+  )
+  for signature_name, (sample_input, _) in sample_inputs.items():
+    converter.add_signature(
+        signature_name,
+        attention_mask_module.eval(),
+        sample_kwargs=sample_input,
+    )
+  # Cache Update
+  cache_update_module = split_cache_module.CacheUpdate(model)
+  sample_inputs = cache_update_module.get_sample_inputs(
+      text_model_config, export_config
+  )
+  for signature_name, (sample_input, _) in sample_inputs.items():
+    converter.add_signature(
+        signature_name,
+        cache_update_module.eval(),
+        sample_kwargs=sample_input,
+    )
+  lrt_model = converter.convert(strict_export=False)
+  model_path = os.path.join(work_dir, 'auxiliary.tflite')
+  lrt_model.export(model_path)
+  return model_path
+
+
 def export_tokenizer(
     tokenizer,
     work_dir: str,