ai-edge-torch-nightly 0.2.0.dev20240730__py3-none-any.whl → 0.2.0.dev20240802__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ai-edge-torch-nightly might be problematic.
- ai_edge_torch/convert/conversion.py +12 -8
- ai_edge_torch/convert/conversion_utils.py +38 -20
- ai_edge_torch/convert/converter.py +11 -5
- ai_edge_torch/convert/fx_passes/__init__.py +3 -4
- ai_edge_torch/convert/fx_passes/_pass_base.py +6 -2
- ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +45 -36
- ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py +11 -10
- ai_edge_torch/convert/fx_passes/canonicalize_pass.py +2 -3
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +18 -7
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +4 -3
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +6 -4
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +9 -5
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +1 -2
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +14 -10
- ai_edge_torch/convert/test/test_convert.py +39 -16
- ai_edge_torch/convert/test/test_convert_composites.py +115 -86
- ai_edge_torch/convert/test/test_convert_multisig.py +18 -10
- ai_edge_torch/convert/test/test_to_channel_last_io.py +1 -2
- ai_edge_torch/convert/to_channel_last_io.py +6 -2
- ai_edge_torch/debug/culprit.py +41 -16
- ai_edge_torch/debug/test/test_culprit.py +4 -3
- ai_edge_torch/debug/test/test_search_model.py +4 -3
- ai_edge_torch/debug/utils.py +3 -1
- ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +4 -3
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py +10 -8
- ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py +7 -4
- ai_edge_torch/generative/examples/experimental/phi/phi2.py +10 -8
- ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +1 -2
- ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +10 -8
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +4 -3
- ai_edge_torch/generative/examples/gemma/gemma.py +13 -9
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +7 -4
- ai_edge_torch/generative/examples/phi2/phi2.py +13 -9
- ai_edge_torch/generative/examples/stable_diffusion/attention.py +3 -1
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +20 -9
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +26 -13
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +15 -7
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +47 -16
- ai_edge_torch/generative/examples/stable_diffusion/encoder.py +4 -3
- ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +42 -12
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +4 -3
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +4 -3
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +4 -3
- ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +4 -1
- ai_edge_torch/generative/examples/stable_diffusion/util.py +9 -3
- ai_edge_torch/generative/examples/t5/convert_to_tflite.py +8 -5
- ai_edge_torch/generative/examples/t5/t5.py +158 -125
- ai_edge_torch/generative/examples/t5/t5_attention.py +15 -7
- ai_edge_torch/generative/examples/test_models/toy_model.py +7 -5
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +3 -4
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +4 -5
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +4 -3
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +10 -8
- ai_edge_torch/generative/fx_passes/__init__.py +1 -2
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +6 -3
- ai_edge_torch/generative/layers/attention.py +19 -11
- ai_edge_torch/generative/layers/builder.py +3 -4
- ai_edge_torch/generative/layers/kv_cache.py +4 -3
- ai_edge_torch/generative/layers/model_config.py +6 -2
- ai_edge_torch/generative/layers/rotary_position_embedding.py +3 -1
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py +1 -2
- ai_edge_torch/generative/layers/unet/blocks_2d.py +69 -21
- ai_edge_torch/generative/layers/unet/builder.py +7 -4
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +9 -4
- ai_edge_torch/generative/quantize/example.py +2 -3
- ai_edge_torch/generative/quantize/quant_recipe.py +2 -1
- ai_edge_torch/generative/quantize/quant_recipe_utils.py +10 -0
- ai_edge_torch/generative/quantize/quant_recipes.py +8 -0
- ai_edge_torch/generative/test/loader_test.py +5 -4
- ai_edge_torch/generative/test/test_experimental_ekv.py +22 -11
- ai_edge_torch/generative/test/test_model_conversion.py +2 -3
- ai_edge_torch/generative/test/test_quantize.py +45 -47
- ai_edge_torch/generative/utilities/loader.py +55 -28
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py +86 -33
- ai_edge_torch/generative/utilities/t5_loader.py +77 -48
- ai_edge_torch/hlfb/mark_pattern/__init__.py +2 -3
- ai_edge_torch/hlfb/mark_pattern/pattern.py +16 -7
- ai_edge_torch/hlfb/test/test_mark_pattern.py +4 -3
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +12 -6
- ai_edge_torch/model.py +8 -5
- ai_edge_torch/quantize/pt2e_quantizer.py +30 -15
- ai_edge_torch/quantize/pt2e_quantizer_utils.py +30 -11
- ai_edge_torch/quantize/quant_config.py +6 -2
- ai_edge_torch/testing/model_coverage/model_coverage.py +11 -7
- {ai_edge_torch_nightly-0.2.0.dev20240730.dist-info → ai_edge_torch_nightly-0.2.0.dev20240802.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.2.0.dev20240730.dist-info → ai_edge_torch_nightly-0.2.0.dev20240802.dist-info}/RECORD +89 -89
- {ai_edge_torch_nightly-0.2.0.dev20240730.dist-info → ai_edge_torch_nightly-0.2.0.dev20240802.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240730.dist-info → ai_edge_torch_nightly-0.2.0.dev20240802.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240730.dist-info → ai_edge_torch_nightly-0.2.0.dev20240802.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/utilities/t5_loader.py CHANGED

@@ -18,11 +18,10 @@ import glob
 import os
 from typing import Callable, Dict
 
+from ai_edge_torch.generative.layers import model_config
 from safetensors import safe_open
 import torch
 
-from ai_edge_torch.generative.layers import model_config
-
 
 def load_safetensors(full_path: str):
   """Loads safetensors into a single state dictionary.

@@ -71,7 +70,11 @@ def load_pytorch_statedict(full_path: str):
   Raises:
     ValueError: If no tensors are loaded from the provided directory or file.
   """
-  pattern = os.path.join(full_path, "*.bin") if os.path.isdir(full_path) else full_path
+  pattern = (
+      os.path.join(full_path, "*.bin")
+      if os.path.isdir(full_path)
+      else full_path
+  )
   files = []
   for file in glob.glob(pattern):
     files.append(file)

@@ -131,7 +134,10 @@ class ModelLoader:
     self._loader = self._get_loader()
 
   def load(
-      self, model: torch.nn.Module, strict: bool = True, fuse_attention: bool = True
+      self,
+      model: torch.nn.Module,
+      strict: bool = True,
+      fuse_attention: bool = True,
   ):
     """Load the model from the checkpoint
 

@@ -166,11 +172,14 @@ class ModelLoader:
 
     if strict and state:
       raise ValueError(
-          f"Failed to map all tensor. Remaining tensor are: {list(state.keys())}"
+          "Failed to map all tensor. Remaining tensor are:"
+          f" {list(state.keys())}"
       )
     model.load_state_dict(converted_state, strict=strict)
 
-  def _do_load(self, model, state, names, additional_prefix="", fuse_attention=True):
+  def _do_load(
+      self, model, state, names, additional_prefix="", fuse_attention=True
+  ):
     """Load the model from the checkpoint
 
     Args:

@@ -183,7 +192,9 @@ class ModelLoader:
     """
    converted_state = dict()
    if names.embedding is not None:
-      converted_state["tok_embedding.weight"] = state.pop(f"{names.embedding}.weight")
+      converted_state["tok_embedding.weight"] = state.pop(
+          f"{names.embedding}.weight"
+      )
    if names.lm_head is not None:
      converted_state["lm_head.weight"] = state.pop(f"{names.lm_head}.weight")
      if model.config.lm_head_use_bias:

@@ -195,7 +206,9 @@ class ModelLoader:
          f"{final_norm_name}.weight"
      )
      if f"{final_norm_name}.bias" in state:
-        converted_state["final_norm.bias"] = state.pop(f"{final_norm_name}.bias")
+        converted_state["final_norm.bias"] = state.pop(
+            f"{final_norm_name}.bias"
+        )
 
    if names.relative_attn_bias:
      rel_attn_name = names.relative_attn_bias

@@ -205,7 +218,9 @@ class ModelLoader:
      )
 
    for i in range(model.config.num_layers):
-      self._map_norm(i, model.config, state, converted_state, names, additional_prefix)
+      self._map_norm(
+          i, model.config, state, converted_state, names, additional_prefix
+      )
      self._map_feedforward(
          i, model.config, state, converted_state, names, additional_prefix
      )

@@ -268,13 +283,19 @@ class ModelLoader:
    if config.ff_config.type == model_config.FeedForwardType.SEQUENTIAL:
      ff_up_proj_name = names.ff_up_proj.format(idx)
      ff_down_proj_name = names.ff_down_proj.format(idx)
-      converted_state[f"{prefix}.ff.w1.weight"] = state.pop(f"{ff_up_proj_name}.weight")
+      converted_state[f"{prefix}.ff.w1.weight"] = state.pop(
+          f"{ff_up_proj_name}.weight"
+      )
      converted_state[f"{prefix}.ff.w2.weight"] = state.pop(
          f"{ff_down_proj_name}.weight"
      )
      if config.ff_config.use_bias:
-        converted_state[f"{prefix}.ff.w1.bias"] = state.pop(f"{ff_up_proj_name}.bias")
-        converted_state[f"{prefix}.ff.w2.bias"] = state.pop(f"{ff_down_proj_name}.bias")
+        converted_state[f"{prefix}.ff.w1.bias"] = state.pop(
+            f"{ff_up_proj_name}.bias"
+        )
+        converted_state[f"{prefix}.ff.w2.bias"] = state.pop(
+            f"{ff_down_proj_name}.bias"
+        )
    else:
      if names.ff_gate_proj is not None:
        ff_up_proj_name = names.ff_up_proj.format(idx)

@@ -290,7 +311,9 @@ class ModelLoader:
            f"{ff_gate_proj_name}.weight"
        )
        if config.ff_config.use_bias:
-          converted_state[f"{prefix}.ff.w3.bias"] = state.pop(f"{ff_up_proj_name}.bias")
+          converted_state[f"{prefix}.ff.w3.bias"] = state.pop(
+              f"{ff_up_proj_name}.bias"
+          )
          converted_state[f"{prefix}.ff.w2.bias"] = state.pop(
              f"{ff_down_proj_name}.bias"
          )

@@ -355,12 +378,12 @@ class ModelLoader:
    )
 
    o_name = names.attn_output_proj.format(idx)
-    converted_state[f"{prefix}.atten_func.output_projection.weight"] = state.pop(
-        f"{o_name}.weight"
+    converted_state[f"{prefix}.atten_func.output_projection.weight"] = (
+        state.pop(f"{o_name}.weight")
    )
    if config.attn_config.output_proj_use_bias:
-      converted_state[f"{prefix}.atten_func.output_projection.bias"] = state.pop(
-          f"{o_name}.bias"
+      converted_state[f"{prefix}.atten_func.output_projection.bias"] = (
+          state.pop(f"{o_name}.bias")
      )
 
  def _map_cross_attention(

@@ -385,47 +408,51 @@ class ModelLoader:
    v_name = names.cross_attn_value_proj.format(idx)
 
    if fuse_attention:
-      converted_state[f"{prefix}.cross_atten_func.attn.weight"] = self._fuse_qkv(
-          config,
-          state.pop(f"{q_name}.weight"),
-          state.pop(f"{k_name}.weight"),
-          state.pop(f"{v_name}.weight"),
+      converted_state[f"{prefix}.cross_atten_func.attn.weight"] = (
+          self._fuse_qkv(
+              config,
+              state.pop(f"{q_name}.weight"),
+              state.pop(f"{k_name}.weight"),
+              state.pop(f"{v_name}.weight"),
+          )
      )
      if config.attn_config.qkv_use_bias:
-        converted_state[f"{prefix}.cross_atten_func.attn.bias"] = self._fuse_qkv(
-            config,
-            state.pop(f"{q_name}.bias"),
-            state.pop(f"{k_name}.bias"),
-            state.pop(f"{v_name}.bias"),
+        converted_state[f"{prefix}.cross_atten_func.attn.bias"] = (
+            self._fuse_qkv(
+                config,
+                state.pop(f"{q_name}.bias"),
+                state.pop(f"{k_name}.bias"),
+                state.pop(f"{v_name}.bias"),
+            )
        )
    else:
-      converted_state[f"{prefix}.cross_atten_func.q_projection.weight"] = state.pop(
-          f"{q_name}.weight"
+      converted_state[f"{prefix}.cross_atten_func.q_projection.weight"] = (
+          state.pop(f"{q_name}.weight")
      )
-      converted_state[f"{prefix}.cross_atten_func.k_projection.weight"] = state.pop(
-          f"{k_name}.weight"
+      converted_state[f"{prefix}.cross_atten_func.k_projection.weight"] = (
+          state.pop(f"{k_name}.weight")
      )
-      converted_state[f"{prefix}.cross_atten_func.v_projection.weight"] = state.pop(
-          f"{v_name}.weight"
+      converted_state[f"{prefix}.cross_atten_func.v_projection.weight"] = (
+          state.pop(f"{v_name}.weight")
      )
      if config.attn_config.qkv_use_bias:
-        converted_state[f"{prefix}.cross_atten_func.q_projection.bias"] = state.pop(
-            f"{q_name}.bias"
+        converted_state[f"{prefix}.cross_atten_func.q_projection.bias"] = (
+            state.pop(f"{q_name}.bias")
        )
-        converted_state[f"{prefix}.cross_atten_func.k_projection.bias"] = state.pop(
-            f"{k_name}.bias"
+        converted_state[f"{prefix}.cross_atten_func.k_projection.bias"] = (
+            state.pop(f"{k_name}.bias")
        )
-        converted_state[f"{prefix}.cross_atten_func.v_projection.bias"] = state.pop(
-            f"{v_name}.bias"
+        converted_state[f"{prefix}.cross_atten_func.v_projection.bias"] = (
+            state.pop(f"{v_name}.bias")
        )
 
    o_name = names.cross_attn_output_proj.format(idx)
-    converted_state[f"{prefix}.cross_atten_func.output_projection.weight"] = state.pop(
-        f"{o_name}.weight"
+    converted_state[f"{prefix}.cross_atten_func.output_projection.weight"] = (
+        state.pop(f"{o_name}.weight")
    )
    if config.attn_config.output_proj_use_bias:
-      converted_state[f"{prefix}.cross_atten_func.output_projection.bias"] = state.pop(
-          f"{o_name}.bias"
+      converted_state[f"{prefix}.cross_atten_func.output_projection.bias"] = (
+          state.pop(f"{o_name}.bias")
      )
 
  def _map_norm(

@@ -450,12 +477,12 @@ class ModelLoader:
 
    if names.pre_cross_attn_norm:
      pre_cross_attn_norm_name = names.pre_cross_attn_norm.format(idx)
-      converted_state[f"{prefix}.cross_atten_func.pre_atten_norm.weight"] = state.pop(
-          f"{pre_cross_attn_norm_name}.weight"
+      converted_state[f"{prefix}.cross_atten_func.pre_atten_norm.weight"] = (
+          state.pop(f"{pre_cross_attn_norm_name}.weight")
      )
      if f"{pre_cross_attn_norm_name}.bias" in state:
-        converted_state[f"{prefix}.cross_atten_func.pre_atten_norm.bias"] = state.pop(
-            f"{pre_cross_attn_norm_name}.bias"
+        converted_state[f"{prefix}.cross_atten_func.pre_atten_norm.bias"] = (
+            state.pop(f"{pre_cross_attn_norm_name}.bias")
        )
 
    if names.pre_ff_norm is not None:

@@ -475,7 +502,9 @@ class ModelLoader:
      k: torch.Tensor,
      v: torch.Tensor,
  ) -> torch.Tensor:
-    q_per_kv = config.attn_config.num_heads // config.attn_config.num_query_groups
+    q_per_kv = (
+        config.attn_config.num_heads // config.attn_config.num_query_groups
+    )
    qs = torch.split(q, config.head_dim * q_per_kv)
    ks = torch.split(k, config.head_dim)
    vs = torch.split(v, config.head_dim)
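
The `_fuse_qkv` hunk above only re-wraps the `q_per_kv` computation; the fusion itself is unchanged. As a rough illustration of the splitting it relies on, here is a minimal standalone sketch with made-up dimensions (not the actual ModelLoader code): grouped-query Q/K/V tensors are split per key-value group and concatenated into one fused projection.

import torch

# Hypothetical dimensions, for illustration only.
num_heads, num_query_groups, head_dim, in_features = 8, 4, 16, 32
q_per_kv = num_heads // num_query_groups

q = torch.randn(num_heads * head_dim, in_features)
k = torch.randn(num_query_groups * head_dim, in_features)
v = torch.randn(num_query_groups * head_dim, in_features)

# Split exactly as in the hunk above, then interleave each query block with
# its key/value block to form a single fused QKV weight.
qs = torch.split(q, head_dim * q_per_kv)
ks = torch.split(k, head_dim)
vs = torch.split(v, head_dim)
fused = torch.cat([part for group in zip(qs, ks, vs) for part in group])
assert fused.shape[0] == (num_heads + 2 * num_query_groups) * head_dim
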
ai_edge_torch/hlfb/mark_pattern/__init__.py CHANGED

@@ -16,11 +16,10 @@ import copy
 from typing import Any
 import uuid
 
-import torch
-from torch_xla.experimental import xla_marker
-
 from ai_edge_torch.hlfb.mark_pattern.pattern import Pattern
 from ai_edge_torch.hlfb.mark_pattern.pattern import ScalarAttrTracker  # NOQA
+import torch
+from torch_xla.experimental import xla_marker
 
 
 @torch._dynamo.assume_constant_result
ai_edge_torch/hlfb/mark_pattern/pattern.py CHANGED

@@ -16,6 +16,7 @@ import copy
 import dataclasses
 from typing import Any, Callable, Optional, Union
 
+from ai_edge_torch.hlfb.mark_pattern import passes
 import torch
 from torch.export.graph_signature import TensorArgument
 from torch.fx import Graph

@@ -23,8 +24,6 @@ from torch.fx import GraphModule
 from torch.fx.passes.utils.matcher_utils import InternalMatch
 from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
 
-from ai_edge_torch.hlfb.mark_pattern import passes
-
 
 def _are_equal(x: Any, y: Any) -> bool:
   if type(x) != type(y):

@@ -69,7 +68,9 @@ class ScalarAttrTracker:
  pattern_arg_pos: int
  transform: Callable = lambda x: x
  inverse_transform: Callable = lambda x: x
-  _source_targets: list[tuple[Any, Any]] = dataclasses.field(default_factory=list)
+  _source_targets: list[tuple[Any, Any]] = dataclasses.field(
+      default_factory=list
+  )
 
  def track(self, *sources):
    """Register magic values to track the (transformed) attr values in

@@ -78,7 +79,9 @@ class ScalarAttrTracker:
    for source in sources:
      target = self.transform(source)
      if not _are_equal(self.inverse_transform(target), source):
-        raise Exception(f"Invalid transform/inverse_transform for {self.attr_name}")
+        raise Exception(
+            f"Invalid transform/inverse_transform for {self.attr_name}"
+        )
      self._source_targets.append([source, target])
    return self
 

@@ -189,7 +192,9 @@ class Pattern:
 
    self.name = name
    self.attr_builder = attr_builder
-    self._scalar_attr_trackers = scalar_attr_trackers if scalar_attr_trackers else []
+    self._scalar_attr_trackers = (
+        scalar_attr_trackers if scalar_attr_trackers else []
+    )
 
    exported_program = torch.export.export(module, export_args)
    if decomp_table is not None:

@@ -201,7 +206,9 @@ class Pattern:
    self._scalar_attr_locations = []
    for tracker in self._scalar_attr_trackers:
      self._scalar_attr_locations.append(
-          _find_scalar_attr(module, export_args, tracker, decomp_table=decomp_table)
+          _find_scalar_attr(
+              module, export_args, tracker, decomp_table=decomp_table
+          )
      )
 
    # Sanitize graph_module for more precise pattern matching.

@@ -251,7 +258,9 @@ class Pattern:
      attrs = {}
 
      for loc in self._scalar_attr_locations:
-        attrs[loc.attr_name] = self._get_attr_value_from_pattern_match(match, loc)
+        attrs[loc.attr_name] = self._get_attr_value_from_pattern_match(
+            match, loc
+        )
 
      attrs = attrs if attrs else None
      match_with_attrs.append((match, attrs))
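
For context on the check the reformatted `raise Exception(...)` above guards: `ScalarAttrTracker.track()` requires every tracked source value to round-trip through `transform`/`inverse_transform`. A minimal standalone sketch of that invariant (hypothetical transform pair, not the library's code):

# Hypothetical transform pair used only to illustrate the round-trip check.
transform = lambda x: x * 2
inverse_transform = lambda x: x // 2

for source in (1, 3, 5):
  target = transform(source)
  # This is the condition track() enforces; if it fails, the Exception shown
  # in the hunk above is raised.
  assert inverse_transform(target) == source
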
ai_edge_torch/hlfb/test/test_mark_pattern.py CHANGED

@@ -15,11 +15,10 @@
 
 import unittest
 
+from ai_edge_torch.hlfb import mark_pattern
 import torch
 import torch_xla
 
-from ai_edge_torch.hlfb import mark_pattern
-
 
 def _export_stablehlo_mlir(model, args=None):
   if not isinstance(model, torch.export.ExportedProgram):

@@ -73,7 +72,9 @@ class TestMarkPattern(unittest.TestCase):
    mlir = _export_stablehlo_mlir(exported_program)
 
    self.assertEqual(mlir.count('stablehlo.composite "test.add"'), 2)
-    self.assertEqual(mlir.count('composite_attributes = {alias = "test.test_add"}'), 2)
+    self.assertEqual(
+        mlir.count('composite_attributes = {alias = "test.test_add"}'), 2
+    )
 
  def test_mark_pattern_with_scalar_attr_tracker(self):
    class TestModel(torch.nn.Module):
ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py CHANGED

@@ -15,12 +15,11 @@
 import math
 import unittest
 
+from ai_edge_torch.hlfb import StableHLOCompositeBuilder
 import torch
 import torch.nn.functional as F
 import torch_xla
 
-from ai_edge_torch.hlfb import StableHLOCompositeBuilder
-
 
 def _export_stablehlo_mlir(model, args):
   ep = torch.export.export(model, args)

@@ -80,7 +79,9 @@ class TestStableHLOCompositeBuilder(unittest.TestCase):
        super().__init__()
 
      def log_softmax(self, x: torch.Tensor, dim: int):
-        builder = StableHLOCompositeBuilder(name="test.log_softmax", attr={"dim": dim})
+        builder = StableHLOCompositeBuilder(
+            name="test.log_softmax", attr={"dim": dim}
+        )
        x = builder.mark_inputs(x)
        y = torch.nn.functional.log_softmax(x, dim=dim)
        y = builder.mark_outputs(y)

@@ -126,7 +127,8 @@ class TestStableHLOCompositeBuilder(unittest.TestCase):
    self.assertEqual(mlir.count('stablehlo.composite "test.log_softmax"'), 1)
    self.assertEqual(
        mlir.count(
-            'composite_attributes = {dim = 0 : i64, source = "torch.nn", version = 1.000000e+00 : f32}'
+            'composite_attributes = {dim = 0 : i64, source = "torch.nn",'
+            " version = 1.000000e+00 : f32}"
        ),
        1,
    )

@@ -236,8 +238,12 @@ class TestStableHLOCompositeBuilder(unittest.TestCase):
    self.assertEqual(
        mlir.count('stablehlo.composite "test.scaled_dot_product_attention"'), 2
    )
-    self.assertEqual(mlir.count("composite_attributes = {include_captanh = true}"), 1)
-    self.assertEqual(mlir.count("composite_attributes = {include_captanh = false}"), 1)
+    self.assertEqual(
+        mlir.count("composite_attributes = {include_captanh = true}"), 1
+    )
+    self.assertEqual(
+        mlir.count("composite_attributes = {include_captanh = false}"), 1
+    )
 
  def test_build_composite_with_multiple_inputs_outputs(self):
    class SampleModel(torch.nn.Module):
ai_edge_torch/model.py CHANGED

@@ -21,12 +21,11 @@ from __future__ import annotations
 
 import abc
 
+from ai_edge_torch.convert import conversion_utils as cutils
 import numpy as np
 import numpy.typing as npt
 import tensorflow as tf
 
-from ai_edge_torch.convert import conversion_utils as cutils
-
 
 class Model(abc.ABC):
   """Represents and edge model."""

@@ -84,7 +83,8 @@ class TfLiteModel(Model):
    signature_list = interpreter.get_signature_list()
    if signature_name not in signature_list:
      raise ValueError(
-          f'Invalid signature name provided. Available signatures: {", ".join(signature_list.keys())}'
+          'Invalid signature name provided. Available signatures:'
+          f' {", ".join(signature_list.keys())}'
      )
 
    try:

@@ -92,14 +92,17 @@ class TfLiteModel(Model):
    except ValueError as exception:
      if 'Invalid signature_key provided.' in str(exception):
        raise ValueError(
-            f'Invalid signature key provided. Available signatures: {list(signature_list.keys())}'
+            'Invalid signature key provided. Available signatures:'
+            f' {list(signature_list.keys())}'
        )
      else:
        raise exception
 
    if len(signature_list[signature_name]['inputs']) != len(args) + len(kwargs):
      raise ValueError(
-          f'The model requires {len(signature_list[signature_name]["inputs"])} arguments but {len(args)} was provided.'
+          'The model requires'
+          f' {len(signature_list[signature_name]["inputs"])} arguments but'
+          f' {len(args)} was provided.'
      )
 
    # Gather the input dictionary based on the signature.
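
The model.py hunks only re-wrap these error messages; the checks themselves come from the TFLite interpreter's signature API. A hedged sketch of that API using standard TensorFlow Lite calls ("model.tflite" is a placeholder path, not part of this package):

import tensorflow as tf

# Placeholder path; any converted .tflite file with named signatures works.
interpreter = tf.lite.Interpreter(model_path="model.tflite")
signature_list = interpreter.get_signature_list()
# e.g. {'serving_default': {'inputs': ['x'], 'outputs': ['output_0']}}
print(signature_list)

name = next(iter(signature_list))
# The number of declared inputs is what TfLiteModel compares against
# len(args) + len(kwargs) before raising the ValueError shown above.
expected_inputs = len(signature_list[name]["inputs"])
runner = interpreter.get_signature_runner(name)
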
ai_edge_torch/quantize/pt2e_quantizer.py CHANGED

@@ -19,6 +19,12 @@ import copy
 import functools
 from typing import Any, Callable, Dict, List, Optional, Set
 
+from ai_edge_torch.quantize.pt2e_quantizer_utils import _convert_scalars_to_attrs  # NOQA
+from ai_edge_torch.quantize.pt2e_quantizer_utils import OP_TO_ANNOTATOR
+from ai_edge_torch.quantize.pt2e_quantizer_utils import OperatorConfig
+from ai_edge_torch.quantize.pt2e_quantizer_utils import OperatorPatternType
+from ai_edge_torch.quantize.pt2e_quantizer_utils import propagate_annotation
+from ai_edge_torch.quantize.pt2e_quantizer_utils import QuantizationConfig
 import torch
 from torch.ao.quantization.fake_quantize import FusedMovingAvgObsFakeQuantize
 from torch.ao.quantization.observer import HistogramObserver

@@ -34,20 +40,15 @@ from torch.ao.quantization.quantizer import Quantizer
 from torch.fx import Node
 import torch.nn.functional as F
 
-from ai_edge_torch.quantize.pt2e_quantizer_utils import _convert_scalars_to_attrs  # NOQA
-from ai_edge_torch.quantize.pt2e_quantizer_utils import OP_TO_ANNOTATOR
-from ai_edge_torch.quantize.pt2e_quantizer_utils import OperatorConfig
-from ai_edge_torch.quantize.pt2e_quantizer_utils import OperatorPatternType
-from ai_edge_torch.quantize.pt2e_quantizer_utils import propagate_annotation
-from ai_edge_torch.quantize.pt2e_quantizer_utils import QuantizationConfig
-
 __all__ = [
     "PT2EQuantizer",
     "get_symmetric_quantization_config",
 ]
 
 
-def _supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
+def _supported_symmetric_quantized_operators() -> (
+    Dict[str, List[OperatorPatternType]]
+):
  supported_operators: Dict[str, List[OperatorPatternType]] = {
      # Both conv and linear should be able to handle relu + hardtanh fusion since
      # those are clamp ops

@@ -92,7 +93,9 @@ def get_symmetric_quantization_config(
 ):
  if is_qat:
    if is_dynamic:
-      raise NotImplementedError("dynamic quantization for qat is not yet implemented.")
+      raise NotImplementedError(
+          "dynamic quantization for qat is not yet implemented."
+      )
    act_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize
  else:
    if is_dynamic:

@@ -106,12 +109,18 @@ def get_symmetric_quantization_config(
      quant_max=127,
      qscheme=torch.per_tensor_affine,
      is_dynamic=is_dynamic,
-      observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(eps=2**-12),
+      observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
+          eps=2**-12
+      ),
  )
  qscheme = (
-      torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric
+      torch.per_channel_symmetric
+      if is_per_channel
+      else torch.per_tensor_symmetric
+  )
+  weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
+      MinMaxObserver
  )
-  weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = MinMaxObserver
  if is_qat:
    weight_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize
  elif is_per_channel:

@@ -197,7 +206,9 @@ def _get_module_name_filter(module_name: str):
    # }
    # get_attr nodes doesn't have nn_module_stack?
    nn_module_stack = n.meta.get("nn_module_stack", {})
-    names = [n[len("L__self___") :].replace("_", ".") for n in nn_module_stack.keys()]
+    names = [
+        n[len("L__self___") :].replace("_", ".") for n in nn_module_stack.keys()
+    ]
    return module_name in names
 
  return module_name_filter

@@ -232,7 +243,9 @@ def _get_not_module_type_or_name_filter(
    tp_list: List[Callable], module_name_list: List[str]
 ) -> Callable[[Node], bool]:
  module_type_filters = [_get_module_type_filter(tp) for tp in tp_list]
-  module_name_list_filters = [_get_module_name_filter(m) for m in module_name_list]
+  module_name_list_filters = [
+      _get_module_name_filter(m) for m in module_name_list
+  ]
 
  def not_module_type_or_name_filter(n: Node) -> bool:
    return not any(f(n) for f in module_type_filters + module_name_list_filters)

@@ -307,7 +320,9 @@ class PT2EQuantizer(Quantizer):
      return ops
    return []
 
-  def set_global(self, quantization_config: QuantizationConfig) -> PT2EQuantizer:
+  def set_global(
+      self, quantization_config: QuantizationConfig
+  ) -> PT2EQuantizer:
    self.global_config = quantization_config
    return self
 
ai_edge_torch/quantize/pt2e_quantizer_utils.py CHANGED

@@ -154,7 +154,9 @@ def get_weight_qspec(quantization_config: Optional[QuantizationConfig]):
      torch.per_tensor_symmetric,
      torch.per_channel_symmetric,
  ]:
-    raise ValueError(f"Unsupported quantization_spec {quantization_spec} for weight")
+    raise ValueError(
+        f"Unsupported quantization_spec {quantization_spec} for weight"
+    )
  return quantization_spec
 
 

@@ -193,7 +195,10 @@ def _annotate_linear(
  weight_qspec = get_weight_qspec(quantization_config)
  bias_qspec = get_bias_qspec(quantization_config)
  for node in gm.graph.nodes:
-    if node.op != "call_function" or node.target != torch.ops.aten.linear.default:
+    if (
+        node.op != "call_function"
+        or node.target != torch.ops.aten.linear.default
+    ):
      continue
    if filter_fn and not filter_fn(node):
      continue

@@ -417,7 +422,9 @@ def _annotate_conv_bn(
  Find conv + batchnorm parititions
  Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv.
  """
-  return _do_annotate_conv_bn(gm, quantization_config, filter_fn, has_relu=False)
+  return _do_annotate_conv_bn(
+      gm, quantization_config, filter_fn, has_relu=False
+  )
 
 
 @register_annotator("conv_bn_relu")

@@ -486,7 +493,9 @@ def _do_annotate_conv_bn(
  # Match against all conv dimensions and cuda variants
  for (conv_fn, example_inputs), is_cuda, relu_is_inplace in combinations:
    pattern = get_pattern(conv_fn, relu_is_inplace)
-    pattern = _get_aten_graph_module_for_pattern(pattern, example_inputs, is_cuda)
+    pattern = _get_aten_graph_module_for_pattern(
+        pattern, example_inputs, is_cuda
+    )
    pattern.graph.eliminate_dead_code()
    pattern.recompile()
    matcher = SubgraphMatcherWithNameNodeMap(pattern, ignore_literals=True)

@@ -676,7 +685,9 @@ def _annotate_adaptive_avg_pool2d(
        and pool_node.target != torch.ops.aten.mean.dim
        and pool_node.target != torch.ops.aten.as_strided_.default
    ):
-      raise ValueError(f"{pool_node} is not an aten adaptive_avg_pool2d operator")
+      raise ValueError(
+          f"{pool_node} is not an aten adaptive_avg_pool2d operator"
+      )
 
    if _is_annotated([pool_node]):
      continue

@@ -741,7 +752,8 @@ def _annotate_fixed_qparams(
      continue
 
    node.meta["quantization_annotation"] = QuantizationAnnotation(
-        output_qspec=get_fixed_qparams_qspec(quantization_config), _annotated=True
+        output_qspec=get_fixed_qparams_qspec(quantization_config),
+        _annotated=True,
    )
    _mark_nodes_as_annotated(partition)
    annotated_partitions.append(partition)

@@ -885,7 +897,9 @@ def _annotate_mul(
    filter_fn: Optional[Callable[[Node], bool]] = None,
 ) -> Optional[List[List[Node]]]:
  mul_partitions = get_source_partitions(
-      gm.graph, ["mul", "mul_", operator.mul, torch.mul, operator.imul], filter_fn
+      gm.graph,
+      ["mul", "mul_", operator.mul, torch.mul, operator.imul],
+      filter_fn,
  )
  mul_partitions = list(itertools.chain(*mul_partitions.values()))
  annotated_partitions = []

@@ -932,8 +946,9 @@ def _annotate_cat(
 
    if cat_node.target != torch.ops.aten.cat.default:
      raise Exception(
-          f"Expected cat node: torch.ops.aten.cat.default, but found {cat_node.target}"
-          " please check if you are calling the correct capture API"
+          "Expected cat node: torch.ops.aten.cat.default, but found"
+          f" {cat_node.target} please check if you are calling the correct"
+          " capture API"
      )
 
    annotated_partitions.append(cat_partition.nodes)

@@ -987,7 +1002,9 @@ def propagate_annotation(model: torch.fx.GraphModule) -> None:
    if not isinstance(prev_node, Node):
      continue
 
-    quantization_annotation = prev_node.meta.get("quantization_annotation", None)
+    quantization_annotation = prev_node.meta.get(
+        "quantization_annotation", None
+    )
    if not quantization_annotation:
      continue
 

@@ -1014,7 +1031,9 @@ def propagate_annotation(model: torch.fx.GraphModule) -> None:
 
 
 # TODO: make the list of ops customizable
-def _convert_scalars_to_attrs(model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+def _convert_scalars_to_attrs(
+    model: torch.fx.GraphModule,
+) -> torch.fx.GraphModule:
  for n in model.graph.nodes:
    if n.op != "call_function" or n.target not in [
        torch.ops.aten.add.Tensor,
ai_edge_torch/quantize/quant_config.py CHANGED

@@ -76,6 +76,10 @@ class QuantConfig:
    elif generative_recipe is not None:
      generative_recipe.verify()
      object.__setattr__(self, 'generative_recipe', generative_recipe)
-      object.__setattr__(self, '_quantizer_mode', self._QuantizerMode.AI_EDGE_QUANTIZER)
+      object.__setattr__(
+          self, '_quantizer_mode', self._QuantizerMode.AI_EDGE_QUANTIZER
+      )
    else:
-      raise ValueError('Either pt2e_quantizer or generative_recipe must be set.')
+      raise ValueError(
+          'Either pt2e_quantizer or generative_recipe must be set.'
+      )
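
The quant_config.py hunk reflows the validation in `__post_init__`: one of `pt2e_quantizer` or `generative_recipe` must be supplied, otherwise the ValueError above is raised. A hedged usage sketch built only from names this diff exposes (`PT2EQuantizer`, `get_symmetric_quantization_config`, `QuantConfig`); default arguments are assumed:

from ai_edge_torch.quantize.pt2e_quantizer import get_symmetric_quantization_config
from ai_edge_torch.quantize.pt2e_quantizer import PT2EQuantizer
from ai_edge_torch.quantize.quant_config import QuantConfig

# Valid: a PT2E quantizer with a global symmetric quantization config.
quantizer = PT2EQuantizer().set_global(get_symmetric_quantization_config())
config = QuantConfig(pt2e_quantizer=quantizer)

# Invalid: neither pt2e_quantizer nor generative_recipe is set, so
# __post_init__ raises the ValueError reformatted in the hunk above.
try:
  QuantConfig()
except ValueError as e:
  print(e)
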