onnxscript 0.6.3.dev20260403__tar.gz → 0.6.3.dev20260411__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {onnxscript-0.6.3.dev20260403/onnxscript.egg-info → onnxscript-0.6.3.dev20260411}/PKG-INFO +2 -2
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_internal/builder.py +48 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/function_libs/torch_lib/ops/core.py +14 -2
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/rules/common/__init__.py +4 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/rules/common/_fuse_batchnorm.py +69 -8
- onnxscript-0.6.3.dev20260411/onnxscript/rewriter/rules/common/_remove_expand_before_binary_op.py +295 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411/onnxscript.egg-info}/PKG-INFO +2 -2
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript.egg-info/SOURCES.txt +1 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/LICENSE +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/MANIFEST.in +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/README.md +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/VERSION +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/__init__.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_framework_apis/__init__.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_framework_apis/torch_2_11.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_framework_apis/torch_2_5.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_framework_apis/torch_2_6.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_framework_apis/torch_2_7.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_framework_apis/torch_2_8.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_framework_apis/torch_2_9.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_internal/__init__.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_internal/_inference.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_internal/_inliner.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_internal/analysis.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_internal/ast_utils.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_internal/autocast.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_internal/converter.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_internal/deprecation.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_internal/evaluator.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_internal/irbuilder.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_internal/main.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_internal/param_manipulation.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_internal/sourceinfo.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_internal/type_annotation.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_internal/utils.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_internal/values.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_internal/version_utils.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/backend/__init__.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/backend/onnx_backend.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/backend/onnx_export.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/evaluator.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/function_libs/tools/torch_lib/generate_aten_signatures.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/function_libs/tools/torch_lib/generate_prims_signatures.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/function_libs/torch_lib/__init__.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/function_libs/torch_lib/_constants.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/function_libs/torch_lib/_flags.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/function_libs/torch_lib/graph_building/__init__.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/function_libs/torch_lib/ops/__init__.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/function_libs/torch_lib/ops/common.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/function_libs/torch_lib/ops/fft.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/function_libs/torch_lib/ops/linalg.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/function_libs/torch_lib/ops/nested.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/function_libs/torch_lib/ops/nn.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/function_libs/torch_lib/ops/prims.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/function_libs/torch_lib/ops/quantized_decomposed.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/function_libs/torch_lib/ops/sparse.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/function_libs/torch_lib/ops/special.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/function_libs/torch_lib/ops/vision.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/function_libs/torch_lib/registration.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/function_libs/torch_lib/tensor_typing.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/ir/__init__.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/ir/_schemas.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/ir/_tape.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/ir/convenience.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/ir/passes/__init__.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/ir/passes/common/__init__.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/nn/__init__.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/nn/_module.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/nn/_module_list.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/nn/_parameter.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/nn/_sequential.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/__init__.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/_impl/opset1.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/_impl/opset10.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/_impl/opset11.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/_impl/opset12.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/_impl/opset13.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/_impl/opset14.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/_impl/opset15.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/_impl/opset16.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/_impl/opset17.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/_impl/opset18.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/_impl/opset19.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/_impl/opset2.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/_impl/opset20.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/_impl/opset21.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/_impl/opset22.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/_impl/opset23.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/_impl/opset24.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/_impl/opset3.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/_impl/opset4.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/_impl/opset5.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/_impl/opset6.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/_impl/opset7.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/_impl/opset8.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/_impl/opset9.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/_impl/opset_ai_onnx_ml1.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/_impl/opset_ai_onnx_ml2.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/_impl/opset_ai_onnx_ml3.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/_impl/opset_ai_onnx_ml4.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/_impl/opset_ai_onnx_ml5.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_types.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/optimizer/__init__.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/optimizer/_constant_folding.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/optimizer/_optimizer.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/py.typed +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/__init__.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/_basics.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/_fusion_utils.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/_ir_utils.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/_matcher.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/_pattern_ir.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/_rewrite_rule.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/models/_bart_encoder.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/models/_phi2lm.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/models/_phi4lm.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/models/_rotary_embedding_models.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/models/_smollm_1.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/models/_smollm_2.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/models/_test_models.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/models/_whisper_decoder.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/models/_whisper_encoder.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/onnx_fusions/__init__.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/onnx_fusions/_onnx_fusions.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/onnxruntime/__init__.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/onnxruntime/bfloat16_utils/bfloat16_converter.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/ort_fusions/__init__.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/ort_fusions/_core.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/ort_fusions/_test_utils.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/ort_fusions/attention.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/ort_fusions/bias_gelu.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/ort_fusions/cos_sin_cache.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/ort_fusions/erfgelu.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/ort_fusions/gelu.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/ort_fusions/gqa.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/ort_fusions/gqa_packed_qkv.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/ort_fusions/group_normalization_merge_silu.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/ort_fusions/instance_to_group_normalization.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/ort_fusions/mha.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/ort_fusions/mha_bias.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/ort_fusions/mha_scale.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/ort_fusions/rms_normalization.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/ort_fusions/rotary_embedding.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/ort_fusions/sdpa.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/ort_fusions/sdpa_via_mha.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/ort_fusions/shape_optimization.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/ort_fusions/skip_normalization.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/ort_fusions/softmax.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/pattern.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/rules/__init__.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/rules/common/_basic_rules.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/rules/common/_broadcast_to_matmul.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/rules/common/_cast_constant_of_shape.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/rules/common/_collapse_slices.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/rules/common/_fuse_conv_affine.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/rules/common/_fuse_hardswish.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/rules/common/_fuse_pad_into_conv.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/rules/common/_fuse_relus_clips.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/rules/common/_gemm_to_matmul_add.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/rules/common/_matmul_add_to_gemm.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/rules/common/_min_max_to_clip.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/rules/common/_no_op.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/rules/common/_redundant_scatter_nd.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/rules/common/_remove_optional_bias.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/rules/fusion/__init__.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/rules/fusion/_gqa.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/rules/fusion/_layer_norm.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/rules/fusion/_rms_normalization.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/rules/fusion/_rotary_embedding.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/testing.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/tensor.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/testing/__init__.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/tools/__init__.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/tools/memory_peak.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/tools/transformers_models/__init__.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/tools/transformers_models/llama.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/tools/transformers_models/mistral.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/tools/transformers_models/phi.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/tools/transformers_models/phi3.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/utils/__init__.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/utils/evaluation_utils.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/utils/metadata_merger.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/utils/replace.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/utils/timing_utils.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/utils/utils.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/values.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/version_converter/__init__.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/version_converter/_c_api_utils.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/version_converter/_version_converter.py +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript.egg-info/dependency_links.txt +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript.egg-info/requires.txt +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript.egg-info/top_level.txt +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/pyproject.toml +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/setup.cfg +0 -0
- {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/setup.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: onnxscript
|
|
3
|
-
Version: 0.6.3.dev20260403
|
|
3
|
+
Version: 0.6.3.dev20260411
|
|
4
4
|
Summary: Naturally author ONNX functions and models using a subset of Python
|
|
5
5
|
Author-email: Microsoft Corporation <onnx@microsoft.com>
|
|
6
6
|
License: MIT License
|
|
@@ -27,7 +27,7 @@ License: MIT License
|
|
|
27
27
|
|
|
28
28
|
Project-URL: Homepage, https://microsoft.github.io/onnxscript/
|
|
29
29
|
Project-URL: Repository, https://github.com/microsoft/onnxscript
|
|
30
|
-
Project-URL: Commit, https://github.com/microsoft/onnxscript/tree/
|
|
30
|
+
Project-URL: Commit, https://github.com/microsoft/onnxscript/tree/13f265cd01b21210267b86a24efdb0072c0ee374
|
|
31
31
|
Classifier: Development Status :: 4 - Beta
|
|
32
32
|
Classifier: Environment :: Console
|
|
33
33
|
Classifier: Intended Audience :: Developers
|
{onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_internal/builder.py
RENAMED
|
@@ -276,6 +276,54 @@ class GraphBuilder:
|
|
|
276
276
|
self._graph.register_initializer(value)
|
|
277
277
|
return value
|
|
278
278
|
|
|
279
|
+
def input(
|
|
280
|
+
self,
|
|
281
|
+
name: str,
|
|
282
|
+
dtype: ir.DataType | None = None,
|
|
283
|
+
shape: ir.Shape | Sequence[int | str | None] | None = None,
|
|
284
|
+
*,
|
|
285
|
+
type: ir.TypeProtocol | None = None,
|
|
286
|
+
const_value: ir.TensorProtocol | None = None,
|
|
287
|
+
metadata_props: dict[str, str] | None = None,
|
|
288
|
+
) -> ir.Value:
|
|
289
|
+
"""Create an input to the graph and return the corresponding ir.Value.
|
|
290
|
+
|
|
291
|
+
Args:
|
|
292
|
+
name: The name of the value.
|
|
293
|
+
dtype: The data type of the TensorType of the value. This is used only when type is None.
|
|
294
|
+
shape: The shape of the value.
|
|
295
|
+
type: The type of the value. Only one of dtype and type can be specified.
|
|
296
|
+
const_value: The constant tensor that initializes the value. Supply this argument
|
|
297
|
+
when you want to create an initializer. The type and shape can be obtained from the tensor.
|
|
298
|
+
metadata_props: The metadata properties that will be serialized to the ONNX proto.
|
|
299
|
+
|
|
300
|
+
Returns:
|
|
301
|
+
A Value object.
|
|
302
|
+
"""
|
|
303
|
+
value = ir.val(
|
|
304
|
+
name=name,
|
|
305
|
+
dtype=dtype,
|
|
306
|
+
shape=shape,
|
|
307
|
+
type=type,
|
|
308
|
+
const_value=const_value,
|
|
309
|
+
metadata_props=metadata_props,
|
|
310
|
+
)
|
|
311
|
+
self._graph.inputs.append(value)
|
|
312
|
+
if const_value is not None:
|
|
313
|
+
self._graph.register_initializer(value)
|
|
314
|
+
return value
|
|
315
|
+
|
|
316
|
+
def add_output(self, value: ir.Value, name: str | None) -> None:
|
|
317
|
+
"""Add an output to the graph.
|
|
318
|
+
|
|
319
|
+
Args:
|
|
320
|
+
value: The ir.Value to add as an output.
|
|
321
|
+
name: The name to assign to the output value. If None, no renaming is done.
|
|
322
|
+
"""
|
|
323
|
+
if name:
|
|
324
|
+
value.name = name
|
|
325
|
+
self._graph.outputs.append(value)
|
|
326
|
+
|
|
279
327
|
def _input_to_ir_value(
|
|
280
328
|
self, value: VALUE_LIKE, like_type: ir.Value | None = None
|
|
281
329
|
) -> ir.Value | None:
|
|
@@ -6231,7 +6231,7 @@ def aten_mean_complex(self: TReal) -> TReal:
|
|
|
6231
6231
|
|
|
6232
6232
|
|
|
6233
6233
|
@torch_op("aten::mean.dim", trace_only=True)
|
|
6234
|
-
def aten_mean_dim(self: TReal, dim: INT64, keepdim: bool = False) -> TReal:
|
|
6234
|
+
def aten_mean_dim(self: TReal, dim: INT64, keepdim: bool = False, dtype: int = -1) -> TReal:
|
|
6235
6235
|
"""mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"""
|
|
6236
6236
|
|
|
6237
6237
|
if len(self.shape) == 0:
|
|
@@ -6239,11 +6239,17 @@ def aten_mean_dim(self: TReal, dim: INT64, keepdim: bool = False) -> TReal:
|
|
|
6239
6239
|
else:
|
|
6240
6240
|
dims = op.Reshape(dim, op.Constant(value_ints=[-1]))
|
|
6241
6241
|
result = op.ReduceMean(self, dims, keepdims=keepdim)
|
|
6242
|
+
|
|
6243
|
+
if dtype != -1 and dtype is not None:
|
|
6244
|
+
result = op.Cast(result, to=dtype)
|
|
6245
|
+
|
|
6242
6246
|
return result
|
|
6243
6247
|
|
|
6244
6248
|
|
|
6245
6249
|
@torch_op("aten::mean.dim", trace_only=True, complex=True)
|
|
6246
|
-
def aten_mean_dim_complex(
|
|
6250
|
+
def aten_mean_dim_complex(
|
|
6251
|
+
self: TReal, dim: INT64, keepdim: bool = False, dtype: int = -1
|
|
6252
|
+
) -> TReal:
|
|
6247
6253
|
"""mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"""
|
|
6248
6254
|
|
|
6249
6255
|
if len(self.shape) == 1:
|
|
@@ -6254,6 +6260,12 @@ def aten_mean_dim_complex(self: TReal, dim: INT64, keepdim: bool = False) -> TRe
|
|
|
6254
6260
|
dim = op.Where(op.Less(dim, zero), op.Sub(dim, one), dim)
|
|
6255
6261
|
dims = op.Reshape(dim, op.Constant(value_ints=[-1]))
|
|
6256
6262
|
result = op.ReduceMean(self, dims, keepdims=keepdim)
|
|
6263
|
+
|
|
6264
|
+
if dtype != -1 and dtype is not None:
|
|
6265
|
+
raise NotImplementedError(
|
|
6266
|
+
"support for the dtype argument is not implemented for complex tensors"
|
|
6267
|
+
)
|
|
6268
|
+
|
|
6257
6269
|
return result
|
|
6258
6270
|
|
|
6259
6271
|
|
|
@@ -12,6 +12,7 @@ __all__ = [
|
|
|
12
12
|
"div_by_1_rule",
|
|
13
13
|
"dropout_inference_rule",
|
|
14
14
|
"dropout_zero_rule",
|
|
15
|
+
"expand_before_binary_op_rules",
|
|
15
16
|
"flatten_to_reshape_rule",
|
|
16
17
|
"fuse_batchnorm_into_conv_rule",
|
|
17
18
|
"fuse_batchnorm_into_conv_transpose_rule",
|
|
@@ -125,6 +126,9 @@ from onnxscript.rewriter.rules.common._redundant_scatter_nd import (
|
|
|
125
126
|
no_op_dynamic_scatter_nd_rule,
|
|
126
127
|
no_op_static_scatter_nd_rule,
|
|
127
128
|
)
|
|
129
|
+
from onnxscript.rewriter.rules.common._remove_expand_before_binary_op import (
|
|
130
|
+
expand_before_binary_op_rules,
|
|
131
|
+
)
|
|
128
132
|
from onnxscript.rewriter.rules.common._remove_optional_bias import (
|
|
129
133
|
remove_optional_bias_from_conv_rule,
|
|
130
134
|
remove_optional_bias_from_conv_transpose_rule,
|
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
# Licensed under the MIT License.
|
|
3
3
|
"""Fuses BatchNormalization nodes into preceding nodes. Supported fusion patterns:
|
|
4
4
|
- BatchNormalization ∘ Conv -> Conv
|
|
5
|
-
- BatchNormalization ∘
|
|
5
|
+
- BatchNormalization ∘ ConvTranspose -> ConvTranspose
|
|
6
6
|
- BatchNormalization ∘ Gemm -> Gemm
|
|
7
7
|
|
|
8
8
|
Approach:
|
|
@@ -14,7 +14,7 @@ Approach:
|
|
|
14
14
|
- B_fused = (B - μ) * (gamma / std) + β
|
|
15
15
|
"""
|
|
16
16
|
|
|
17
|
-
from abc import ABC, abstractmethod
|
|
17
|
+
from abc import ABC
|
|
18
18
|
from typing import ClassVar, Mapping
|
|
19
19
|
|
|
20
20
|
import numpy as np
|
|
@@ -33,9 +33,18 @@ def _reshape_for_broadcast(x: np.ndarray, rank: int, axis: int = 1) -> np.ndarra
|
|
|
33
33
|
class _FuseBatchNormBase(RewriteRuleClassBase, ABC):
|
|
34
34
|
"""Interface for BatchNormalization nodes fusion."""
|
|
35
35
|
|
|
36
|
-
@abstractmethod
|
|
37
36
|
def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int:
|
|
38
37
|
"""Return the axis along which BatchNorm scale should be broadcasted."""
|
|
38
|
+
raise NotImplementedError()
|
|
39
|
+
|
|
40
|
+
def _scale_weights(
|
|
41
|
+
self,
|
|
42
|
+
weights: np.ndarray,
|
|
43
|
+
scale_factor: np.ndarray,
|
|
44
|
+
attributes: Mapping[str, ir.Attr],
|
|
45
|
+
) -> np.ndarray:
|
|
46
|
+
axis = self.get_filters_axis(attributes)
|
|
47
|
+
return weights * _reshape_for_broadcast(scale_factor, weights.ndim, axis=axis)
|
|
39
48
|
|
|
40
49
|
def rewrite(self, op, x: ir.Value, inbound_out: ir.Value, batchnorm_out: ir.Value):
|
|
41
50
|
batchnorm_node = batchnorm_out.producer()
|
|
@@ -56,10 +65,8 @@ class _FuseBatchNormBase(RewriteRuleClassBase, ABC):
|
|
|
56
65
|
inbound_node = inbound_out.producer()
|
|
57
66
|
weights = inbound_node.inputs[1].const_value.numpy()
|
|
58
67
|
|
|
59
|
-
# Reshape scale factor so it is broadcastable
|
|
60
|
-
axis = self.get_filters_axis(inbound_node.attributes)
|
|
61
68
|
fused_weights = ir.tensor(
|
|
62
|
-
weights * _reshape_for_broadcast(scale_factor, weights.ndim, axis=axis)
|
|
69
|
+
self._scale_weights(weights, scale_factor, inbound_node.attributes)
|
|
63
70
|
)
|
|
64
71
|
|
|
65
72
|
# Update bias
|
|
@@ -103,6 +110,23 @@ class _FuseBatchNormBase(RewriteRuleClassBase, ABC):
|
|
|
103
110
|
if initializer.is_graph_input():
|
|
104
111
|
return check_result.fail(f"{initializer.name} is a graph input.")
|
|
105
112
|
|
|
113
|
+
# Check that the inbound node's weight and bias initializers are not shared
|
|
114
|
+
# with other nodes outside this matched pattern. When the fusion creates new
|
|
115
|
+
# initializers with the same name as the original shared weights, it overwrites
|
|
116
|
+
# the original initializer in the graph, leaving other nodes that reference the
|
|
117
|
+
# original value with an invalid (unregistered) input.
|
|
118
|
+
matched_nodes = {inbound_node, batchnorm_node}
|
|
119
|
+
inbound_initializers = [inbound_node.inputs[1]]
|
|
120
|
+
if len(inbound_node.inputs) > 2:
|
|
121
|
+
inbound_initializers.append(inbound_node.inputs[2])
|
|
122
|
+
for init_value in inbound_initializers:
|
|
123
|
+
for user, _ in init_value.uses():
|
|
124
|
+
if user not in matched_nodes:
|
|
125
|
+
return check_result.fail(
|
|
126
|
+
f"Initializer '{init_value.name}' is used by another node "
|
|
127
|
+
f"'{user.name}' outside the matched pattern."
|
|
128
|
+
)
|
|
129
|
+
|
|
106
130
|
return check_result
|
|
107
131
|
|
|
108
132
|
|
|
@@ -127,8 +151,26 @@ class FuseBatchNormIntoConvTranspose(_FuseBatchNormBase):
|
|
|
127
151
|
|
|
128
152
|
op_type: ClassVar = "ConvTranspose"
|
|
129
153
|
|
|
130
|
-
def
|
|
131
|
-
|
|
154
|
+
def _scale_weights(
    self,
    weights: np.ndarray,
    scale_factor: np.ndarray,
    attributes: Mapping[str, ir.Attr],
) -> np.ndarray:
    """Fold a per-output-channel scale into ConvTranspose weights.

    The weights are viewed as ``(group, in_channels/group,
    out_channels/group, *kernel)`` so that the flat ``scale_factor`` can be
    split per group and broadcast over the input-channel and kernel axes.

    Args:
        weights: Weight array whose first two axes are ``in_channels`` and
            ``out_channels/group``; any remaining axes are kernel dims.
        scale_factor: Array holding ``group * out_channels/group`` values.
        attributes: Node attributes; ``group`` is read and defaults to 1.

    Returns:
        The scaled weights, reshaped back to ``weights.shape``.
    """
    dims = weights.shape
    in_channels, out_per_group, kernel_dims = dims[0], dims[1], dims[2:]
    group = attributes.get("group", ir.AttrInt64("group", 1)).as_int()

    # Expose the group axis explicitly: [group, in/group, out/group, *kernel].
    grouped_weights = weights.reshape(
        (group, in_channels // group, out_per_group, *kernel_dims)
    )

    # One scale per output channel, split by group. The singleton axes line up
    # with in_channels/group (axis 1) and with every kernel axis.
    per_group_scale = scale_factor.reshape(
        (group, 1, out_per_group) + (1,) * len(kernel_dims)
    )

    return (grouped_weights * per_group_scale).reshape(dims)
|
|
132
174
|
|
|
133
175
|
def pattern(self, op, x):
|
|
134
176
|
return op.BatchNormalization(
|
|
@@ -137,6 +179,25 @@ class FuseBatchNormIntoConvTranspose(_FuseBatchNormBase):
|
|
|
137
179
|
_outputs=["batchnorm_out"],
|
|
138
180
|
)
|
|
139
181
|
|
|
182
|
+
def check(self, context, x, inbound_out, batchnorm_out):
    """Validate that BatchNormalization can be folded into this ConvTranspose.

    Runs the base-class checks first, then verifies that the ``group``
    attribute is usable and that the weight's ``in_channels`` axis splits
    evenly into ``group`` groups.

    Args:
        context: Match context, forwarded to the base-class check.
        x: Matched input value, forwarded to the base-class check.
        inbound_out: Output value of the matched ConvTranspose node.
        batchnorm_out: Output value of the matched BatchNormalization node.

    Returns:
        The check result; it is failed when the base check fails, when
        ``group`` is not positive, or when ``in_channels`` is not divisible
        by ``group``.
    """
    check_result = super().check(context, x, inbound_out, batchnorm_out)
    if not check_result:
        return check_result

    inbound_node = inbound_out.producer()

    in_channels = inbound_node.inputs[1].const_value.numpy().shape[0]
    group = inbound_node.attributes.get("group", ir.AttrInt64("group", 1)).as_int()

    # A non-positive group cannot describe a grouping, and group == 0 would
    # make the modulo below raise ZeroDivisionError instead of failing the
    # match cleanly.
    if group <= 0:
        return check_result.fail(
            f"ConvTranspose group ({group}) must be a positive integer."
        )

    # The ONNX checker accepts a ConvTranspose whose in_channels is not a
    # multiple of group, but such a node is invalid, so reject it explicitly
    # here rather than attempting the fusion.
    if in_channels % group != 0:
        return check_result.fail(
            f"ConvTranspose in_channels ({in_channels}) is not divisible by group ({group})."
        )

    return check_result
|
|
200
|
+
|
|
140
201
|
|
|
141
202
|
class FuseBatchNormIntoGemm(_FuseBatchNormBase):
|
|
142
203
|
"""Replaces ``BatchNormalization(Gemm(x))`` with ``Gemm(x)``."""
|
onnxscript-0.6.3.dev20260411/onnxscript/rewriter/rules/common/_remove_expand_before_binary_op.py
ADDED
|
@@ -0,0 +1,295 @@
|
|
|
1
|
+
# Copyright (c) Microsoft Corporation.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
"""Fusion rule to remove an Expand node before a binary operator.
|
|
4
|
+
|
|
5
|
+
This implements the optimization:
|
|
6
|
+
|
|
7
|
+
BinaryOp(Expand(x, shape), y) -> BinaryOp(x, y)
|
|
8
|
+
BinaryOp(x, Expand(y, shape)) -> BinaryOp(x, y)
|
|
9
|
+
|
|
10
|
+
This is valid when the binary operator's broadcasting semantics would produce
|
|
11
|
+
the same output shape as first expanding the input and then applying the op.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
from onnxscript import ir
|
|
17
|
+
from onnxscript.rewriter._basics import MatchResult
|
|
18
|
+
from onnxscript.rewriter._ir_utils import get_numpy_value
|
|
19
|
+
from onnxscript.rewriter._rewrite_rule import RewriteRuleClassBase, RewriteRuleSet
|
|
20
|
+
|
|
21
|
+
# Binary operators in ONNX standard opset that support numpy-style broadcasting.
|
|
22
|
+
_BROADCAST_BINARY_OPS: tuple[str, ...] = (
|
|
23
|
+
"Add",
|
|
24
|
+
"And",
|
|
25
|
+
"BitShift",
|
|
26
|
+
"BitwiseAnd",
|
|
27
|
+
"BitwiseOr",
|
|
28
|
+
"BitwiseXor",
|
|
29
|
+
"Div",
|
|
30
|
+
"Equal",
|
|
31
|
+
"Greater",
|
|
32
|
+
"GreaterOrEqual",
|
|
33
|
+
"Less",
|
|
34
|
+
"LessOrEqual",
|
|
35
|
+
"Mod",
|
|
36
|
+
"Mul",
|
|
37
|
+
"Or",
|
|
38
|
+
"Pow",
|
|
39
|
+
"PRelu",
|
|
40
|
+
"Sub",
|
|
41
|
+
"Xor",
|
|
42
|
+
)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _compute_broadcast_dim(d1, d2):
    """Return the numpy broadcast of two dimension values.

    Each dimension value may be an ``int`` or an ``onnx_ir.SymbolicDim``.

    Args:
        d1: First dimension value.
        d2: Second dimension value.

    Returns:
        The broadcast dimension, or ``None`` when the result cannot be
        determined statically (e.g. two distinct symbolic values neither of
        which is known to be 1, or two anonymous symbolic values).
    """
    if d1 == 1:
        return d2
    if d2 == 1:
        return d1

    # An anonymous symbolic dim (no name attached) carries no identity, so two
    # of them comparing equal does not prove the runtime sizes agree.
    # NOTE(review): assumes anonymous dims expose ``value is None`` and
    # compare equal to each other -- confirm against onnx_ir.SymbolicDim.
    for dim in (d1, d2):
        if not isinstance(dim, int) and getattr(dim, "value", None) is None:
            return None

    if d1 == d2:
        return d1
    return None
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _compute_broadcast_shape(shape1: ir.Shape, shape2: ir.Shape) -> list | None:
    """Symbolically broadcast two shapes using numpy's right-aligned rule.

    The shorter shape is treated as if it were padded on the left with 1s,
    then each pair of dims is combined with ``_compute_broadcast_dim``.

    Args:
        shape1: First shape.
        shape2: Second shape.

    Returns:
        The broadcast dims, outermost first, as a list of ``int`` or
        ``SymbolicDim`` values; ``None`` when either rank is unknown or any
        pair of dims cannot be combined.
    """
    rank1, rank2 = shape1.rank(), shape2.rank()
    if rank1 is None or rank2 is None:
        return None

    out_rank = max(rank1, rank2)
    # Number of implicit leading 1s each shape receives under right-alignment.
    pad1 = out_rank - rank1
    pad2 = out_rank - rank2

    dims = []
    for pos in range(out_rank):
        lhs = 1 if pos < pad1 else shape1[pos - pad1]
        rhs = 1 if pos < pad2 else shape2[pos - pad2]
        merged = _compute_broadcast_dim(lhs, rhs)
        if merged is None:
            return None
        dims.append(merged)
    return dims
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def _dims_provably_equal(a, b) -> bool:
    """Return True only when both dims are identified and compare equal.

    A dim is "identified" when it is an ``int`` or a symbolic dim with a
    name. Anonymous symbolic dims are rejected because comparing two of them
    says nothing about their runtime sizes.

    NOTE(review): assumes anonymous symbolic dims expose ``value is None`` --
    confirm against onnx_ir.SymbolicDim.
    """
    for dim in (a, b):
        if not isinstance(dim, int) and getattr(dim, "value", None) is None:
            return False
    return bool(a == b)


def _check_dims_sufficient(
    expand_shape: ir.Shape,
    x_shape: ir.Shape,
    y_shape: ir.Shape,
) -> MatchResult:
    """Check that x and y together cover every dimension of the expand target.

    The expand target must not have more dimensions than
    ``broadcast(x, y)`` would have: an Expand that prepends axes (even of
    size 1) that neither input provides changes the rank of the binary op's
    output, so it cannot be removed.

    For each dimension ``i`` of *expand_shape* (right-aligned) the expand is
    then considered redundant when at least one of the following holds:

    - ``expand_shape[i] == 1`` - expand cannot shrink a dim, so ``x_d`` must
      also be 1 and both with and without expand produce ``y_d``.
    - ``x_d == expand_shape[i]`` - the expand is a no-op at this dim.
    - ``y_d == expand_shape[i]`` - ``y`` already supplies this expansion.

    Comparisons work for both ``int`` and ``SymbolicDim`` values; anonymous
    symbolic dims never count as equal to anything.

    Args:
        expand_shape: Shape the Expand produces (or its constant target).
        x_shape: Shape of the value fed into the Expand.
        y_shape: Shape of the other operand of the binary op.

    Returns:
        A :class:`MatchResult` that is successful when every dimension is
        covered, and failed (with a reason) otherwise.
    """
    check_result = MatchResult()
    e_rank = expand_shape.rank()
    x_rank = x_shape.rank()
    y_rank = y_shape.rank()
    if e_rank is None:
        return check_result.fail("Expand output rank is unknown.")
    if x_rank is None or y_rank is None:
        return check_result.fail("Input rank is unknown.")

    # Without the Expand the binary op output has rank max(x_rank, y_rank).
    # Extra leading axes introduced by the Expand would be lost.
    if e_rank > max(x_rank, y_rank):
        return check_result.fail(
            f"Expand rank ({e_rank}) exceeds the rank that broadcasting the "
            f"inputs produces ({max(x_rank, y_rank)})."
        )

    for rev_i in range(e_rank):
        i = e_rank - 1 - rev_i
        e_d = expand_shape[i]

        if isinstance(e_d, int) and e_d == 1:
            continue  # expand cannot shrink; x_d is also 1, no-op

        x_idx = x_rank - 1 - rev_i
        x_d = x_shape[x_idx] if x_idx >= 0 else 1
        if _dims_provably_equal(x_d, e_d):
            continue  # expand is a no-op at this dimension

        y_idx = y_rank - 1 - rev_i
        y_d = y_shape[y_idx] if y_idx >= 0 else 1
        if _dims_provably_equal(y_d, e_d):
            continue  # y already supplies this dimension

        return check_result.fail(
            f"Cannot verify that removing Expand is safe at dimension {i}: "
            f"x_d={x_d!r}, expand_d={e_d!r}, y_d={y_d!r}."
        )

    return check_result
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def _is_anonymous_dim(dim) -> bool:
    """Return True when ``dim`` is a symbolic dim with no name attached.

    NOTE(review): assumes anonymous symbolic dims expose ``value is None`` --
    confirm against onnx_ir.SymbolicDim.
    """
    return not isinstance(dim, int) and getattr(dim, "value", None) is None


def _check_expand_removable(
    expand_input: ir.Value,
    shape: ir.Value,
    other_input: ir.Value,
    expand_output: ir.Value | None = None,
    binary_op_output: ir.Value | None = None,
) -> MatchResult:
    """Check if an Expand node can be safely removed before a binary op.

    The Expand ``expanded_x = Expand(x, expand_shape)`` before a binary op
    ``out = BinaryOp(expanded_x, y)`` is redundant when the binary op's own
    broadcasting produces the same output as if the expand had been applied.

    In every strategy the Expand must not raise the output rank above
    ``max(rank(x), rank(y))``: removing an Expand that prepends axes (even of
    size 1) which neither operand has would change the rank of ``out``.

    Three strategies are tried in order:

    1. **Constant expand shape** - When ``shape`` is a compile-time constant,
       the dimension values are extracted from it and the check is performed
       directly.

    2. **Expand output shape annotation** - When ``shape`` is dynamic but the
       Expand node's output value already carries a shape annotation (e.g.
       after ONNX shape inference has been applied to the model), those
       dimension values are used for the check.

    3. **Binary op output shape** - When neither of the above is available,
       the rule verifies that ``broadcast(x.shape, y.shape)`` symbolically
       equals the binary op's output shape. If they agree, the binary op's
       own broadcasting already accounts for all the expansion and the
       Expand is redundant. Anonymous symbolic dims never count as a match.

    Args:
        expand_input: The value fed into the Expand node (``x``).
        shape: The target shape operand of the Expand node.
        other_input: The other operand of the binary op (``y``).
        expand_output: The output value of the Expand node. Required for
            strategy 2.
        binary_op_output: The output value of the binary op. Required for
            strategy 3.

    Returns:
        A :class:`MatchResult` that is successful when the Expand can be
        removed.
    """
    check_result = MatchResult()

    x_shape = expand_input.shape
    y_shape = other_input.shape
    if x_shape is None or y_shape is None:
        return check_result.fail("Input shapes are not known.")

    x_rank = x_shape.rank()
    y_rank = y_shape.rank()
    if x_rank is None or y_rank is None:
        return check_result.fail("Input ranks are not known.")

    # Rank of BinaryOp(x, y) once the Expand is gone.
    rank_without_expand = max(x_rank, y_rank)

    # --- Strategy 1: expand target shape is a compile-time constant ---
    expand_shape_val = get_numpy_value(shape)
    if expand_shape_val is not None:
        expand_shape = tuple(int(v) for v in expand_shape_val.tolist())
        expand_rank = len(expand_shape)

        # Leading axes that exist only in the expand target would disappear
        # from the output, whatever their size.
        if expand_rank > rank_without_expand:
            return check_result.fail(
                f"Expand target rank ({expand_rank}) exceeds the rank that "
                f"broadcasting the inputs produces ({rank_without_expand})."
            )

        for rev_i in range(expand_rank):
            i = expand_rank - 1 - rev_i
            e_d = expand_shape[i]  # always a known integer from numpy

            if e_d == 1:
                continue  # expand cannot shrink; x_d is also 1, no-op

            x_idx = x_rank - 1 - rev_i
            x_d = x_shape[x_idx] if x_idx >= 0 else 1

            if isinstance(x_d, int) and x_d == e_d:
                continue  # expand is a no-op at this dimension

            y_idx = y_rank - 1 - rev_i
            y_d = y_shape[y_idx] if y_idx >= 0 else 1

            if isinstance(y_d, int) and y_d == e_d:
                continue  # y already supplies this dimension

            return check_result.fail(
                f"Cannot verify that removing Expand is safe at dimension {i}: "
                f"x_d={x_d!r}, expand_d={e_d}, y_d={y_d!r}."
            )

        return check_result

    # --- Strategy 2: Expand output shape is known (e.g. from shape inference) ---
    if expand_output is not None and expand_output.shape is not None:
        expanded_rank = expand_output.shape.rank()
        if expanded_rank is not None and expanded_rank > rank_without_expand:
            return check_result.fail(
                f"Expand output rank ({expanded_rank}) exceeds the rank that "
                f"broadcasting the inputs produces ({rank_without_expand})."
            )
        return _check_dims_sufficient(expand_output.shape, x_shape, y_shape)

    # --- Strategy 3: use the binary op's output shape ---
    # broadcast(x.shape, y.shape) must equal the binary op's output shape.
    # If it does, the binary op's own broadcasting already produces the same
    # result as first expanding x and then broadcasting.
    if binary_op_output is not None and binary_op_output.shape is not None:
        op_output_shape = binary_op_output.shape
        if op_output_shape.rank() is not None:
            computed = _compute_broadcast_shape(x_shape, y_shape)
            if computed is not None and len(computed) == op_output_shape.rank():
                # Two anonymous dims comparing equal proves nothing about
                # their runtime sizes, so they must not count as a match.
                if all(
                    not _is_anonymous_dim(c) and not _is_anonymous_dim(a) and c == a
                    for c, a in zip(computed, op_output_shape)
                ):
                    return check_result
        return check_result.fail(
            "broadcast(x.shape, y.shape) does not match the binary op output shape."
        )

    return check_result.fail(
        "Expand target shape is not a constant and no shape annotations are available."
    )
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
class _ExpandFirstInput(RewriteRuleClassBase):
    """Removes ``BinaryOp(Expand(x, shape), y)`` -> ``BinaryOp(x, y)``.

    The replacement node is built from the two inputs only, so the rule
    declines to fire on a binary op that carries attributes (for example
    ``BitShift``'s ``direction`` or ``Mod``'s ``fmod``): rebuilding such a
    node without them would change or invalidate it.
    """

    def __init__(self, op_type: str) -> None:
        """Create the rule for the binary operator named ``op_type``."""
        super().__init__(f"ExpandFirst_{op_type}", remove_nodes=False)
        self._op_type = op_type

    def pattern(self, op, x: ir.Value, shape: ir.Value, y: ir.Value) -> ir.Value:
        """Match the binary op whose first input is ``Expand(x, shape)``."""
        return getattr(op, self._op_type)(op.Expand(x, shape), y)

    def check(self, context, x: ir.Value, shape: ir.Value, y: ir.Value) -> MatchResult:
        """Return a successful result when the Expand on the first input is redundant."""
        binary_node = context.root

        # rewrite() does not copy attributes onto the new node.
        # NOTE(review): assumes the matcher accepts nodes with attributes the
        # pattern does not mention -- confirm against the rewriter defaults.
        if binary_node.attributes:
            return MatchResult().fail(
                f"{self._op_type} node has attributes that the rewritten node "
                "would not carry over."
            )

        expand_output = binary_node.inputs[0] if binary_node.inputs else None
        binary_op_output = binary_node.outputs[0] if binary_node.outputs else None
        return _check_expand_removable(
            x, shape, y, expand_output=expand_output, binary_op_output=binary_op_output
        )

    def rewrite(self, op, x: ir.Value, shape: ir.Value, y: ir.Value) -> ir.Value:
        """Emit the binary op applied directly to ``x`` and ``y``."""
        return getattr(op, self._op_type)(x, y)
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
class _ExpandSecondInput(RewriteRuleClassBase):
    """Removes ``BinaryOp(x, Expand(y, shape))`` -> ``BinaryOp(x, y)``.

    The replacement node is built from the two inputs only, so the rule
    declines to fire on a binary op that carries attributes (for example
    ``BitShift``'s ``direction`` or ``Mod``'s ``fmod``): rebuilding such a
    node without them would change or invalidate it.
    """

    def __init__(self, op_type: str) -> None:
        """Create the rule for the binary operator named ``op_type``."""
        super().__init__(f"ExpandSecond_{op_type}", remove_nodes=False)
        self._op_type = op_type

    def pattern(self, op, x: ir.Value, y: ir.Value, shape: ir.Value) -> ir.Value:
        """Match the binary op whose second input is ``Expand(y, shape)``."""
        return getattr(op, self._op_type)(x, op.Expand(y, shape))

    def check(self, context, x: ir.Value, y: ir.Value, shape: ir.Value) -> MatchResult:
        """Return a successful result when the Expand on the second input is redundant."""
        binary_node = context.root

        # rewrite() does not copy attributes onto the new node.
        # NOTE(review): assumes the matcher accepts nodes with attributes the
        # pattern does not mention -- confirm against the rewriter defaults.
        if binary_node.attributes:
            return MatchResult().fail(
                f"{self._op_type} node has attributes that the rewritten node "
                "would not carry over."
            )

        expand_output = binary_node.inputs[1] if binary_node.inputs else None
        binary_op_output = binary_node.outputs[0] if binary_node.outputs else None
        # Here the expanded operand is ``y``, so it plays the "x" role of the
        # shared helper and ``x`` is the other operand.
        return _check_expand_removable(
            y, shape, x, expand_output=expand_output, binary_op_output=binary_op_output
        )

    def rewrite(self, op, x: ir.Value, y: ir.Value, shape: ir.Value) -> ir.Value:
        """Emit the binary op applied directly to ``x`` and ``y``."""
        return getattr(op, self._op_type)(x, y)
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
def _make_expand_before_binary_op_rules() -> list:
    """Create rewrite rules for removing Expand before each supported binary op.

    Every operator in ``_BROADCAST_BINARY_OPS`` gets a rule for an Expand on
    its second input. A rule for an Expand on the first input is created for
    all of them except ``PRelu``.

    Returns:
        The list of rewrite rules, in operator order.
    """
    rules = []
    for op_type in _BROADCAST_BINARY_OPS:
        # PRelu broadcasts one way only: slope must be broadcastable to X.
        # Dropping an Expand on X can leave slope larger than X, which the
        # operator does not allow, so only the slope-side rule is generated.
        if op_type != "PRelu":
            rules.append(_ExpandFirstInput.rule(op_type))
        rules.append(_ExpandSecondInput.rule(op_type))
    return rules
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
expand_before_binary_op_rules = RewriteRuleSet(_make_expand_before_binary_op_rules())
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: onnxscript
|
|
3
|
-
Version: 0.6.3.dev20260403
|
|
3
|
+
Version: 0.6.3.dev20260411
|
|
4
4
|
Summary: Naturally author ONNX functions and models using a subset of Python
|
|
5
5
|
Author-email: Microsoft Corporation <onnx@microsoft.com>
|
|
6
6
|
License: MIT License
|
|
@@ -27,7 +27,7 @@ License: MIT License
|
|
|
27
27
|
|
|
28
28
|
Project-URL: Homepage, https://microsoft.github.io/onnxscript/
|
|
29
29
|
Project-URL: Repository, https://github.com/microsoft/onnxscript
|
|
30
|
-
Project-URL: Commit, https://github.com/microsoft/onnxscript/tree/
|
|
30
|
+
Project-URL: Commit, https://github.com/microsoft/onnxscript/tree/13f265cd01b21210267b86a24efdb0072c0ee374
|
|
31
31
|
Classifier: Development Status :: 4 - Beta
|
|
32
32
|
Classifier: Environment :: Console
|
|
33
33
|
Classifier: Intended Audience :: Developers
|
{onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript.egg-info/SOURCES.txt
RENAMED
|
@@ -169,6 +169,7 @@ onnxscript/rewriter/rules/common/_matmul_add_to_gemm.py
|
|
|
169
169
|
onnxscript/rewriter/rules/common/_min_max_to_clip.py
|
|
170
170
|
onnxscript/rewriter/rules/common/_no_op.py
|
|
171
171
|
onnxscript/rewriter/rules/common/_redundant_scatter_nd.py
|
|
172
|
+
onnxscript/rewriter/rules/common/_remove_expand_before_binary_op.py
|
|
172
173
|
onnxscript/rewriter/rules/common/_remove_optional_bias.py
|
|
173
174
|
onnxscript/rewriter/rules/fusion/__init__.py
|
|
174
175
|
onnxscript/rewriter/rules/fusion/_gqa.py
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_framework_apis/__init__.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_internal/__init__.py
RENAMED
|
File without changes
|
{onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_internal/_inference.py
RENAMED
|
File without changes
|
{onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_internal/_inliner.py
RENAMED
|
File without changes
|
{onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_internal/analysis.py
RENAMED
|
File without changes
|
{onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_internal/ast_utils.py
RENAMED
|
File without changes
|
{onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_internal/autocast.py
RENAMED
|
File without changes
|
{onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_internal/converter.py
RENAMED
|
File without changes
|
{onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_internal/deprecation.py
RENAMED
|
File without changes
|
{onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_internal/evaluator.py
RENAMED
|
File without changes
|
{onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_internal/irbuilder.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
{onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_internal/sourceinfo.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
{onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_internal/values.py
RENAMED
|
File without changes
|
{onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_internal/version_utils.py
RENAMED
|
File without changes
|
{onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/backend/__init__.py
RENAMED
|
File without changes
|
{onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/backend/onnx_backend.py
RENAMED
|
File without changes
|
{onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/backend/onnx_export.py
RENAMED
|
File without changes
|
|
File without changes
|