ai-edge-torch-nightly 0.3.0.dev20241216__py3-none-any.whl → 0.3.0.dev20241220__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/_convert/conversion.py +5 -1
- ai_edge_torch/_convert/converter.py +8 -0
- ai_edge_torch/generative/examples/gemma/gemma2.py +14 -15
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -3
- ai_edge_torch/generative/layers/attention.py +4 -29
- ai_edge_torch/generative/layers/rotary_position_embedding.py +34 -27
- ai_edge_torch/generative/utilities/model_builder.py +11 -12
- ai_edge_torch/lowertools/_shim.py +4 -0
- ai_edge_torch/lowertools/odml_torch_utils.py +4 -0
- ai_edge_torch/lowertools/torch_xla_utils.py +6 -0
- ai_edge_torch/odml_torch/export.py +4 -0
- ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py +5 -2
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241216.dist-info → ai_edge_torch_nightly-0.3.0.dev20241220.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241216.dist-info → ai_edge_torch_nightly-0.3.0.dev20241220.dist-info}/RECORD +18 -18
- {ai_edge_torch_nightly-0.3.0.dev20241216.dist-info → ai_edge_torch_nightly-0.3.0.dev20241220.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241216.dist-info → ai_edge_torch_nightly-0.3.0.dev20241220.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241216.dist-info → ai_edge_torch_nightly-0.3.0.dev20241220.dist-info}/top_level.txt +0 -0
@@ -78,7 +78,8 @@ def convert_signatures(
|
|
78
78
|
*,
|
79
79
|
strict_export: Union[Literal["auto"], bool] = True,
|
80
80
|
quant_config: Optional[qcfg.QuantConfig] = None,
|
81
|
-
_tfl_converter_flags: Optional[dict[str, Any]],
|
81
|
+
_tfl_converter_flags: Optional[dict[str, Any]] = None,
|
82
|
+
_saved_model_dir: Optional[str] = None,
|
82
83
|
) -> model.TfLiteModel:
|
83
84
|
"""Converts a list of `signature.Signature`s and embeds them into one `model.TfLiteModel`.
|
84
85
|
|
@@ -93,6 +94,8 @@ def convert_signatures(
|
|
93
94
|
quant_config: User-defined quantization method and scheme of the model.
|
94
95
|
_tfl_converter_flags: A nested dictionary allowing setting flags for the
|
95
96
|
underlying tflite converter.
|
97
|
+
_saved_model_dir: Directory for the intermediate saved model. If not
|
98
|
+
specified, a random temporary directory would be used.
|
96
99
|
|
97
100
|
Returns:
|
98
101
|
The converted `model.TfLiteModel` object.
|
@@ -140,6 +143,7 @@ def convert_signatures(
|
|
140
143
|
signatures,
|
141
144
|
quant_config=quant_config,
|
142
145
|
_tfl_converter_flags=_tfl_converter_flags,
|
146
|
+
_saved_model_dir=_saved_model_dir,
|
143
147
|
)
|
144
148
|
|
145
149
|
return model.TfLiteModel(tflite_model)
|
@@ -106,6 +106,7 @@ class Converter:
|
|
106
106
|
quant_config: Optional[qcfg.QuantConfig] = None,
|
107
107
|
dynamic_shapes: Optional[Union[dict[str, Any], Tuple[Any, ...]]] = None,
|
108
108
|
_ai_edge_converter_flags: Optional[dict[str, Any]] = None,
|
109
|
+
_saved_model_dir: Optional[str] = None,
|
109
110
|
) -> model.TfLiteModel:
|
110
111
|
"""Finalizes the conversion and produces an edge model.
|
111
112
|
|
@@ -139,6 +140,8 @@ class Converter:
|
|
139
140
|
of this function and so needs to be treated as such. Please do not rely
|
140
141
|
on this parameter except for local debugging as this can be removed in a
|
141
142
|
future release.
|
143
|
+
_saved_model_dir: Directory for the intermediate saved model. If not
|
144
|
+
specified, a random temporary directory would be used.
|
142
145
|
|
143
146
|
Returns:
|
144
147
|
The converted edge model.
|
@@ -171,6 +174,7 @@ class Converter:
|
|
171
174
|
strict_export=strict_export,
|
172
175
|
quant_config=quant_config,
|
173
176
|
_tfl_converter_flags=_ai_edge_converter_flags,
|
177
|
+
_saved_model_dir=_saved_model_dir,
|
174
178
|
)
|
175
179
|
|
176
180
|
|
@@ -216,6 +220,7 @@ def convert(
|
|
216
220
|
quant_config: Optional[qcfg.QuantConfig] = None,
|
217
221
|
dynamic_shapes: Optional[Union[dict[str, Any], Tuple[Any, ...]]] = None,
|
218
222
|
_ai_edge_converter_flags: Optional[dict[str, Any]] = None,
|
223
|
+
_saved_model_dir: Optional[str] = None,
|
219
224
|
) -> model.TfLiteModel:
|
220
225
|
"""Converts a PyTorch model to an edge model with a default signature.
|
221
226
|
|
@@ -240,6 +245,8 @@ def convert(
|
|
240
245
|
this function and so needs to be treated as such. Please do not rely on
|
241
246
|
this parameter except for local debugging as this can be removed in a
|
242
247
|
future release.
|
248
|
+
_saved_model_dir: Directory for the intermediate saved model. If not
|
249
|
+
specified, a random temporary directory would be used.
|
243
250
|
|
244
251
|
Returns:
|
245
252
|
The converted edge model.
|
@@ -259,4 +266,5 @@ def convert(
|
|
259
266
|
quant_config=quant_config,
|
260
267
|
dynamic_shapes=dynamic_shapes,
|
261
268
|
_ai_edge_converter_flags=_ai_edge_converter_flags,
|
269
|
+
_saved_model_dir=_saved_model_dir,
|
262
270
|
)
|
@@ -22,6 +22,7 @@ from ai_edge_torch.generative.layers import builder
|
|
22
22
|
from ai_edge_torch.generative.layers import kv_cache as kv_utils
|
23
23
|
import ai_edge_torch.generative.layers.attention_utils as attn_utils
|
24
24
|
import ai_edge_torch.generative.layers.model_config as cfg
|
25
|
+
import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
|
25
26
|
from ai_edge_torch.generative.utilities import model_builder
|
26
27
|
import ai_edge_torch.generative.utilities.loader as loading_utils
|
27
28
|
import torch
|
@@ -103,17 +104,12 @@ class Gemma2(nn.Module):
|
|
103
104
|
config.embedding_dim,
|
104
105
|
config.final_norm_config,
|
105
106
|
)
|
106
|
-
# Gemma2 has same hyper parameters for each layer except for attention
|
107
|
-
# types. Use the first layer.
|
108
|
-
attn_config = config.block_config(0).attn_config
|
109
|
-
self.rope_cache = attn_utils.build_rope_cache(
|
110
|
-
size=config.kv_cache_max,
|
111
|
-
dim=int(attn_config.rotary_percentage * attn_config.head_dim),
|
112
|
-
base=attn_config.rotary_base,
|
113
|
-
)
|
114
107
|
self.mask_cache = attn_utils.build_causal_mask_cache(
|
115
108
|
size=config.kv_cache_max,
|
116
109
|
)
|
110
|
+
# Gemma2 has same hyper parameters for each layer except for attention
|
111
|
+
# types. Use the first layer.
|
112
|
+
attn_config = config.block_config(0).attn_config
|
117
113
|
self.sliding_window_mask_cache = attn_utils.build_sliding_window_mask_cache(
|
118
114
|
size=config.kv_cache_max,
|
119
115
|
window_size=attn_config.sliding_window_size,
|
@@ -145,24 +141,27 @@ class Gemma2(nn.Module):
|
|
145
141
|
" must be the same."
|
146
142
|
)
|
147
143
|
|
148
|
-
|
149
|
-
|
150
|
-
|
144
|
+
# RoPE parameters are the same for all blocks. Use the first layer.
|
145
|
+
attn_config = self.config.block_config(0).attn_config
|
146
|
+
n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
|
147
|
+
rope = rotary_pos_emb.build_rope(
|
148
|
+
input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
|
149
|
+
)
|
151
150
|
|
152
151
|
# token embeddings of shape (b, t, n_embd)
|
153
152
|
x = self.tok_embedding(tokens)
|
154
153
|
x = x * (self.config.embedding_dim**0.5)
|
155
154
|
|
156
|
-
|
155
|
+
updated_kv_entries = []
|
157
156
|
for i, block in enumerate(self.transformer_blocks):
|
158
157
|
mask = self.get_attention_mask(
|
159
158
|
block.config.attn_config.attn_type, input_pos
|
160
159
|
)
|
161
160
|
kv_entry = kv_cache.caches[i] if kv_cache else None
|
162
|
-
x, kv_entry = block(x,
|
161
|
+
x, kv_entry = block(x, rope, mask, input_pos, kv_entry)
|
163
162
|
if kv_entry:
|
164
|
-
|
165
|
-
updated_kv_cache = kv_utils.KVCache(tuple(
|
163
|
+
updated_kv_entries.append(kv_entry)
|
164
|
+
updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
|
166
165
|
|
167
166
|
if export_config is not None:
|
168
167
|
if (
|
@@ -72,14 +72,14 @@ class ToyModelWithKVCache(torch.nn.Module):
|
|
72
72
|
mask = self.mask_cache.index_select(2, input_pos)
|
73
73
|
mask = mask[:, :, :, : self.config.max_seq_len]
|
74
74
|
|
75
|
-
|
75
|
+
updated_kv_entries = []
|
76
76
|
for i, block in enumerate(self.transformer_blocks):
|
77
77
|
kv_entry = kv_cache.caches[i] if kv_cache else None
|
78
78
|
x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
|
79
79
|
if kv_entry:
|
80
|
-
|
80
|
+
updated_kv_entries.append(kv_entry)
|
81
81
|
|
82
|
-
updated_kv_cache = kv_utils.KVCache(tuple(
|
82
|
+
updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
|
83
83
|
|
84
84
|
if export_config is not None:
|
85
85
|
if (
|
@@ -26,33 +26,6 @@ import torch
|
|
26
26
|
from torch import nn
|
27
27
|
|
28
28
|
|
29
|
-
def _embed_rope(
|
30
|
-
q: torch.Tensor,
|
31
|
-
k: torch.Tensor,
|
32
|
-
n_elem: int,
|
33
|
-
rope: Tuple[torch.Tensor, torch.Tensor],
|
34
|
-
) -> Tuple[torch.Tensor, torch.Tensor]:
|
35
|
-
"""Embed rotary positional embedding for query and key.
|
36
|
-
|
37
|
-
Args:
|
38
|
-
q (torch.Tensor): query tensor.
|
39
|
-
k (torch.Tensor): key tensor.
|
40
|
-
n_elem (int): number of elements to embed rotarty positional embedding.
|
41
|
-
rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
|
42
|
-
"""
|
43
|
-
if n_elem > 0:
|
44
|
-
cos, sin = rope
|
45
|
-
q_roped = rotary_pos_emb.apply_rope(
|
46
|
-
q[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
|
47
|
-
)
|
48
|
-
k_roped = rotary_pos_emb.apply_rope(
|
49
|
-
k[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
|
50
|
-
)
|
51
|
-
q = torch.cat((q_roped, q[..., n_elem:]), dim=-1)
|
52
|
-
k = torch.cat((k_roped, k[..., n_elem:]), dim=-1)
|
53
|
-
return q, k
|
54
|
-
|
55
|
-
|
56
29
|
class TransformerBlock(nn.Module):
|
57
30
|
|
58
31
|
def __init__(
|
@@ -238,7 +211,8 @@ class CausalSelfAttention(nn.Module):
|
|
238
211
|
if rope is not None:
|
239
212
|
# Compute rotary positional embedding for query and key.
|
240
213
|
n_elem = int(self.config.rotary_percentage * self.config.head_dim)
|
241
|
-
|
214
|
+
cos, sin = rope
|
215
|
+
q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)
|
242
216
|
|
243
217
|
if kv_cache is not None:
|
244
218
|
kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
|
@@ -374,7 +348,8 @@ class CrossAttention(nn.Module):
|
|
374
348
|
if rope is not None:
|
375
349
|
# Compute rotary positional embedding for query and key.
|
376
350
|
n_elem = int(self.config.rotary_percentage * self.config.head_dim)
|
377
|
-
|
351
|
+
cos, sin = rope
|
352
|
+
q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)
|
378
353
|
|
379
354
|
if kv_cache is not None:
|
380
355
|
kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
|
@@ -32,57 +32,64 @@ def apply_rope(
|
|
32
32
|
"""
|
33
33
|
x = x.transpose(1, 2)
|
34
34
|
head_size = x.size(-1)
|
35
|
-
x1 = x
|
36
|
-
|
37
|
-
|
38
|
-
roped = (
|
35
|
+
x1, x2 = torch.split(x, head_size // 2, dim=-1)
|
36
|
+
left = x1 * cos - x2 * sin
|
37
|
+
right = x2 * cos + x1 * sin
|
38
|
+
roped = torch.cat([left, right], dim=-1)
|
39
39
|
return roped.transpose(1, 2).type_as(x)
|
40
40
|
|
41
41
|
|
42
|
-
def
|
43
|
-
q: torch.Tensor,
|
44
|
-
k: torch.Tensor,
|
42
|
+
def build_rope(
|
45
43
|
input_pos: torch.Tensor,
|
46
44
|
n_elem: int,
|
45
|
+
head_dim: int,
|
47
46
|
base: int = 10_000,
|
48
47
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
49
|
-
"""Computes rotary positional embedding
|
48
|
+
"""Computes rotary positional embedding cosine and sine tensors.
|
50
49
|
|
51
50
|
Args:
|
52
|
-
q: the query tensor.
|
53
|
-
k: the key tensor.
|
54
51
|
input_pos: the sequence indices for the query and key
|
55
52
|
n_elem: number of elements of the head dimension for RoPE computation
|
53
|
+
base: the base of the exponentiated value for RoPE.
|
56
54
|
|
57
55
|
Returns:
|
58
|
-
|
56
|
+
cos, sin tensors
|
59
57
|
"""
|
60
58
|
|
61
59
|
if n_elem <= 0:
|
62
|
-
return
|
60
|
+
return None, None
|
63
61
|
|
64
62
|
theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
|
65
63
|
freq_exponents = (2.0 / n_elem) * torch.arange(
|
66
|
-
|
64
|
+
head_dim // 2, dtype=torch.float32
|
67
65
|
)
|
68
66
|
timescale = float(base) ** freq_exponents
|
69
67
|
radians = input_pos.clone().unsqueeze(0).unsqueeze(-1) / timescale.unsqueeze(
|
70
68
|
0
|
71
69
|
).unsqueeze(0)
|
72
|
-
cos = torch.cos(radians)
|
73
|
-
sin = torch.sin(radians)
|
70
|
+
cos = torch.cos(radians)
|
71
|
+
sin = torch.sin(radians)
|
72
|
+
return cos, sin
|
73
|
+
|
74
74
|
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
75
|
+
def apply_rope_inline(
|
76
|
+
q: torch.Tensor,
|
77
|
+
k: torch.Tensor,
|
78
|
+
cos: torch.Tensor,
|
79
|
+
sin: torch.Tensor,
|
80
|
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
81
|
+
"""Computes rotary positional embedding inline for a query and key.
|
82
|
+
|
83
|
+
Args:
|
84
|
+
q: the query tensor.
|
85
|
+
k: the key tensor.
|
86
|
+
cos: the cosine tensor.
|
87
|
+
sin: the sine tensor.
|
88
|
+
|
89
|
+
Returns:
|
90
|
+
output the RoPE'd query and key.
|
91
|
+
"""
|
85
92
|
|
86
|
-
q_roped =
|
87
|
-
k_roped =
|
93
|
+
q_roped = apply_rope(q, cos, sin)
|
94
|
+
k_roped = apply_rope(k, cos, sin)
|
88
95
|
return q_roped, k_roped
|
@@ -24,6 +24,7 @@ from ai_edge_torch.generative.layers import builder
|
|
24
24
|
from ai_edge_torch.generative.layers import kv_cache as kv_utils
|
25
25
|
import ai_edge_torch.generative.layers.attention_utils as attn_utils
|
26
26
|
import ai_edge_torch.generative.layers.model_config as cfg
|
27
|
+
import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
|
27
28
|
import ai_edge_torch.generative.utilities.loader as loading_utils
|
28
29
|
import torch
|
29
30
|
from torch import nn
|
@@ -85,13 +86,6 @@ class DecoderOnlyModel(nn.Module):
|
|
85
86
|
config.embedding_dim,
|
86
87
|
config.final_norm_config,
|
87
88
|
)
|
88
|
-
# ROPE parameters for all attn_configs are the same. Take the first one.
|
89
|
-
attn_config = config.block_config(0).attn_config
|
90
|
-
self.rope_cache = attn_utils.build_rope_cache(
|
91
|
-
size=config.kv_cache_max,
|
92
|
-
dim=int(attn_config.rotary_percentage * attn_config.head_dim),
|
93
|
-
base=attn_config.rotary_base,
|
94
|
-
)
|
95
89
|
self.mask_cache = attn_utils.build_causal_mask_cache(
|
96
90
|
size=config.kv_cache_max,
|
97
91
|
)
|
@@ -113,11 +107,16 @@ class DecoderOnlyModel(nn.Module):
|
|
113
107
|
|
114
108
|
# token embeddings of shape (b, t, n_embd)
|
115
109
|
input_embeds = self.tok_embedding(tokens)
|
116
|
-
cos, sin = self.rope_cache
|
117
|
-
rope = (cos.index_select(0, input_pos), sin.index_select(0, input_pos))
|
118
110
|
mask = self.mask_cache.index_select(2, input_pos)
|
119
111
|
mask = mask[:, :, :, : self.config.kv_cache_max]
|
120
112
|
|
113
|
+
# ROPE parameters for all attn_configs are the same. Take the first one.
|
114
|
+
attn_config = self.config.block_config(0).attn_config
|
115
|
+
n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
|
116
|
+
rope = rotary_pos_emb.build_rope(
|
117
|
+
input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
|
118
|
+
)
|
119
|
+
|
121
120
|
return self.forward_with_embeds(
|
122
121
|
input_embeds, rope, mask, input_pos, kv_cache, export_config
|
123
122
|
)
|
@@ -141,13 +140,13 @@ class DecoderOnlyModel(nn.Module):
|
|
141
140
|
if self.config.embedding_scale is not None:
|
142
141
|
x = x * self.config.embedding_scale
|
143
142
|
|
144
|
-
|
143
|
+
updated_kv_entries = []
|
145
144
|
for i, block in enumerate(self.transformer_blocks):
|
146
145
|
kv_entry = kv_cache.caches[i] if kv_cache else None
|
147
146
|
x, kv_entry = block(x, rope, mask, input_pos, kv_entry)
|
148
147
|
if kv_entry:
|
149
|
-
|
150
|
-
updated_kv_cache = kv_utils.KVCache(tuple(
|
148
|
+
updated_kv_entries.append(kv_entry)
|
149
|
+
updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
|
151
150
|
|
152
151
|
if export_config is not None:
|
153
152
|
if (
|
@@ -50,6 +50,7 @@ def exported_programs_to_tflite(
|
|
50
50
|
*,
|
51
51
|
quant_config: Optional[qcfg.QuantConfig] = None,
|
52
52
|
_tfl_converter_flags: Optional[dict[str, Any]] = None,
|
53
|
+
_saved_model_dir: Optional[str] = None
|
53
54
|
):
|
54
55
|
"""Converts a list of ExportedProgram to a TFLite model.
|
55
56
|
|
@@ -57,6 +58,8 @@ def exported_programs_to_tflite(
|
|
57
58
|
exported_programs: A list of ExportedProgram.
|
58
59
|
signatures: A list of Signature.
|
59
60
|
quant_config: A QuantConfig.
|
61
|
+
_saved_model_dir: Directory for the intermediate saved model. If not
|
62
|
+
specified, a random temporary directory would be used.
|
60
63
|
_tfl_converter_flags: A dict of flags for TFLiteConverter.
|
61
64
|
|
62
65
|
Returns:
|
@@ -79,4 +82,5 @@ def exported_programs_to_tflite(
|
|
79
82
|
signatures,
|
80
83
|
quant_config=quant_config,
|
81
84
|
_tfl_converter_flags=_tfl_converter_flags,
|
85
|
+
_saved_model_dir=_saved_model_dir,
|
82
86
|
)
|
@@ -138,6 +138,7 @@ def merged_bundle_to_tfl_model(
|
|
138
138
|
*,
|
139
139
|
quant_config: Optional[qcfg.QuantConfig] = None,
|
140
140
|
_tfl_converter_flags: dict = {},
|
141
|
+
_saved_model_dir: Optional[str] = None,
|
141
142
|
):
|
142
143
|
tf_state_dict = merged_bundle.bundles[0].state_dict
|
143
144
|
|
@@ -173,6 +174,9 @@ def merged_bundle_to_tfl_model(
|
|
173
174
|
# We need to temporarily save since TFLite's from_concrete_functions does not
|
174
175
|
# allow providing names for each of the concrete functions.
|
175
176
|
with tempfile.TemporaryDirectory() as temp_dir_path:
|
177
|
+
if _saved_model_dir is not None:
|
178
|
+
temp_dir_path = _saved_model_dir
|
179
|
+
|
176
180
|
tf.saved_model.save(
|
177
181
|
tf_module,
|
178
182
|
temp_dir_path,
|
@@ -192,6 +192,7 @@ def merged_bundle_to_tfl_model(
|
|
192
192
|
*,
|
193
193
|
quant_config: Optional[qcfg.QuantConfig] = None,
|
194
194
|
_tfl_converter_flags: dict = {},
|
195
|
+
_saved_model_dir: Optional[str] = None,
|
195
196
|
) -> None:
|
196
197
|
"""Converts a StableHLOGraphModule to a tflite model.
|
197
198
|
|
@@ -200,6 +201,8 @@ def merged_bundle_to_tfl_model(
|
|
200
201
|
signatures: List of signatures from which names of the signatures is
|
201
202
|
extracted.
|
202
203
|
quant_config: User-defined quantization method and scheme of the model.
|
204
|
+
_saved_model_dir: Directory for the intermediate saved model. If not
|
205
|
+
specified, a random temporary directory would be used.
|
203
206
|
_tfl_converter_flags: A nested dictionary allowing setting flags for the
|
204
207
|
underlying tflite converter.
|
205
208
|
"""
|
@@ -246,6 +249,9 @@ def merged_bundle_to_tfl_model(
|
|
246
249
|
# We need to temporarily save since TFLite's from_concrete_functions does not
|
247
250
|
# allow providing names for each of the concrete functions.
|
248
251
|
with tempfile.TemporaryDirectory() as temp_dir_path:
|
252
|
+
if _saved_model_dir is not None:
|
253
|
+
temp_dir_path = _saved_model_dir
|
254
|
+
|
249
255
|
tf.saved_model.save(
|
250
256
|
tf_module,
|
251
257
|
temp_dir_path,
|
@@ -304,9 +304,13 @@ def exported_program_to_mlir(
|
|
304
304
|
)
|
305
305
|
|
306
306
|
_convert_i64_to_i32(exported_program)
|
307
|
+
|
307
308
|
exported_program = _torch_future.safe_run_decompositions(
|
308
309
|
exported_program, lowerings.decompositions()
|
309
310
|
)
|
311
|
+
|
312
|
+
# Passes below mutate the exported program to a state not executable by torch.
|
313
|
+
# Do not call run_decompositions after applying the passes.
|
310
314
|
_convert_q_dq_per_channel_args_to_list(exported_program)
|
311
315
|
|
312
316
|
with export_utils.create_ir_context() as context, ir.Location.unknown():
|
@@ -52,10 +52,13 @@ def _uniform_quantized_type(
|
|
52
52
|
assert isinstance(scale, (list, tuple))
|
53
53
|
assert isinstance(zero_point, (list, tuple))
|
54
54
|
|
55
|
+
scale = list(scale)
|
56
|
+
zero_point = list(zero_point)
|
57
|
+
|
55
58
|
if len(scale) == 1:
|
56
|
-
scale
|
59
|
+
scale = scale * channel_axis_size
|
57
60
|
if len(zero_point) == 1:
|
58
|
-
zero_point
|
61
|
+
zero_point = zero_point * channel_axis_size
|
59
62
|
|
60
63
|
assert len(scale) == len(zero_point) == channel_axis_size
|
61
64
|
scale_zp_strs = []
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.3.0.
|
3
|
+
Version: 0.3.0.dev20241220
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -3,11 +3,11 @@ ai_edge_torch/_config.py,sha256=QIrerb6uHMahRvMilmhodJ_6jfiRps3qgLOBeidPnS4,1614
|
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
|
5
5
|
ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
|
6
|
-
ai_edge_torch/version.py,sha256=
|
6
|
+
ai_edge_torch/version.py,sha256=xD-MWAEa1ROHhyF3rY7MaL28xsuON0aJwaiXbJ04qfc,706
|
7
7
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
8
|
-
ai_edge_torch/_convert/conversion.py,sha256=
|
8
|
+
ai_edge_torch/_convert/conversion.py,sha256=SzbR16V2JEfkCjjPwRVAFUbFnzu-_1iHPKgGT9Yz7gQ,5678
|
9
9
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
10
|
-
ai_edge_torch/_convert/converter.py,sha256=
|
10
|
+
ai_edge_torch/_convert/converter.py,sha256=075F8LRewk_033Ebsnft7FJr3KgtIbtZ_-8udIPy6ho,9980
|
11
11
|
ai_edge_torch/_convert/signature.py,sha256=rGpBNss3Y9FCCCcdBwDo16KqavJi8N5P0M_6Wnmoz6Y,2320
|
12
12
|
ai_edge_torch/_convert/to_channel_last_io.py,sha256=_31phf7TYgZY2ftpNbrdlB1RhDium1lz_BXEQ6IsMFc,2893
|
13
13
|
ai_edge_torch/_convert/fx_passes/__init__.py,sha256=NVe-eGcm7j8jZpP2pcMhC8j5dVjgR1pPzyXhHdvKH4E,1267
|
@@ -47,7 +47,7 @@ ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIX
|
|
47
47
|
ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=GhwtQZ1xuMyKJl8qdxU6uKavQnlm5US9xhKJvdmgACc,2309
|
48
48
|
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=hsy4Gd7Inchi0p_Cc5yecH6vr9A7X4MvmQNfTt8N2sQ,2311
|
49
49
|
ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=N0jKVZA3qWKOaHVbIM3WmQh3u0Sq7MTw_oO3Zo16wCw,3456
|
50
|
-
ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=
|
50
|
+
ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=roEwWVXASbk5BFj7jojjEJpHui6gCelT51l-TtN_ZaQ,9367
|
51
51
|
ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
|
52
52
|
ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=IoBhEMwH07-tFm5-U6F2hpCsI8xynglhq1x9tIOdaPQ,1322
|
53
53
|
ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=tR8RflXocDZqvuStyw9aFlzuiTllEC8rNnjrxms6_Is,5727
|
@@ -107,7 +107,7 @@ ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=l01oYyJo77INzRwN4xqX
|
|
107
107
|
ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
108
108
|
ai_edge_torch/generative/examples/test_models/convert_toy_model.py,sha256=6-WaNHckq_LlXMVTh8x90MGWeWq2bu_T_XQd3w9FnGg,3261
|
109
109
|
ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=4113jZK-Hu3kYop__WTc8Bq-bG6YzQtADbxHtYPEB4w,5036
|
110
|
-
ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=
|
110
|
+
ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=C9dzJFK3TybxKpM1vSdLjOKftkJ72DGjr8YR4H7vCe8,4664
|
111
111
|
ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
112
112
|
ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=5rgbTIxHoFg8sTnzrGA_ekT-HJEt9p7Dla7cIY874jU,2338
|
113
113
|
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=mhJ18rb9sxrYRzv1YSzhbNs97oUZck99avZDcUO2oV8,2800
|
@@ -115,14 +115,14 @@ ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=7Bk8z033M-BCXJ299f
|
|
115
115
|
ai_edge_torch/generative/fx_passes/__init__.py,sha256=jrzCB3ZyY_t5jJM1e2Czdt3DjAIL43R0_a-T-I7wOzw,1155
|
116
116
|
ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkDMv0isZJhmuLiod66ZODaJ8uSPSVTJVHBabQ,1931
|
117
117
|
ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
118
|
-
ai_edge_torch/generative/layers/attention.py,sha256=
|
118
|
+
ai_edge_torch/generative/layers/attention.py,sha256=_OmamS3f0m_JtW73ljwGLwFPeMLL837JCLY-dJ3iRUg,12453
|
119
119
|
ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
|
120
120
|
ai_edge_torch/generative/layers/builder.py,sha256=LXGuSHIx6QZAzLFm7aJvlzoMPgQwbXLFchGEKYwOOUA,5090
|
121
121
|
ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
|
122
122
|
ai_edge_torch/generative/layers/kv_cache.py,sha256=DhHIggaOQ2IAY4aRuMAuCLWZv1dBz5PYtmOEjkx9EQY,6291
|
123
123
|
ai_edge_torch/generative/layers/model_config.py,sha256=viX51T_naJ9sPpPxPoMnSueBPYE2zxWNOD0xn0f-_bM,7510
|
124
124
|
ai_edge_torch/generative/layers/normalization.py,sha256=h2btgRHMMjOcyLm8adEmcT0pG6imq4QcWblKJK5MYXA,7479
|
125
|
-
ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=
|
125
|
+
ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=zbFTNgQdOT-tcKK1QaIX6fG-50syYwQX_ZbLhg2C98c,2691
|
126
126
|
ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=gXxh3papKy4FBpGEX7VyZ7rZ1Js6aHK70Q6DKrVSckY,4154
|
127
127
|
ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
128
128
|
ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=9jKzOfiBQ66bp1ZnVIAoREIifVNFx4aTlQeYMAx2_pA,29062
|
@@ -147,7 +147,7 @@ ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5l
|
|
147
147
|
ai_edge_torch/generative/utilities/converter.py,sha256=hIwWUWjgPvWLATtsYYG6RWbFQWhOr2RpPlMrd-4Am9U,5959
|
148
148
|
ai_edge_torch/generative/utilities/dynamic_update_slice.py,sha256=e2mhx-Vp8sUK4EXoPtpZLSx3TViqLAKs67EhKcXBjAQ,2121
|
149
149
|
ai_edge_torch/generative/utilities/loader.py,sha256=A3SOjPXp--AsvoP1hqj5QKWE4sgxoFc3H5EBUz_Eogc,13531
|
150
|
-
ai_edge_torch/generative/utilities/model_builder.py,sha256=
|
150
|
+
ai_edge_torch/generative/utilities/model_builder.py,sha256=q82-1E2zYlzpbFW6Vw-MWrJivRXHKpRh8jUxpR-w0sY,6349
|
151
151
|
ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
|
152
152
|
ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
|
153
153
|
ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
|
@@ -160,16 +160,16 @@ ai_edge_torch/hlfb/mark_pattern/pattern.py,sha256=NP2mYhe5D2GjtqQfqqldp-ko3xtNgh
|
|
160
160
|
ai_edge_torch/hlfb/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
161
161
|
ai_edge_torch/hlfb/test/test_mark_pattern.py,sha256=ivq0eVjuf31idfNY0E12F4FxdkSI9hwYXapLJBkIf8Q,4831
|
162
162
|
ai_edge_torch/lowertools/__init__.py,sha256=A8WBXvWtuFYYWtNTqPD7waVntLaSVAnSMwx5ugjZBIw,761
|
163
|
-
ai_edge_torch/lowertools/_shim.py,sha256=
|
163
|
+
ai_edge_torch/lowertools/_shim.py,sha256=Mbg16tnCVK0YsHowfbpqpNX1qySuMLvpGI_-I5SIrG0,3276
|
164
164
|
ai_edge_torch/lowertools/common_utils.py,sha256=Z7p-ivOHtddktpnHrlDm_dSoTxJOdEjFXIGQbzjgwQo,4504
|
165
|
-
ai_edge_torch/lowertools/odml_torch_utils.py,sha256=
|
165
|
+
ai_edge_torch/lowertools/odml_torch_utils.py,sha256=dxg2pBuVhSZeY2Ouc0F6nHiZilWZmpSPA7I8kGqSkVI,8282
|
166
166
|
ai_edge_torch/lowertools/test_utils.py,sha256=mdxTlhqHABZEQ_GEmPFCL8LIAWtqRtYZUGdSY1ieZjw,1949
|
167
|
-
ai_edge_torch/lowertools/torch_xla_utils.py,sha256=
|
167
|
+
ai_edge_torch/lowertools/torch_xla_utils.py,sha256=tH5BW8-Up1uy5Iq1LdXiJInXBh4-YqNXJpSwwy3kwSg,9460
|
168
168
|
ai_edge_torch/lowertools/translate_recipe.py,sha256=ymkBpFqAUiupRWqrPOWiVphKcXR1K5vHK0RjgBFtxlE,5652
|
169
169
|
ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
|
170
170
|
ai_edge_torch/odml_torch/_torch_future.py,sha256=AJ0klpsbu2ZBTfiZlqSOoaYzBVITt40a1fYN8xKkEPw,3044
|
171
171
|
ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
|
172
|
-
ai_edge_torch/odml_torch/export.py,sha256=
|
172
|
+
ai_edge_torch/odml_torch/export.py,sha256=Wc_JM7U2IjZeBmXA6t1AZxREGOWjZ6EB-PIhEevWWeU,13207
|
173
173
|
ai_edge_torch/odml_torch/export_utils.py,sha256=QeA37Irlty6AiIBuqmHmJgn3lqahBQ5xsh6IKRoKm1g,4774
|
174
174
|
ai_edge_torch/odml_torch/tf_integration.py,sha256=lTFJPPEijLPFmn6qq2jbpVTQOo0YaOTK36kK6rCiyIE,5956
|
175
175
|
ai_edge_torch/odml_torch/composite/__init__.py,sha256=71GM_gDZxJyo38ZSoYSwhZX3xKA9rknO93JS9kw9w_c,778
|
@@ -187,7 +187,7 @@ ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_
|
|
187
187
|
ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=Q0aDzyUcZMoSzSbOU-r3LJMgPe6fble0QwdYVIOHHHk,6887
|
188
188
|
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=OVmlPGwyhDXKhmG4SAeEsa6iLpJHEHV_jKqwfjYvetA,11643
|
189
189
|
ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=khJIvDVk2s332Nd2Be-5dM6-wp5DGff61HCV5lskHmQ,3011
|
190
|
-
ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=
|
190
|
+
ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=XDZ0zLej_XaQDJnaAAxhNFAd7NfQm5SOVEp_nno_krA,6178
|
191
191
|
ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
|
192
192
|
ai_edge_torch/odml_torch/lowerings/decomp.py,sha256=UoJeZVcr4zAN_11i-HzfOhxGCxUm-7b1JXPVBxR2hSs,2414
|
193
193
|
ai_edge_torch/odml_torch/lowerings/registry.py,sha256=Tp2h11l5uTifO0aIkuUOWAF_ibEjmd65Xx99w3EXuGE,1924
|
@@ -200,8 +200,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
|
|
200
200
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
201
201
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
202
202
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
203
|
-
ai_edge_torch_nightly-0.3.0.
|
204
|
-
ai_edge_torch_nightly-0.3.0.
|
205
|
-
ai_edge_torch_nightly-0.3.0.
|
206
|
-
ai_edge_torch_nightly-0.3.0.
|
207
|
-
ai_edge_torch_nightly-0.3.0.
|
203
|
+
ai_edge_torch_nightly-0.3.0.dev20241220.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
204
|
+
ai_edge_torch_nightly-0.3.0.dev20241220.dist-info/METADATA,sha256=PfyYhqbf7VEibw2TEDRb8tBOIPG9dfXhT9tNNou_iZg,1966
|
205
|
+
ai_edge_torch_nightly-0.3.0.dev20241220.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
206
|
+
ai_edge_torch_nightly-0.3.0.dev20241220.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
207
|
+
ai_edge_torch_nightly-0.3.0.dev20241220.dist-info/RECORD,,
|
File without changes
|
File without changes
|