onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
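
For orientation, a minimal usage sketch of the packaged runtime (not taken from the diff), assuming the wheel has been installed with `pip install onnxruntime-directml`. It opens the bundled datasets/sigmoid.onnx sample listed above and requests the DirectML execution provider with a CPU fallback; only public onnxruntime APIs (InferenceSession, datasets.get_example) are used, and the random input is illustrative.

import numpy as np
import onnxruntime as ort
from onnxruntime.datasets import get_example

# Prefer the DirectML EP shipped in this wheel; fall back to CPU if it is unavailable.
session = ort.InferenceSession(
    get_example("sigmoid.onnx"),
    providers=["DmlExecutionProvider", "CPUExecutionProvider"],
)

inp = session.get_inputs()[0]
print("active providers:", session.get_providers())

# Feed random data matching the declared input shape (symbolic dimensions default to 1).
shape = [d if isinstance(d, int) else 1 for d in inp.shape]
outputs = session.run(None, {inp.name: np.random.rand(*shape).astype(np.float32)})
print("output shape:", outputs[0].shape)
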
@@ -0,0 +1,295 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+ import hashlib
+ import os
+ from enum import Enum
+
+ import torch
+ from diffusion_models import CLIP, VAE, CLIPWithProj, PipelineInfo, UNet, UNetXL
+
+
+ class EngineType(Enum):
+     ORT_CUDA = 0  # ONNX Runtime CUDA Execution Provider
+     ORT_TRT = 1  # ONNX Runtime TensorRT Execution Provider
+     TRT = 2  # TensorRT
+     TORCH = 3  # PyTorch
+
+
+ def get_engine_type(name: str) -> EngineType:
+     name_to_type = {
+         "ORT_CUDA": EngineType.ORT_CUDA,
+         "ORT_TRT": EngineType.ORT_TRT,
+         "TRT": EngineType.TRT,
+         "TORCH": EngineType.TORCH,
+     }
+     return name_to_type[name]
+
+
+ class EngineBuilder:
+     def __init__(
+         self,
+         engine_type: EngineType,
+         pipeline_info: PipelineInfo,
+         device="cuda",
+         max_batch_size=16,
+         use_cuda_graph=False,
+     ):
+         """
+         Initializes the Engine Builder.
+
+         Args:
+             pipeline_info (PipelineInfo):
+                 Version and Type of pipeline.
+             device (str | torch.device):
+                 device to run engine
+             max_batch_size (int):
+                 Maximum batch size for dynamic batch engine.
+             use_cuda_graph (bool):
+                 Use CUDA graph to capture engine execution and then launch inference
+         """
+         self.engine_type = engine_type
+         self.pipeline_info = pipeline_info
+         self.max_batch_size = max_batch_size
+         self.use_cuda_graph = use_cuda_graph
+         self.device = torch.device(device)
+         self.torch_device = torch.device(device, torch.cuda.current_device())
+         self.stages = pipeline_info.stages()
+
+         self.vae_torch_fallback = self.pipeline_info.vae_torch_fallback() and self.engine_type != EngineType.TORCH
+         self.custom_fp16_vae = self.pipeline_info.custom_fp16_vae()
+
+         self.models = {}
+         self.engines = {}
+         self.torch_models = {}
+         self.use_vae_slicing = False
+
+         self.torch_sdpa = getattr(torch.nn.functional, "scaled_dot_product_attention", None)
+
+     def enable_vae_slicing(self):
+         self.use_vae_slicing = True
+
+     def disable_torch_spda(self):
+         if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
+             delattr(torch.nn.functional, "scaled_dot_product_attention")
+
+     def enable_torch_spda(self):
+         if (not hasattr(torch.nn.functional, "scaled_dot_product_attention")) and self.torch_sdpa:
+             torch.nn.functional.scaled_dot_product_attention = self.torch_sdpa
+
+     def teardown(self):
+         for engine in self.engines.values():
+             del engine
+         self.engines = {}
+
+     def get_diffusers_module_name(self, model_name):
+         name_mapping = {
+             "clip": "text_encoder",
+             "clip2": "text_encoder_2",
+             "unet": "unet",
+             "unetxl": "unet",
+             "vae": "vae_decoder",
+         }
+         return name_mapping.get(model_name, model_name)
+
+     def get_cached_model_name(self, model_name):
+         model_name = self.get_diffusers_module_name(model_name)
+         is_unet = model_name == "unet"
+         hash_source = []
+         if model_name in ["text_encoder", "text_encoder_2", "unet"] and self.pipeline_info.lora_weights:
+             if self.pipeline_info.lora_weights in [
+                 "latent-consistency/lcm-lora-sdxl",
+                 "latent-consistency/lcm-lora-sdv1-5",
+             ]:
+                 if is_unet:
+                     model_name = "unet_lcm-lora"
+             else:
+                 model_name = model_name + "_lora"
+                 hash_source.append(self.pipeline_info.lora_weights)
+
+         # TODO(tianleiwu): save custom model to a directory named by its original model.
+         if is_unet and self.pipeline_info.custom_unet():
+             model_name = model_name + "_lcm"
+
+         if model_name in ["unet"] and self.pipeline_info.controlnet:
+             model_name = model_name + "_" + "_".join(self.pipeline_info.controlnet)
+
+         if hash_source:
+             model_name += "_" + hashlib.sha256("\t".join(hash_source).encode("utf-8")).hexdigest()[:8]
+
+         # TODO: When we support original VAE, we shall save custom VAE to another directory.
+
+         if self.pipeline_info.is_inpaint():
+             model_name += "_inpaint"
+         return model_name
+
+     def get_model_dir(self, model_name, root_dir, opt=True, suffix="", create=True):
+         engine_name = self.engine_type.name.lower()
+         if engine_name != "ort_cuda" and not suffix:
+             suffix = f".{engine_name}" if opt else ""
+         directory_name = self.get_cached_model_name(model_name) + suffix
+         onnx_model_dir = os.path.join(root_dir, directory_name)
+         if create:
+             os.makedirs(onnx_model_dir, exist_ok=True)
+         return onnx_model_dir
+
+     def get_onnx_path(self, model_name, onnx_dir, opt=True, suffix=""):
+         onnx_model_dir = self.get_model_dir(model_name, onnx_dir, opt=opt, suffix=suffix)
+         return os.path.join(onnx_model_dir, "model.onnx")
+
+     def get_engine_path(self, engine_dir, model_name, profile_id):
+         return os.path.join(engine_dir, self.get_cached_model_name(model_name) + profile_id)
+
+     def load_pipeline_with_lora(self):
+         """Load text encoders and UNet with diffusers pipeline"""
+         from diffusers import DiffusionPipeline  # noqa: PLC0415
+
+         pipeline = DiffusionPipeline.from_pretrained(
+             self.pipeline_info.name(),
+             variant="fp16",
+             torch_dtype=torch.float16,
+         )
+         pipeline.load_lora_weights(self.pipeline_info.lora_weights)
+         pipeline.fuse_lora(lora_scale=self.pipeline_info.lora_scale)
+
+         del pipeline.vae
+         pipeline.vae = None
+         return pipeline
+
+     def get_or_load_model(self, pipeline, model_name, model_obj, framework_model_dir):
+         if model_name in ["clip", "clip2", "unet", "unetxl"] and pipeline:
+             if model_name == "clip":
+                 model = pipeline.text_encoder
+                 pipeline.text_encoder = None
+             elif model_name == "clip2":
+                 model = pipeline.text_encoder_2
+                 pipeline.text_encoder_2 = None
+             else:
+                 model = pipeline.unet
+                 pipeline.unet = None
+         else:
+             model = model_obj.load_model(framework_model_dir)
+
+         return model.to(self.torch_device)
+
+     def load_models(self, framework_model_dir: str):
+         # For TRT or ORT_TRT, we will export fp16 torch model for UNet and VAE
+         # For ORT_CUDA, we export fp32 model first, then optimize to fp16.
+         export_fp16 = self.engine_type in [EngineType.ORT_TRT, EngineType.TRT]
+
+         if "clip" in self.stages:
+             self.models["clip"] = CLIP(
+                 self.pipeline_info,
+                 None,  # not loaded yet
+                 device=self.torch_device,
+                 max_batch_size=self.max_batch_size,
+                 clip_skip=0,
+             )
+
+         if "clip2" in self.stages:
+             self.models["clip2"] = CLIPWithProj(
+                 self.pipeline_info,
+                 None,  # not loaded yet
+                 device=self.torch_device,
+                 max_batch_size=self.max_batch_size,
+                 clip_skip=0,
+             )
+
+         if "unet" in self.stages:
+             self.models["unet"] = UNet(
+                 self.pipeline_info,
+                 None,  # not loaded yet
+                 device=self.torch_device,
+                 fp16=export_fp16,
+                 max_batch_size=self.max_batch_size,
+                 unet_dim=(9 if self.pipeline_info.is_inpaint() else 4),
+             )
+
+         if "unetxl" in self.stages:
+             self.models["unetxl"] = UNetXL(
+                 self.pipeline_info,
+                 None,  # not loaded yet
+                 device=self.torch_device,
+                 fp16=export_fp16,
+                 max_batch_size=self.max_batch_size,
+                 unet_dim=4,
+                 time_dim=(5 if self.pipeline_info.is_xl_refiner() else 6),
+             )
+
+         # VAE Decoder
+         if "vae" in self.stages:
+             self.models["vae"] = VAE(
+                 self.pipeline_info,
+                 None,  # not loaded yet
+                 device=self.torch_device,
+                 max_batch_size=self.max_batch_size,
+                 fp16=export_fp16,
+                 custom_fp16_vae=self.custom_fp16_vae,
+             )
+
+             if self.vae_torch_fallback:
+                 self.torch_models["vae"] = self.models["vae"].load_model(framework_model_dir)
+
+     def load_resources(self, image_height, image_width, batch_size):
+         if self.engine_type == EngineType.TORCH:
+             return
+
+         # Allocate buffers for I/O bindings
+         for model_name, obj in self.models.items():
+             if model_name == "vae" and self.vae_torch_fallback:
+                 continue
+             slice_size = 1 if (model_name == "vae" and self.use_vae_slicing) else batch_size
+             self.engines[model_name].allocate_buffers(
+                 shape_dict=obj.get_shape_dict(slice_size, image_height, image_width), device=self.torch_device
+             )
+
+     def _vae_decode(self, latents):
+         if self.engine_type == EngineType.TORCH:
+             if self.pipeline_info.is_xl() and not self.custom_fp16_vae:  # need upcast
+                 latents = latents.to(dtype=torch.float32)
+                 images = self.engines["vae"](latents)["sample"]
+             else:
+                 images = self.engines["vae"](latents)["sample"]
+         elif self.vae_torch_fallback:
+             if not self.custom_fp16_vae:
+                 latents = latents.to(dtype=torch.float32)
+                 self.torch_models["vae"] = self.torch_models["vae"].to(dtype=torch.float32)
+             images = self.torch_models["vae"](latents)["sample"]
+         else:
+             if self.pipeline_info.is_xl() and not self.custom_fp16_vae:  # need upcast
+                 images = self.run_engine("vae", {"latent": latents.to(dtype=torch.float32)})["images"]
+             else:
+                 images = self.run_engine("vae", {"latent": latents})["images"]
+
+         return images
+
+     def vae_decode(self, latents):
+         if self.use_vae_slicing:
+             # The output tensor points to same buffer. Need clone it to avoid overwritten.
+             decoded_slices = [self._vae_decode(z_slice).clone() for z_slice in latents.split(1)]
+             return torch.cat(decoded_slices)
+
+         return self._vae_decode(latents)
+
+
+ def get_engine_paths(
+     work_dir: str, pipeline_info: PipelineInfo, engine_type: EngineType, framework_model_dir: str | None = None
+ ):
+     root_dir = work_dir or "."
+     short_name = pipeline_info.short_name()
+
+     # When both ORT_CUDA and ORT_TRT/TRT is used, we shall make sub directory for each engine since
+     # ORT_CUDA need fp32 torch model, while ORT_TRT/TRT use fp16 torch model.
+     onnx_dir = os.path.join(root_dir, engine_type.name, short_name, "onnx")
+     engine_dir = os.path.join(root_dir, engine_type.name, short_name, "engine")
+     output_dir = os.path.join(root_dir, engine_type.name, short_name, "output")
+
+     timing_cache = os.path.join(root_dir, engine_type.name, "timing_cache")
+
+     # Shared among ORT_CUDA, ORT_TRT and TRT engines, and need use load_model(..., always_download_fp16=True)
+     # So that the shared model is always fp16.
+     if framework_model_dir is None:
+         framework_model_dir = os.path.join(root_dir, "torch_model")
+
+     return onnx_dir, engine_dir, output_dir, framework_model_dir, timing_cache
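
A small sketch (not part of the diff) of how the helpers above fit together, assuming it runs from the stable_diffusion model directory so that diffusion_models and engine_builder import directly; the PipelineInfo("xl-1.0") constructor call is an assumption, so check diffusion_models.PipelineInfo for the actual arguments.

from diffusion_models import PipelineInfo  # assumed import path when run from the model directory
from engine_builder import get_engine_paths, get_engine_type

engine_type = get_engine_type("ORT_CUDA")  # maps a CLI-style name to EngineType.ORT_CUDA
pipeline_info = PipelineInfo("xl-1.0")     # assumed constructor arguments

onnx_dir, engine_dir, output_dir, framework_model_dir, timing_cache = get_engine_paths(
    work_dir="./work", pipeline_info=pipeline_info, engine_type=engine_type
)
# Layout: ./work/ORT_CUDA/<short_name>/{onnx,engine,output} plus a shared ./work/torch_model
print(onnx_dir, engine_dir, output_dir, framework_model_dir, timing_cache)
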
@@ -0,0 +1,387 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+
+ import gc
+ import logging
+ import os
+
+ import onnx
+ import torch
+ from diffusion_models import PipelineInfo
+ from engine_builder import EngineBuilder, EngineType
+ from packaging import version
+
+ import onnxruntime as ort
+ from onnxruntime.transformers.io_binding_helper import CudaSession, GpuBindingManager
+ from onnxruntime.transformers.onnx_model import OnnxModel
+
+ logger = logging.getLogger(__name__)
+
+
+ class OrtCudaEngine:
+     def __init__(
+         self,
+         onnx_path,
+         device_id: int = 0,
+         enable_cuda_graph: bool = False,
+         disable_optimization: bool = False,
+         max_cuda_graphs: int = 1,
+     ):
+         self.onnx_path = onnx_path
+         self.provider = "CUDAExecutionProvider"
+         self.stream = torch.cuda.current_stream().cuda_stream
+         self.provider_options = CudaSession.get_cuda_provider_options(device_id, enable_cuda_graph, self.stream)
+         session_options = ort.SessionOptions()
+
+         # When the model has been optimized by onnxruntime, we can disable optimization to save session creation time.
+         if disable_optimization:
+             session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
+
+         logger.info("creating CUDA EP session for %s", onnx_path)
+         ort_session = ort.InferenceSession(
+             onnx_path,
+             session_options,
+             providers=[
+                 (self.provider, self.provider_options),
+                 "CPUExecutionProvider",
+             ],
+         )
+         logger.info("created CUDA EP session for %s", onnx_path)
+
+         device = torch.device("cuda", device_id)
+         self.enable_cuda_graph = enable_cuda_graph
+
+         # Support multiple CUDA graphs for different input shapes.
+         # For clip2 model that disabled cuda graph, max_cuda_graphs is updated to 0 here.
+         self.gpu_binding_manager = GpuBindingManager(
+             ort_session=ort_session,
+             device=device,
+             stream=self.stream,
+             max_cuda_graphs=max_cuda_graphs if enable_cuda_graph else 0,
+         )
+
+         self.current_gpu_binding = None
+
+     def metadata(self, name: str):
+         data = {}
+         if self.current_gpu_binding is not None:
+             if self.current_gpu_binding.last_run_gpu_graph_id >= 0:
+                 data[f"{name}.gpu_graph_id"] = self.current_gpu_binding.last_run_gpu_graph_id
+         return data
+
+     def infer(self, feed_dict: dict[str, torch.Tensor]):
+         return self.current_gpu_binding.infer(feed_dict=feed_dict, disable_cuda_graph_in_run=not self.enable_cuda_graph)
+
+     def allocate_buffers(self, shape_dict, device):
+         self.current_gpu_binding = self.gpu_binding_manager.get_binding(
+             shape_dict=shape_dict, use_cuda_graph=self.enable_cuda_graph
+         )
+
+
+ class _ModelConfig:
+     """
+     Configuration of one model (like Clip, UNet etc) on ONNX export and optimization for CUDA provider.
+     For example, if you want to use fp32 in layer normalization, set the following:
+         force_fp32_ops=["SkipLayerNormalization", "LayerNormalization"]
+     """
+
+     def __init__(
+         self,
+         onnx_opset_version: int,
+         use_cuda_graph: bool,
+         fp16: bool = True,
+         force_fp32_ops: list[str] | None = None,
+         optimize_by_ort: bool = True,
+     ):
+         self.onnx_opset_version = onnx_opset_version
+         self.use_cuda_graph = use_cuda_graph
+         self.fp16 = fp16
+         self.force_fp32_ops = force_fp32_ops
+         self.optimize_by_ort = optimize_by_ort
+
+
+ class OrtCudaEngineBuilder(EngineBuilder):
+     def __init__(
+         self,
+         pipeline_info: PipelineInfo,
+         max_batch_size=16,
+         device="cuda",
+         use_cuda_graph=False,
+     ):
+         """
+         Initializes the ONNX Runtime CUDA ExecutionProvider Engine Builder.
+
+         Args:
+             pipeline_info (PipelineInfo):
+                 Version and Type of pipeline.
+             max_batch_size (int):
+                 Maximum batch size for dynamic batch engine.
+             device (str):
+                 device to run.
+             use_cuda_graph (bool):
+                 Use CUDA graph to capture engine execution and then launch inference
+         """
+         super().__init__(
+             EngineType.ORT_CUDA,
+             pipeline_info,
+             max_batch_size=max_batch_size,
+             device=device,
+             use_cuda_graph=use_cuda_graph,
+         )
+
+         self.model_config = {}
+
+     def _configure(
+         self,
+         model_name: str,
+         onnx_opset_version: int,
+         use_cuda_graph: bool,
+         fp16: bool = True,
+         force_fp32_ops: list[str] | None = None,
+         optimize_by_ort: bool = True,
+     ):
+         self.model_config[model_name] = _ModelConfig(
+             onnx_opset_version,
+             use_cuda_graph,
+             fp16=fp16,
+             force_fp32_ops=force_fp32_ops,
+             optimize_by_ort=optimize_by_ort,
+         )
+
+     def configure_xl(self, onnx_opset_version: int):
+         self._configure(
+             "clip",
+             onnx_opset_version=onnx_opset_version,
+             use_cuda_graph=self.use_cuda_graph,
+         )
+         self._configure(
+             "clip2",
+             onnx_opset_version=onnx_opset_version,  # TODO: ArgMax-12 is not implemented in CUDA
+             use_cuda_graph=False,  # TODO: fix Runtime Error with cuda graph
+         )
+         self._configure(
+             "unetxl",
+             onnx_opset_version=onnx_opset_version,
+             use_cuda_graph=self.use_cuda_graph,
+         )
+
+         self._configure(
+             "vae",
+             onnx_opset_version=onnx_opset_version,
+             use_cuda_graph=self.use_cuda_graph,
+         )
+
+     def optimized_onnx_path(self, engine_dir, model_name):
+         suffix = "" if self.model_config[model_name].fp16 else ".fp32"
+         return self.get_onnx_path(model_name, engine_dir, opt=True, suffix=suffix)
+
+     def import_diffusers_engine(self, diffusers_onnx_dir: str, engine_dir: str):
+         """Import optimized onnx models for diffusers from Olive or optimize_pipeline tools.
+
+         Args:
+             diffusers_onnx_dir (str): optimized onnx directory of Olive
+             engine_dir (str): the directory to store imported onnx
+         """
+         if version.parse(ort.__version__) < version.parse("1.17.0"):
+             print("Skip importing since onnxruntime-gpu version < 1.17.0.")
+             return
+
+         for model_name, model_obj in self.models.items():
+             onnx_import_path = self.optimized_onnx_path(diffusers_onnx_dir, model_name)
+             if not os.path.exists(onnx_import_path):
+                 print(f"{onnx_import_path} not existed. Skip importing.")
+                 continue
+
+             onnx_opt_path = self.optimized_onnx_path(engine_dir, model_name)
+             if os.path.exists(onnx_opt_path):
+                 print(f"{onnx_opt_path} existed. Skip importing.")
+                 continue
+
+             if model_name == "vae" and self.pipeline_info.is_xl():
+                 print(f"Skip importing VAE since it is not fully compatible with float16: {onnx_import_path}.")
+                 continue
+
+             model = OnnxModel(onnx.load(onnx_import_path, load_external_data=True))
+
+             if model_name in ["clip", "clip2"]:
+                 hidden_states_per_layer = []
+                 for output in model.graph().output:
+                     if output.name.startswith("hidden_states."):
+                         hidden_states_per_layer.append(output.name)
+                 if hidden_states_per_layer:
+                     kept_hidden_states = hidden_states_per_layer[-2 - model_obj.clip_skip]
+                     model.rename_graph_output(kept_hidden_states, "hidden_states")
+
+                 model.rename_graph_output(
+                     "last_hidden_state" if model_name == "clip" else "text_embeds", "text_embeddings"
+                 )
+                 model.prune_graph(
+                     ["text_embeddings", "hidden_states"] if hidden_states_per_layer else ["text_embeddings"]
+                 )
+
+                 if model_name == "clip2":
+                     model.change_graph_input_type(model.find_graph_input("input_ids"), onnx.TensorProto.INT32)
+
+                 model.save_model_to_file(onnx_opt_path, use_external_data_format=(model_name == "clip2"))
+             elif model_name in ["unet", "unetxl"]:
+                 model.rename_graph_output("out_sample", "latent")
+                 model.save_model_to_file(onnx_opt_path, use_external_data_format=True)
+
+             del model
+             continue
+
+     def build_engines(
+         self,
+         engine_dir: str,
+         framework_model_dir: str,
+         onnx_dir: str,
+         tmp_dir: str | None = None,
+         onnx_opset_version: int = 17,
+         device_id: int = 0,
+         save_fp32_intermediate_model: bool = False,
+         import_engine_dir: str | None = None,
+         max_cuda_graphs: int = 1,
+     ):
+         self.torch_device = torch.device("cuda", device_id)
+         self.load_models(framework_model_dir)
+
+         if not os.path.isdir(engine_dir):
+             os.makedirs(engine_dir)
+
+         if not os.path.isdir(onnx_dir):
+             os.makedirs(onnx_dir)
+
+         # Add default configuration if missing
+         if self.pipeline_info.is_xl():
+             self.configure_xl(onnx_opset_version)
+         for model_name in self.models:
+             if model_name not in self.model_config:
+                 self.model_config[model_name] = _ModelConfig(onnx_opset_version, self.use_cuda_graph)
+
+         # Import Engine
+         if import_engine_dir:
+             if self.pipeline_info.is_xl():
+                 self.import_diffusers_engine(import_engine_dir, engine_dir)
+             else:
+                 print(f"Only support importing SDXL onnx. Ignore --engine-dir {import_engine_dir}")
+
+         # Load lora only when we need export text encoder or UNet to ONNX.
+         load_lora = False
+         if self.pipeline_info.lora_weights:
+             for model_name in self.models:
+                 if model_name not in ["clip", "clip2", "unet", "unetxl"]:
+                     continue
+                 onnx_path = self.get_onnx_path(model_name, onnx_dir, opt=False)
+                 onnx_opt_path = self.optimized_onnx_path(engine_dir, model_name)
+                 if not os.path.exists(onnx_opt_path):
+                     if not os.path.exists(onnx_path):
+                         load_lora = True
+                         break
+
+         # Export models to ONNX
+         self.disable_torch_spda()
+         pipe = self.load_pipeline_with_lora() if load_lora else None
+
+         for model_name, model_obj in self.models.items():
+             if model_name == "vae" and self.vae_torch_fallback:
+                 continue
+
+             onnx_path = self.get_onnx_path(model_name, onnx_dir, opt=False)
+             onnx_opt_path = self.optimized_onnx_path(engine_dir, model_name)
+             if not os.path.exists(onnx_opt_path):
+                 if not os.path.exists(onnx_path):
+                     print("----")
+                     logger.info("Exporting model: %s", onnx_path)
+
+                     model = self.get_or_load_model(pipe, model_name, model_obj, framework_model_dir)
+                     model = model.to(torch.float32)
+
+                     with torch.inference_mode():
+                         # For CUDA EP, export FP32 onnx since some graph fusion only supports fp32 graph pattern.
+                         # Export model with sample of batch size 1, image size 512 x 512
+                         inputs = model_obj.get_sample_input(1, 512, 512)
+
+                         torch.onnx.export(
+                             model,
+                             inputs,
+                             onnx_path,
+                             export_params=True,
+                             opset_version=self.model_config[model_name].onnx_opset_version,
+                             do_constant_folding=True,
+                             input_names=model_obj.get_input_names(),
+                             output_names=model_obj.get_output_names(),
+                             dynamic_axes=model_obj.get_dynamic_axes(),
+                         )
+                     del model
+                     torch.cuda.empty_cache()
+                     gc.collect()
+                 else:
+                     logger.info("Found cached model: %s", onnx_path)
+
+             # Generate fp32 optimized model.
+             # If final target is fp16 model, we save fp32 optimized model so that it is easy to tune
+             # fp16 conversion. That could save a lot of time in developing.
+             use_fp32_intermediate = save_fp32_intermediate_model and self.model_config[model_name].fp16
+             onnx_fp32_path = onnx_path
+             if use_fp32_intermediate:
+                 onnx_fp32_path = self.get_onnx_path(model_name, engine_dir, opt=True, suffix=".fp32")
+                 if not os.path.exists(onnx_fp32_path):
+                     print("------")
+                     logger.info("Generating optimized model: %s", onnx_fp32_path)
+                     model_obj.optimize_ort(
+                         onnx_path,
+                         onnx_fp32_path,
+                         to_fp16=False,
+                         fp32_op_list=self.model_config[model_name].force_fp32_ops,
+                         optimize_by_ort=self.model_config[model_name].optimize_by_ort,
+                         tmp_dir=self.get_model_dir(model_name, tmp_dir, opt=False, suffix=".fp32", create=False),
+                     )
+                 else:
+                     logger.info("Found cached optimized model: %s", onnx_fp32_path)
+
+             # Generate the final optimized model.
+             if not os.path.exists(onnx_opt_path):
+                 print("------")
+                 logger.info("Generating optimized model: %s", onnx_opt_path)
+
+                 # When there is fp32 intermediate optimized model, this will just convert model from fp32 to fp16.
+                 optimize_by_ort = False if use_fp32_intermediate else self.model_config[model_name].optimize_by_ort
+
+                 model_obj.optimize_ort(
+                     onnx_fp32_path,
+                     onnx_opt_path,
+                     to_fp16=self.model_config[model_name].fp16,
+                     fp32_op_list=self.model_config[model_name].force_fp32_ops,
+                     optimize_by_ort=optimize_by_ort,
+                     optimize_by_fusion=not use_fp32_intermediate,
+                     tmp_dir=self.get_model_dir(model_name, tmp_dir, opt=False, suffix=".ort", create=False),
+                 )
+             else:
+                 logger.info("Found cached optimized model: %s", onnx_opt_path)
+         self.enable_torch_spda()
+
+         built_engines = {}
+         for model_name in self.models:
+             if model_name == "vae" and self.vae_torch_fallback:
+                 continue
+
+             onnx_opt_path = self.optimized_onnx_path(engine_dir, model_name)
+             use_cuda_graph = self.model_config[model_name].use_cuda_graph
+
+             engine = OrtCudaEngine(
+                 onnx_opt_path,
+                 device_id=device_id,
+                 enable_cuda_graph=use_cuda_graph,
+                 disable_optimization=False,
+                 max_cuda_graphs=max_cuda_graphs,
+             )
+
+             logger.info("%s options for %s: %s", engine.provider, model_name, engine.provider_options)
+             built_engines[model_name] = engine
+
+         self.engines = built_engines
+
+     def run_engine(self, model_name, feed_dict):
+         return self.engines[model_name].infer(feed_dict)
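
And a sketch (again not part of the diff) of driving the CUDA EP builder end to end, assuming a CUDA-capable onnxruntime build rather than this DirectML wheel and that the script runs from the stable_diffusion model directory; the PipelineInfo("1.5") call and the 512x512 latent shape are illustrative assumptions.

import torch
from diffusion_models import PipelineInfo  # assumed import path
from engine_builder import EngineType, get_engine_paths
from engine_builder_ort_cuda import OrtCudaEngineBuilder

pipeline_info = PipelineInfo("1.5")  # assumed constructor arguments
builder = OrtCudaEngineBuilder(pipeline_info, max_batch_size=4, use_cuda_graph=False)

onnx_dir, engine_dir, _output_dir, framework_model_dir, _ = get_engine_paths(
    "./work", pipeline_info, EngineType.ORT_CUDA
)

# Export, optimize and wrap each stage in an OrtCudaEngine, then bind I/O buffers.
builder.build_engines(engine_dir, framework_model_dir, onnx_dir, device_id=0)
builder.load_resources(image_height=512, image_width=512, batch_size=1)

# Decode a latent through the "vae" engine; 1x4x64x64 corresponds to a 512x512 image.
latents = torch.randn(1, 4, 64, 64, dtype=torch.float16, device="cuda")
images = builder.vae_decode(latents)
print(images.shape)
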