ai-edge-torch-nightly 0.4.0.dev20250225__py3-none-any.whl → 0.4.0.dev20250227__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/_convert/conversion.py +4 -0
- ai_edge_torch/fx_infra/graph_utils.py +32 -0
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +9 -2
- ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +25 -8
- ai_edge_torch/generative/layers/attention.py +0 -12
- ai_edge_torch/generative/layers/kv_cache.py +8 -5
- ai_edge_torch/generative/layers/model_config.py +0 -3
- ai_edge_torch/generative/utilities/converter.py +14 -4
- ai_edge_torch/generative/utilities/model_builder.py +2 -0
- ai_edge_torch/odml_torch/debuginfo/_build.py +11 -10
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.4.0.dev20250225.dist-info → ai_edge_torch_nightly-0.4.0.dev20250227.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.4.0.dev20250225.dist-info → ai_edge_torch_nightly-0.4.0.dev20250227.dist-info}/RECORD +16 -16
- {ai_edge_torch_nightly-0.4.0.dev20250225.dist-info → ai_edge_torch_nightly-0.4.0.dev20250227.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.4.0.dev20250225.dist-info → ai_edge_torch_nightly-0.4.0.dev20250227.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.4.0.dev20250225.dist-info → ai_edge_torch_nightly-0.4.0.dev20250227.dist-info}/top_level.txt +0 -0
@@ -125,6 +125,10 @@ def convert_signatures(
|
|
125
125
|
else:
|
126
126
|
exported_program = torch.export.export(**kwargs, strict=True)
|
127
127
|
|
128
|
+
exported_program = fx_infra.graph_utils.reset_from_node_meta(
|
129
|
+
exported_program
|
130
|
+
)
|
131
|
+
|
128
132
|
exported_program = fx_infra.safe_run_decompositions(
|
129
133
|
exported_program,
|
130
134
|
fx_infra.decomp.pre_convert_decomp(),
|
@@ -13,7 +13,10 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
"""FX graph utilities."""
|
16
|
+
|
17
|
+
from packaging import version
|
16
18
|
import torch
|
19
|
+
from torch.fx import traceback
|
17
20
|
|
18
21
|
|
19
22
|
def remove_dangling_args(graph_module: torch.fx.GraphModule):
|
@@ -40,3 +43,32 @@ def remove_assert_tensor_metadata_nodes(graph_module: torch.fx.GraphModule):
|
|
40
43
|
graph_module.graph.lint()
|
41
44
|
graph_module.recompile()
|
42
45
|
return graph_module
|
46
|
+
|
47
|
+
|
48
|
+
def is_torch_version_under(torch_version: str) -> bool:
|
49
|
+
"""Checks if the current torch version is under the given version."""
|
50
|
+
if not torch_version:
|
51
|
+
raise ValueError("torch_version cannot be empty.")
|
52
|
+
current_version = version.parse(torch.__version__)
|
53
|
+
compared_version = version.parse(torch_version)
|
54
|
+
return current_version < compared_version
|
55
|
+
|
56
|
+
|
57
|
+
def reset_from_node_meta(ep: torch.export.ExportedProgram):
|
58
|
+
"""Resets the "from_node" meta field to fx node name only for the exported program."""
|
59
|
+
|
60
|
+
for node in ep.graph.nodes:
|
61
|
+
if not hasattr(node, "meta") or "from_node" not in node.meta:
|
62
|
+
continue
|
63
|
+
if is_torch_version_under("2.6.0.dev0"):
|
64
|
+
# For torch version under 2.6.0, the history stack is a list of tuple. We
|
65
|
+
# will only keep the current node's name in the history stack.
|
66
|
+
history = [(node.name,)]
|
67
|
+
else:
|
68
|
+
# Clean up the history stack by keeping only the current node info (fx
|
69
|
+
# node name and graph id) in a list of size 1. Clear the "from_node" field
|
70
|
+
# to prevent redundant additions to the history stack.
|
71
|
+
history = [traceback.NodeSource(node)]
|
72
|
+
history[0].from_node = []
|
73
|
+
node.meta["from_node"] = history
|
74
|
+
return ep
|
@@ -22,7 +22,7 @@ from absl import app
|
|
22
22
|
from absl import flags
|
23
23
|
from ai_edge_torch.generative.examples.smollm import smollm
|
24
24
|
from ai_edge_torch.generative.utilities import converter
|
25
|
-
from ai_edge_torch.generative.utilities
|
25
|
+
from ai_edge_torch.generative.utilities import model_builder
|
26
26
|
|
27
27
|
_CHECKPOINT_PATH = flags.DEFINE_string(
|
28
28
|
'checkpoint_path',
|
@@ -59,6 +59,11 @@ _LORA_RANKS = flags.DEFINE_multi_integer(
|
|
59
59
|
None,
|
60
60
|
'If set, the model will be converted with the provided list of LoRA ranks.',
|
61
61
|
)
|
62
|
+
_DECODE_BATCH_SIZE = flags.DEFINE_integer(
|
63
|
+
'decode_batch_size',
|
64
|
+
1,
|
65
|
+
'The batch size for the decode signature.',
|
66
|
+
)
|
62
67
|
|
63
68
|
|
64
69
|
def main(_):
|
@@ -72,7 +77,9 @@ def main(_):
|
|
72
77
|
prefill_seq_len=_PREFILL_SEQ_LENS.value,
|
73
78
|
quantize=_QUANTIZE.value,
|
74
79
|
lora_ranks=_LORA_RANKS.value,
|
75
|
-
export_config=ExportConfig(
|
80
|
+
export_config=model_builder.ExportConfig(
|
81
|
+
decode_batch_size=_DECODE_BATCH_SIZE.value
|
82
|
+
),
|
76
83
|
)
|
77
84
|
|
78
85
|
|
@@ -22,17 +22,22 @@ from absl import app
|
|
22
22
|
from absl import flags
|
23
23
|
from ai_edge_torch.generative.examples.smollm import smollm
|
24
24
|
from ai_edge_torch.generative.utilities import converter
|
25
|
-
from ai_edge_torch.generative.utilities
|
25
|
+
from ai_edge_torch.generative.utilities import model_builder
|
26
26
|
|
27
27
|
_CHECKPOINT_PATH = flags.DEFINE_string(
|
28
28
|
'checkpoint_path',
|
29
29
|
os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smollm2'),
|
30
30
|
'The path to the model checkpoint, or directory holding the checkpoint.',
|
31
31
|
)
|
32
|
-
|
33
|
-
'
|
32
|
+
_OUTPUT_PATH = flags.DEFINE_string(
|
33
|
+
'output_path',
|
34
34
|
'/tmp/',
|
35
|
-
'The
|
35
|
+
'The path to export the tflite model.',
|
36
|
+
)
|
37
|
+
_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
|
38
|
+
'output_name_prefix',
|
39
|
+
'smollm2',
|
40
|
+
'The prefix of the output tflite model name.',
|
36
41
|
)
|
37
42
|
_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
|
38
43
|
'prefill_seq_lens',
|
@@ -49,6 +54,16 @@ _QUANTIZE = flags.DEFINE_bool(
|
|
49
54
|
True,
|
50
55
|
'Whether the model should be quantized.',
|
51
56
|
)
|
57
|
+
_LORA_RANKS = flags.DEFINE_multi_integer(
|
58
|
+
'lora_ranks',
|
59
|
+
None,
|
60
|
+
'If set, the model will be converted with the provided list of LoRA ranks.',
|
61
|
+
)
|
62
|
+
_DECODE_BATCH_SIZE = flags.DEFINE_integer(
|
63
|
+
'decode_batch_size',
|
64
|
+
1,
|
65
|
+
'The batch size for the decode signature.',
|
66
|
+
)
|
52
67
|
|
53
68
|
|
54
69
|
def main(_):
|
@@ -56,14 +71,16 @@ def main(_):
|
|
56
71
|
_CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
|
57
72
|
)
|
58
73
|
|
59
|
-
quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
|
60
|
-
output_filename = f'smollm2_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
|
61
74
|
converter.convert_to_tflite(
|
62
75
|
pytorch_model,
|
63
|
-
|
76
|
+
output_path=_OUTPUT_PATH.value,
|
77
|
+
output_name_prefix=_OUTPUT_NAME_PREFIX.value,
|
64
78
|
prefill_seq_len=_PREFILL_SEQ_LENS.value,
|
65
79
|
quantize=_QUANTIZE.value,
|
66
|
-
|
80
|
+
lora_ranks=_LORA_RANKS.value,
|
81
|
+
export_config=model_builder.ExportConfig(
|
82
|
+
decode_batch_size=_DECODE_BATCH_SIZE.value
|
83
|
+
),
|
67
84
|
)
|
68
85
|
|
69
86
|
|
@@ -48,7 +48,6 @@ class TransformerBlock(nn.Module):
|
|
48
48
|
config.pre_attention_norm_config,
|
49
49
|
)
|
50
50
|
self.atten_func = CausalSelfAttention(
|
51
|
-
model_config.batch_size,
|
52
51
|
model_config.embedding_dim,
|
53
52
|
config.attn_config,
|
54
53
|
model_config.enable_hlfb,
|
@@ -115,7 +114,6 @@ class CausalSelfAttention(nn.Module):
|
|
115
114
|
|
116
115
|
def __init__(
|
117
116
|
self,
|
118
|
-
batch_size: int,
|
119
117
|
dim: int,
|
120
118
|
config: cfg.AttentionConfig,
|
121
119
|
enable_hlfb: bool,
|
@@ -123,14 +121,12 @@ class CausalSelfAttention(nn.Module):
|
|
123
121
|
"""Initialize an instance of CausalSelfAttention.
|
124
122
|
|
125
123
|
Args:
|
126
|
-
batch_size (int): batch size of the input tensor.
|
127
124
|
dim (int): causal attention's input/output dimmension.
|
128
125
|
config (cfg.AttentionConfig): attention specific configurations.
|
129
126
|
enable_hlfb (bool): whether hlfb is enabled or not.
|
130
127
|
"""
|
131
128
|
super().__init__()
|
132
129
|
self.kv_cache = None
|
133
|
-
self.batch_size = batch_size
|
134
130
|
qkv_shape = (
|
135
131
|
config.num_heads + 2 * config.num_query_groups
|
136
132
|
) * config.head_dim
|
@@ -179,11 +175,6 @@ class CausalSelfAttention(nn.Module):
|
|
179
175
|
"""
|
180
176
|
# Batch size, sequence length, embedding dimensionality.
|
181
177
|
B, T, E = x.size()
|
182
|
-
assert B == self.batch_size, (
|
183
|
-
"batch size of input tensor must match with the batch size specified in"
|
184
|
-
" the model configuration."
|
185
|
-
)
|
186
|
-
|
187
178
|
qkv = self.qkv_projection(x)
|
188
179
|
|
189
180
|
# Assemble into a number of query groups to support MHA, MQA and GQA.
|
@@ -290,7 +281,6 @@ class CrossAttention(nn.Module):
|
|
290
281
|
|
291
282
|
def __init__(
|
292
283
|
self,
|
293
|
-
batch_size: int,
|
294
284
|
query_dim: int,
|
295
285
|
cross_dim: int,
|
296
286
|
hidden_dim: int,
|
@@ -301,7 +291,6 @@ class CrossAttention(nn.Module):
|
|
301
291
|
"""Initialize an instance of CrossAttention.
|
302
292
|
|
303
293
|
Args:
|
304
|
-
batch_size (int): batch size of the input tensor.
|
305
294
|
query_dim (int): query tensor's dimension.
|
306
295
|
cross_dim (int): cross attention's dimensions, for key and value tensors.
|
307
296
|
hidden_dim (int): hidden dimension that q, k, v tensors project to.
|
@@ -376,7 +365,6 @@ class CrossAttention(nn.Module):
|
|
376
365
|
|
377
366
|
if rope is not None:
|
378
367
|
# Compute rotary positional embedding for query and key.
|
379
|
-
n_elem = int(self.config.rotary_percentage * self.config.head_dim)
|
380
368
|
cos, sin = rope
|
381
369
|
q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)
|
382
370
|
|
@@ -18,14 +18,11 @@
|
|
18
18
|
import dataclasses
|
19
19
|
from typing import List, Tuple
|
20
20
|
|
21
|
-
from ai_edge_torch import hlfb
|
22
21
|
from ai_edge_torch.generative.layers import model_config
|
23
22
|
from ai_edge_torch.generative.utilities.dynamic_update_slice import dynamic_update_slice
|
24
23
|
import torch
|
25
24
|
import torch.utils._pytree as pytree
|
26
25
|
|
27
|
-
BATCH_SIZE = 1
|
28
|
-
|
29
26
|
|
30
27
|
@dataclasses.dataclass
|
31
28
|
class KVCacheEntry:
|
@@ -45,9 +42,10 @@ class KVCacheEntry:
|
|
45
42
|
config: model_config.AttentionConfig,
|
46
43
|
dtype: torch.dtype = torch.float32,
|
47
44
|
device: torch.device = None,
|
45
|
+
batch_size: int = 1,
|
48
46
|
) -> "KVCacheEntry":
|
49
47
|
"""Build an instance of the class based on model config."""
|
50
|
-
shape = (
|
48
|
+
shape = (batch_size, kv_cache_max, config.num_query_groups, config.head_dim)
|
51
49
|
k = torch.zeros(shape, dtype=dtype, device=device)
|
52
50
|
v = torch.zeros(shape, dtype=dtype, device=device)
|
53
51
|
obj = cls(k_cache=k, v_cache=v)
|
@@ -66,6 +64,7 @@ class KVCache:
|
|
66
64
|
config: model_config.ModelConfig,
|
67
65
|
dtype: torch.dtype = torch.float32,
|
68
66
|
device: torch.device = None,
|
67
|
+
batch_size: int = 1,
|
69
68
|
) -> "KVCache":
|
70
69
|
"""Build an instance of the class based on model config.
|
71
70
|
|
@@ -75,17 +74,21 @@ class KVCache:
|
|
75
74
|
Defaults to torch.float32.
|
76
75
|
device (torch.device, optional): The device placement of the cache
|
77
76
|
tensors. Defaults to None.
|
77
|
+
batch_size (int, optional): The batch size of the cache tensors.
|
78
|
+
Defaults to 1.
|
78
79
|
|
79
80
|
Returns:
|
80
81
|
KVCache: The created cache object.
|
81
82
|
"""
|
82
83
|
caches = [
|
83
84
|
KVCacheEntry.from_model_config(
|
84
|
-
config.kv_cache_max
|
85
|
+
config.kv_cache_max
|
86
|
+
if not config.block_config(idx).kv_cache_max_len
|
85
87
|
else config.block_config(idx).kv_cache_max_len,
|
86
88
|
config.block_config(idx).attn_config,
|
87
89
|
dtype,
|
88
90
|
device,
|
91
|
+
batch_size,
|
89
92
|
)
|
90
93
|
for idx in range(config.num_layers)
|
91
94
|
]
|
@@ -220,9 +220,6 @@ class ModelConfig:
|
|
220
220
|
# The maximum sequence length of the KV cache. Should not exceed max_seq_len.
|
221
221
|
kv_cache_max_len: int = 0
|
222
222
|
|
223
|
-
# Default batch size of the exported model. Default value is 1.
|
224
|
-
batch_size: int = 1
|
225
|
-
|
226
223
|
# Softcap on the model output logits.
|
227
224
|
final_logit_softcap: Optional[float] = None
|
228
225
|
|
@@ -110,6 +110,11 @@ def convert_to_tflite(
|
|
110
110
|
lora_suffix = (
|
111
111
|
'' if not lora_ranks else f'_lora{",".join(map(str, lora_ranks))}'
|
112
112
|
)
|
113
|
+
|
114
|
+
if export_config is not None:
|
115
|
+
if export_config.decode_batch_size > 1:
|
116
|
+
output_name_prefix += f'_dbs{export_config.decode_batch_size}'
|
117
|
+
|
113
118
|
output_filename = (
|
114
119
|
f'{output_name_prefix}_{quant_suffix}_ekv{kv_size}{lora_suffix}.tflite'
|
115
120
|
)
|
@@ -162,9 +167,14 @@ def _export_helper(
|
|
162
167
|
if prefill_masks:
|
163
168
|
assert len(prefill_masks) == len(prefill_seq_lens)
|
164
169
|
|
165
|
-
decode_token = torch.tensor(
|
170
|
+
decode_token = torch.tensor(
|
171
|
+
[[0] for _ in range(export_config.decode_batch_size)], dtype=torch.int
|
172
|
+
)
|
166
173
|
decode_input_pos = torch.tensor([0], dtype=torch.int)
|
167
|
-
|
174
|
+
prefill_kv = export_config.kvcache_cls.from_model_config(config)
|
175
|
+
decode_kv = export_config.kvcache_cls.from_model_config(
|
176
|
+
config, batch_size=export_config.decode_batch_size
|
177
|
+
)
|
168
178
|
|
169
179
|
quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
|
170
180
|
|
@@ -183,7 +193,7 @@ def _export_helper(
|
|
183
193
|
sample_kwargs = {
|
184
194
|
'tokens': prefill_tokens,
|
185
195
|
'input_pos': prefill_input_pos,
|
186
|
-
'kv_cache':
|
196
|
+
'kv_cache': prefill_kv,
|
187
197
|
}
|
188
198
|
if prefill_masks is not None:
|
189
199
|
sample_kwargs['mask'] = prefill_masks[i]
|
@@ -211,7 +221,7 @@ def _export_helper(
|
|
211
221
|
sample_kwargs = {
|
212
222
|
'tokens': decode_token,
|
213
223
|
'input_pos': decode_input_pos,
|
214
|
-
'kv_cache':
|
224
|
+
'kv_cache': decode_kv,
|
215
225
|
}
|
216
226
|
if export_config.decode_mask is not None:
|
217
227
|
sample_kwargs['mask'] = export_config.decode_mask
|
@@ -60,6 +60,8 @@ class ExportConfig:
|
|
60
60
|
decode_mask: Optional[torch.Tensor | List[torch.Tensor]] = None
|
61
61
|
# The KV Cache class for K and V buffers in attention.
|
62
62
|
kvcache_cls: type = kv_utils.KVCache
|
63
|
+
# The batch size of the decode signature.
|
64
|
+
decode_batch_size: int = 1
|
63
65
|
|
64
66
|
|
65
67
|
class DecoderOnlyModel(nn.Module):
|
@@ -13,6 +13,7 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
import re
|
16
|
+
from ai_edge_torch.fx_infra import graph_utils
|
16
17
|
import torch
|
17
18
|
|
18
19
|
|
@@ -48,8 +49,6 @@ def _get_canonical_filename(filename: str):
|
|
48
49
|
The canonicalized filename.
|
49
50
|
"""
|
50
51
|
|
51
|
-
# TODO(yijieyang): We should add a config option to provide a regex to strip
|
52
|
-
# from the debug info. Currently absolute path is used.
|
53
52
|
return filename
|
54
53
|
|
55
54
|
|
@@ -57,9 +56,13 @@ def _get_canoical_nodename(node: torch.fx.Node) -> str:
|
|
57
56
|
"""Get the canonical node name from the node's history."""
|
58
57
|
|
59
58
|
history = node.meta.get("from_node", [])
|
59
|
+
if not history:
|
60
|
+
return None
|
60
61
|
|
61
|
-
|
62
|
-
|
62
|
+
# Compatible with torch version under 2.6.0. The history stack is a list of
|
63
|
+
# tuple. The first element of the first tuple is the node name.
|
64
|
+
if graph_utils.is_torch_version_under("2.6.0.dev0"):
|
65
|
+
return history[0][0]
|
63
66
|
|
64
67
|
if not hasattr(history[0], "name"):
|
65
68
|
return None
|
@@ -68,12 +71,10 @@ def _get_canoical_nodename(node: torch.fx.Node) -> str:
|
|
68
71
|
names.append(history[0].name)
|
69
72
|
history = history[0].from_node
|
70
73
|
|
71
|
-
#
|
72
|
-
#
|
73
|
-
#
|
74
|
-
|
75
|
-
return names[-3]
|
76
|
-
return None
|
74
|
+
# The history stack is generated by tracing the node's transformation history
|
75
|
+
# during lowering. The last name in the history stack is used to map to the
|
76
|
+
# original torch fx node name.
|
77
|
+
return names[-1]
|
77
78
|
|
78
79
|
|
79
80
|
def build_mlir_file_debuginfo(node: torch.fx.Node):
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.4.0.
|
3
|
+
Version: 0.4.0.dev20250227
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -2,9 +2,9 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
|
|
2
2
|
ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
|
5
|
-
ai_edge_torch/version.py,sha256=
|
5
|
+
ai_edge_torch/version.py,sha256=K2jtDrBNGi74j_uQYVUT6MJ2-aQFKkKy5ZYur9iWdVU,706
|
6
6
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
7
|
-
ai_edge_torch/_convert/conversion.py,sha256=
|
7
|
+
ai_edge_torch/_convert/conversion.py,sha256=gpXQnifODU-mWxkUZw_3ov1lEYBw1SPVIcqj5k7pTGo,5550
|
8
8
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
9
9
|
ai_edge_torch/_convert/converter.py,sha256=075F8LRewk_033Ebsnft7FJr3KgtIbtZ_-8udIPy6ho,9980
|
10
10
|
ai_edge_torch/_convert/signature.py,sha256=-YKJdLk-eNEHfhdPCtcQVtZf915SoVePEFxKXPPf16c,2572
|
@@ -41,7 +41,7 @@ ai_edge_torch/fx_infra/__init__.py,sha256=APjkSqEfwDxcnI8k53rGi3Ef-G2L-M8fdaPGpx
|
|
41
41
|
ai_edge_torch/fx_infra/_canonicalize_pass.py,sha256=GDRoDdPVQw--QQFTT5J_C3TVuphL31m6K6F1-67SE4s,1097
|
42
42
|
ai_edge_torch/fx_infra/_safe_run_decompositions.py,sha256=ZbWheeZ8ydsxCk2aVGUgUynrkEkBOMjBCzPhS5uq4sU,2595
|
43
43
|
ai_edge_torch/fx_infra/decomp.py,sha256=S58SCgwMHYVFl_hJwlJxvu2wcI-AGNn82gel3qmTPrU,2500
|
44
|
-
ai_edge_torch/fx_infra/graph_utils.py,sha256=
|
44
|
+
ai_edge_torch/fx_infra/graph_utils.py,sha256=nqGe-xIJ77RamSUh0UYyI2XHOsZqFDWax-vpRAtVR_E,2796
|
45
45
|
ai_edge_torch/fx_infra/pass_base.py,sha256=Ic2AlhSoRFscz6l7gJKvWVNMDLQFfAw5kRf84-ZR9qM,2904
|
46
46
|
ai_edge_torch/generative/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
47
47
|
ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
@@ -102,8 +102,8 @@ ai_edge_torch/generative/examples/qwen_vl/verify.py,sha256=JUwHoC_zvcC3RC3wZ3e3e
|
|
102
102
|
ai_edge_torch/generative/examples/qwen_vl/verify_decoder.py,sha256=xPWoOBLh2eK12KEhELLYymfL7xvc0chmYC98c6x37oo,2602
|
103
103
|
ai_edge_torch/generative/examples/qwen_vl/verify_image_encoder.py,sha256=PZ392nDoJG2OmHZ_7Jet3Zu1JkN6QErxKcDc7a-PPds,3126
|
104
104
|
ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
105
|
-
ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=
|
106
|
-
ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py,sha256=
|
105
|
+
ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=hWko-RJB8eXNUfi4EzQ2yjW30YE4UB4zAz7rd2Q5qpg,2708
|
106
|
+
ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py,sha256=sJ-o385eqQsciv0TEQRkixvS0DD6dKruAuK0zlEsDoY,2715
|
107
107
|
ai_edge_torch/generative/examples/smollm/smollm.py,sha256=3uUltb6D3Q1aHpndcYTJrsWM_RBwLAraKDniH8ZZous,3779
|
108
108
|
ai_edge_torch/generative/examples/smollm/verify.py,sha256=KpYxVz_lv61YWy6HLfwT68n0owZMvty5Rr3W7ZNWWSw,2702
|
109
109
|
ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
@@ -136,13 +136,13 @@ ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=LRu6PSw7Lqu6HGbv1t
|
|
136
136
|
ai_edge_torch/generative/fx_passes/__init__.py,sha256=4rFrppMRKlTwwZeX1ON_cdp4yUqoTOES161IZQkJF6c,1143
|
137
137
|
ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=myGjal5A8yIBoqgArd2k40rZmCgD1Ya369KR7182bhI,2129
|
138
138
|
ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
139
|
-
ai_edge_torch/generative/layers/attention.py,sha256=
|
139
|
+
ai_edge_torch/generative/layers/attention.py,sha256=wLZ1jgUlcODBWgK3hnnhclHuuQDqYuGOZdYAI9EooOM,13247
|
140
140
|
ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
|
141
141
|
ai_edge_torch/generative/layers/builder.py,sha256=LXGuSHIx6QZAzLFm7aJvlzoMPgQwbXLFchGEKYwOOUA,5090
|
142
142
|
ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
|
143
|
-
ai_edge_torch/generative/layers/kv_cache.py,sha256=
|
143
|
+
ai_edge_torch/generative/layers/kv_cache.py,sha256=jwbt0-2fd_CNWS2fp4nf0zvh6kk5citINGlFC_RtEUU,6540
|
144
144
|
ai_edge_torch/generative/layers/lora.py,sha256=hsvWLLOnW7HQ0AysOZu30x_cetMquDd1tjfyLz8HCSU,17892
|
145
|
-
ai_edge_torch/generative/layers/model_config.py,sha256=
|
145
|
+
ai_edge_torch/generative/layers/model_config.py,sha256=wNsZDzZQoimOKdZ9FWMCktPj-pQ_0D7084hgzMT5XYo,8155
|
146
146
|
ai_edge_torch/generative/layers/normalization.py,sha256=MbwH-n80Fob5YvjBzdqDjBizMHLzSJGYRDdbD-rL5C0,6174
|
147
147
|
ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=975zR202MdIrILJ7blceAcxrNqX1ZCN0ECKG1gz-bV8,2655
|
148
148
|
ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=vp8dVx6tOe99neJhpbrtIt5fvN5NFw19JVH1v0yi5Mg,4154
|
@@ -173,10 +173,10 @@ ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728Fc
|
|
173
173
|
ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
|
174
174
|
ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
|
175
175
|
ai_edge_torch/generative/utilities/bmm_4d.py,sha256=2BMOYiFVUsl-bjxmLkrX4N7kpO0CnhB7eDYxm_iBCr8,2533
|
176
|
-
ai_edge_torch/generative/utilities/converter.py,sha256=
|
176
|
+
ai_edge_torch/generative/utilities/converter.py,sha256=VtG42CVz657XbvTj-FZJiCFW0Hm11OVKKC_mr2tjxhc,8413
|
177
177
|
ai_edge_torch/generative/utilities/dynamic_update_slice.py,sha256=e2mhx-Vp8sUK4EXoPtpZLSx3TViqLAKs67EhKcXBjAQ,2121
|
178
178
|
ai_edge_torch/generative/utilities/loader.py,sha256=A3SOjPXp--AsvoP1hqj5QKWE4sgxoFc3H5EBUz_Eogc,13531
|
179
|
-
ai_edge_torch/generative/utilities/model_builder.py,sha256=
|
179
|
+
ai_edge_torch/generative/utilities/model_builder.py,sha256=eY3qAcBhupIn955YnWuzUi9hoWYvl4ntRWA6PBudzMo,6888
|
180
180
|
ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
|
181
181
|
ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
|
182
182
|
ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
|
@@ -205,7 +205,7 @@ ai_edge_torch/odml_torch/composite/__init__.py,sha256=71GM_gDZxJyo38ZSoYSwhZX3xK
|
|
205
205
|
ai_edge_torch/odml_torch/composite/mark_tensor.py,sha256=U--rwl-XkWKgkdXCXDn6yySug8FR66o1YFUAIoSaWW4,3523
|
206
206
|
ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py,sha256=2Y52E_gLeoXpMcPpV-svXsgN3JbEIjnPVjm0xkpTUdQ,3319
|
207
207
|
ai_edge_torch/odml_torch/debuginfo/__init__.py,sha256=3A_lMyj-B-DOhLJG6WmjKvZK5te2rXje8FrfqOhZsN0,959
|
208
|
-
ai_edge_torch/odml_torch/debuginfo/_build.py,sha256=
|
208
|
+
ai_edge_torch/odml_torch/debuginfo/_build.py,sha256=6Ns2rlfOilLJEk5cUxlkRwm2uxOgEF2-0S2DMcOqr6A,3319
|
209
209
|
ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py,sha256=IvOBQyROI9WHS3umHRxsDW-1YElU9BPWzKtJA2eKWOI,1739
|
210
210
|
ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=e9Oa4J3An9FYr3zM0OzjzyNNitEeg-IoBUGNfUxsDSA,798
|
211
211
|
ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=LqwZ1vCJTSOzgzvH8LUAN-sAkF-l_pGj1AMEIzAqHCA,6638
|
@@ -230,8 +230,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
|
|
230
230
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
231
231
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
232
232
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
233
|
-
ai_edge_torch_nightly-0.4.0.
|
234
|
-
ai_edge_torch_nightly-0.4.0.
|
235
|
-
ai_edge_torch_nightly-0.4.0.
|
236
|
-
ai_edge_torch_nightly-0.4.0.
|
237
|
-
ai_edge_torch_nightly-0.4.0.
|
233
|
+
ai_edge_torch_nightly-0.4.0.dev20250227.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
234
|
+
ai_edge_torch_nightly-0.4.0.dev20250227.dist-info/METADATA,sha256=cHcz3adq1WwVddazAJ06h7SKITJm70eMpFVjoNa2Jw4,1966
|
235
|
+
ai_edge_torch_nightly-0.4.0.dev20250227.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
236
|
+
ai_edge_torch_nightly-0.4.0.dev20250227.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
237
|
+
ai_edge_torch_nightly-0.4.0.dev20250227.dist-info/RECORD,,
|
File without changes
|
File without changes
|