ai-edge-torch-nightly 0.4.0.dev20250226__py3-none-any.whl → 0.4.0.dev20250228__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +9 -2
- ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +25 -8
- ai_edge_torch/generative/layers/attention.py +0 -12
- ai_edge_torch/generative/layers/experimental/attention.py +0 -8
- ai_edge_torch/generative/layers/experimental/kv_cache.py +45 -31
- ai_edge_torch/generative/layers/kv_cache.py +8 -5
- ai_edge_torch/generative/layers/model_config.py +0 -3
- ai_edge_torch/generative/utilities/converter.py +14 -4
- ai_edge_torch/generative/utilities/model_builder.py +2 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.4.0.dev20250226.dist-info → ai_edge_torch_nightly-0.4.0.dev20250228.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.4.0.dev20250226.dist-info → ai_edge_torch_nightly-0.4.0.dev20250228.dist-info}/RECORD +15 -15
- {ai_edge_torch_nightly-0.4.0.dev20250226.dist-info → ai_edge_torch_nightly-0.4.0.dev20250228.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.4.0.dev20250226.dist-info → ai_edge_torch_nightly-0.4.0.dev20250228.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.4.0.dev20250226.dist-info → ai_edge_torch_nightly-0.4.0.dev20250228.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/smollm/convert_to_tflite.py
CHANGED
@@ -22,7 +22,7 @@ from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.smollm import smollm
 from ai_edge_torch.generative.utilities import converter
-from ai_edge_torch.generative.utilities.model_builder import ExportConfig
+from ai_edge_torch.generative.utilities import model_builder
 
 _CHECKPOINT_PATH = flags.DEFINE_string(
     'checkpoint_path',
@@ -59,6 +59,11 @@ _LORA_RANKS = flags.DEFINE_multi_integer(
     None,
     'If set, the model will be converted with the provided list of LoRA ranks.',
 )
+_DECODE_BATCH_SIZE = flags.DEFINE_integer(
+    'decode_batch_size',
+    1,
+    'The batch size for the decode signature.',
+)
 
 
 def main(_):
@@ -72,7 +77,9 @@ def main(_):
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
      lora_ranks=_LORA_RANKS.value,
-      export_config=ExportConfig(),
+      export_config=model_builder.ExportConfig(
+          decode_batch_size=_DECODE_BATCH_SIZE.value
+      ),
   )
 
 
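Net effect for this exporter: the decode signature's batch size is now a flag, threaded through model_builder.ExportConfig. A minimal programmatic sketch of the same call (the checkpoint path is a placeholder, and smollm.build_model with a kv_cache_max_len keyword is assumed to mirror the build_model_v2 usage in the convert_v2_to_tflite.py hunks below):

    from ai_edge_torch.generative.examples.smollm import smollm
    from ai_edge_torch.generative.utilities import converter
    from ai_edge_torch.generative.utilities import model_builder

    # Placeholder checkpoint path; any checkpoint layout the example
    # accepts would do here.
    pytorch_model = smollm.build_model(
        '/path/to/smollm_checkpoint', kv_cache_max_len=1024
    )
    converter.convert_to_tflite(
        pytorch_model,
        output_path='/tmp/',
        output_name_prefix='smollm',
        prefill_seq_len=[128],
        quantize=True,
        lora_ranks=None,
        # Trace the decode signature with batch size 2 instead of the default 1.
        export_config=model_builder.ExportConfig(decode_batch_size=2),
    )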
ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py
CHANGED
@@ -22,17 +22,22 @@ from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.smollm import smollm
 from ai_edge_torch.generative.utilities import converter
-from ai_edge_torch.generative.utilities.model_builder import ExportConfig
+from ai_edge_torch.generative.utilities import model_builder
 
 _CHECKPOINT_PATH = flags.DEFINE_string(
     'checkpoint_path',
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smollm2'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'smollm2',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,6 +54,16 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
+_DECODE_BATCH_SIZE = flags.DEFINE_integer(
+    'decode_batch_size',
+    1,
+    'The batch size for the decode signature.',
+)
 
 
 def main(_):
@@ -56,14 +71,16 @@ def main(_):
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
 
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'smollm2_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
-      export_config=ExportConfig(),
+      lora_ranks=_LORA_RANKS.value,
+      export_config=model_builder.ExportConfig(
+          decode_batch_size=_DECODE_BATCH_SIZE.value
+      ),
   )
 
 
ai_edge_torch/generative/layers/attention.py
CHANGED
@@ -48,7 +48,6 @@ class TransformerBlock(nn.Module):
         config.pre_attention_norm_config,
     )
     self.atten_func = CausalSelfAttention(
-        model_config.batch_size,
         model_config.embedding_dim,
         config.attn_config,
         model_config.enable_hlfb,
@@ -115,7 +114,6 @@ class CausalSelfAttention(nn.Module):
 
   def __init__(
       self,
-      batch_size: int,
       dim: int,
       config: cfg.AttentionConfig,
       enable_hlfb: bool,
@@ -123,14 +121,12 @@ class CausalSelfAttention(nn.Module):
    """Initialize an instance of CausalSelfAttention.
 
    Args:
-      batch_size (int): batch size of the input tensor.
      dim (int): causal attention's input/output dimmension.
      config (cfg.AttentionConfig): attention specific configurations.
      enable_hlfb (bool): whether hlfb is enabled or not.
    """
    super().__init__()
    self.kv_cache = None
-    self.batch_size = batch_size
    qkv_shape = (
        config.num_heads + 2 * config.num_query_groups
    ) * config.head_dim
@@ -179,11 +175,6 @@ class CausalSelfAttention(nn.Module):
     """
     # Batch size, sequence length, embedding dimensionality.
     B, T, E = x.size()
-    assert B == self.batch_size, (
-        "batch size of input tensor must match with the batch size specified in"
-        " the model configuration."
-    )
-
     qkv = self.qkv_projection(x)
 
     # Assemble into a number of query groups to support MHA, MQA and GQA.
@@ -290,7 +281,6 @@ class CrossAttention(nn.Module):
 
   def __init__(
      self,
-      batch_size: int,
      query_dim: int,
      cross_dim: int,
      hidden_dim: int,
@@ -301,7 +291,6 @@ class CrossAttention(nn.Module):
    """Initialize an instance of CrossAttention.
 
    Args:
-      batch_size (int): batch size of the input tensor.
      query_dim (int): query tensor's dimension.
      cross_dim (int): cross attention's dimensions, for key and value tensors.
      hidden_dim (int): hidden dimension that q, k, v tensors project to.
@@ -376,7 +365,6 @@ class CrossAttention(nn.Module):
 
    if rope is not None:
      # Compute rotary positional embedding for query and key.
-      n_elem = int(self.config.rotary_percentage * self.config.head_dim)
      cos, sin = rope
      q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)
 
ai_edge_torch/generative/layers/experimental/attention.py
CHANGED
@@ -52,7 +52,6 @@ class TransformerBlock(nn.Module):
         config.pre_attention_norm_config,
     )
     self.atten_func = CausalSelfAttention(
-        model_config.batch_size,
         model_config.embedding_dim,
         config.attn_config,
         model_config.enable_hlfb,
@@ -119,7 +118,6 @@ class CausalSelfAttention(nn.Module):
 
   def __init__(
      self,
-      batch_size: int,
      dim: int,
      config: cfg.AttentionConfig,
      enable_hlfb: bool,
@@ -127,14 +125,12 @@ class CausalSelfAttention(nn.Module):
    """Initialize an instance of CausalSelfAttention.
 
    Args:
-      batch_size (int): batch size of the input tensor.
      dim (int): causal attention's input/output dimmension.
      config (cfg.AttentionConfig): attention specific configurations.
      enable_hlfb (bool): whether hlfb is enabled or not.
    """
    super().__init__()
    self.kv_cache = None
-    self.batch_size = batch_size
    qkv_shape = (
        config.num_heads + 2 * config.num_query_groups
    ) * config.head_dim
@@ -180,10 +176,6 @@ class CausalSelfAttention(nn.Module):
     """
     # Batch size, sequence length, embedding dimensionality.
     B, T, E = x.size()
-    assert B == self.batch_size, (
-        "batch size of input tensor must match with the batch size specified in"
-        " the model configuration."
-    )
 
     qkv = self.qkv_projection(x)
 
ai_edge_torch/generative/layers/experimental/kv_cache.py
CHANGED
@@ -21,23 +21,19 @@ This is an experimental implementation and is subject to change at any time.
 import dataclasses
 from typing import List, Tuple
 
-from ai_edge_torch import hlfb
 from ai_edge_torch.generative.layers import model_config
-from ai_edge_torch.generative.layers.experimental import types
-from ai_edge_torch.generative.utilities.dynamic_update_slice import dynamic_update_slice
+from ai_edge_torch.generative.layers.experimental import types
+from ai_edge_torch.generative.utilities import dynamic_update_slice as dus_utils
 import torch
-import torch.nn as nn
 import torch.utils._pytree as pytree
 
-BATCH_SIZE = 1
-
 
 @dataclasses.dataclass
 class KVCacheEntryBase:
   """A single cache entry that includes K and V caches.
 
   The chaches are built based on the provided config with the shape of
-  (batch_size=1, kv_cache_max, num_query_groups, head_dim).
+  (batch_size, kv_cache_max, num_query_groups, head_dim).
   """
 
   k_cache: torch.Tensor
@@ -46,10 +42,8 @@ class KVCacheEntryBase:
   @classmethod
   def _from_model_config(
       cls,
-      kv_cache_max: int,
-      config: model_config.AttentionConfig,
-      k_shape: Tuple,
-      v_shape: Tuple,
+      k_shape: Tuple[int, ...],
+      v_shape: Tuple[int, ...],
       dtype: torch.dtype = torch.float32,
       device: torch.device = None,
   ) -> "KVCacheEntryBase":
@@ -66,12 +60,11 @@ class KVCacheEntryBase:
       config: model_config.AttentionConfig,
       dtype: torch.dtype = torch.float32,
       device: torch.device = None,
+      batch_size: int = 1,
   ) -> "KVCacheEntryBase":
     """Build an instance of the class based on model config."""
-    shape = (BATCH_SIZE, kv_cache_max, config.num_query_groups, config.head_dim)
-    return cls._from_model_config(
-        kv_cache_max, config, shape, shape, dtype, device
-    )
+    shape = (batch_size, kv_cache_max, config.num_query_groups, config.head_dim)
+    return cls._from_model_config(shape, shape, dtype, device)
 
 
 @dataclasses.dataclass
@@ -93,24 +86,22 @@ class KVCacheEntryTransposed(KVCacheEntryBase):
       config: model_config.AttentionConfig,
       dtype: torch.dtype = torch.float32,
       device: torch.device = None,
+      batch_size: int = 1,
   ) -> "KVCacheEntryBase":
     """Build an instance of the class based on model config."""
-    num_kv_heads = config.num_query_groups
     k_shape = (
-        BATCH_SIZE,
-        num_kv_heads,
+        batch_size,
+        config.num_query_groups,
         kv_cache_max,
        config.head_dim,
-    )  # 1, bk, s, h
+    )  # b, k, s, h
     v_shape = (
-        BATCH_SIZE,
-        num_kv_heads,
+        batch_size,
+        config.num_query_groups,
        config.head_dim,
        kv_cache_max,
-    )  # 1, bk, h, s
-    return cls._from_model_config(
-        kv_cache_max, config, k_shape, v_shape, dtype, device
-    )
+    )  # b, k, h, s
+    return cls._from_model_config(k_shape, v_shape, dtype, device)
 
 
 @dataclasses.dataclass
@@ -126,6 +117,7 @@ class KVCacheBase:
       config: model_config.ModelConfig,
       dtype: torch.dtype = torch.float32,
       device: torch.device = None,
+      batch_size: int = 1,
   ) -> "KVCacheBase":
     caches = [
         kv_entry_cls.from_model_config(
@@ -133,6 +125,7 @@ class KVCacheBase:
             config.block_config(idx).attn_config,
             dtype,
             device,
+            batch_size,
         )
         for idx in range(config.num_layers)
     ]
@@ -145,6 +138,7 @@ class KVCacheBase:
       config: model_config.ModelConfig,
       dtype: torch.dtype = torch.float32,
       device: torch.device = None,
+      batch_size: int = 1,
   ) -> "KVCacheBase":
     """Build an instance of the class based on model config.
 
@@ -154,12 +148,19 @@ class KVCacheBase:
         Defaults to torch.float32.
       device (torch.device, optional): The device placement of the cache
         tensors. Defaults to None.
+      batch_size (int, optional): The batch size of the cache tensors.
+        Defaults to 1.
 
     Returns:
       KVCacheBase: The created cache object.
     """
+    assert batch_size == 1, "Batch size must be 1 for KV Cache."
     return cls._from_model_config(
-        KVCacheEntryBase, config=config, dtype=dtype, device=device
+        KVCacheEntryBase,
+        config=config,
+        dtype=dtype,
+        device=device,
+        batch_size=batch_size,
     )
 
   def flatten(self) -> List[torch.Tensor]:
@@ -177,9 +178,14 @@ class KVCacheBTNH(KVCacheBase):
       config: model_config.ModelConfig,
       dtype: torch.dtype = torch.float32,
       device: torch.device = None,
+      batch_size: int = 1,
   ) -> "KVCacheBTNH":
     return cls._from_model_config(
-        KVCacheEntryBTNH, config=config, dtype=dtype, device=device
+        KVCacheEntryBTNH,
+        config=config,
+        dtype=dtype,
+        device=device,
+        batch_size=batch_size,
     )
 
 
@@ -192,9 +198,14 @@ class KVCacheTransposed(KVCacheBase):
       config: model_config.ModelConfig,
       dtype: torch.dtype = torch.float32,
       device: torch.device = None,
+      batch_size: int = 1,
   ) -> "KVCacheBTNH":
     return cls._from_model_config(
-        KVCacheEntryTransposed, config=config, dtype=dtype, device=device
+        KVCacheEntryTransposed,
+        config=config,
+        dtype=dtype,
+        device=device,
+        batch_size=batch_size,
     )
 
 
@@ -258,7 +269,6 @@ def update(
     input_pos: torch.Tensor,
     k_slice: torch.Tensor,
     v_slice: torch.Tensor,
-    use_dus: bool = True,
 ) -> KVCacheEntryBase:
   """Out of place update of Cache buffer.
 
@@ -309,6 +319,10 @@ def _update_kv_impl(
   positions = input_pos.clone()
   k_slice_indices = _get_slice_indices(positions, cache_dim, k_ts_idx)
   v_slice_indices = _get_slice_indices(positions, cache_dim, v_ts_idx)
-  k = dynamic_update_slice(cache.k_cache, k_slice, [x for x in k_slice_indices])
-  v = dynamic_update_slice(cache.v_cache, v_slice, [x for x in v_slice_indices])
+  k = dus_utils.dynamic_update_slice(
+      cache.k_cache, k_slice, [x for x in k_slice_indices]
+  )
+  v = dus_utils.dynamic_update_slice(
+      cache.v_cache, v_slice, [x for x in v_slice_indices]
+  )
   return KVCacheEntryTransposed(k, v)
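Taken together, these changes replace the module-level BATCH_SIZE = 1 constant with a batch_size argument threaded from the cache container down to each entry's shape. A sketch of the resulting shapes, assuming config is a populated model_config.ModelConfig and that the container keeps its entries in a caches field (as the list built in _from_model_config suggests):

    from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils

    # Transposed layout with an explicit batch dimension of 2.
    cache = kv_utils.KVCacheTransposed.from_model_config(config, batch_size=2)
    entry = cache.caches[0]
    # Per the shape comments in the diff:
    #   k_cache: (b, k, s, h) = (2, num_query_groups, kv_cache_max, head_dim)
    #   v_cache: (b, k, h, s) = (2, num_query_groups, head_dim, kv_cache_max)
    print(entry.k_cache.shape, entry.v_cache.shape)

Note that KVCacheBase.from_model_config itself still asserts batch_size == 1, so batched caches are only reachable through the subclasses that forward batch_size.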
ai_edge_torch/generative/layers/kv_cache.py
CHANGED
@@ -18,14 +18,11 @@
 import dataclasses
 from typing import List, Tuple
 
-from ai_edge_torch import hlfb
 from ai_edge_torch.generative.layers import model_config
 from ai_edge_torch.generative.utilities.dynamic_update_slice import dynamic_update_slice
 import torch
 import torch.utils._pytree as pytree
 
-BATCH_SIZE = 1
-
 
 @dataclasses.dataclass
 class KVCacheEntry:
@@ -45,9 +42,10 @@ class KVCacheEntry:
       config: model_config.AttentionConfig,
       dtype: torch.dtype = torch.float32,
       device: torch.device = None,
+      batch_size: int = 1,
   ) -> "KVCacheEntry":
     """Build an instance of the class based on model config."""
-    shape = (BATCH_SIZE, kv_cache_max, config.num_query_groups, config.head_dim)
+    shape = (batch_size, kv_cache_max, config.num_query_groups, config.head_dim)
     k = torch.zeros(shape, dtype=dtype, device=device)
     v = torch.zeros(shape, dtype=dtype, device=device)
     obj = cls(k_cache=k, v_cache=v)
@@ -66,6 +64,7 @@ class KVCache:
       config: model_config.ModelConfig,
       dtype: torch.dtype = torch.float32,
       device: torch.device = None,
+      batch_size: int = 1,
   ) -> "KVCache":
     """Build an instance of the class based on model config.
 
@@ -75,17 +74,21 @@ class KVCache:
         Defaults to torch.float32.
       device (torch.device, optional): The device placement of the cache
         tensors. Defaults to None.
+      batch_size (int, optional): The batch size of the cache tensors.
+        Defaults to 1.
 
     Returns:
       KVCache: The created cache object.
     """
     caches = [
         KVCacheEntry.from_model_config(
-            config.kv_cache_max if not config.block_config(idx).kv_cache_max_len
+            config.kv_cache_max
+            if not config.block_config(idx).kv_cache_max_len
             else config.block_config(idx).kv_cache_max_len,
             config.block_config(idx).attn_config,
             dtype,
             device,
+            batch_size,
         )
         for idx in range(config.num_layers)
     ]
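The stable kv_cache.py gets the same treatment: each entry is now allocated as (batch_size, kv_cache_max, num_query_groups, head_dim) instead of hard-coding the leading dimension. A minimal sketch, again assuming a populated config and a caches field on KVCache:

    from ai_edge_torch.generative.layers import kv_cache as kv_utils

    decode_kv = kv_utils.KVCache.from_model_config(config, batch_size=2)
    for entry in decode_kv.caches:
      # The leading dimension is the requested batch size, not a constant.
      assert entry.k_cache.shape[0] == 2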
ai_edge_torch/generative/layers/model_config.py
CHANGED
@@ -220,9 +220,6 @@ class ModelConfig:
   # The maximum sequence length of the KV cache. Should not exceed max_seq_len.
   kv_cache_max_len: int = 0
 
-  # Default batch size of the exported model. Default value is 1.
-  batch_size: int = 1
-
   # Softcap on the model output logits.
   final_logit_softcap: Optional[float] = None
 
ai_edge_torch/generative/utilities/converter.py
CHANGED
@@ -110,6 +110,11 @@ def convert_to_tflite(
   lora_suffix = (
       '' if not lora_ranks else f'_lora{",".join(map(str, lora_ranks))}'
   )
+
+  if export_config is not None:
+    if export_config.decode_batch_size > 1:
+      output_name_prefix += f'_dbs{export_config.decode_batch_size}'
+
   output_filename = (
       f'{output_name_prefix}_{quant_suffix}_ekv{kv_size}{lora_suffix}.tflite'
   )
@@ -162,9 +167,14 @@ def _export_helper(
   if prefill_masks:
     assert len(prefill_masks) == len(prefill_seq_lens)
 
-  decode_token = torch.tensor([[0]], dtype=torch.int)
+  decode_token = torch.tensor(
+      [[0] for _ in range(export_config.decode_batch_size)], dtype=torch.int
+  )
   decode_input_pos = torch.tensor([0], dtype=torch.int)
-  kv = export_config.kvcache_cls.from_model_config(config)
+  prefill_kv = export_config.kvcache_cls.from_model_config(config)
+  decode_kv = export_config.kvcache_cls.from_model_config(
+      config, batch_size=export_config.decode_batch_size
+  )
 
   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
 
@@ -183,7 +193,7 @@ def _export_helper(
     sample_kwargs = {
         'tokens': prefill_tokens,
         'input_pos': prefill_input_pos,
-        'kv_cache': kv,
+        'kv_cache': prefill_kv,
     }
     if prefill_masks is not None:
       sample_kwargs['mask'] = prefill_masks[i]
@@ -211,7 +221,7 @@ def _export_helper(
     sample_kwargs = {
        'tokens': decode_token,
        'input_pos': decode_input_pos,
-        'kv_cache': kv,
+        'kv_cache': decode_kv,
     }
     if export_config.decode_mask is not None:
       sample_kwargs['mask'] = export_config.decode_mask
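A worked example of the naming logic above, with illustrative values (the 'q8' suffix for quantized models comes from the code removed in convert_v2_to_tflite.py):

    # Reproduces converter.py's filename assembly for decode_batch_size=2.
    output_name_prefix = 'smollm2'
    quant_suffix, kv_size, lora_suffix = 'q8', 1280, ''
    decode_batch_size = 2
    if decode_batch_size > 1:
      output_name_prefix += f'_dbs{decode_batch_size}'
    output_filename = (
        f'{output_name_prefix}_{quant_suffix}_ekv{kv_size}{lora_suffix}.tflite'
    )
    assert output_filename == 'smollm2_dbs2_q8_ekv1280.tflite'

The _export_helper change is the companion piece: prefill keeps a batch-1 cache, while the decode signature is traced against a cache built with batch_size=export_config.decode_batch_size and a token tensor with one row per batch element.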
ai_edge_torch/generative/utilities/model_builder.py
CHANGED
@@ -60,6 +60,8 @@ class ExportConfig:
   decode_mask: Optional[torch.Tensor | List[torch.Tensor]] = None
   # The KV Cache class for K and V buffers in attention.
   kvcache_cls: type = kv_utils.KVCache
+  # The batch size of the decode signature.
+  decode_batch_size: int = 1
 
 
 class DecoderOnlyModel(nn.Module):
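With ExportConfig carrying both the cache class and the decode batch size, one object now describes a batched decode export end to end. A short sketch using only fields visible in this diff (the kv_utils alias is assumed to match model_builder's own import):

    from ai_edge_torch.generative.layers import kv_cache as kv_utils
    from ai_edge_torch.generative.utilities import model_builder

    export_config = model_builder.ExportConfig(
        kvcache_cls=kv_utils.KVCache,  # the default, spelled out
        decode_batch_size=2,
    )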
ai_edge_torch/version.py
CHANGED
-__version__ = "0.4.0.dev20250226"
+__version__ = "0.4.0.dev20250228"

{ai_edge_torch_nightly-0.4.0.dev20250226.dist-info → ai_edge_torch_nightly-0.4.0.dev20250228.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.4.0.dev20250226
+Version: 0.4.0.dev20250228
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.4.0.dev20250226.dist-info → ai_edge_torch_nightly-0.4.0.dev20250228.dist-info}/RECORD
CHANGED
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
 ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256
+ai_edge_torch/version.py,sha256=-EqWeDLQh8HxiqQxA-N-t0YXsYU9QT1iaq2h-kCDBdo,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=gpXQnifODU-mWxkUZw_3ov1lEYBw1SPVIcqj5k7pTGo,5550
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -102,8 +102,8 @@ ai_edge_torch/generative/examples/qwen_vl/verify.py,sha256=JUwHoC_zvcC3RC3wZ3e3e
 ai_edge_torch/generative/examples/qwen_vl/verify_decoder.py,sha256=xPWoOBLh2eK12KEhELLYymfL7xvc0chmYC98c6x37oo,2602
 ai_edge_torch/generative/examples/qwen_vl/verify_image_encoder.py,sha256=PZ392nDoJG2OmHZ_7Jet3Zu1JkN6QErxKcDc7a-PPds,3126
 ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=
-ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py,sha256=
+ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=hWko-RJB8eXNUfi4EzQ2yjW30YE4UB4zAz7rd2Q5qpg,2708
+ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py,sha256=sJ-o385eqQsciv0TEQRkixvS0DD6dKruAuK0zlEsDoY,2715
 ai_edge_torch/generative/examples/smollm/smollm.py,sha256=3uUltb6D3Q1aHpndcYTJrsWM_RBwLAraKDniH8ZZous,3779
 ai_edge_torch/generative/examples/smollm/verify.py,sha256=KpYxVz_lv61YWy6HLfwT68n0owZMvty5Rr3W7ZNWWSw,2702
 ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -136,19 +136,19 @@ ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=LRu6PSw7Lqu6HGbv1t
 ai_edge_torch/generative/fx_passes/__init__.py,sha256=4rFrppMRKlTwwZeX1ON_cdp4yUqoTOES161IZQkJF6c,1143
 ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=myGjal5A8yIBoqgArd2k40rZmCgD1Ya369KR7182bhI,2129
 ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/layers/attention.py,sha256=
+ai_edge_torch/generative/layers/attention.py,sha256=wLZ1jgUlcODBWgK3hnnhclHuuQDqYuGOZdYAI9EooOM,13247
 ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
 ai_edge_torch/generative/layers/builder.py,sha256=LXGuSHIx6QZAzLFm7aJvlzoMPgQwbXLFchGEKYwOOUA,5090
 ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
-ai_edge_torch/generative/layers/kv_cache.py,sha256=
+ai_edge_torch/generative/layers/kv_cache.py,sha256=jwbt0-2fd_CNWS2fp4nf0zvh6kk5citINGlFC_RtEUU,6540
 ai_edge_torch/generative/layers/lora.py,sha256=hsvWLLOnW7HQ0AysOZu30x_cetMquDd1tjfyLz8HCSU,17892
-ai_edge_torch/generative/layers/model_config.py,sha256=
+ai_edge_torch/generative/layers/model_config.py,sha256=wNsZDzZQoimOKdZ9FWMCktPj-pQ_0D7084hgzMT5XYo,8155
 ai_edge_torch/generative/layers/normalization.py,sha256=MbwH-n80Fob5YvjBzdqDjBizMHLzSJGYRDdbD-rL5C0,6174
 ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=975zR202MdIrILJ7blceAcxrNqX1ZCN0ECKG1gz-bV8,2655
 ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=vp8dVx6tOe99neJhpbrtIt5fvN5NFw19JVH1v0yi5Mg,4154
 ai_edge_torch/generative/layers/experimental/__init__.py,sha256=nz-K0h8DfiATHzR6s1_bCw2akUmHWffU1bDRSkIzSqI,592
-ai_edge_torch/generative/layers/experimental/attention.py,sha256=
-ai_edge_torch/generative/layers/experimental/kv_cache.py,sha256=
+ai_edge_torch/generative/layers/experimental/attention.py,sha256=95djjlJItDVuSNE3BL0b6u3lQoIhmmdvaik7qBBvQA0,8909
+ai_edge_torch/generative/layers/experimental/kv_cache.py,sha256=VN4gn4ylaVOwaTR5EXKv0YTVgpQ850bmjGLCgCCI1ps,9267
 ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py,sha256=1vMh1L3uYX4ptKQMWcAjxkL1v2-g0jmOiuai8ydp0dc,2879
 ai_edge_torch/generative/layers/experimental/types.py,sha256=bPPxw6TOCZVWdeDP3vCbOnjNP5-bdUMmfsfO-EtdazQ,2847
 ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -173,10 +173,10 @@ ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728Fc
 ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
 ai_edge_torch/generative/utilities/bmm_4d.py,sha256=2BMOYiFVUsl-bjxmLkrX4N7kpO0CnhB7eDYxm_iBCr8,2533
-ai_edge_torch/generative/utilities/converter.py,sha256=
+ai_edge_torch/generative/utilities/converter.py,sha256=VtG42CVz657XbvTj-FZJiCFW0Hm11OVKKC_mr2tjxhc,8413
 ai_edge_torch/generative/utilities/dynamic_update_slice.py,sha256=e2mhx-Vp8sUK4EXoPtpZLSx3TViqLAKs67EhKcXBjAQ,2121
 ai_edge_torch/generative/utilities/loader.py,sha256=A3SOjPXp--AsvoP1hqj5QKWE4sgxoFc3H5EBUz_Eogc,13531
-ai_edge_torch/generative/utilities/model_builder.py,sha256=
+ai_edge_torch/generative/utilities/model_builder.py,sha256=eY3qAcBhupIn955YnWuzUi9hoWYvl4ntRWA6PBudzMo,6888
 ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
 ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
 ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
@@ -230,8 +230,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.4.0.dev20250226.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.4.0.dev20250226.dist-info/METADATA,sha256=
-ai_edge_torch_nightly-0.4.0.dev20250226.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
-ai_edge_torch_nightly-0.4.0.dev20250226.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.4.0.dev20250226.dist-info/RECORD,,
+ai_edge_torch_nightly-0.4.0.dev20250228.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.4.0.dev20250228.dist-info/METADATA,sha256=oGVZ_Z3zOzdyxj4cJ5XTT-YzPpTa99SBgFJo5zUBqJU,1966
+ai_edge_torch_nightly-0.4.0.dev20250228.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_torch_nightly-0.4.0.dev20250228.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.4.0.dev20250228.dist-info/RECORD,,
{ai_edge_torch_nightly-0.4.0.dev20250226.dist-info → ai_edge_torch_nightly-0.4.0.dev20250228.dist-info}/LICENSE
File without changes

{ai_edge_torch_nightly-0.4.0.dev20250226.dist-info → ai_edge_torch_nightly-0.4.0.dev20250228.dist-info}/WHEEL
File without changes

{ai_edge_torch_nightly-0.4.0.dev20250226.dist-info → ai_edge_torch_nightly-0.4.0.dev20250228.dist-info}/top_level.txt
File without changes