onnxruntime_directml-1.20.0-cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (305)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py
@@ -0,0 +1,108 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+import logging
+
+from diffusion_models import PipelineInfo
+from engine_builder import EngineBuilder, EngineType
+
+logger = logging.getLogger(__name__)
+
+
+class TorchEngineBuilder(EngineBuilder):
+    def __init__(
+        self,
+        pipeline_info: PipelineInfo,
+        max_batch_size=16,
+        device="cuda",
+        use_cuda_graph=False,
+    ):
+        """
+        Initializes the PyTorch Engine Builder.
+
+        Args:
+            pipeline_info (PipelineInfo):
+                Version and type of pipeline.
+            max_batch_size (int):
+                Maximum batch size for dynamic batch engine.
+            device (str):
+                Device to run on.
+            use_cuda_graph (bool):
+                Use CUDA graph to capture engine execution and then launch inference.
+        """
+        super().__init__(
+            EngineType.TORCH,
+            pipeline_info,
+            max_batch_size=max_batch_size,
+            device=device,
+            use_cuda_graph=use_cuda_graph,
+        )
+
+        self.compile_config = {}
+        if use_cuda_graph:
+            self.compile_config = {
+                "clip": {"mode": "reduce-overhead", "dynamic": False},
+                "clip2": {"mode": "reduce-overhead", "dynamic": False},
+                "unet": {"mode": "reduce-overhead", "fullgraph": True, "dynamic": False},
+                "unetxl": {"mode": "reduce-overhead", "fullgraph": True, "dynamic": False},
+                "vae": {"mode": "reduce-overhead", "fullgraph": False, "dynamic": False},
+            }
+
+    def build_engines(
+        self,
+        framework_model_dir: str,
+    ):
+        import torch
+
+        self.torch_device = torch.device("cuda", torch.cuda.current_device())
+        self.load_models(framework_model_dir)
+
+        pipe = self.load_pipeline_with_lora() if self.pipeline_info.lora_weights else None
+
+        built_engines = {}
+        for model_name, model_obj in self.models.items():
+            model = self.get_or_load_model(pipe, model_name, model_obj, framework_model_dir)
+            if self.pipeline_info.is_xl() and not self.custom_fp16_vae:
+                model = model.to(device=self.torch_device, dtype=torch.float32)
+            else:
+                model = model.to(device=self.torch_device, dtype=torch.float16)
+
+            if model_name in self.compile_config:
+                compile_config = self.compile_config[model_name]
+                if model_name in ["unet", "unetxl"]:
+                    model.to(memory_format=torch.channels_last)
+                engine = torch.compile(model, **compile_config)
+                built_engines[model_name] = engine
+            else:  # eager mode
+                built_engines[model_name] = model
+
+        self.engines = built_engines
+
+    def run_engine(self, model_name, feed_dict):
+        if model_name in ["unet", "unetxl"]:
+            if "controlnet_images" in feed_dict:
+                return {"latent": self.engines[model_name](**feed_dict)}
+
+            if model_name == "unetxl":
+                added_cond_kwargs = {k: feed_dict[k] for k in feed_dict if k in ["text_embeds", "time_ids"]}
+                return {
+                    "latent": self.engines[model_name](
+                        feed_dict["sample"],
+                        feed_dict["timestep"],
+                        feed_dict["encoder_hidden_states"],
+                        added_cond_kwargs=added_cond_kwargs,
+                        return_dict=False,
+                    )[0]
+                }
+
+            return {
+                "latent": self.engines[model_name](
+                    feed_dict["sample"], feed_dict["timestep"], feed_dict["encoder_hidden_states"], return_dict=False
+                )[0]
+            }
+
+        if model_name in ["vae_encoder"]:
+            return {"latent": self.engines[model_name](feed_dict["images"])}
+
+        raise RuntimeError(f"Shall not reach here: {model_name}")
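For orientation, here is a minimal usage sketch of TorchEngineBuilder (not part of the wheel; the PipelineInfo constructor arguments and the directory path are assumptions):

from diffusion_models import PipelineInfo
from engine_builder_torch import TorchEngineBuilder

pipeline_info = PipelineInfo("1.5")  # assumed signature: a version string selecting the pipeline
builder = TorchEngineBuilder(pipeline_info, max_batch_size=4, use_cuda_graph=True)
builder.build_engines("./framework_models")  # assumed path; loads weights and torch.compile's the models

# feed_dict would hold torch tensors keyed as in run_engine above:
# sample, timestep, encoder_hidden_states.
# latent = builder.run_engine("unet", feed_dict)["latent"]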
onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py
@@ -0,0 +1,350 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+#
+# This script converts stable diffusion onnx models from float to half (mixed) precision for GPU inference.
+#
+# Before running this script, follow README.md to set up the python environment and convert the stable diffusion
+# checkpoint to float32 onnx models.
+#
+# For example, if the float32 ONNX pipeline is saved to the ./sd-v1-5 directory, you can optimize and convert it
+# to float16 like the following:
+#    python optimize_pipeline.py -i ./sd-v1-5 -o ./sd-v1-5-fp16 --float16
+#
+# Note that the optimizations target the CUDA Execution Provider first; other EPs may not support
+# the fused operators. Users can disable operator fusion manually to work around this.
+
+import argparse
+import logging
+import os
+import shutil
+import tempfile
+from pathlib import Path
+from typing import List, Optional
+
+import __init__  # noqa: F401. Workaround to run this script directly
+import coloredlogs
+import onnx
+from fusion_options import FusionOptions
+from onnx_model_clip import ClipOnnxModel
+from onnx_model_unet import UnetOnnxModel
+from onnx_model_vae import VaeOnnxModel
+from optimizer import optimize_by_onnxruntime, optimize_model
+from packaging import version
+
+import onnxruntime
+
+logger = logging.getLogger(__name__)
+
+
+def has_external_data(onnx_model_path):
+    original_model = onnx.load_model(str(onnx_model_path), load_external_data=False)
+    for initializer in original_model.graph.initializer:
+        if initializer.HasField("data_location") and initializer.data_location == onnx.TensorProto.EXTERNAL:
+            return True
+    return False
+
+
+def _optimize_sd_pipeline(
+    source_dir: Path,
+    target_dir: Path,
+    use_external_data_format: Optional[bool],
+    float16: bool,
+    force_fp32_ops: List[str],
+    enable_runtime_optimization: bool,
+    args,
+):
+    """Optimize onnx models used in the stable diffusion onnx pipeline and optionally convert to float16.
+
+    Args:
+        source_dir (Path): Root of input directory of stable diffusion onnx pipeline with float32 models.
+        target_dir (Path): Root of output directory of stable diffusion onnx pipeline with optimized models.
+        use_external_data_format (Optional[bool]): use external data format.
+        float16 (bool): use half precision.
+        force_fp32_ops (List[str]): operators that are forced to run in float32.
+        enable_runtime_optimization (bool): run graph optimization using Onnx Runtime.
+
+    Raises:
+        RuntimeError: input onnx model does not exist
+        RuntimeError: output onnx model path already exists
+    """
+    model_type_mapping = {
+        "unet": "unet",
+        "vae_encoder": "vae",
+        "vae_decoder": "vae",
+        "text_encoder": "clip",
+        "text_encoder_2": "clip",
+        "safety_checker": "unet",
+    }
+
+    model_type_class_mapping = {
+        "unet": UnetOnnxModel,
+        "vae": VaeOnnxModel,
+        "clip": ClipOnnxModel,
+    }
+
+    force_fp32_operators = {
+        "unet": [],
+        "vae_encoder": [],
+        "vae_decoder": [],
+        "text_encoder": [],
+        "text_encoder_2": [],
+        "safety_checker": [],
+    }
+
+    is_xl = (source_dir / "text_encoder_2").exists()
+
+    if force_fp32_ops:
+        for fp32_operator in force_fp32_ops:
+            parts = fp32_operator.split(":")
+            if len(parts) == 2 and parts[0] in force_fp32_operators and (parts[1] and parts[1][0].isupper()):
+                force_fp32_operators[parts[0]].append(parts[1])
+            else:
+                raise ValueError(
+                    f"--force_fp32_ops shall be in the format of module:operator like unet:Attention, got {fp32_operator}"
+                )
+
+    for name, model_type in model_type_mapping.items():
+        onnx_model_path = source_dir / name / "model.onnx"
+        if not os.path.exists(onnx_model_path):
+            if name != "safety_checker":
+                logger.info("input onnx model does not exist: %s", onnx_model_path)
+            # some models are optional so we do not raise an error here.
+            continue
+
+        # Prepare output directory
+        optimized_model_path = target_dir / name / "model.onnx"
+        output_dir = optimized_model_path.parent
+        output_dir.mkdir(parents=True, exist_ok=True)
+
+        if use_external_data_format is None:
+            use_external_data_format = has_external_data(onnx_model_path)
+
+        # Run graph fusion before fp16 conversion; otherwise some patterns cannot be fused later.
+        logger.info(f"Optimize {onnx_model_path}...")
+
+        args.model_type = model_type
+        fusion_options = FusionOptions.parse(args)
+
+        if model_type in ["unet"]:
+            # Some optimizations are not available in v1.14 or older versions: packed QKV and BiasAdd
+            has_all_optimizations = version.parse(onnxruntime.__version__) >= version.parse("1.15.0")
+            fusion_options.enable_packed_kv = float16 and fusion_options.enable_packed_kv
+            fusion_options.enable_packed_qkv = float16 and has_all_optimizations and fusion_options.enable_packed_qkv
+            fusion_options.enable_bias_add = has_all_optimizations and fusion_options.enable_bias_add
+
+        m = optimize_model(
+            str(onnx_model_path),
+            model_type=model_type,
+            num_heads=0,  # will be deduced from graph
+            hidden_size=0,  # will be deduced from graph
+            opt_level=0,
+            optimization_options=fusion_options,
+            use_gpu=True,
+            provider=args.provider,
+        )
+
+        if float16:
+            # For SD-XL, using FP16 in the VAE decoder causes NaN and black images, so we keep it in FP32.
+            if is_xl and name == "vae_decoder":
+                logger.info("Skip converting %s to float16 to avoid NaN", name)
+            else:
+                logger.info("Convert %s to float16 ...", name)
+                m.convert_float_to_float16(
+                    keep_io_types=False,
+                    op_block_list=force_fp32_operators[name],
+                )
+
+        if enable_runtime_optimization:
+            # Use this step to see the final graph that is executed by Onnx Runtime.
+            with tempfile.TemporaryDirectory() as tmp_dir:
+                # Save to a temporary file so that we can load it with Onnx Runtime.
+                logger.info("Saving a temporary model to run OnnxRuntime graph optimizations...")
+                tmp_model_path = Path(tmp_dir) / "model.onnx"
+                m.save_model_to_file(str(tmp_model_path), use_external_data_format=use_external_data_format)
+                ort_optimized_model_path = Path(tmp_dir) / "optimized.onnx"
+                optimize_by_onnxruntime(
+                    str(tmp_model_path),
+                    use_gpu=True,
+                    provider=args.provider,
+                    optimized_model_path=str(ort_optimized_model_path),
+                    save_as_external_data=use_external_data_format,
+                )
+                model = onnx.load(str(ort_optimized_model_path), load_external_data=True)
+                m = model_type_class_mapping[model_type](model)
+
+        m.get_operator_statistics()
+        m.get_fused_operator_statistics()
+        m.save_model_to_file(str(optimized_model_path), use_external_data_format=use_external_data_format)
+        logger.info("%s is optimized", name)
+        logger.info("*" * 20)
+
+
+def _copy_extra_directory(source_dir: Path, target_dir: Path):
+    """Copy extra directories that do not have an onnx model.
+
+    Args:
+        source_dir (Path): source directory
+        target_dir (Path): target directory
+
+    Raises:
+        RuntimeError: source path does not exist
+    """
+    extra_dirs = ["scheduler", "tokenizer", "tokenizer_2", "feature_extractor"]
+
+    for name in extra_dirs:
+        source_path = source_dir / name
+        if not os.path.exists(source_path):
+            continue
+
+        target_path = target_dir / name
+        shutil.copytree(source_path, target_path)
+        logger.info("%s => %s", source_path, target_path)
+
+    extra_files = ["model_index.json"]
+    for name in extra_files:
+        source_path = source_dir / name
+        if not os.path.exists(source_path):
+            raise RuntimeError(f"source path does not exist: {source_path}")
+
+        target_path = target_dir / name
+        shutil.copyfile(source_path, target_path)
+        logger.info("%s => %s", source_path, target_path)
+
+    # Some directories are optional
+    onnx_model_dirs = ["text_encoder", "text_encoder_2", "unet", "vae_encoder", "vae_decoder", "safety_checker"]
+    for onnx_model_dir in onnx_model_dirs:
+        source_path = source_dir / onnx_model_dir / "config.json"
+        target_path = target_dir / onnx_model_dir / "config.json"
+        if source_path.exists():
+            target_path.parent.mkdir(parents=True, exist_ok=True)
+            shutil.copyfile(source_path, target_path)
+            logger.info("%s => %s", source_path, target_path)
+
+
+def optimize_stable_diffusion_pipeline(
+    input_dir: str,
+    output_dir: str,
+    overwrite: bool,
+    use_external_data_format: Optional[bool],
+    float16: bool,
+    enable_runtime_optimization: bool,
+    args,
+):
+    if os.path.exists(output_dir):
+        if overwrite:
+            shutil.rmtree(output_dir, ignore_errors=True)
+        else:
+            raise RuntimeError(f"output directory already exists: {output_dir}. Add --overwrite to empty the directory.")
+
+    source_dir = Path(input_dir)
+    target_dir = Path(output_dir)
+    target_dir.mkdir(parents=True, exist_ok=True)
+
+    _copy_extra_directory(source_dir, target_dir)
+
+    _optimize_sd_pipeline(
+        source_dir,
+        target_dir,
+        use_external_data_format,
+        float16,
+        args.force_fp32_ops,
+        enable_runtime_optimization,
+        args,
+    )
+
+
+def parse_arguments(argv: Optional[List[str]] = None):
+    """Parse arguments
+
+    Returns:
+        Namespace: arguments
+    """
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "-i",
+        "--input",
+        required=True,
+        type=str,
+        help="Root of input directory of stable diffusion onnx pipeline with float32 models.",
+    )
+
+    parser.add_argument(
+        "-o",
+        "--output",
+        required=True,
+        type=str,
+        help="Root of output directory of stable diffusion onnx pipeline with optimized models.",
+    )
+
+    parser.add_argument(
+        "--float16",
+        required=False,
+        action="store_true",
+        help="Output models of half or mixed precision.",
+    )
+    parser.set_defaults(float16=False)
+
+    parser.add_argument(
+        "--force_fp32_ops",
+        required=False,
+        nargs="+",
+        type=str,
+        help="Force given operators (like unet:Attention) to run in float32. It is case sensitive!",
+    )
+
+    parser.add_argument(
+        "--inspect",
+        required=False,
+        action="store_true",
+        help="Save the optimized graph from Onnx Runtime. "
+        "This option has no impact on inference performance, though it might reduce session creation time.",
+    )
+    parser.set_defaults(inspect=False)
+
+    parser.add_argument(
+        "--overwrite",
+        required=False,
+        action="store_true",
+        help="Overwrite existing files.",
+    )
+    parser.set_defaults(overwrite=False)
+
+    parser.add_argument(
+        "-e",
+        "--use_external_data_format",
+        required=False,
+        action="store_true",
+        help="Onnx models larger than 2GB need to use the external data format. "
+        "If specified, save each onnx model to two files: one for the onnx graph, another for the weights. "
+        "If not specified, use the same format as the original model by default.",
+    )
+    parser.set_defaults(use_external_data_format=None)
+
+    parser.add_argument(
+        "--provider",
+        required=False,
+        type=str,
+        default=None,
+        help="Execution provider to use.",
+    )
+
+    FusionOptions.add_arguments(parser)
+
+    args = parser.parse_args(argv)
+    return args
+
+
+def main(argv: Optional[List[str]] = None):
+    args = parse_arguments(argv)
+    logger.info("Arguments: %s", str(args))
+    optimize_stable_diffusion_pipeline(
+        args.input, args.output, args.overwrite, args.use_external_data_format, args.float16, args.inspect, args
+    )
+
+
+if __name__ == "__main__":
+    coloredlogs.install(fmt="%(funcName)20s: %(message)s")
+    main()
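The header comment above shows the typical CLI invocation; since main() also accepts an argv list, the same optimization can be driven from Python (a sketch; the directory names are placeholders):

from optimize_pipeline import main

main([
    "-i", "./sd-v1-5",       # input: float32 ONNX pipeline
    "-o", "./sd-v1-5-fp16",  # output: optimized (and fp16-converted) pipeline
    "--float16",             # convert models to half precision
    "--overwrite",           # replace the output directory if it already exists
])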
onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py
@@ -0,0 +1,136 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+
+"""
+ONNX Model Optimizer for Stable Diffusion
+"""
+
+import gc
+import logging
+import os
+import shutil
+import tempfile
+from pathlib import Path
+
+import onnx
+from packaging import version
+
+from onnxruntime.transformers.fusion_options import FusionOptions
+from onnxruntime.transformers.onnx_model_clip import ClipOnnxModel
+from onnxruntime.transformers.onnx_model_unet import UnetOnnxModel
+from onnxruntime.transformers.onnx_model_vae import VaeOnnxModel
+from onnxruntime.transformers.optimizer import optimize_by_onnxruntime, optimize_model
+
+logger = logging.getLogger(__name__)
+
+
+class OrtStableDiffusionOptimizer:
+    def __init__(self, model_type: str):
+        assert model_type in ["vae", "unet", "clip"]
+        self.model_type = model_type
+        self.model_type_class_mapping = {
+            "unet": UnetOnnxModel,
+            "vae": VaeOnnxModel,
+            "clip": ClipOnnxModel,
+        }
+
+    def _optimize_by_ort(self, onnx_model, use_external_data_format, tmp_dir):
+        # Save to a temporary file so that we can load it with Onnx Runtime.
+        logger.info("Saving a temporary model to run OnnxRuntime graph optimizations...")
+        tmp_model_path = Path(tmp_dir) / "model.onnx"
+        onnx_model.save_model_to_file(str(tmp_model_path), use_external_data_format=use_external_data_format)
+
+        del onnx_model
+        gc.collect()
+
+        ort_optimized_model_path = Path(tmp_dir) / "optimized.onnx"
+        optimize_by_onnxruntime(
+            str(tmp_model_path),
+            use_gpu=True,
+            optimized_model_path=str(ort_optimized_model_path),
+            save_as_external_data=use_external_data_format,
+            external_data_filename="optimized.onnx_data",
+        )
+        model = onnx.load(str(ort_optimized_model_path), load_external_data=True)
+        return self.model_type_class_mapping[self.model_type](model)
+
+    def optimize_by_ort(self, onnx_model, use_external_data_format=False, tmp_dir=None):
+        # Use this step to see the final graph that is executed by Onnx Runtime.
+        if tmp_dir is None:
+            with tempfile.TemporaryDirectory() as temp_dir:
+                return self._optimize_by_ort(onnx_model, use_external_data_format, temp_dir)
+        else:
+            os.makedirs(tmp_dir, exist_ok=True)
+            model = self._optimize_by_ort(onnx_model, use_external_data_format, tmp_dir)
+            shutil.rmtree(tmp_dir)
+            return model
+
+    def optimize(
+        self,
+        input_fp32_onnx_path,
+        optimized_onnx_path,
+        float16=True,
+        keep_io_types=False,
+        fp32_op_list=None,
+        keep_outputs=None,
+        optimize_by_ort=True,
+        optimize_by_fusion=True,
+        final_target_float16=True,
+        tmp_dir=None,
+    ):
+        """Optimize onnx model using ONNX Runtime transformers optimizer"""
+        logger.info(f"Optimize {input_fp32_onnx_path}...")
+
+        if optimize_by_fusion:
+            fusion_options = FusionOptions(self.model_type)
+
+            # float16=False with final_target_float16=True is allowed: fp32 can be an intermediate optimization step.
+            # For the rare fp32 use case, we can disable packed kv/qkv since there is no fp32 TRT fused attention kernel.
+            if self.model_type in ["unet"] and not final_target_float16:
+                fusion_options.enable_packed_kv = False
+                fusion_options.enable_packed_qkv = False
+
+            m = optimize_model(
+                input_fp32_onnx_path,
+                model_type=self.model_type,
+                num_heads=0,  # will be deduced from graph
+                hidden_size=0,  # will be deduced from graph
+                opt_level=0,
+                optimization_options=fusion_options,
+                use_gpu=True,
+            )
+        else:
+            model = onnx.load_model(input_fp32_onnx_path, load_external_data=True)
+            m = self.model_type_class_mapping[self.model_type](model)
+
+        if keep_outputs:
+            m.prune_graph(outputs=keep_outputs)
+
+        model_size = m.model.ByteSize()
+
+        # model size might be negative (overflow?) on Windows.
+        use_external_data_format = model_size <= 0 or model_size >= onnx.checker.MAXIMUM_PROTOBUF
+
+        # Note that ORT < 1.16 could not save a model larger than 2GB.
+        # This step is optional since it has no impact on inference latency.
+        # The optimized model is not portable. It can only run with the same execution provider (CUDA EP in this case).
+        # When the model has been optimized by onnxruntime, we can disable optimization in SessionOptions
+        # to save session creation time. Another benefit is to inspect the final graph for development purposes.
+        from onnxruntime import __version__ as ort_version
+
+        if optimize_by_ort and (version.parse(ort_version) >= version.parse("1.16.0") or not use_external_data_format):
+            m = self.optimize_by_ort(m, use_external_data_format=use_external_data_format, tmp_dir=tmp_dir)
+
+        if float16:
+            logger.info("Convert to float16 ...")
+            m.convert_float_to_float16(
+                keep_io_types=keep_io_types,
+                op_block_list=fp32_op_list,
+            )
+
+        m.get_operator_statistics()
+        m.get_fused_operator_statistics()
+        m.save_model_to_file(optimized_onnx_path, use_external_data_format=use_external_data_format)
+        logger.info("%s is optimized: %s", self.model_type, optimized_onnx_path)
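A short sketch of how OrtStableDiffusionOptimizer might be called (paths are placeholders; the fp32_op_list value is an example of an op kept in float32, not a requirement of the API):

from onnxruntime.transformers.models.stable_diffusion.ort_optimizer import OrtStableDiffusionOptimizer

optimizer = OrtStableDiffusionOptimizer("unet")
optimizer.optimize(
    input_fp32_onnx_path="./sd-v1-5/unet/model.onnx",        # placeholder path
    optimized_onnx_path="./sd-v1-5-fp16/unet/model.onnx",    # placeholder path
    float16=True,
    keep_io_types=False,
    fp32_op_list=["RandomNormalLike"],  # example: block these ops from fp16 conversion
)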