tico 0.2.0.dev260511__tar.gz → 0.2.0.dev260513__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/PKG-INFO +1 -1
- tico-0.2.0.dev260513/tico/_version.py +1 -0
- tico-0.2.0.dev260513/tico/passes/const_prop_pass.py +595 -0
- tico-0.2.0.dev260513/tico/passes/remove_unused_placeholder.py +130 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/algorithm/gptq/quantizer.py +47 -14
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/config/builders.py +23 -5
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/config/gptq.py +15 -0
- tico-0.2.0.dev260513/tico/quantization/config/llama_attention.py +209 -0
- tico-0.2.0.dev260513/tico/quantization/evaluation/mmmu_eval_utils.py +411 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/evaluation/vlm_eval_utils.py +125 -12
- tico-0.2.0.dev260513/tico/quantization/passes/quantize_bias.py +145 -0
- tico-0.2.0.dev260513/tico/quantization/passes/remove_weight_dequant_op.py +292 -0
- tico-0.2.0.dev260513/tico/quantization/wrapq/examples/nn/quantize_tied_embedding.py +330 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py +334 -79
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/examples/quantize_qwen3_vl_with_gptq.py +59 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/wrappers/llama/export_adapters.py +197 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/wrappers/llama/quant_attention.py +391 -90
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py +32 -5
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/wrappers/llama/quant_model.py +22 -1
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/circle_serializer.py +175 -14
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/utils/convert.py +2 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico.egg-info/PKG-INFO +1 -1
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico.egg-info/SOURCES.txt +4 -1
- tico-0.2.0.dev260511/tico/_version.py +0 -1
- tico-0.2.0.dev260511/tico/passes/const_prop_pass.py +0 -307
- tico-0.2.0.dev260511/tico/quantization/passes/quantize_bias.py +0 -122
- tico-0.2.0.dev260511/tico/quantization/passes/remove_weight_dequant_op.py +0 -177
- tico-0.2.0.dev260511/tico/quantization/wrapq/examples/quantize_full_vlm_model_with_gptq.py +0 -257
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/LICENSE +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/README.md +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/pyproject.toml +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/setup.cfg +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/__init__.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/config/__init__.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/config/base.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/config/factory.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/config/v1.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/experimental/__init__.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/interpreter/__init__.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/interpreter/infer.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/interpreter/interpreter.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/passes/__init__.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/passes/cast_aten_where_arg_type.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/passes/cast_clamp_mixed_type_args.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/passes/cast_mixed_type_args.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/passes/convert_conv1d_to_conv2d.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/passes/convert_conv3d_to_conv2d.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/passes/convert_expand_to_slice_cat.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/passes/convert_layout_op_to_reshape.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/passes/convert_matmul_to_linear.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/passes/convert_repeat_to_expand_copy.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/passes/convert_sym_size_to_circle_shape.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/passes/convert_to_relu6.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/passes/decompose_addmm.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/passes/decompose_batch_norm.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/passes/decompose_fake_quantize.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/passes/decompose_fake_quantize_tensor_qparams.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/passes/decompose_group_norm.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/passes/decompose_grouped_conv2d.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/passes/decompose_slice_scatter.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/passes/eliminate_rank_round_trip_region.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/passes/extract_dtype_kwargs.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/passes/fill_meta_val.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/passes/fuse_leading_unsqueeze_reshape.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/passes/fuse_redundant_reshape_to_mean.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/passes/legalize_causal_mask_value.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/passes/legalize_predefined_layout_operators.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/passes/lower_copy.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/passes/lower_pow2_to_mul.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/passes/lower_to_resize_nearest_neighbor.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/passes/lower_to_slice.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/passes/merge_consecutive_cat.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/passes/ops.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/passes/remove_nop.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/passes/remove_redundant_assert_nodes.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/passes/remove_redundant_expand.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/passes/remove_redundant_permute.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/passes/remove_redundant_reshape.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/passes/remove_redundant_slice.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/passes/remove_redundant_to_copy.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/passes/restore_linear.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/passes/segment_index_select.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/pt2_to_circle.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/__init__.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/algorithm/__init__.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/algorithm/cle/__init__.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/algorithm/cle/cle.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/algorithm/cle/quantizer.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/algorithm/fpi_gptq/__init__.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/algorithm/fpi_gptq/fpi_gptq.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/algorithm/fpi_gptq/quantizer.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/algorithm/gptq/__init__.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/algorithm/gptq/gptq.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/algorithm/gptq/quant.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/algorithm/gptq/utils.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/algorithm/qwen3_vl_gptq/__init__.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/algorithm/qwen3_vl_gptq/gptq.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/algorithm/qwen3_vl_gptq/quantizer.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/algorithm/qwen3_vl_gptq/utils.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/algorithm/smoothquant/__init__.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/algorithm/smoothquant/observer.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/algorithm/smoothquant/quantizer.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/algorithm/smoothquant/smooth_quant.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/algorithm/spinquant/__init__.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/algorithm/spinquant/fuse_norm_utils.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/algorithm/spinquant/hadamard_utils.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/algorithm/spinquant/quantizer.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/algorithm/spinquant/rotation_utils.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/algorithm/spinquant/spin_llama.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/config/__init__.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/config/base.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/config/cle.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/config/fpi_gptq.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/config/ptq.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/config/qwen3_vl_gptq.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/config/smoothquant.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/config/spinquant.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/config/utils.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/evaluation/__init__.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/evaluation/backend.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/evaluation/evaluate.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/evaluation/executor/__init__.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/evaluation/executor/backend_executor.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/evaluation/executor/circle_executor.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/evaluation/executor/triv24_executor.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/evaluation/metric.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/evaluation/mmlu_eval_utils.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/evaluation/script/llm_tasks_eval.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/evaluation/script/mini_vqa_eval.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/evaluation/utils.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/passes/__init__.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/passes/fold_quant_ops.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/passes/insert_quantize_on_dtype_mismatch.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/passes/propagate_qparam_backward.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/passes/propagate_qparam_forward.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/public_interface.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/quantizer.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/quantizer_registry.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/__init__.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/dtypes.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/examples/__init__.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/examples/compare_ppl.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/examples/debug_quant_outputs.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/examples/evaluate_fk_llama_model.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/examples/llama/__init__.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/examples/llama/quantize_attention_decode.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/examples/llama/quantize_attention_prefill.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/examples/llama/quantize_decoder_layer_decode.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/examples/llama/quantize_decoder_layer_prefill.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/examples/llama/quantize_mlp.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/examples/nn/__init__.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/examples/nn/quantize_conv3d.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/examples/nn/quantize_conv3d_special_case.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/examples/nn/quantize_layernorm.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/examples/nn/quantize_linear.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/examples/quantize_with_gptq.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/examples/qwen/__init__.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/examples/qwen/quantize_for_conditional_generation.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/examples/qwen/quantize_model.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/examples/qwen/quantize_text_attention.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/examples/qwen/quantize_text_decoder_layer.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/examples/qwen/quantize_text_mlp.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/examples/qwen/quantize_text_model.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/examples/qwen/quantize_vision_attention.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/examples/qwen/quantize_vision_block.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/examples/qwen/quantize_vision_mlp.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/examples/qwen/quantize_vision_model.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/examples/qwen/quantize_vision_patch_embed.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/examples/qwen/quantize_vision_patch_merger.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/examples/qwen/trace_qwen.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/examples/static_llama_layer_runtime.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/mode.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/observers/__init__.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/observers/affine_base.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/observers/base.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/observers/ema.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/observers/identity.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/observers/minmax.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/observers/mx.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/qscheme.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/quantizer.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/utils/__init__.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/utils/check_missing_qparam.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/utils/introspection.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/utils/metrics.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/utils/reduce_utils.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/utils/utils.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/utils/version.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/wrap_helper.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/wrappers/__init__.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/wrappers/fairseq/__init__.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/wrappers/fairseq/quant_mha.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/wrappers/llama/__init__.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/wrappers/llama/quant_mlp.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/wrappers/llama/quant_model_for_causal_lm.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/wrappers/nn/__init__.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/wrappers/nn/quant_conv3d.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/wrappers/nn/quant_conv3d_decomposed.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/wrappers/nn/quant_embedding.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/wrappers/nn/quant_layernorm.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/wrappers/nn/quant_linear.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/wrappers/nn/quant_silu.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/wrappers/ops/__init__.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/wrappers/ops/quant_rmsnorm.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/wrappers/ptq_wrapper.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/wrappers/quant_elementwise.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/wrappers/quant_module_base.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/wrappers/qwen_vl/quant_for_conditional_generation.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/wrappers/qwen_vl/quant_model.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/wrappers/qwen_vl/quant_text_attention.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/wrappers/qwen_vl/quant_text_decoder_layer.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/wrappers/qwen_vl/quant_text_mlp.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/wrappers/qwen_vl/quant_text_model.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/wrappers/qwen_vl/quant_vision_attention.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/wrappers/qwen_vl/quant_vision_block.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/wrappers/qwen_vl/quant_vision_mlp.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/wrappers/qwen_vl/quant_vision_model.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/wrappers/qwen_vl/quant_vision_patch_embed.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/wrappers/qwen_vl/quant_vision_patch_merger.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/quantization/wrapq/wrappers/registry.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/__init__.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/circle_graph.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/circle_mapping.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/__init__.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/adapters/__init__.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/adapters/llama_rmsnorm.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/adapters/onert/__init__.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/adapters/onert/llama_attention.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/hashable_opcode.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/node_visitor.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_abs.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_add.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_alias_copy.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_any.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_arange_start_step.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_argmax.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_attention.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_avg_pool2d.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_bmm.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_cat.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_circle_shape.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_clamp.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_clone.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_constant_pad_nd.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_conv2d.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_cos.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_cumsum.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_depthwise_conv2d.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_dequantize_per_channel.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_dequantize_per_tensor.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_div.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_embedding.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_eq.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_exp.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_expand.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_full.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_full_like.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_ge.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_gelu.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_gt.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_index.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_index_select.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_instance_norm.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_le.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_leaky_relu.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_linear.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_log.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_log1p.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_logical_and.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_logical_not.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_lt.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_max_dim.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_max_pool2d_with_indices.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_maximum.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_mean.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_minimum.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_mm.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_mul.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_ne.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_neg.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_permute.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_pow.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_prelu.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_quantize_per_tensor.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_reciprocal.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_relu.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_relu6.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_repeat.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_reshape.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_resize_nearest_neighbor.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_rmsnorm.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_round.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_rsqrt.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_scalar_tensor.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_select_copy.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_sigmoid.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_sin.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_slice.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_softmax.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_split_with_sizes.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_sqrt.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_squeeze.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_sub.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_sum.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_tanh.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_to_copy.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_transpose_conv.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_unsqueeze.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_view.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/op_where.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/operators/utils.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/pack.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/serialize/quant_param.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/utils/__init__.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/utils/compat/__init__.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/utils/compat/torch.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/utils/compat/transformers.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/utils/define.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/utils/diff_graph.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/utils/dtype.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/utils/errors.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/utils/graph.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/utils/installed_packages.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/utils/logging.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/utils/model.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/utils/mx/__init__.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/utils/mx/elemwise_ops.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/utils/mx/formats.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/utils/mx/mx_ops.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/utils/padding.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/utils/passes.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/utils/pytree_utils.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/utils/record_input.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/utils/register_custom_op.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/utils/serialize.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/utils/signature.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/utils/trace_decorators.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/utils/utils.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/utils/validate_args_kwargs.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico/utils/version.py +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico.egg-info/dependency_links.txt +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico.egg-info/entry_points.txt +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico.egg-info/requires.txt +0 -0
- {tico-0.2.0.dev260511 → tico-0.2.0.dev260513}/tico.egg-info/top_level.txt +0 -0
|
@@ -0,0 +1 @@
|
|
|
1
|
+
__version__ = "0.2.0.dev260513"
|
|
@@ -0,0 +1,595 @@
|
|
|
1
|
+
# Portions of this file are adapted from code originally authored by
|
|
2
|
+
# Meta Platforms, Inc. and affiliates, licensed under the BSD-style
|
|
3
|
+
# license found in the LICENSE file in the root directory of their source tree.
|
|
4
|
+
|
|
5
|
+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
|
|
6
|
+
#
|
|
7
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
8
|
+
# you may not use this file except in compliance with the License.
|
|
9
|
+
# You may obtain a copy of the License at
|
|
10
|
+
#
|
|
11
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
12
|
+
#
|
|
13
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
14
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
15
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
16
|
+
# See the License for the specific language governing permissions and
|
|
17
|
+
# limitations under the License.
|
|
18
|
+
|
|
19
|
+
# https://github.com/pytorch/executorch/blob/61ddee5/exir/passes/constant_prop_pass.py
|
|
20
|
+
|
|
21
|
+
from collections import OrderedDict
|
|
22
|
+
from typing import Any, List, Mapping, Optional, TYPE_CHECKING
|
|
23
|
+
|
|
24
|
+
if TYPE_CHECKING:
|
|
25
|
+
import torch.fx
|
|
26
|
+
import torch
|
|
27
|
+
from torch._export.utils import (
|
|
28
|
+
get_buffer,
|
|
29
|
+
get_lifted_tensor_constant,
|
|
30
|
+
get_param,
|
|
31
|
+
is_buffer,
|
|
32
|
+
is_lifted_tensor_constant,
|
|
33
|
+
is_param,
|
|
34
|
+
)
|
|
35
|
+
from torch.export import ExportedProgram
|
|
36
|
+
from torch.export.exported_program import InputKind, InputSpec
|
|
37
|
+
from torch.utils import _pytree as pytree
|
|
38
|
+
|
|
39
|
+
from tico.serialize.circle_graph import _PRIMITIVE_TYPES
|
|
40
|
+
from tico.utils import logging
|
|
41
|
+
from tico.utils.graph import create_input_spec, generate_fqn, get_first_user_input
|
|
42
|
+
from tico.utils.passes import PassBase, PassResult
|
|
43
|
+
from tico.utils.trace_decorators import (
|
|
44
|
+
trace_const_diff_on_pass,
|
|
45
|
+
trace_graph_diff_on_pass,
|
|
46
|
+
)
|
|
47
|
+
from tico.utils.utils import get_fake_mode
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
# Sentinel returned by `_get_argument_value` when an argument is found neither
# positionally nor by keyword. Callers compare against it with `is`, which
# keeps it distinct from a legitimate `None` argument value.
_MISSING = object()

# Hashable identity of a tensor's storage, in this order:
# (device string, data pointer, storage offset, shape, strides, dtype, layout).
TensorStorageIdentityKey = tuple[
    str,
    int,
    int,
    tuple[int, ...],
    tuple[int, ...],
    torch.dtype,
    torch.layout,
]
# Cache key of a propagated quantize op: the op name, the source tensor's
# storage identity key, then the hashable quantization parameters.
QuantizedTensorCacheKey = tuple[Any, ...]
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def _get_quantized_decomposed_default_op(op_name: str) -> Any | None:
|
|
65
|
+
"""Return a quantized_decomposed default op if it is registered.
|
|
66
|
+
|
|
67
|
+
Quantized decomposed ops may not be registered when this module is imported.
|
|
68
|
+
Therefore, all access to torch.ops.quantized_decomposed must be lazy.
|
|
69
|
+
"""
|
|
70
|
+
try:
|
|
71
|
+
namespace = torch.ops.quantized_decomposed
|
|
72
|
+
except AttributeError:
|
|
73
|
+
return None
|
|
74
|
+
|
|
75
|
+
try:
|
|
76
|
+
overload_packet = getattr(namespace, op_name)
|
|
77
|
+
except AttributeError:
|
|
78
|
+
return None
|
|
79
|
+
|
|
80
|
+
try:
|
|
81
|
+
return overload_packet.default
|
|
82
|
+
except AttributeError:
|
|
83
|
+
return None
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def _get_dequantize_ops() -> tuple[Any, ...]:
    """Collect the quantized_decomposed dequantize ops that are registered.

    Returns:
        The default overloads of ``dequantize_per_channel`` and
        ``dequantize_per_tensor``, in that order, leaving out any op that is
        not registered.
    """
    candidate_names = ("dequantize_per_channel", "dequantize_per_tensor")
    resolved_ops = map(_get_quantized_decomposed_default_op, candidate_names)
    return tuple(op for op in resolved_ops if op is not None)
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def _get_quantize_ops() -> tuple[Any, ...]:
    """Collect the quantized_decomposed quantize ops that are registered.

    Returns:
        The default overloads of ``quantize_per_channel`` and
        ``quantize_per_tensor``, in that order, leaving out any op that is not
        registered.
    """
    candidate_names = ("quantize_per_channel", "quantize_per_tensor")
    resolved_ops = map(_get_quantized_decomposed_default_op, candidate_names)
    return tuple(op for op in resolved_ops if op is not None)
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def _get_tensor_storage_identity_key(
    tensor: torch.Tensor,
) -> Optional[TensorStorageIdentityKey]:
    """Build a hashable key describing which storage a tensor refers to.

    The key is derived from storage identity rather than from the tensor's
    values: cloned tensors that merely hold equal values get different keys,
    while tied weights sharing one storage get the same key.

    Args:
        tensor: Tensor to identify.

    Returns:
        ``(device, data_ptr, storage_offset, shape, stride, dtype, layout)``,
        or ``None`` when no reliable identity exists (non-strided layout,
        empty tensor, inaccessible or null data pointer).
    """
    # Empty tensors often report a zero data pointer, so unrelated empty
    # tensors could look identical; they are rejected together with
    # non-strided layouts.
    if tensor.layout != torch.strided or tensor.numel() == 0:
        return None

    try:
        address = tensor.data_ptr()
    except RuntimeError:
        return None

    if address == 0:
        return None

    device_name = str(tensor.device)
    shape = tuple(tensor.shape)
    strides = tuple(tensor.stride())
    return (
        device_name,
        address,
        tensor.storage_offset(),
        shape,
        strides,
        tensor.dtype,
        tensor.layout,
    )
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def _make_hashable_constant(value: Any) -> Any:
|
|
147
|
+
"""Convert constant values into hashable values for cache keys.
|
|
148
|
+
|
|
149
|
+
Quantization parameters are part of the operation semantics, so tensor
|
|
150
|
+
values used as quantization parameters are compared by value. The quantized
|
|
151
|
+
weight input itself is handled separately by storage identity.
|
|
152
|
+
"""
|
|
153
|
+
if isinstance(value, torch.Tensor):
|
|
154
|
+
tensor = value.detach()
|
|
155
|
+
if tensor.layout != torch.strided:
|
|
156
|
+
return ("tensor", str(tensor.layout), repr(tensor))
|
|
157
|
+
|
|
158
|
+
cpu_tensor = tensor.cpu().contiguous()
|
|
159
|
+
return (
|
|
160
|
+
"tensor",
|
|
161
|
+
cpu_tensor.dtype,
|
|
162
|
+
tuple(cpu_tensor.shape),
|
|
163
|
+
tuple(cpu_tensor.reshape(-1).tolist()),
|
|
164
|
+
)
|
|
165
|
+
|
|
166
|
+
if isinstance(value, (tuple, list)):
|
|
167
|
+
return tuple(_make_hashable_constant(v) for v in value)
|
|
168
|
+
|
|
169
|
+
if isinstance(value, dict):
|
|
170
|
+
return tuple(sorted((k, _make_hashable_constant(v)) for k, v in value.items()))
|
|
171
|
+
|
|
172
|
+
if isinstance(value, torch.dtype):
|
|
173
|
+
return ("torch.dtype", str(value))
|
|
174
|
+
|
|
175
|
+
if isinstance(value, torch.device):
|
|
176
|
+
return ("torch.device", str(value))
|
|
177
|
+
|
|
178
|
+
if isinstance(value, torch.layout):
|
|
179
|
+
return ("torch.layout", str(value))
|
|
180
|
+
|
|
181
|
+
if isinstance(value, (str, int, float, bool, type(None))):
|
|
182
|
+
return value
|
|
183
|
+
|
|
184
|
+
try:
|
|
185
|
+
hash(value)
|
|
186
|
+
return value
|
|
187
|
+
except TypeError:
|
|
188
|
+
return repr(value)
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
def _get_argument_value(
|
|
192
|
+
args: tuple[Any, ...],
|
|
193
|
+
kwargs: Mapping[str, Any],
|
|
194
|
+
position: int,
|
|
195
|
+
names: tuple[str, ...],
|
|
196
|
+
) -> Any:
|
|
197
|
+
"""Return an argument value by position or by one of its possible names."""
|
|
198
|
+
if len(args) > position:
|
|
199
|
+
return args[position]
|
|
200
|
+
|
|
201
|
+
for name in names:
|
|
202
|
+
if name in kwargs:
|
|
203
|
+
return kwargs[name]
|
|
204
|
+
|
|
205
|
+
return _MISSING
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
def _get_quantized_tensor_cache_key(
    node: torch.fx.Node,
    args_data: tuple[Any, ...],
    kwargs_data: Mapping[str, Any],
) -> Optional[QuantizedTensorCacheKey]:
    """Build the cache key used to share quantized results of tied tensors.

    The tensor being quantized is keyed by storage identity, while the
    quantization parameters are keyed by value. Two quantize nodes therefore
    get the same key only when they quantize the same storage with the same
    parameters.

    Args:
        node: The quantize node being propagated.
        args_data: Concrete positional arguments of the node.
        kwargs_data: Concrete keyword arguments of the node.

    Returns:
        A hashable key, or ``None`` when the node is not a registered
        ``quantize_per_tensor`` / ``quantize_per_channel`` default op, an
        operand is missing, the input is not a tensor, or the input has no
        usable storage identity.
    """
    per_tensor_op = _get_quantized_decomposed_default_op("quantize_per_tensor")
    per_channel_op = _get_quantized_decomposed_default_op("quantize_per_channel")

    # Accepted keyword spellings of every operand, listed in positional order.
    # The first entry is always the tensor being quantized.
    if per_tensor_op is not None and node.target == per_tensor_op:
        operand_names: tuple[tuple[str, ...], ...] = (
            ("input", "tensor"),
            ("scale",),
            ("zero_point", "zero_p"),
            ("quant_min",),
            ("quant_max",),
            ("dtype",),
        )
    elif per_channel_op is not None and node.target == per_channel_op:
        operand_names = (
            ("input", "tensor"),
            ("scales", "scale"),
            ("zero_points", "zero_point", "zero_p"),
            ("axis",),
            ("quant_min",),
            ("quant_max",),
            ("dtype",),
        )
    else:
        return None

    operands = [
        _get_argument_value(args_data, kwargs_data, index, names)
        for index, names in enumerate(operand_names)
    ]
    if any(operand is _MISSING for operand in operands):
        return None

    source_tensor, *quant_params = operands
    if not isinstance(source_tensor, torch.Tensor):
        return None

    source_key = _get_tensor_storage_identity_key(source_tensor)
    if source_key is None:
        return None

    return (
        str(node.target),
        source_key,
        *(_make_hashable_constant(param) for param in quant_params),
    )
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
def get_constant_placeholder_to_tensor_dict(
|
|
307
|
+
exported_program: ExportedProgram,
|
|
308
|
+
) -> OrderedDict[torch.fx.Node, torch.Tensor]:
|
|
309
|
+
"""Return a dictionary from constant placeholder nodes to constant tensors."""
|
|
310
|
+
const_node_to_tensor: OrderedDict[torch.fx.Node, torch.Tensor] = OrderedDict()
|
|
311
|
+
graph_module = exported_program.graph_module
|
|
312
|
+
graph: torch.fx.Graph = graph_module.graph
|
|
313
|
+
for node in graph.nodes:
|
|
314
|
+
if node.op != "placeholder":
|
|
315
|
+
continue
|
|
316
|
+
tensor: Optional[torch.Tensor] = None
|
|
317
|
+
if is_param(exported_program, node):
|
|
318
|
+
tensor = get_param(exported_program, node)
|
|
319
|
+
elif is_buffer(exported_program, node):
|
|
320
|
+
tensor = get_buffer(exported_program, node)
|
|
321
|
+
elif is_lifted_tensor_constant(exported_program, node):
|
|
322
|
+
tensor = get_lifted_tensor_constant(exported_program, node)
|
|
323
|
+
|
|
324
|
+
if tensor is not None:
|
|
325
|
+
assert node not in const_node_to_tensor
|
|
326
|
+
const_node_to_tensor[node] = tensor
|
|
327
|
+
|
|
328
|
+
return const_node_to_tensor
|
|
329
|
+
|
|
330
|
+
|
|
331
|
+
def has_constant_data(arg, const_node_to_tensor=None) -> bool:
    """Tell whether an argument is made of constant data only.

    Placeholder nodes cannot tell constants from user inputs by themselves,
    so nodes are looked up in the constant mapping supplied by the caller.

    Args:
        arg: A node argument: a container, a primitive value, or an fx node.
        const_node_to_tensor: Optional mapping of nodes known to be constant.

    Returns:
        ``True`` for primitive values, for nodes found in
        ``const_node_to_tensor``, and for tuples/lists/dicts whose elements
        are all constant; ``False`` otherwise.
    """
    if isinstance(arg, (tuple, list)):
        children = arg
    elif isinstance(arg, dict):
        children = arg.values()
    else:
        children = None

    if children is not None:
        return all(has_constant_data(child, const_node_to_tensor) for child in children)

    if isinstance(arg, _PRIMITIVE_TYPES):
        return True
    if not isinstance(arg, torch.fx.Node):
        return False
    return const_node_to_tensor is not None and arg in const_node_to_tensor
|
|
353
|
+
|
|
354
|
+
|
|
355
|
+
def get_data(
|
|
356
|
+
arg,
|
|
357
|
+
exported_program: ExportedProgram,
|
|
358
|
+
const_node_to_tensor: Mapping[torch.fx.Node, torch.Tensor],
|
|
359
|
+
):
|
|
360
|
+
"""Return concrete constant data for a constant argument."""
|
|
361
|
+
if isinstance(arg, (tuple, list)):
|
|
362
|
+
return (get_data(x, exported_program, const_node_to_tensor) for x in arg)
|
|
363
|
+
elif isinstance(arg, _PRIMITIVE_TYPES):
|
|
364
|
+
return arg
|
|
365
|
+
elif arg in const_node_to_tensor:
|
|
366
|
+
return const_node_to_tensor[arg]
|
|
367
|
+
return None
|
|
368
|
+
|
|
369
|
+
|
|
370
|
+
def propagate_constants(
    exported_program: ExportedProgram,
) -> OrderedDict[torch.fx.Node, torch.Tensor]:
    """Evaluate every foldable node and map constant nodes to their values.

    Starting from the constant placeholders, each ``call_function`` node whose
    arguments are all constant is executed under ``torch.no_grad()`` and its
    result is recorded. Dequantize ops are never folded.

    Results of quantize ops are cached by the storage identity of the source
    tensor plus the quantization parameters, so two quantize nodes that
    quantize the same tied weight with the same parameters end up mapped to
    the very same tensor object.

    Args:
        exported_program: Program whose graph is traversed.

    Returns:
        An ordered mapping from constant nodes (placeholders and folded
        nodes) to their constant values.
    """
    folded = get_constant_placeholder_to_tensor_dict(exported_program)
    shared_quantized: dict[QuantizedTensorCacheKey, torch.Tensor] = {}

    skipped_targets = _get_dequantize_ops()
    cacheable_targets = _get_quantize_ops()

    def resolve(leaf):
        return get_data(leaf, exported_program, folded)

    for node in exported_program.graph_module.graph.nodes:
        if node.op != "call_function" or node.target in skipped_targets:
            continue
        if not has_constant_data([node.args, node.kwargs], folded):
            continue

        args_data, kwargs_data = pytree.tree_map(resolve, (node.args, node.kwargs))

        cache_key = (
            _get_quantized_tensor_cache_key(node, args_data, kwargs_data)
            if node.target in cacheable_targets
            else None
        )
        if cache_key is not None and cache_key in shared_quantized:
            # Same tied source and same quantization parameters: reuse the
            # tensor object produced for the earlier node.
            folded[node] = shared_quantized[cache_key]
            continue

        # Every argument is constant, so the node can be evaluated right now.
        with torch.no_grad():
            result = node.target(*args_data, **kwargs_data)

        if cache_key is not None:
            shared_quantized[cache_key] = result
        folded[node] = result

    return folded
|
|
428
|
+
|
|
429
|
+
|
|
430
|
+
def erase_constant_node(
|
|
431
|
+
exported_program: ExportedProgram,
|
|
432
|
+
node: torch.fx.Node,
|
|
433
|
+
) -> None:
|
|
434
|
+
"""Remove the corresponding tensor from parameter or constant dictionaries.
|
|
435
|
+
|
|
436
|
+
The input signature maps do not need to be updated here because the final
|
|
437
|
+
input specs are rebuilt at the end of this pass.
|
|
438
|
+
"""
|
|
439
|
+
signature = exported_program.graph_signature
|
|
440
|
+
if name := signature.inputs_to_parameters.get(node.name, None):
|
|
441
|
+
exported_program.state_dict.pop(name, None)
|
|
442
|
+
elif name := signature.inputs_to_lifted_tensor_constants.get(node.name, None):
|
|
443
|
+
exported_program.constants.pop(name, None)
|
|
444
|
+
elif name := signature.inputs_to_buffers.get(node.name, None):
|
|
445
|
+
exported_program.constants.pop(name, None)
|
|
446
|
+
exported_program.state_dict.pop(name, None)
|
|
447
|
+
|
|
448
|
+
# Remove from graph.
|
|
449
|
+
exported_program.graph.erase_node(node)
|
|
450
|
+
|
|
451
|
+
|
|
452
|
+
def create_constant_placeholder(
    const_node_to_tensor: Mapping[torch.fx.Node, torch.Tensor],
    exported_program: ExportedProgram,
) -> List[torch.fx.Node]:
    """Replace folded nodes with constant placeholder nodes.

    Constant nodes are visited in reverse order. A node whose users are all
    constant is erased. Existing placeholders are kept. Every remaining
    folded node is replaced by a placeholder whose tensor is stored in
    ``exported_program.constants``.

    When several folded nodes share the same tensor object, a single
    placeholder is created and the other nodes are redirected to it. This is
    how tied weights whose quantize ops were cached by `propagate_constants`
    keep sharing one constant.

    Args:
        const_node_to_tensor: Mapping from constant nodes to their values.
        exported_program: Program to modify in place.

    Returns:
        The newly created placeholder nodes, in creation order.
    """
    graph = exported_program.graph
    created: List[torch.fx.Node] = []
    placeholder_by_tensor_id: dict[int, torch.fx.Node] = {}

    fake_mode = get_fake_mode(exported_program)
    anchor = get_first_user_input(exported_program)
    if not anchor:
        # Placeholder nodes must be the first N nodes of a graph, so without
        # a user input the new placeholders go to the very start.
        assert exported_program.graph.nodes
        anchor = next(iter(exported_program.graph.nodes))

    # Walk backwards so users are handled before the nodes feeding them.
    for node, constant in reversed(const_node_to_tensor.items()):
        if all(user in const_node_to_tensor for user in node.users):
            # Only constant nodes consume this one, so it needs no
            # replacement placeholder.
            erase_constant_node(exported_program, node)
            continue

        if node.op == "placeholder":
            continue

        constant_id = id(constant)
        if constant_id in placeholder_by_tensor_id:
            shared_placeholder = placeholder_by_tensor_id[constant_id]
            node.replace_all_uses_with(shared_placeholder, propagate_meta=False)
            graph.erase_node(node)
            continue

        fqn = generate_fqn("_prop_tensor_constant", exported_program)

        with graph.inserting_before(anchor):
            new_placeholder = graph.placeholder(fqn)

        # This key must match the "target" of the InputSpec created later.
        exported_program.constants[fqn] = constant

        node.replace_all_uses_with(new_placeholder, propagate_meta=True)
        graph.erase_node(node)

        # Overwrite the propagated "val" with a fake tensor for the constant.
        fake_constant = fake_mode.from_tensor(constant, static_shapes=True)
        fake_constant.constant = constant
        new_placeholder.meta["val"] = fake_constant

        placeholder_by_tensor_id[constant_id] = new_placeholder
        created.append(new_placeholder)

    return created
|
|
520
|
+
|
|
521
|
+
|
|
522
|
+
def create_input_specs(
|
|
523
|
+
placeholders: List[torch.fx.Node],
|
|
524
|
+
) -> dict[str, InputSpec]:
|
|
525
|
+
"""Create input specs for newly created constant placeholders."""
|
|
526
|
+
name_to_spec: dict[str, InputSpec] = {}
|
|
527
|
+
|
|
528
|
+
# https://pytorch.org/docs/stable/export.ir_spec.html#placeholder
|
|
529
|
+
# %name = placeholder[target = name](args = ())
|
|
530
|
+
for node in placeholders:
|
|
531
|
+
name_to_spec[node.name] = create_input_spec(node, InputKind.CONSTANT_TENSOR)
|
|
532
|
+
|
|
533
|
+
return name_to_spec
|
|
534
|
+
|
|
535
|
+
|
|
536
|
+
@trace_graph_diff_on_pass
@trace_const_diff_on_pass
class ConstPropPass(PassBase):
    """Fold constant nodes and propagate the resulting constants.

    An exported program lifts parameters, buffers, and constant tensors out
    of the graph as inputs, so after folding this pass also rebuilds the
    input specs of the graph signature.

    [WHAT IT DOES]
    [1] Propagate the constants.
    [2] Get propagated data from constant nodes.
    [3] Create the constant placeholder nodes according to the propagated data.
    [4] Create input specs according to the created placeholders.
    [5] Update the input specs.
    """

    def __init__(self) -> None:
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """Run constant propagation on ``exported_program`` in place.

        Returns:
            A ``PassResult`` that always reports no modification, so the pass
            is not scheduled to run again.
        """
        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph: torch.fx.Graph = graph_module.graph

        # [1], [2]
        folded: OrderedDict[torch.fx.Node, torch.Tensor] = propagate_constants(
            exported_program
        )
        # [3]
        new_placeholders = create_constant_placeholder(folded, exported_program)
        # [4]
        new_specs = create_input_specs(new_placeholders)

        # [5] Merge the existing specs with the new ones, then emit them in
        # the order in which the placeholders appear in the graph.
        spec_by_name = {
            spec.arg.name: spec
            for spec in exported_program.graph_signature.input_specs
        }
        spec_by_name.update(new_specs)

        ordered_specs = []
        for node in exported_program.graph.nodes:
            if node.op == "placeholder":
                assert node.name in spec_by_name, node.name
                ordered_specs.append(spec_by_name[node.name])
        exported_program.graph_signature.input_specs = ordered_specs

        graph.eliminate_dead_code()
        graph_module.recompile()

        logger.debug("Constant nodes are propagated")
        # A single run is enough for constant folding, so report `modified`
        # as False.
        return PassResult(False)
|