ai-edge-torch-nightly 0.4.0.dev20250220__py3-none-any.whl → 0.4.0.dev20250222__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +1 -1
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +15 -10
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +5 -1
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +5 -1
- ai_edge_torch/odml_torch/debuginfo/_build.py +30 -10
- ai_edge_torch/odml_torch/lowerings/_basic.py +9 -0
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -1
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.4.0.dev20250220.dist-info → ai_edge_torch_nightly-0.4.0.dev20250222.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.4.0.dev20250220.dist-info → ai_edge_torch_nightly-0.4.0.dev20250222.dist-info}/RECORD +13 -13
- {ai_edge_torch_nightly-0.4.0.dev20250220.dist-info → ai_edge_torch_nightly-0.4.0.dev20250222.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.4.0.dev20250220.dist-info → ai_edge_torch_nightly-0.4.0.dev20250222.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.4.0.dev20250220.dist-info → ai_edge_torch_nightly-0.4.0.dev20250222.dist-info}/top_level.txt +0 -0
@@ -29,41 +29,46 @@ import torch
|
|
29
29
|
|
30
30
|
_CLIP_CKPT = flags.DEFINE_string(
|
31
31
|
'clip_ckpt',
|
32
|
-
|
32
|
+
os.path.join(
|
33
|
+
pathlib.Path.home(),
|
34
|
+
'Downloads/stable-diffusion-v1-5/v1-5-pruned-emaonly.safetensors',
|
35
|
+
),
|
33
36
|
help='Path to source CLIP model checkpoint',
|
34
|
-
required=True,
|
35
37
|
)
|
36
38
|
|
37
39
|
_DIFFUSION_CKPT = flags.DEFINE_string(
|
38
40
|
'diffusion_ckpt',
|
39
|
-
|
41
|
+
os.path.join(
|
42
|
+
pathlib.Path.home(),
|
43
|
+
'Downloads/stable-diffusion-v1-5/v1-5-pruned-emaonly.safetensors',
|
44
|
+
),
|
40
45
|
help='Path to source diffusion model checkpoint',
|
41
|
-
required=True,
|
42
46
|
)
|
43
47
|
|
44
48
|
_DECODER_CKPT = flags.DEFINE_string(
|
45
49
|
'decoder_ckpt',
|
46
|
-
|
50
|
+
os.path.join(
|
51
|
+
pathlib.Path.home(),
|
52
|
+
'Downloads/stable-diffusion-v1-5/v1-5-pruned-emaonly.safetensors',
|
53
|
+
),
|
47
54
|
help='Path to source image decoder model checkpoint',
|
48
|
-
required=True,
|
49
55
|
)
|
50
56
|
|
51
57
|
_OUTPUT_DIR = flags.DEFINE_string(
|
52
58
|
'output_dir',
|
53
|
-
|
59
|
+
'/tmp/sd_tflite',
|
54
60
|
help='Path to the converted TF Lite directory.',
|
55
|
-
required=True,
|
56
61
|
)
|
57
62
|
|
58
63
|
_QUANTIZE = flags.DEFINE_bool(
|
59
64
|
'quantize',
|
60
65
|
help='Whether to quantize the model during conversion.',
|
61
|
-
default=
|
66
|
+
default=False,
|
62
67
|
)
|
63
68
|
|
64
69
|
_DEVICE_TYPE = flags.DEFINE_string(
|
65
70
|
'device_type',
|
66
|
-
'
|
71
|
+
'gpu',
|
67
72
|
help='The device type of the model. Currently supported: cpu, gpu.',
|
68
73
|
)
|
69
74
|
|
@@ -13,6 +13,7 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
+
import ai_edge_torch
|
16
17
|
import ai_edge_torch.generative.layers.builder as layers_builder
|
17
18
|
import ai_edge_torch.generative.layers.model_config as layers_cfg
|
18
19
|
from ai_edge_torch.generative.layers.unet import blocks_2d
|
@@ -281,7 +282,10 @@ def get_model_config(device_type: str = "cpu") -> unet_cfg.AutoEncoderConfig:
|
|
281
282
|
|
282
283
|
# For now, only turns on StableHLO composite ops on GPU backend for better
|
283
284
|
# performance. CPU should also switch to it once the support is done.
|
284
|
-
enable_hlfb =
|
285
|
+
enable_hlfb = False
|
286
|
+
if device_type == "gpu":
|
287
|
+
ai_edge_torch.config.enable_group_norm_composite = True
|
288
|
+
enable_hlfb = True
|
285
289
|
|
286
290
|
norm_config = layers_cfg.NormalizationConfig(
|
287
291
|
layers_cfg.NormalizationType.GROUP_NORM,
|
@@ -13,6 +13,7 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
+
import ai_edge_torch
|
16
17
|
import ai_edge_torch.generative.layers.builder as layers_builder
|
17
18
|
import ai_edge_torch.generative.layers.model_config as layers_cfg
|
18
19
|
from ai_edge_torch.generative.layers.unet import blocks_2d
|
@@ -601,7 +602,10 @@ def get_model_config(
|
|
601
602
|
|
602
603
|
# For now, only turns on StableHLO composite ops on GPU backend for better
|
603
604
|
# performance. CPU should also switch to it once the support is done.
|
604
|
-
enable_hlfb =
|
605
|
+
enable_hlfb = False
|
606
|
+
if device_type == "gpu":
|
607
|
+
ai_edge_torch.config.enable_group_norm_composite = True
|
608
|
+
enable_hlfb = True
|
605
609
|
|
606
610
|
# Residual configs.
|
607
611
|
residual_norm_config = layers_cfg.NormalizationConfig(
|
@@ -12,8 +12,8 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
import torch
|
16
15
|
import re
|
16
|
+
import torch
|
17
17
|
|
18
18
|
|
19
19
|
def _class_fullname(cls):
|
@@ -35,7 +35,7 @@ def _get_hierarchy(node: torch.fx.Node):
|
|
35
35
|
return hierachy_str
|
36
36
|
|
37
37
|
|
38
|
-
def _get_canonical_filename(filename):
|
38
|
+
def _get_canonical_filename(filename: str):
|
39
39
|
"""Remove unnecessary path prefix to make the filename more readable.
|
40
40
|
|
41
41
|
This should be factored out so that pattern is a global option that a user
|
@@ -53,6 +53,29 @@ def _get_canonical_filename(filename):
|
|
53
53
|
return filename
|
54
54
|
|
55
55
|
|
56
|
+
def _get_canoical_nodename(node: torch.fx.Node) -> str:
|
57
|
+
"""Get the canonical node name from the node's history."""
|
58
|
+
|
59
|
+
history = node.meta.get("from_node", [])
|
60
|
+
|
61
|
+
if len(history) > 1: # Compatible with torch version under 2.6.0
|
62
|
+
return history[1][0]
|
63
|
+
|
64
|
+
if not hasattr(history[0], "name"):
|
65
|
+
return None
|
66
|
+
names = []
|
67
|
+
while history:
|
68
|
+
names.append(history[0].name)
|
69
|
+
history = history[0].from_node
|
70
|
+
|
71
|
+
# Based on the experiment, the third to last name in the history stack
|
72
|
+
# can be mapped to the original torch node name. The history stack is
|
73
|
+
# generated by tracing the node's transformation history during lowering.
|
74
|
+
if len(names) > 2:
|
75
|
+
return names[-3]
|
76
|
+
return None
|
77
|
+
|
78
|
+
|
56
79
|
def build_mlir_file_debuginfo(node: torch.fx.Node):
|
57
80
|
"""Build the file and line info for the given node's lowerings in MLIR."""
|
58
81
|
|
@@ -66,16 +89,13 @@ def build_mlir_file_debuginfo(node: torch.fx.Node):
|
|
66
89
|
return _get_canonical_filename(pt_trace.file), int(pt_trace.lineno)
|
67
90
|
|
68
91
|
|
69
|
-
def build_nodename_debuginfo(node: torch.fx.Node):
|
92
|
+
def build_nodename_debuginfo(node: torch.fx.Node) -> str:
|
70
93
|
"""Build the fx node name for the given node's lowerings in MLIR."""
|
71
|
-
|
72
|
-
if not
|
94
|
+
|
95
|
+
if not hasattr(node, "meta") or "from_node" not in node.meta:
|
73
96
|
return None
|
74
|
-
|
75
|
-
|
76
|
-
if hasattr(history[0], "name"): # torch 2.6.0+
|
77
|
-
return history[0].name
|
78
|
-
return None
|
97
|
+
|
98
|
+
return _get_canoical_nodename(node)
|
79
99
|
|
80
100
|
|
81
101
|
def build_mlir_debuginfo(node: torch.fx.Node):
|
@@ -215,6 +215,15 @@ def _aten_floor(lctx, x: ir.Value, *, out=None) -> ir.Value:
|
|
215
215
|
return stablehlo.floor(x)
|
216
216
|
|
217
217
|
|
218
|
+
# Schema:
|
219
|
+
# - aten::abs(Tensor input) -> Tensor
|
220
|
+
# Torch Reference:
|
221
|
+
# - https://pytorch.org/docs/main/generated/torch.abs.html
|
222
|
+
@lower(torch.ops.aten.abs.default)
|
223
|
+
def _aten_abs(lctx, input: ir.Value, *, out=None) -> ir.Value:
|
224
|
+
return stablehlo.abs(input)
|
225
|
+
|
226
|
+
|
218
227
|
# Schema:
|
219
228
|
# - aten::cat(Tensor[] tensors, int dim=0) -> Tensor
|
220
229
|
# Torch Reference:
|
@@ -77,7 +77,6 @@ lower_by_torch_xla2(torch.ops.aten._softmax)
|
|
77
77
|
lower_by_torch_xla2(torch.ops.aten._to_copy)
|
78
78
|
lower_by_torch_xla2(torch.ops.aten._unsafe_index)
|
79
79
|
lower_by_torch_xla2(torch.ops.aten._unsafe_view)
|
80
|
-
lower_by_torch_xla2(torch.ops.aten.abs)
|
81
80
|
lower_by_torch_xla2(torch.ops.aten.acos)
|
82
81
|
lower_by_torch_xla2(torch.ops.aten.acosh)
|
83
82
|
lower_by_torch_xla2(torch.ops.aten.add.Scalar)
|
ai_edge_torch_nightly-0.4.0.dev20250222.dist-info/METADATA
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.4.0.
|
3
|
+
Version: 0.4.0.dev20250222
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
|
|
2
2
|
ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
|
5
|
-
ai_edge_torch/version.py,sha256=
|
5
|
+
ai_edge_torch/version.py,sha256=3V323rXqU4LWg1Um9kHwmrGlIkEIet9NlHpxI3IJNyw,706
|
6
6
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
7
7
|
ai_edge_torch/_convert/conversion.py,sha256=mckvxznKLXdF2HuJg_IxQaT5Ty-iWl_iXElHEugH3VI,5452
|
8
8
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -47,7 +47,7 @@ ai_edge_torch/generative/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf
|
|
47
47
|
ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
48
48
|
ai_edge_torch/generative/examples/amd_llama_135m/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
49
49
|
ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py,sha256=urNif89PyCXbdXT5spOeDvdM5luJ-a5HaXHM86v4JnU,2766
|
50
|
-
ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py,sha256=
|
50
|
+
ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py,sha256=JQJrMw50R_9h8tHyhD-GS9WwASBTAnz12tlfVzk9f70,2564
|
51
51
|
ai_edge_torch/generative/examples/amd_llama_135m/verify.py,sha256=o13NkFlBgawBsjdJup05VMUjAPvDRAmig6VyEkX8q6U,2426
|
52
52
|
ai_edge_torch/generative/examples/deepseek/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
|
53
53
|
ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py,sha256=I5eA-XfFdHjYwDsLIjn23T2e-IgnSCQ129-5DOU8j44,2532
|
@@ -109,9 +109,9 @@ ai_edge_torch/generative/examples/smollm/verify.py,sha256=KpYxVz_lv61YWy6HLfwT68
|
|
109
109
|
ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
110
110
|
ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
|
111
111
|
ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=5M4auM33SgCTODt0VT8TO-EVILruqGDRiNILBPeB83Y,6072
|
112
|
-
ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=
|
113
|
-
ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=
|
114
|
-
ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=
|
112
|
+
ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=_yk6wVoZm1_FRMFJF5URaPZNNdmMR89fwmKz81BEyao,5601
|
113
|
+
ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=afyHXc86h-ij5zTULmZnM1h313N9VWCyIVriH6pqeSo,16368
|
114
|
+
ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=ylqXOZhYc6XFCaNBKQw0jAnYrCtRFFQKzQzEsFIntvo,34890
|
115
115
|
ai_edge_torch/generative/examples/stable_diffusion/encoder.py,sha256=CAPsW84A8f00nS6fLFeh_XUjCPsDCA5UxHOUsMrLfSU,3450
|
116
116
|
ai_edge_torch/generative/examples/stable_diffusion/pipeline.py,sha256=GnY3vPZ-obrWuJifuE5bUooKLqAI7v6q71oaTuLKeBE,8778
|
117
117
|
ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py,sha256=xychak9hdLd6ieXBYEwrK2BkF8NRZWZSSCijIsESpBA,3420
|
@@ -205,17 +205,17 @@ ai_edge_torch/odml_torch/composite/__init__.py,sha256=71GM_gDZxJyo38ZSoYSwhZX3xK
|
|
205
205
|
ai_edge_torch/odml_torch/composite/mark_tensor.py,sha256=U--rwl-XkWKgkdXCXDn6yySug8FR66o1YFUAIoSaWW4,3523
|
206
206
|
ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py,sha256=2Y52E_gLeoXpMcPpV-svXsgN3JbEIjnPVjm0xkpTUdQ,3319
|
207
207
|
ai_edge_torch/odml_torch/debuginfo/__init__.py,sha256=3A_lMyj-B-DOhLJG6WmjKvZK5te2rXje8FrfqOhZsN0,959
|
208
|
-
ai_edge_torch/odml_torch/debuginfo/_build.py,sha256=
|
208
|
+
ai_edge_torch/odml_torch/debuginfo/_build.py,sha256=k9Kas790kpMS5OrVcLzIr48ejAzcc2smrroKAHHM7TQ,3311
|
209
209
|
ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py,sha256=IvOBQyROI9WHS3umHRxsDW-1YElU9BPWzKtJA2eKWOI,1739
|
210
210
|
ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=e9Oa4J3An9FYr3zM0OzjzyNNitEeg-IoBUGNfUxsDSA,798
|
211
211
|
ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=LqwZ1vCJTSOzgzvH8LUAN-sAkF-l_pGj1AMEIzAqHCA,6638
|
212
212
|
ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
|
213
213
|
ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=uJ-niilt1c-D6QJzLwgvCUf62le_JsxQTlqj_iP_Ps0,1009
|
214
|
-
ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=
|
214
|
+
ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=Jq8_yAxC7ilzd6tOaRyBsOUEeenFF_EAC5haacZT4Pg,10247
|
215
215
|
ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
|
216
216
|
ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=Q0aDzyUcZMoSzSbOU-r3LJMgPe6fble0QwdYVIOHHHk,6887
|
217
217
|
ai_edge_torch/odml_torch/lowerings/_decomp_registry.py,sha256=VhmeGFnB5hrUsALiVWV96JJOqPDrTIWouHjTvLuT5eU,2477
|
218
|
-
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=
|
218
|
+
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=h-YHW7xmvt9dpea-7Zj82HW7h5TKzW6GBEE13dIJQ40,11518
|
219
219
|
ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=khJIvDVk2s332Nd2Be-5dM6-wp5DGff61HCV5lskHmQ,3011
|
220
220
|
ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=XDZ0zLej_XaQDJnaAAxhNFAd7NfQm5SOVEp_nno_krA,6178
|
221
221
|
ai_edge_torch/odml_torch/lowerings/_rand.py,sha256=g6SuqDkuC6hD35lyP1-5H7ASDIzPSmKukeNT5naZSv8,4133
|
@@ -230,8 +230,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
|
|
230
230
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
231
231
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
232
232
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
233
|
-
ai_edge_torch_nightly-0.4.0.
|
234
|
-
ai_edge_torch_nightly-0.4.0.
|
235
|
-
ai_edge_torch_nightly-0.4.0.
|
236
|
-
ai_edge_torch_nightly-0.4.0.
|
237
|
-
ai_edge_torch_nightly-0.4.0.
|
233
|
+
ai_edge_torch_nightly-0.4.0.dev20250222.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
234
|
+
ai_edge_torch_nightly-0.4.0.dev20250222.dist-info/METADATA,sha256=nMXtCeRH62h6fhvV2GuzHlG0TDmvPAP_W9KQF6lVsc0,1966
|
235
|
+
ai_edge_torch_nightly-0.4.0.dev20250222.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
236
|
+
ai_edge_torch_nightly-0.4.0.dev20250222.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
237
|
+
ai_edge_torch_nightly-0.4.0.dev20250222.dist-info/RECORD,,
|
File without changes
|
File without changes
|