onnxscript 0.6.3.dev20260403__tar.gz → 0.6.3.dev20260411__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (197)
  1. {onnxscript-0.6.3.dev20260403/onnxscript.egg-info → onnxscript-0.6.3.dev20260411}/PKG-INFO +2 -2
  2. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_internal/builder.py +48 -0
  3. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/function_libs/torch_lib/ops/core.py +14 -2
  4. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/rules/common/__init__.py +4 -0
  5. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/rules/common/_fuse_batchnorm.py +69 -8
  6. onnxscript-0.6.3.dev20260411/onnxscript/rewriter/rules/common/_remove_expand_before_binary_op.py +295 -0
  7. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411/onnxscript.egg-info}/PKG-INFO +2 -2
  8. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript.egg-info/SOURCES.txt +1 -0
  9. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/LICENSE +0 -0
  10. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/MANIFEST.in +0 -0
  11. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/README.md +0 -0
  12. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/VERSION +0 -0
  13. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/__init__.py +0 -0
  14. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_framework_apis/__init__.py +0 -0
  15. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_framework_apis/torch_2_11.py +0 -0
  16. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_framework_apis/torch_2_5.py +0 -0
  17. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_framework_apis/torch_2_6.py +0 -0
  18. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_framework_apis/torch_2_7.py +0 -0
  19. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_framework_apis/torch_2_8.py +0 -0
  20. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_framework_apis/torch_2_9.py +0 -0
  21. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_internal/__init__.py +0 -0
  22. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_internal/_inference.py +0 -0
  23. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_internal/_inliner.py +0 -0
  24. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_internal/analysis.py +0 -0
  25. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_internal/ast_utils.py +0 -0
  26. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_internal/autocast.py +0 -0
  27. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_internal/converter.py +0 -0
  28. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_internal/deprecation.py +0 -0
  29. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_internal/evaluator.py +0 -0
  30. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_internal/irbuilder.py +0 -0
  31. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_internal/main.py +0 -0
  32. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_internal/param_manipulation.py +0 -0
  33. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_internal/sourceinfo.py +0 -0
  34. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_internal/type_annotation.py +0 -0
  35. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_internal/utils.py +0 -0
  36. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_internal/values.py +0 -0
  37. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/_internal/version_utils.py +0 -0
  38. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/backend/__init__.py +0 -0
  39. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/backend/onnx_backend.py +0 -0
  40. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/backend/onnx_export.py +0 -0
  41. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/evaluator.py +0 -0
  42. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints.py +0 -0
  43. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/function_libs/tools/torch_lib/generate_aten_signatures.py +0 -0
  44. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/function_libs/tools/torch_lib/generate_prims_signatures.py +0 -0
  45. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/function_libs/torch_lib/__init__.py +0 -0
  46. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/function_libs/torch_lib/_constants.py +0 -0
  47. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/function_libs/torch_lib/_flags.py +0 -0
  48. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/function_libs/torch_lib/graph_building/__init__.py +0 -0
  49. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/function_libs/torch_lib/ops/__init__.py +0 -0
  50. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/function_libs/torch_lib/ops/common.py +0 -0
  51. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/function_libs/torch_lib/ops/fft.py +0 -0
  52. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/function_libs/torch_lib/ops/linalg.py +0 -0
  53. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/function_libs/torch_lib/ops/nested.py +0 -0
  54. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/function_libs/torch_lib/ops/nn.py +0 -0
  55. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/function_libs/torch_lib/ops/prims.py +0 -0
  56. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/function_libs/torch_lib/ops/quantized_decomposed.py +0 -0
  57. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/function_libs/torch_lib/ops/sparse.py +0 -0
  58. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/function_libs/torch_lib/ops/special.py +0 -0
  59. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/function_libs/torch_lib/ops/vision.py +0 -0
  60. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/function_libs/torch_lib/registration.py +0 -0
  61. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/function_libs/torch_lib/tensor_typing.py +0 -0
  62. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/ir/__init__.py +0 -0
  63. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/ir/_schemas.py +0 -0
  64. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/ir/_tape.py +0 -0
  65. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/ir/convenience.py +0 -0
  66. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/ir/passes/__init__.py +0 -0
  67. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/ir/passes/common/__init__.py +0 -0
  68. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/nn/__init__.py +0 -0
  69. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/nn/_module.py +0 -0
  70. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/nn/_module_list.py +0 -0
  71. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/nn/_parameter.py +0 -0
  72. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/nn/_sequential.py +0 -0
  73. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/__init__.py +0 -0
  74. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/_impl/opset1.py +0 -0
  75. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/_impl/opset10.py +0 -0
  76. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/_impl/opset11.py +0 -0
  77. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/_impl/opset12.py +0 -0
  78. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/_impl/opset13.py +0 -0
  79. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/_impl/opset14.py +0 -0
  80. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/_impl/opset15.py +0 -0
  81. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/_impl/opset16.py +0 -0
  82. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/_impl/opset17.py +0 -0
  83. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/_impl/opset18.py +0 -0
  84. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/_impl/opset19.py +0 -0
  85. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/_impl/opset2.py +0 -0
  86. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/_impl/opset20.py +0 -0
  87. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/_impl/opset21.py +0 -0
  88. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/_impl/opset22.py +0 -0
  89. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/_impl/opset23.py +0 -0
  90. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/_impl/opset24.py +0 -0
  91. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/_impl/opset3.py +0 -0
  92. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/_impl/opset4.py +0 -0
  93. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/_impl/opset5.py +0 -0
  94. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/_impl/opset6.py +0 -0
  95. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/_impl/opset7.py +0 -0
  96. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/_impl/opset8.py +0 -0
  97. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/_impl/opset9.py +0 -0
  98. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/_impl/opset_ai_onnx_ml1.py +0 -0
  99. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/_impl/opset_ai_onnx_ml2.py +0 -0
  100. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/_impl/opset_ai_onnx_ml3.py +0 -0
  101. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/_impl/opset_ai_onnx_ml4.py +0 -0
  102. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_opset/_impl/opset_ai_onnx_ml5.py +0 -0
  103. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/onnx_types.py +0 -0
  104. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/optimizer/__init__.py +0 -0
  105. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/optimizer/_constant_folding.py +0 -0
  106. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/optimizer/_optimizer.py +0 -0
  107. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/py.typed +0 -0
  108. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/__init__.py +0 -0
  109. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/_basics.py +0 -0
  110. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/_fusion_utils.py +0 -0
  111. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/_ir_utils.py +0 -0
  112. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/_matcher.py +0 -0
  113. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/_pattern_ir.py +0 -0
  114. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/_rewrite_rule.py +0 -0
  115. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/models/_bart_encoder.py +0 -0
  116. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/models/_phi2lm.py +0 -0
  117. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/models/_phi4lm.py +0 -0
  118. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/models/_rotary_embedding_models.py +0 -0
  119. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/models/_smollm_1.py +0 -0
  120. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/models/_smollm_2.py +0 -0
  121. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/models/_test_models.py +0 -0
  122. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/models/_whisper_decoder.py +0 -0
  123. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/models/_whisper_encoder.py +0 -0
  124. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/onnx_fusions/__init__.py +0 -0
  125. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/onnx_fusions/_onnx_fusions.py +0 -0
  126. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/onnxruntime/__init__.py +0 -0
  127. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/onnxruntime/bfloat16_utils/bfloat16_converter.py +0 -0
  128. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/ort_fusions/__init__.py +0 -0
  129. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/ort_fusions/_core.py +0 -0
  130. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/ort_fusions/_test_utils.py +0 -0
  131. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/ort_fusions/attention.py +0 -0
  132. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/ort_fusions/bias_gelu.py +0 -0
  133. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/ort_fusions/cos_sin_cache.py +0 -0
  134. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/ort_fusions/erfgelu.py +0 -0
  135. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py +0 -0
  136. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/ort_fusions/gelu.py +0 -0
  137. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/ort_fusions/gqa.py +0 -0
  138. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/ort_fusions/gqa_packed_qkv.py +0 -0
  139. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/ort_fusions/group_normalization_merge_silu.py +0 -0
  140. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/ort_fusions/instance_to_group_normalization.py +0 -0
  141. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/ort_fusions/mha.py +0 -0
  142. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/ort_fusions/mha_bias.py +0 -0
  143. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/ort_fusions/mha_scale.py +0 -0
  144. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/ort_fusions/rms_normalization.py +0 -0
  145. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/ort_fusions/rotary_embedding.py +0 -0
  146. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/ort_fusions/sdpa.py +0 -0
  147. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/ort_fusions/sdpa_via_mha.py +0 -0
  148. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/ort_fusions/shape_optimization.py +0 -0
  149. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/ort_fusions/skip_normalization.py +0 -0
  150. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/ort_fusions/softmax.py +0 -0
  151. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/pattern.py +0 -0
  152. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/rules/__init__.py +0 -0
  153. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/rules/common/_basic_rules.py +0 -0
  154. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/rules/common/_broadcast_to_matmul.py +0 -0
  155. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/rules/common/_cast_constant_of_shape.py +0 -0
  156. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/rules/common/_collapse_slices.py +0 -0
  157. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/rules/common/_fuse_conv_affine.py +0 -0
  158. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/rules/common/_fuse_hardswish.py +0 -0
  159. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/rules/common/_fuse_pad_into_conv.py +0 -0
  160. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/rules/common/_fuse_relus_clips.py +0 -0
  161. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/rules/common/_gemm_to_matmul_add.py +0 -0
  162. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/rules/common/_matmul_add_to_gemm.py +0 -0
  163. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/rules/common/_min_max_to_clip.py +0 -0
  164. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/rules/common/_no_op.py +0 -0
  165. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/rules/common/_redundant_scatter_nd.py +0 -0
  166. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/rules/common/_remove_optional_bias.py +0 -0
  167. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/rules/fusion/__init__.py +0 -0
  168. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/rules/fusion/_gqa.py +0 -0
  169. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/rules/fusion/_layer_norm.py +0 -0
  170. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/rules/fusion/_rms_normalization.py +0 -0
  171. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/rules/fusion/_rotary_embedding.py +0 -0
  172. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/rewriter/testing.py +0 -0
  173. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/tensor.py +0 -0
  174. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/testing/__init__.py +0 -0
  175. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/tools/__init__.py +0 -0
  176. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/tools/memory_peak.py +0 -0
  177. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/tools/transformers_models/__init__.py +0 -0
  178. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/tools/transformers_models/llama.py +0 -0
  179. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/tools/transformers_models/mistral.py +0 -0
  180. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/tools/transformers_models/phi.py +0 -0
  181. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/tools/transformers_models/phi3.py +0 -0
  182. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/utils/__init__.py +0 -0
  183. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/utils/evaluation_utils.py +0 -0
  184. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/utils/metadata_merger.py +0 -0
  185. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/utils/replace.py +0 -0
  186. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/utils/timing_utils.py +0 -0
  187. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/utils/utils.py +0 -0
  188. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/values.py +0 -0
  189. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/version_converter/__init__.py +0 -0
  190. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/version_converter/_c_api_utils.py +0 -0
  191. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript/version_converter/_version_converter.py +0 -0
  192. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript.egg-info/dependency_links.txt +0 -0
  193. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript.egg-info/requires.txt +0 -0
  194. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/onnxscript.egg-info/top_level.txt +0 -0
  195. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/pyproject.toml +0 -0
  196. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/setup.cfg +0 -0
  197. {onnxscript-0.6.3.dev20260403 → onnxscript-0.6.3.dev20260411}/setup.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: onnxscript
3
- Version: 0.6.3.dev20260403
3
+ Version: 0.6.3.dev20260411
4
4
  Summary: Naturally author ONNX functions and models using a subset of Python
5
5
  Author-email: Microsoft Corporation <onnx@microsoft.com>
6
6
  License: MIT License
@@ -27,7 +27,7 @@ License: MIT License
27
27
 
28
28
  Project-URL: Homepage, https://microsoft.github.io/onnxscript/
29
29
  Project-URL: Repository, https://github.com/microsoft/onnxscript
30
- Project-URL: Commit, https://github.com/microsoft/onnxscript/tree/a47ccbadd85851ebdafd41e13f880e6ea399535d
30
+ Project-URL: Commit, https://github.com/microsoft/onnxscript/tree/13f265cd01b21210267b86a24efdb0072c0ee374
31
31
  Classifier: Development Status :: 4 - Beta
32
32
  Classifier: Environment :: Console
33
33
  Classifier: Intended Audience :: Developers
@@ -276,6 +276,54 @@ class GraphBuilder:
276
276
  self._graph.register_initializer(value)
277
277
  return value
278
278
 
279
+ def input(
280
+ self,
281
+ name: str,
282
+ dtype: ir.DataType | None = None,
283
+ shape: ir.Shape | Sequence[int | str | None] | None = None,
284
+ *,
285
+ type: ir.TypeProtocol | None = None,
286
+ const_value: ir.TensorProtocol | None = None,
287
+ metadata_props: dict[str, str] | None = None,
288
+ ) -> ir.Value:
289
+ """Create an input to the graph and return the corresponding ir.Value.
290
+
291
+ Args:
292
+ name: The name of the value.
293
+ dtype: The data type of the TensorType of the value. This is used only when type is None.
294
+ shape: The shape of the value.
295
+ type: The type of the value. Only one of dtype and type can be specified.
296
+ const_value: The constant tensor that initializes the value. Supply this argument
297
+ when you want to create an initializer. The type and shape can be obtained from the tensor.
298
+ metadata_props: The metadata properties that will be serialized to the ONNX proto.
299
+
300
+ Returns:
301
+ A Value object.
302
+ """
303
+ value = ir.val(
304
+ name=name,
305
+ dtype=dtype,
306
+ shape=shape,
307
+ type=type,
308
+ const_value=const_value,
309
+ metadata_props=metadata_props,
310
+ )
311
+ self._graph.inputs.append(value)
312
+ if const_value is not None:
313
+ self._graph.register_initializer(value)
314
+ return value
315
+
316
+ def add_output(self, value: ir.Value, name: str | None) -> None:
317
+ """Add an output to the graph.
318
+
319
+ Args:
320
+ value: The ir.Value to add as an output.
321
+ name: The name to assign to the output value. If None or empty, no renaming is done.
322
+ """
323
+ if name:
324
+ value.name = name
325
+ self._graph.outputs.append(value)
326
+
279
327
  def _input_to_ir_value(
280
328
  self, value: VALUE_LIKE, like_type: ir.Value | None = None
281
329
  ) -> ir.Value | None:
@@ -6231,7 +6231,7 @@ def aten_mean_complex(self: TReal) -> TReal:
6231
6231
 
6232
6232
 
6233
6233
  @torch_op("aten::mean.dim", trace_only=True)
6234
- def aten_mean_dim(self: TReal, dim: INT64, keepdim: bool = False) -> TReal:
6234
+ def aten_mean_dim(self: TReal, dim: INT64, keepdim: bool = False, dtype: int = -1) -> TReal:
6235
6235
  """mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"""
6236
6236
 
6237
6237
  if len(self.shape) == 0:
@@ -6239,11 +6239,17 @@ def aten_mean_dim(self: TReal, dim: INT64, keepdim: bool = False) -> TReal:
6239
6239
  else:
6240
6240
  dims = op.Reshape(dim, op.Constant(value_ints=[-1]))
6241
6241
  result = op.ReduceMean(self, dims, keepdims=keepdim)
6242
+
6243
+ if dtype != -1 and dtype is not None:
6244
+ result = op.Cast(result, to=dtype)
6245
+
6242
6246
  return result
6243
6247
 
6244
6248
 
6245
6249
  @torch_op("aten::mean.dim", trace_only=True, complex=True)
6246
- def aten_mean_dim_complex(self: TReal, dim: INT64, keepdim: bool = False) -> TReal:
6250
+ def aten_mean_dim_complex(
6251
+ self: TReal, dim: INT64, keepdim: bool = False, dtype: int = -1
6252
+ ) -> TReal:
6247
6253
  """mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"""
6248
6254
 
6249
6255
  if len(self.shape) == 1:
@@ -6254,6 +6260,12 @@ def aten_mean_dim_complex(self: TReal, dim: INT64, keepdim: bool = False) -> TRe
6254
6260
  dim = op.Where(op.Less(dim, zero), op.Sub(dim, one), dim)
6255
6261
  dims = op.Reshape(dim, op.Constant(value_ints=[-1]))
6256
6262
  result = op.ReduceMean(self, dims, keepdims=keepdim)
6263
+
6264
+ if dtype != -1 and dtype is not None:
6265
+ raise NotImplementedError(
6266
+ "support for the dtype argument is not implemented for complex tensors"
6267
+ )
6268
+
6257
6269
  return result
6258
6270
 
6259
6271
 
@@ -12,6 +12,7 @@ __all__ = [
12
12
  "div_by_1_rule",
13
13
  "dropout_inference_rule",
14
14
  "dropout_zero_rule",
15
+ "expand_before_binary_op_rules",
15
16
  "flatten_to_reshape_rule",
16
17
  "fuse_batchnorm_into_conv_rule",
17
18
  "fuse_batchnorm_into_conv_transpose_rule",
@@ -125,6 +126,9 @@ from onnxscript.rewriter.rules.common._redundant_scatter_nd import (
125
126
  no_op_dynamic_scatter_nd_rule,
126
127
  no_op_static_scatter_nd_rule,
127
128
  )
129
+ from onnxscript.rewriter.rules.common._remove_expand_before_binary_op import (
130
+ expand_before_binary_op_rules,
131
+ )
128
132
  from onnxscript.rewriter.rules.common._remove_optional_bias import (
129
133
  remove_optional_bias_from_conv_rule,
130
134
  remove_optional_bias_from_conv_transpose_rule,
@@ -2,7 +2,7 @@
2
2
  # Licensed under the MIT License.
3
3
  """Fuses BatchNormalization nodes into preceding nodes. Supported fusion patterns:
4
4
  - BatchNormalization ∘ Conv -> Conv
5
- - BatchNormalization ∘ ConvTranpose -> ConvTranpose
5
+ - BatchNormalization ∘ ConvTranspose -> ConvTranspose
6
6
  - BatchNormalization ∘ Gemm -> Gemm
7
7
 
8
8
  Approach:
@@ -14,7 +14,7 @@ Approach:
14
14
  - B_fused = (B - μ) * (gamma / std) + β
15
15
  """
16
16
 
17
- from abc import ABC, abstractmethod
17
+ from abc import ABC
18
18
  from typing import ClassVar, Mapping
19
19
 
20
20
  import numpy as np
@@ -33,9 +33,18 @@ def _reshape_for_broadcast(x: np.ndarray, rank: int, axis: int = 1) -> np.ndarra
33
33
  class _FuseBatchNormBase(RewriteRuleClassBase, ABC):
34
34
  """Interface for BatchNormalization nodes fusion."""
35
35
 
36
- @abstractmethod
37
36
  def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int:
38
37
  """Return the axis along which BatchNorm scale should be broadcasted."""
38
+ raise NotImplementedError()
39
+
40
+ def _scale_weights(
41
+ self,
42
+ weights: np.ndarray,
43
+ scale_factor: np.ndarray,
44
+ attributes: Mapping[str, ir.Attr],
45
+ ) -> np.ndarray:
46
+ axis = self.get_filters_axis(attributes)
47
+ return weights * _reshape_for_broadcast(scale_factor, weights.ndim, axis=axis)
39
48
 
40
49
  def rewrite(self, op, x: ir.Value, inbound_out: ir.Value, batchnorm_out: ir.Value):
41
50
  batchnorm_node = batchnorm_out.producer()
@@ -56,10 +65,8 @@ class _FuseBatchNormBase(RewriteRuleClassBase, ABC):
56
65
  inbound_node = inbound_out.producer()
57
66
  weights = inbound_node.inputs[1].const_value.numpy()
58
67
 
59
- # Reshape scale factor so it is broadcastable
60
- axis = self.get_filters_axis(inbound_node.attributes)
61
68
  fused_weights = ir.tensor(
62
- weights * _reshape_for_broadcast(scale_factor, weights.ndim, axis=axis)
69
+ self._scale_weights(weights, scale_factor, inbound_node.attributes)
63
70
  )
64
71
 
65
72
  # Update bias
@@ -103,6 +110,23 @@ class _FuseBatchNormBase(RewriteRuleClassBase, ABC):
103
110
  if initializer.is_graph_input():
104
111
  return check_result.fail(f"{initializer.name} is a graph input.")
105
112
 
113
+ # Check that the inbound node's weight and bias initializers are not shared
114
+ # with other nodes outside this matched pattern. When the fusion creates new
115
+ # initializers with the same name as the original shared weights, it overwrites
116
+ # the original initializer in the graph, leaving other nodes that reference the
117
+ # original value with an invalid (unregistered) input.
118
+ matched_nodes = {inbound_node, batchnorm_node}
119
+ inbound_initializers = [inbound_node.inputs[1]]
120
+ if len(inbound_node.inputs) > 2:
121
+ inbound_initializers.append(inbound_node.inputs[2])
122
+ for init_value in inbound_initializers:
123
+ for user, _ in init_value.uses():
124
+ if user not in matched_nodes:
125
+ return check_result.fail(
126
+ f"Initializer '{init_value.name}' is used by another node "
127
+ f"'{user.name}' outside the matched pattern."
128
+ )
129
+
106
130
  return check_result
107
131
 
108
132
 
@@ -127,8 +151,26 @@ class FuseBatchNormIntoConvTranspose(_FuseBatchNormBase):
127
151
 
128
152
  op_type: ClassVar = "ConvTranspose"
129
153
 
130
- def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int:
131
- return 1
154
def _scale_weights(
    self,
    weights: np.ndarray,
    scale_factor: np.ndarray,
    attributes: Mapping[str, ir.Attr],
) -> np.ndarray:
    """Scale ConvTranspose weights by a per-output-channel factor.

    The weight tensor is laid out as
    ``(in_channels, out_channels / group, *kernel)``, so the output channel
    of an element depends on both its group (derived from the leading axis)
    and its position along axis 1. The weights are therefore viewed per
    group before multiplying.

    Args:
        weights: ConvTranspose weight array.
        scale_factor: Flat array that is reshaped to
            ``(group, out_channels / group)``; it must hold exactly that
            many elements.
        attributes: Node attributes; ``group`` defaults to 1 when absent.

    Returns:
        The scaled weights, with the same shape as ``weights``.
    """
    original_shape = weights.shape
    num_in = original_shape[0]
    out_per_group = original_shape[1]
    spatial = original_shape[2:]
    group = attributes.get("group", ir.AttrInt64("group", 1)).as_int()

    # View as (group, in_channels/group, out_channels/group, *kernel).
    grouped = weights.reshape(group, num_in // group, out_per_group, *spatial)

    # One factor per (group, out channel); unit axes make it broadcast over
    # the in_channels/group axis and over every kernel axis.
    broadcast_shape = (group, 1, out_per_group) + (1,) * len(spatial)
    per_group_scale = scale_factor.reshape(broadcast_shape)

    return (grouped * per_group_scale).reshape(original_shape)
132
174
 
133
175
  def pattern(self, op, x):
134
176
  return op.BatchNormalization(
@@ -137,6 +179,25 @@ class FuseBatchNormIntoConvTranspose(_FuseBatchNormBase):
137
179
  _outputs=["batchnorm_out"],
138
180
  )
139
181
 
182
def check(self, context, x, inbound_out, batchnorm_out):
    """Run the base checks, then reject weights that cannot be split by group.

    The ONNX checker does not reject a ConvTranspose whose ``in_channels``
    is not a multiple of ``group``, but such a node is invalid and the
    grouped weight reshape done during fusion would fail on it.
    """
    outcome = super().check(context, x, inbound_out, batchnorm_out)
    if not outcome:
        return outcome

    conv_transpose = inbound_out.producer()

    # Axis 0 of the ConvTranspose weight is in_channels.
    in_channels = conv_transpose.inputs[1].const_value.numpy().shape[0]
    group = conv_transpose.attributes.get("group", ir.AttrInt64("group", 1)).as_int()

    if in_channels % group == 0:
        return outcome

    return outcome.fail(
        f"ConvTranspose in_channels ({in_channels}) is not divisible by group ({group})."
    )
200
+
140
201
 
141
202
  class FuseBatchNormIntoGemm(_FuseBatchNormBase):
142
203
  """Replaces ``BatchNormalization(Gemm(x))`` with ``Gemm(x)``."""
@@ -0,0 +1,295 @@
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+ """Fusion rule to remove an Expand node before a binary operator.
4
+
5
+ This implements the optimization:
6
+
7
+ BinaryOp(Expand(x, shape), y) -> BinaryOp(x, y)
8
+ BinaryOp(x, Expand(y, shape)) -> BinaryOp(x, y)
9
+
10
+ This is valid when the binary operator's broadcasting semantics would produce
11
+ the same output shape as first expanding the input and then applying the op.
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ from onnxscript import ir
17
+ from onnxscript.rewriter._basics import MatchResult
18
+ from onnxscript.rewriter._ir_utils import get_numpy_value
19
+ from onnxscript.rewriter._rewrite_rule import RewriteRuleClassBase, RewriteRuleSet
20
+
21
+ # Binary operators in ONNX standard opset that support numpy-style broadcasting.
22
+ _BROADCAST_BINARY_OPS: tuple[str, ...] = (
23
+ "Add",
24
+ "And",
25
+ "BitShift",
26
+ "BitwiseAnd",
27
+ "BitwiseOr",
28
+ "BitwiseXor",
29
+ "Div",
30
+ "Equal",
31
+ "Greater",
32
+ "GreaterOrEqual",
33
+ "Less",
34
+ "LessOrEqual",
35
+ "Mod",
36
+ "Mul",
37
+ "Or",
38
+ "Pow",
39
+ "PRelu",
40
+ "Sub",
41
+ "Xor",
42
+ )
43
+
44
+
45
def _compute_broadcast_dim(d1, d2):
    """Return the numpy broadcast of two dimension values.

    Each dimension value may be an ``int`` or a symbolic dimension object
    (``onnx_ir.SymbolicDim``).

    Returns ``None`` when the result cannot be determined statically: two
    distinct values neither of which is known to be 1, or any *anonymous*
    symbolic dimension (one whose ``value`` is ``None``).

    NOTE(review): anonymous dims are rejected because symbolic-dim equality
    appears to compare the underlying ``value``, which would make two
    unrelated unknown dims compare equal and be mistaken for the same size.
    Confirm against the onnx_ir ``SymbolicDim`` implementation.
    """

    def _is_anonymous(dim) -> bool:
        # Plain ints have no ``value`` attribute, so they are never anonymous.
        return hasattr(dim, "value") and dim.value is None

    if _is_anonymous(d1) or _is_anonymous(d2):
        # An unnamed unknown cannot be proven equal to anything, and a result
        # that is itself anonymous could not be compared by the caller.
        return None
    if d1 == 1:
        return d2
    if d2 == 1:
        return d1
    if d1 == d2:
        return d1
    return None
59
+
60
+
61
def _compute_broadcast_shape(shape1: ir.Shape, shape2: ir.Shape) -> list | None:
    """Broadcast two shapes numpy-style, keeping symbolic dims symbolic.

    The shorter shape is treated as if it were left-padded with 1s. Each
    aligned pair of dims is combined with ``_compute_broadcast_dim``.

    Returns:
        A list with one entry per output dimension, or ``None`` if either
        rank is unknown or any dimension pair cannot be resolved.
    """
    rank1 = shape1.rank()
    rank2 = shape2.rank()
    if rank1 is None or rank2 is None:
        return None

    out_rank = max(rank1, rank2)
    # Number of implicit leading 1s each shape receives when right-aligned.
    pad1 = out_rank - rank1
    pad2 = out_rank - rank2

    dims = []
    for axis in range(out_rank):
        a = 1 if axis < pad1 else shape1[axis - pad1]
        b = 1 if axis < pad2 else shape2[axis - pad2]
        merged = _compute_broadcast_dim(a, b)
        if merged is None:
            return None
        dims.append(merged)
    return dims
84
+
85
+
86
def _check_dims_sufficient(
    expand_shape: ir.Shape,
    x_shape: ir.Shape,
    y_shape: ir.Shape,
) -> MatchResult:
    """Check that x and y together cover every dimension of the expand target.

    Shapes are right-aligned; a missing leading dimension counts as 1. For
    each dimension ``i`` of *expand_shape* the expand is considered redundant
    when at least one of the following holds:

    - ``expand_shape[i]`` is the integer 1 - Expand broadcasts, so a target
      of 1 leaves the input dimension unchanged.
    - ``x_d == expand_shape[i]`` - the expand is a no-op at this dim.
    - ``y_d == expand_shape[i]`` - ``y`` already supplies this expansion.

    An *anonymous* expand dimension (a symbolic dim whose ``value`` is
    ``None``) never satisfies the equality tests.

    NOTE(review): symbolic-dim equality appears to compare the underlying
    ``value``, so two unrelated unknown dims would otherwise compare equal
    and an Expand that really grows the dimension would be dropped. Confirm
    against the onnx_ir ``SymbolicDim`` implementation.

    Returns:
        A successful :class:`MatchResult` when every dimension is covered,
        otherwise a failed one naming the first offending dimension.
    """

    def _is_anonymous(dim) -> bool:
        # Plain ints have no ``value`` attribute, so they are never anonymous.
        return hasattr(dim, "value") and dim.value is None

    check_result = MatchResult()
    e_rank = expand_shape.rank()
    x_rank = x_shape.rank()
    y_rank = y_shape.rank()
    if e_rank is None:
        return check_result.fail("Expand output rank is unknown.")
    if x_rank is None or y_rank is None:
        # Without ranks the shapes cannot be right-aligned.
        return check_result.fail("Input ranks are unknown.")

    for rev_i in range(e_rank):
        i = e_rank - 1 - rev_i
        e_d = expand_shape[i]

        if isinstance(e_d, int) and e_d == 1:
            continue  # a target of 1 never changes the input dimension

        x_idx = x_rank - 1 - rev_i
        x_d = x_shape[x_idx] if x_idx >= 0 else 1
        y_idx = y_rank - 1 - rev_i
        y_d = y_shape[y_idx] if y_idx >= 0 else 1

        # Equality is only meaningful when the target dim has an identity.
        if not _is_anonymous(e_d) and (x_d == e_d or y_d == e_d):
            continue  # no-op on x, or y already supplies this dimension

        return check_result.fail(
            f"Cannot verify that removing Expand is safe at dimension {i}: "
            f"x_d={x_d!r}, expand_d={e_d!r}, y_d={y_d!r}."
        )

    return check_result
133
+
134
+
135
def _check_expand_removable(
    expand_input: ir.Value,
    shape: ir.Value,
    other_input: ir.Value,
    expand_output: ir.Value | None = None,
    binary_op_output: ir.Value | None = None,
) -> MatchResult:
    """Decide whether ``Expand(x, shape)`` feeding a binary op is redundant.

    For ``out = BinaryOp(Expand(x, shape), y)`` the Expand can be dropped
    when the binary op's own broadcasting of ``x`` and ``y`` yields the same
    result. Both input shapes must be known. The first applicable source of
    information decides the outcome:

    1. **Constant target shape** - ``shape`` resolves to a constant; each
       target dim other than 1 must equal the matching integer dim of ``x``
       or of ``y``.
    2. **Expand output annotation** - the Expand output carries a shape; the
       same per-dimension test is delegated to ``_check_dims_sufficient``.
    3. **Binary op output annotation** - ``broadcast(x.shape, y.shape)`` must
       equal the annotated output shape of the binary op.

    Args:
        expand_input: The value fed into the Expand node (``x``).
        shape: The target-shape operand of the Expand node.
        other_input: The other operand of the binary op (``y``).
        expand_output: Output of the Expand node; enables strategy 2.
        binary_op_output: Output of the binary op; enables strategy 3.

    Returns:
        A :class:`MatchResult` that is successful when the Expand can be
        removed, and otherwise carries the failure reason.
    """
    outcome = MatchResult()

    lhs_shape = expand_input.shape
    rhs_shape = other_input.shape
    if lhs_shape is None or rhs_shape is None:
        return outcome.fail("Input shapes are not known.")

    lhs_rank = lhs_shape.rank()
    rhs_rank = rhs_shape.rank()

    # --- Strategy 1: the target shape is a compile-time constant ---
    target_array = get_numpy_value(shape)
    if target_array is not None:
        target_dims = tuple(int(v) for v in target_array.tolist())
        target_rank = len(target_dims)

        # ``offset`` counts dimensions from the right, starting at 1.
        for offset, target_d in enumerate(reversed(target_dims), start=1):
            if target_d == 1:
                continue  # a target of 1 never changes the input dimension

            lhs_pos = lhs_rank - offset
            lhs_d = lhs_shape[lhs_pos] if lhs_pos >= 0 else 1
            if isinstance(lhs_d, int) and lhs_d == target_d:
                continue  # expand is a no-op at this dimension

            rhs_pos = rhs_rank - offset
            rhs_d = rhs_shape[rhs_pos] if rhs_pos >= 0 else 1
            if isinstance(rhs_d, int) and rhs_d == target_d:
                continue  # y already supplies this dimension

            axis = target_rank - offset
            return outcome.fail(
                f"Cannot verify that removing Expand is safe at dimension {axis}: "
                f"x_d={lhs_d!r}, expand_d={target_d}, y_d={rhs_d!r}."
            )

        return outcome

    # --- Strategy 2: the Expand output shape is annotated ---
    if expand_output is not None:
        annotated = expand_output.shape
        if annotated is not None:
            return _check_dims_sufficient(annotated, lhs_shape, rhs_shape)

    # --- Strategy 3: compare against the binary op's annotated output ---
    # If broadcasting the un-expanded inputs already gives the annotated
    # output shape, the Expand contributed nothing.
    out_shape = None if binary_op_output is None else binary_op_output.shape
    if out_shape is not None and out_shape.rank() is not None:
        candidate = _compute_broadcast_shape(lhs_shape, rhs_shape)
        shapes_agree = (
            candidate is not None
            and len(candidate) == out_shape.rank()
            and all(c == a for c, a in zip(candidate, out_shape))
        )
        if shapes_agree:
            return outcome
        return outcome.fail(
            "broadcast(x.shape, y.shape) does not match the binary op output shape."
        )

    return outcome.fail(
        "Expand target shape is not a constant and no shape annotations are available."
    )
242
+
243
+
244
class _ExpandFirstInput(RewriteRuleClassBase):
    """Rewrites ``BinaryOp(Expand(x, shape), y)`` to ``BinaryOp(x, y)``.

    One instance handles a single binary op type, given at construction.
    """

    def __init__(self, op_type: str) -> None:
        super().__init__(f"ExpandFirst_{op_type}", remove_nodes=False)
        self._op_type = op_type

    def pattern(self, op, x: ir.Value, shape: ir.Value, y: ir.Value) -> ir.Value:
        """Match the binary op whose first operand is an Expand output."""
        binary_op = getattr(op, self._op_type)
        expanded = op.Expand(x, shape)
        return binary_op(expanded, y)

    def check(self, context, x: ir.Value, shape: ir.Value, y: ir.Value) -> MatchResult:
        """Accept the match only if the Expand is provably redundant."""
        # ``context.root`` is the matched binary op; its first input is the
        # Expand output.
        root = context.root
        operands = root.inputs
        expand_output = operands[0] if operands else None
        results = root.outputs
        binary_op_output = results[0] if results else None
        return _check_expand_removable(
            x, shape, y, expand_output=expand_output, binary_op_output=binary_op_output
        )

    def rewrite(self, op, x: ir.Value, shape: ir.Value, y: ir.Value) -> ir.Value:
        """Emit the binary op applied directly to the un-expanded input."""
        binary_op = getattr(op, self._op_type)
        return binary_op(x, y)
263
+
264
+
265
class _ExpandSecondInput(RewriteRuleClassBase):
    """Rewrites ``BinaryOp(x, Expand(y, shape))`` to ``BinaryOp(x, y)``.

    One instance handles a single binary op type, given at construction.
    """

    def __init__(self, op_type: str) -> None:
        super().__init__(f"ExpandSecond_{op_type}", remove_nodes=False)
        self._op_type = op_type

    def pattern(self, op, x: ir.Value, y: ir.Value, shape: ir.Value) -> ir.Value:
        """Match the binary op whose second operand is an Expand output."""
        binary_op = getattr(op, self._op_type)
        expanded = op.Expand(y, shape)
        return binary_op(x, expanded)

    def check(self, context, x: ir.Value, y: ir.Value, shape: ir.Value) -> MatchResult:
        """Accept the match only if the Expand is provably redundant."""
        # ``context.root`` is the matched binary op; its second input is the
        # Expand output. ``y`` is the expanded operand here, so it is passed
        # as the expand input and ``x`` as the other operand.
        root = context.root
        operands = root.inputs
        expand_output = operands[1] if operands else None
        results = root.outputs
        binary_op_output = results[0] if results else None
        return _check_expand_removable(
            y, shape, x, expand_output=expand_output, binary_op_output=binary_op_output
        )

    def rewrite(self, op, x: ir.Value, y: ir.Value, shape: ir.Value) -> ir.Value:
        """Emit the binary op applied directly to the un-expanded input."""
        binary_op = getattr(op, self._op_type)
        return binary_op(x, y)
284
+
285
+
286
def _make_expand_before_binary_op_rules() -> list:
    """Create rewrite rules for removing Expand before each supported binary op.

    Every op in ``_BROADCAST_BINARY_OPS`` gets a rule for an Expand on its
    second input. A rule for an Expand on the first input is created only
    for ops that broadcast in both directions.

    Returns:
        The list of rule objects, in op order.
    """
    # PRelu uses unidirectional broadcasting: the slope (second input) must
    # broadcast to X (first input), and the output takes X's shape. Dropping
    # an Expand on X because the slope "supplies" the dimension would shrink
    # the output or make the node invalid, so no first-input rule is built.
    first_input_unsafe = frozenset({"PRelu"})

    rules = []
    for op_type in _BROADCAST_BINARY_OPS:
        if op_type not in first_input_unsafe:
            rules.append(_ExpandFirstInput.rule(op_type))
        rules.append(_ExpandSecondInput.rule(op_type))
    return rules
293
+
294
+
295
+ expand_before_binary_op_rules = RewriteRuleSet(_make_expand_before_binary_op_rules())
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: onnxscript
3
- Version: 0.6.3.dev20260403
3
+ Version: 0.6.3.dev20260411
4
4
  Summary: Naturally author ONNX functions and models using a subset of Python
5
5
  Author-email: Microsoft Corporation <onnx@microsoft.com>
6
6
  License: MIT License
@@ -27,7 +27,7 @@ License: MIT License
27
27
 
28
28
  Project-URL: Homepage, https://microsoft.github.io/onnxscript/
29
29
  Project-URL: Repository, https://github.com/microsoft/onnxscript
30
- Project-URL: Commit, https://github.com/microsoft/onnxscript/tree/a47ccbadd85851ebdafd41e13f880e6ea399535d
30
+ Project-URL: Commit, https://github.com/microsoft/onnxscript/tree/13f265cd01b21210267b86a24efdb0072c0ee374
31
31
  Classifier: Development Status :: 4 - Beta
32
32
  Classifier: Environment :: Console
33
33
  Classifier: Intended Audience :: Developers
@@ -169,6 +169,7 @@ onnxscript/rewriter/rules/common/_matmul_add_to_gemm.py
169
169
  onnxscript/rewriter/rules/common/_min_max_to_clip.py
170
170
  onnxscript/rewriter/rules/common/_no_op.py
171
171
  onnxscript/rewriter/rules/common/_redundant_scatter_nd.py
172
+ onnxscript/rewriter/rules/common/_remove_expand_before_binary_op.py
172
173
  onnxscript/rewriter/rules/common/_remove_optional_bias.py
173
174
  onnxscript/rewriter/rules/fusion/__init__.py
174
175
  onnxscript/rewriter/rules/fusion/_gqa.py