onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,108 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+ import logging
+
+ from diffusion_models import PipelineInfo
+ from engine_builder import EngineBuilder, EngineType
+
+ logger = logging.getLogger(__name__)
+
+
+ class TorchEngineBuilder(EngineBuilder):
+     def __init__(
+         self,
+         pipeline_info: PipelineInfo,
+         max_batch_size=16,
+         device="cuda",
+         use_cuda_graph=False,
+     ):
+         """
+         Initializes the PyTorch engine builder.
+
+         Args:
+             pipeline_info (PipelineInfo):
+                 Version and type of pipeline.
+             max_batch_size (int):
+                 Maximum batch size for dynamic batch engine.
+             device (str):
+                 Device to run on.
+             use_cuda_graph (bool):
+                 Use CUDA graph to capture engine execution and then launch inference.
+         """
+         super().__init__(
+             EngineType.TORCH,
+             pipeline_info,
+             max_batch_size=max_batch_size,
+             device=device,
+             use_cuda_graph=use_cuda_graph,
+         )
+
+         self.compile_config = {}
+         if use_cuda_graph:
+             self.compile_config = {
+                 "clip": {"mode": "reduce-overhead", "dynamic": False},
+                 "clip2": {"mode": "reduce-overhead", "dynamic": False},
+                 "unet": {"mode": "reduce-overhead", "fullgraph": True, "dynamic": False},
+                 "unetxl": {"mode": "reduce-overhead", "fullgraph": True, "dynamic": False},
+                 "vae": {"mode": "reduce-overhead", "fullgraph": False, "dynamic": False},
+             }
+
+     def build_engines(
+         self,
+         framework_model_dir: str,
+     ):
+         import torch  # noqa: PLC0415
+
+         self.torch_device = torch.device("cuda", torch.cuda.current_device())
+         self.load_models(framework_model_dir)
+
+         pipe = self.load_pipeline_with_lora() if self.pipeline_info.lora_weights else None
+
+         built_engines = {}
+         for model_name, model_obj in self.models.items():
+             model = self.get_or_load_model(pipe, model_name, model_obj, framework_model_dir)
+             if self.pipeline_info.is_xl() and not self.custom_fp16_vae:
+                 model = model.to(device=self.torch_device, dtype=torch.float32)
+             else:
+                 model = model.to(device=self.torch_device, dtype=torch.float16)
+
+             if model_name in self.compile_config:
+                 compile_config = self.compile_config[model_name]
+                 if model_name in ["unet", "unetxl"]:
+                     model.to(memory_format=torch.channels_last)
+                 engine = torch.compile(model, **compile_config)
+                 built_engines[model_name] = engine
+             else:  # eager mode
+                 built_engines[model_name] = model
+
+         self.engines = built_engines
+
+     def run_engine(self, model_name, feed_dict):
+         if model_name in ["unet", "unetxl"]:
+             if "controlnet_images" in feed_dict:
+                 return {"latent": self.engines[model_name](**feed_dict)}
+
+             if model_name == "unetxl":
+                 added_cond_kwargs = {k: feed_dict[k] for k in feed_dict if k in ["text_embeds", "time_ids"]}
+                 return {
+                     "latent": self.engines[model_name](
+                         feed_dict["sample"],
+                         feed_dict["timestep"],
+                         feed_dict["encoder_hidden_states"],
+                         added_cond_kwargs=added_cond_kwargs,
+                         return_dict=False,
+                     )[0]
+                 }
+
+             return {
+                 "latent": self.engines[model_name](
+                     feed_dict["sample"], feed_dict["timestep"], feed_dict["encoder_hidden_states"], return_dict=False
+                 )[0]
+             }
+
+         if model_name in ["vae_encoder"]:
+             return {"latent": self.engines[model_name](feed_dict["images"])}
+
+         raise RuntimeError(f"Shall not reach here: {model_name}")
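
For orientation, here is a minimal sketch of how this builder might be driven, based only on the methods visible in the diff above; the `pipeline_info` value and the input tensors are placeholders, not the package's verified API:

    # Hypothetical driver for TorchEngineBuilder (sketch; assumes a valid
    # PipelineInfo instance and CUDA tensors shaped for the chosen model).
    builder = TorchEngineBuilder(pipeline_info, max_batch_size=4, use_cuda_graph=True)
    builder.build_engines("./framework_models")  # loads models; applies torch.compile where configured

    feed_dict = {
        "sample": latents,                         # noisy latent batch (float16)
        "timestep": timestep,                      # current denoising step
        "encoder_hidden_states": text_embeddings,  # text encoder output
    }
    latent = builder.run_engine("unet", feed_dict)["latent"]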
@@ -0,0 +1,590 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+ #
+ # This script converts stable diffusion onnx models from float to half (mixed) precision for GPU inference.
+ #
+ # Before running this script, follow README.md to set up the Python environment and convert the stable diffusion
+ # checkpoint to float32 onnx models.
+ #
+ # For example, if the float32 ONNX pipeline is saved to the ./sd-v1-5 directory, you can optimize and convert it to
+ # float16 like the following:
+ #    python optimize_pipeline.py -i ./sd-v1-5 -o ./sd-v1-5-fp16 --float16
+ #
+ # Note that the optimizations are carried out for the CUDA Execution Provider first; other EPs may not support the
+ # fused operators. Users can disable operator fusion manually to work around this.
+
+ import argparse
+ import logging
+ import os
+ import shutil
+ import tempfile
+ import warnings
+ from pathlib import Path
+
+ import onnx
+ from fusion_options import FusionOptions
+ from onnx_model_clip import ClipOnnxModel
+ from onnx_model_mmdit import MmditOnnxModel
+ from onnx_model_t5 import T5OnnxModel
+ from onnx_model_unet import UnetOnnxModel
+ from onnx_model_vae import VaeOnnxModel
+ from optimizer import optimize_by_onnxruntime, optimize_model
+ from packaging import version
+
+ import onnxruntime
+
+ logger = logging.getLogger(__name__)
+
+
+ def has_external_data(onnx_model_path):
+     original_model = onnx.load_model(str(onnx_model_path), load_external_data=False)
+     for initializer in original_model.graph.initializer:
+         if initializer.HasField("data_location") and initializer.data_location == onnx.TensorProto.EXTERNAL:
+             return True
+     return False
+
+
+ def is_sd_3(source_dir: Path):
+     return (source_dir / "text_encoder_3").exists()
+
+
+ def is_sdxl(source_dir: Path):
+     return (
+         (source_dir / "text_encoder_2").exists()
+         and not (source_dir / "text_encoder_3").exists()
+         and not (source_dir / "transformer").exists()
+     )
+
+
+ def is_flux(source_dir: Path):
+     return (
+         (source_dir / "text_encoder_2").exists()
+         and not (source_dir / "text_encoder_3").exists()
+         and (source_dir / "transformer").exists()
+     )
+
+
+ def _classify_pipeline_type(source_dir: Path):
+     # May also check _class_name in model_index.json like `StableDiffusion3Pipeline` or `FluxPipeline` etc. to classify.
+     if is_sd_3(source_dir):
+         return "sd3"
+
+     if is_flux(source_dir):
+         return "flux"
+
+     if is_sdxl(source_dir):
+         return "sdxl"
+
+     # sd 1.x and 2.x
+     return "sd"
+
+
+ def _get_model_list(pipeline_type: str):
+     if pipeline_type == "sd3":
+         return ["text_encoder", "text_encoder_2", "text_encoder_3", "transformer", "vae_encoder", "vae_decoder"]
+
+     if pipeline_type == "flux":
+         return ["text_encoder", "text_encoder_2", "transformer", "vae_encoder", "vae_decoder"]
+
+     if pipeline_type == "sdxl":
+         return ["text_encoder", "text_encoder_2", "unet", "vae_encoder", "vae_decoder"]
+
+     assert pipeline_type == "sd"
+     return ["text_encoder", "unet", "vae_encoder", "vae_decoder"]
+
+
+ def _optimize_sd_pipeline(
+     source_dir: Path,
+     target_dir: Path,
+     pipeline_type: str,
+     model_list: list[str],
+     use_external_data_format: bool | None,
+     float16: bool,
+     bfloat16: bool,
+     force_fp32_ops: list[str],
+     enable_runtime_optimization: bool,
+     args,
+ ):
+     """Optimize onnx models used in the stable diffusion onnx pipeline, and optionally convert them to float16.
+
+     Args:
+         source_dir (Path): Root of input directory of stable diffusion onnx pipeline with float32 models.
+         target_dir (Path): Root of output directory of stable diffusion onnx pipeline with optimized models.
+         pipeline_type (str): Pipeline type like "sd", "sdxl", "sd3" or "flux".
+         model_list (List[str]): List of directory names with onnx models.
+         use_external_data_format (Optional[bool]): Use external data format.
+         float16 (bool): Use half precision.
+         bfloat16 (bool): Use bfloat16 as fallback if float16 is also provided.
+         force_fp32_ops (List[str]): Operators that are forced to run in float32.
+         enable_runtime_optimization (bool): Run graph optimization using ONNX Runtime.
+
+     Raises:
+         RuntimeError: input onnx model does not exist
+         RuntimeError: output onnx model path exists
+     """
+     is_flux_pipeline = pipeline_type == "flux"
+     model_type_mapping = {
+         "transformer": "mmdit",
+         "unet": "unet",
+         "vae_encoder": "vae",
+         "vae_decoder": "vae",
+         "text_encoder": "clip",
+         "text_encoder_2": "t5" if is_flux_pipeline else "clip",
+         "text_encoder_3": "t5",  # t5-v1_1-xxl is used in SD 3.x text_encoder_3 and Flux text_encoder_2.
+         "safety_checker": "unet",
+     }
+
+     model_type_class_mapping = {
+         "unet": UnetOnnxModel,
+         "vae": VaeOnnxModel,
+         "clip": ClipOnnxModel,
+         "t5": T5OnnxModel,
+         "mmdit": MmditOnnxModel,
+     }
+
+     force_fp32_operators = {
+         "unet": [],
+         "vae_encoder": [],
+         "vae_decoder": [],
+         "text_encoder": [],
+         "text_encoder_2": [],
+         "safety_checker": [],
+         "text_encoder_3": [],
+         "transformer": [],
+     }
+
+     # The node block list is generated by running the fp32 model and collecting statistics of node inputs and outputs.
+     # Nodes with any input or output of float or double data type, but with values out of the range of float16, are candidates.
+     # python optimize_pipeline.py -i ./flux1_schnell_onnx/fp32 -o ./flux1_schnell_onnx/fp32_opt
+     # export ORT_DEBUG_NODE_IO_DUMP_STATISTICS_DATA=1
+     # export ORT_DEBUG_NODE_IO_DUMP_INPUT_DATA=1
+     # export ORT_DEBUG_NODE_IO_DUMP_OUTPUT_DATA=1
+     # python benchmark.py --height 1024 --width 1024 --steps 4 -b 1 -v Flux.1S -p flux1_schnell_onnx/fp32_opt -e optimum >stdout.txt 2>stderr.txt
+     # Warning: The node names might change in different export settings. See benchmark_flux.sh for the settings.
+     flux_node_block_list = {
+         "text_encoder_2": [
+             "/encoder/block.10/layer.1/DenseReluDense/wo/MatMul",
+             "SkipLayerNorm_20",
+             "SkipLayerNorm_21",
+             "SkipLayerNorm_22",
+             "SkipLayerNorm_23",
+             "SkipLayerNorm_24",
+             "SkipLayerNorm_25",
+             "SkipLayerNorm_26",
+             "SkipLayerNorm_27",
+             "SkipLayerNorm_28",
+             "SkipLayerNorm_29",
+             "SkipLayerNorm_30",
+             "SkipLayerNorm_31",
+             "SkipLayerNorm_32",
+             "SkipLayerNorm_33",
+             "SkipLayerNorm_34",
+             "SkipLayerNorm_35",
+             "SkipLayerNorm_36",
+             "SkipLayerNorm_37",
+             "SkipLayerNorm_38",
+             "SkipLayerNorm_39",
+             "SkipLayerNorm_40",
+             "SkipLayerNorm_41",
+             "SkipLayerNorm_42",
+             "SkipLayerNorm_43",
+             "SkipLayerNorm_44",
+             "SkipLayerNorm_45",
+             "/encoder/block.23/layer.1/DenseReluDense/wo/MatMul",
+             "SkipLayerNorm_46",
+         ],
+         "vae_decoder": [
+             "/decoder/mid_block/attentions.0/MatMul",
+             "/decoder/mid_block/attentions.0/Softmax",
+         ],
+         "transformer": [
+             "/transformer_blocks.18/Mul_5",
+             "/transformer_blocks.18/Add_7",
+             "/Concat_1",
+             "LayerNorm_76",
+             "/single_transformer_blocks.0/Add",
+             "LayerNorm_77",
+             "/single_transformer_blocks.1/Add",
+             "LayerNorm_78",
+             "/single_transformer_blocks.2/Add",
+             "LayerNorm_79",
+             "/single_transformer_blocks.3/Add",
+             "LayerNorm_80",
+             "/single_transformer_blocks.4/Add",
+             "LayerNorm_81",
+             "/single_transformer_blocks.5/Add",
+             "LayerNorm_82",
+             "/single_transformer_blocks.6/Add",
+             "LayerNorm_83",
+             "/single_transformer_blocks.7/Add",
+             "LayerNorm_84",
+             "/single_transformer_blocks.8/Add",
+             "LayerNorm_85",
+             "/single_transformer_blocks.9/Add",
+             "LayerNorm_86",
+             "/single_transformer_blocks.10/Add",
+             "LayerNorm_87",
+             "/single_transformer_blocks.11/Add",
+             "LayerNorm_88",
+             "/single_transformer_blocks.12/Add",
+             "LayerNorm_89",
+             "/single_transformer_blocks.13/Add",
+             "LayerNorm_90",
+             "/single_transformer_blocks.14/Add",
+             "LayerNorm_91",
+             "/single_transformer_blocks.15/Add",
+             "LayerNorm_92",
+             "/single_transformer_blocks.16/Add",
+             "LayerNorm_93",
+             "/single_transformer_blocks.17/Add",
+             "LayerNorm_94",
+             "/single_transformer_blocks.18/Add",
+             "LayerNorm_95",
+             "/single_transformer_blocks.19/Add",
+             "LayerNorm_96",
+             "/single_transformer_blocks.20/Add",
+             "LayerNorm_97",
+             "/single_transformer_blocks.21/Add",
+             "LayerNorm_98",
+             "/single_transformer_blocks.22/Add",
+             "LayerNorm_99",
+             "/single_transformer_blocks.23/Add",
+             "LayerNorm_100",
+             "/single_transformer_blocks.24/Add",
+             "LayerNorm_101",
+             "/single_transformer_blocks.25/Add",
+             "LayerNorm_102",
+             "/single_transformer_blocks.26/Add",
+             "LayerNorm_103",
+             "/single_transformer_blocks.27/Add",
+             "LayerNorm_104",
+             "/single_transformer_blocks.28/Add",
+             "LayerNorm_105",
+             "/single_transformer_blocks.29/Add",
+             "LayerNorm_106",
+             "/single_transformer_blocks.30/Add",
+             "LayerNorm_107",
+             "/single_transformer_blocks.31/Add",
+             "LayerNorm_108",
+             "/single_transformer_blocks.32/Add",
+             "LayerNorm_109",
+             "/single_transformer_blocks.33/Add",
+             "LayerNorm_110",
+             "/single_transformer_blocks.34/Add",
+             "LayerNorm_111",
+             "/single_transformer_blocks.35/Add",
+             "LayerNorm_112",
+             "/single_transformer_blocks.36/Add",
+             "LayerNorm_113",
+             "/single_transformer_blocks.37/Add",
+             "/Shape",
+             "/Slice",
+         ],
+     }
+
+     sd3_node_block_list = {"text_encoder_3": flux_node_block_list["text_encoder_2"]}
+
+     if force_fp32_ops:
+         for fp32_operator in force_fp32_ops:
+             parts = fp32_operator.split(":")
+             if len(parts) == 2 and parts[0] in force_fp32_operators and (parts[1] and parts[1][0].isupper()):
+                 force_fp32_operators[parts[0]].append(parts[1])
+             else:
+                 raise ValueError(
+                     f"--force_fp32_ops shall be in the format of module:operator like unet:Attention, got {fp32_operator}"
+                 )
+
+     op_counters = {}
+     for name, model_type in model_type_mapping.items():
+         onnx_model_path = source_dir / name / "model.onnx"
+         if not os.path.exists(onnx_model_path):
+             if name != "safety_checker" and name in model_list:
+                 logger.warning("input onnx model does not exist: %s", onnx_model_path)
+             # Some models are optional, so we do not raise an error here.
+             continue
+
+         # Prepare output directory.
+         optimized_model_path = target_dir / name / "model.onnx"
+         if os.path.exists(optimized_model_path):
+             if not args.overwrite:
+                 logger.warning("Skipped optimization since the target file existed: %s", optimized_model_path)
+                 continue
+         output_dir = optimized_model_path.parent
+         output_dir.mkdir(parents=True, exist_ok=True)
+
+         if use_external_data_format is None:
+             use_external_data_format = has_external_data(onnx_model_path)
+
+         # Graph fusion before fp16 conversion, otherwise they cannot be fused later.
+         logger.info("Optimize %s ...", onnx_model_path)
+
+         args.model_type = model_type
+         fusion_options = FusionOptions.parse(args)
+
+         if model_type in ["unet"]:
+             # Some optimizations are not available in v1.14 or older versions: packed QKV and BiasAdd.
+             has_all_optimizations = version.parse(onnxruntime.__version__) >= version.parse("1.15.0")
+             fusion_options.enable_packed_kv = float16 and fusion_options.enable_packed_kv
+             fusion_options.enable_packed_qkv = float16 and has_all_optimizations and fusion_options.enable_packed_qkv
+             fusion_options.enable_bias_add = has_all_optimizations and fusion_options.enable_bias_add
+
+         m = optimize_model(
+             str(onnx_model_path),
+             model_type=model_type,
+             num_heads=0,  # will be deduced from graph
+             hidden_size=0,  # will be deduced from graph
+             opt_level=0,
+             optimization_options=fusion_options,
+             use_gpu=True,
+             provider=args.provider,
+         )
+
+         if float16:
+             model_node_block_list = (
+                 flux_node_block_list if is_flux_pipeline else sd3_node_block_list if pipeline_type == "sd3" else {}
+             )
+             if name in model_node_block_list:
+                 # Opset 12 does not support bfloat16.
+                 # By default, optimum exports the T5 model with opset 12, so we need to check the opset version.
+                 use_bfloat16 = bfloat16
+                 if use_bfloat16:
+                     for opset in m.model.opset_import:
+                         if opset.domain in ["", "ai.onnx"] and opset.version < 13:
+                             logger.warning(
+                                 "onnx model requires opset 13 or higher to use bfloat16. Fall back to float32."
+                             )
+                             use_bfloat16 = False
+
+                 m.convert_float_to_float16(
+                     keep_io_types=False,
+                     node_block_list=model_node_block_list[name],
+                     use_bfloat16_as_blocked_nodes_dtype=use_bfloat16,
+                 )
+             # For SD-XL, using FP16 in the VAE decoder causes NaN and black images, so we keep it in FP32.
+             elif pipeline_type in ["sdxl"] and name in ["vae_decoder"]:
+                 logger.info("Skip converting %s to float16 to avoid NaN", name)
+             else:
+                 logger.info("Convert %s to float16 ...", name)
+                 m.convert_float_to_float16(
+                     keep_io_types=False,
+                     op_block_list=force_fp32_operators[name],
+                 )
+
+         if enable_runtime_optimization:
+             # Use this step to see the final graph that is executed by ONNX Runtime.
+             with tempfile.TemporaryDirectory() as tmp_dir:
+                 # Save to a temporary file so that we can load it with ONNX Runtime.
+                 logger.info("Saving a temporary model to run OnnxRuntime graph optimizations...")
+                 tmp_model_path = Path(tmp_dir) / "model.onnx"
+                 m.save_model_to_file(str(tmp_model_path), use_external_data_format=use_external_data_format)
+                 ort_optimized_model_path = Path(tmp_dir) / "optimized.onnx"
+                 optimize_by_onnxruntime(
+                     str(tmp_model_path),
+                     use_gpu=True,
+                     provider=args.provider,
+                     optimized_model_path=str(ort_optimized_model_path),
+                     save_as_external_data=use_external_data_format,
+                 )
+                 model = onnx.load(str(ort_optimized_model_path), load_external_data=True)
+                 m = model_type_class_mapping[model_type](model)
+
+         m.get_operator_statistics()
+         op_counters[name] = m.get_fused_operator_statistics()
+         m.save_model_to_file(str(optimized_model_path), use_external_data_format=use_external_data_format)
+         logger.info("%s is optimized", name)
+         logger.info("*" * 20)
+
+     return op_counters
+
+
+ def _copy_extra_directory(source_dir: Path, target_dir: Path, model_list: list[str]):
+     """Copy extra directories that do not have onnx models.
+
+     Args:
+         source_dir (Path): source directory
+         target_dir (Path): target directory
+         model_list (List[str]): list of directory names with onnx models.
+
+     Raises:
+         RuntimeError: source path does not exist
+     """
+     extra_dirs = ["scheduler", "tokenizer", "tokenizer_2", "tokenizer_3", "feature_extractor"]
+
+     for name in extra_dirs:
+         source_path = source_dir / name
+         if not os.path.exists(source_path):
+             continue
+
+         target_path = target_dir / name
+         if target_path.exists():
+             shutil.rmtree(target_path)
+         shutil.copytree(source_path, target_path)
+         logger.info("%s => %s", source_path, target_path)
+
+     extra_files = ["model_index.json"]
+     for name in extra_files:
+         source_path = source_dir / name
+         if not os.path.exists(source_path):
+             raise RuntimeError(f"source path does not exist: {source_path}")
+
+         target_path = target_dir / name
+         shutil.copyfile(source_path, target_path)
+         logger.info("%s => %s", source_path, target_path)
+
+     # Some directories are optional.
+     for onnx_model_dir in model_list:
+         source_path = source_dir / onnx_model_dir / "config.json"
+         target_path = target_dir / onnx_model_dir / "config.json"
+         if source_path.exists():
+             target_path.parent.mkdir(parents=True, exist_ok=True)
+             shutil.copyfile(source_path, target_path)
+             logger.info("%s => %s", source_path, target_path)
+
+
+ def optimize_stable_diffusion_pipeline(
+     input_dir: str,
+     output_dir: str,
+     overwrite: bool,
+     use_external_data_format: bool | None,
+     float16: bool,
+     enable_runtime_optimization: bool,
+     args,
+ ):
+     if os.path.exists(output_dir):
+         if overwrite:
+             shutil.rmtree(output_dir, ignore_errors=True)
+
+     source_dir = Path(input_dir)
+     target_dir = Path(output_dir)
+     target_dir.mkdir(parents=True, exist_ok=True)
+
+     pipeline_type = _classify_pipeline_type(source_dir)
+     model_list = _get_model_list(pipeline_type)
+
+     _copy_extra_directory(source_dir, target_dir, model_list)
+
+     return _optimize_sd_pipeline(
+         source_dir,
+         target_dir,
+         pipeline_type,
+         model_list,
+         use_external_data_format,
+         float16,
+         args.bfloat16,
+         args.force_fp32_ops,
+         enable_runtime_optimization,
+         args,
+     )
+
+
+ def parse_arguments(argv: list[str] | None = None):
+     """Parse arguments.
+
+     Returns:
+         Namespace: arguments
+     """
+     parser = argparse.ArgumentParser()
+
+     parser.add_argument(
+         "-i",
+         "--input",
+         required=True,
+         type=str,
+         help="Root of input directory of stable diffusion onnx pipeline with float32 models.",
+     )
+
+     parser.add_argument(
+         "-o",
+         "--output",
+         required=True,
+         type=str,
+         help="Root of output directory of stable diffusion onnx pipeline with optimized models.",
+     )
+
+     parser.add_argument(
+         "--float16",
+         required=False,
+         action="store_true",
+         help="Output models of float16, except that some nodes fall back to float32 or bfloat16 to avoid overflow.",
+     )
+     parser.set_defaults(float16=False)
+
+     parser.add_argument(
+         "--bfloat16",
+         required=False,
+         action="store_true",
+         help="Allow bfloat16 as fallback if --float16 is also provided.",
+     )
+     parser.set_defaults(bfloat16=False)
+
+     parser.add_argument(
+         "--force_fp32_ops",
+         required=False,
+         nargs="+",
+         type=str,
+         help="Force given operators (like unet:Attention) to run in float32. It is case sensitive!",
+     )
+
+     parser.add_argument(
+         "--inspect",
+         required=False,
+         action="store_true",
+         help="Save the optimized graph from ONNX Runtime. "
+         "This option has no impact on inference performance, except that it might reduce session creation time.",
+     )
+     parser.set_defaults(inspect=False)
+
+     parser.add_argument(
+         "--overwrite",
+         required=False,
+         action="store_true",
+         help="Overwrite existing files.",
+     )
+     parser.set_defaults(overwrite=False)
+
+     parser.add_argument(
+         "-e",
+         "--use_external_data_format",
+         required=False,
+         action="store_true",
+         help="Onnx models larger than 2GB need to use the external data format. "
+         "If specified, save each onnx model to two files: one for the onnx graph, another for the weights. "
+         "If not specified, use the same format as the original model by default.",
+     )
+     parser.set_defaults(use_external_data_format=None)
+
+     parser.add_argument(
+         "--provider",
+         required=False,
+         type=str,
+         default=None,
+         help="Execution provider to use.",
+     )
+
+     FusionOptions.add_arguments(parser)
+
+     args = parser.parse_args(argv)
+     return args
+
+
+ def main(argv: list[str] | None = None):
+     warnings.warn(
+         "This example is deprecated. Use the Olive recipe instead: "
+         "https://github.com/microsoft/olive-recipes/tree/main",
+         DeprecationWarning,
+         stacklevel=2,
+     )
+     args = parse_arguments(argv)
+
+     logger.info("Arguments: %s", str(args))
+
+     # Return op counters for testing purposes.
+     return optimize_stable_diffusion_pipeline(
+         args.input, args.output, args.overwrite, args.use_external_data_format, args.float16, args.inspect, args
+     )
+
+
+ if __name__ == "__main__":
+     logging.basicConfig(format="%(funcName)20s: %(message)s", level=logging.INFO)
+     main()
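
As a usage note, since `main` accepts an argv list, the script can also be driven programmatically; the paths below are placeholders mirroring the CLI example in the file's header comment:

    # Sketch: optimize a float32 SD pipeline and convert it to fp16 in one call.
    # Returns the per-model fused-operator counters (used for testing).
    op_counters = main(["-i", "./sd-v1-5", "-o", "./sd-v1-5-fp16", "--float16"])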