onnxruntime_directml-1.24.1-cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py
@@ -0,0 +1,288 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+
+ import gc
+ import logging
+ import os
+
+ import torch
+ from cuda import cudart
+ from diffusion_models import PipelineInfo
+ from engine_builder import EngineBuilder, EngineType
+ from packaging import version
+
+ import onnxruntime as ort
+ from onnxruntime.transformers.io_binding_helper import CudaSession
+
+ logger = logging.getLogger(__name__)
+
+
+ class OrtTensorrtEngine(CudaSession):
+     def __init__(
+         self,
+         engine_path,
+         device_id,
+         onnx_path,
+         fp16,
+         input_profile,
+         workspace_size,
+         enable_cuda_graph,
+         timing_cache_path=None,
+     ):
+         self.engine_path = engine_path
+         self.ort_trt_provider_options = self.get_tensorrt_provider_options(
+             input_profile,
+             workspace_size,
+             fp16,
+             device_id,
+             enable_cuda_graph,
+             timing_cache_path=timing_cache_path,
+         )
+
+         session_options = ort.SessionOptions()
+         session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
+         logger.info("creating TRT EP session for %s", onnx_path)
+         ort_session = ort.InferenceSession(
+             onnx_path,
+             session_options,
+             providers=[
+                 ("TensorrtExecutionProvider", self.ort_trt_provider_options),
+             ],
+         )
+         logger.info("created TRT EP session for %s", onnx_path)
+
+         device = torch.device("cuda", device_id)
+         super().__init__(ort_session, device, enable_cuda_graph)
+
+     def get_tensorrt_provider_options(
+         self, input_profile, workspace_size, fp16, device_id, enable_cuda_graph, timing_cache_path=None
+     ):
+         trt_ep_options = {
+             "device_id": device_id,
+             "trt_fp16_enable": fp16,
+             "trt_engine_cache_enable": True,
+             "trt_timing_cache_enable": True,
+             "trt_detailed_build_log": True,
+             "trt_engine_cache_path": self.engine_path,
+         }
+
+         if version.parse(ort.__version__) > version.parse("1.16.2") and timing_cache_path is not None:
+             trt_ep_options["trt_timing_cache_path"] = timing_cache_path
+
+         if enable_cuda_graph:
+             trt_ep_options["trt_cuda_graph_enable"] = True
+
+         if workspace_size > 0:
+             trt_ep_options["trt_max_workspace_size"] = workspace_size
+
+         if input_profile:
+             min_shapes = []
+             max_shapes = []
+             opt_shapes = []
+             for name, profile in input_profile.items():
+                 assert isinstance(profile, list) and len(profile) == 3
+                 min_shape = profile[0]
+                 opt_shape = profile[1]
+                 max_shape = profile[2]
+                 assert len(min_shape) == len(opt_shape) and len(opt_shape) == len(max_shape)
+
+                 min_shapes.append(f"{name}:" + "x".join([str(x) for x in min_shape]))
+                 opt_shapes.append(f"{name}:" + "x".join([str(x) for x in opt_shape]))
+                 max_shapes.append(f"{name}:" + "x".join([str(x) for x in max_shape]))
+
+             trt_ep_options["trt_profile_min_shapes"] = ",".join(min_shapes)
+             trt_ep_options["trt_profile_max_shapes"] = ",".join(max_shapes)
+             trt_ep_options["trt_profile_opt_shapes"] = ",".join(opt_shapes)
+
+         logger.info("trt_ep_options=%s", trt_ep_options)
+
+         return trt_ep_options
+
+     def allocate_buffers(self, shape_dict, device):
+         super().allocate_buffers(shape_dict)  # "device" is unused here; the CudaSession already knows its device
+
+
+ class OrtTensorrtEngineBuilder(EngineBuilder):
+     def __init__(
+         self,
+         pipeline_info: PipelineInfo,
+         max_batch_size=16,
+         device="cuda",
+         use_cuda_graph=False,
+     ):
+         """
+         Initializes the ONNX Runtime TensorRT ExecutionProvider Engine Builder.
+
+         Args:
+             pipeline_info (PipelineInfo):
+                 Version and type of pipeline.
+             max_batch_size (int):
+                 Maximum batch size for dynamic batch engine.
+             device (str):
+                 Device to run on.
+             use_cuda_graph (bool):
+                 Use CUDA graph to capture engine execution and then launch inference.
+         """
+         super().__init__(
+             EngineType.ORT_TRT,
+             pipeline_info,
+             max_batch_size=max_batch_size,
+             device=device,
+             use_cuda_graph=use_cuda_graph,
+         )
+
+     def has_engine_file(self, engine_path):
+         if os.path.isdir(engine_path):
+             children = os.scandir(engine_path)
+             for entry in children:
+                 if entry.is_file() and entry.name.endswith(".engine"):
+                     return True
+         return False
+
+     def get_work_space_size(self, model_name, max_workspace_size):
+         gibibyte = 2**30
+         workspace_size = 4 * gibibyte if model_name == "clip" else max_workspace_size
+         if workspace_size == 0:
+             _, free_mem, _ = cudart.cudaMemGetInfo()
+             # The following logic is adapted from the TensorRT demo diffusion.
+             if free_mem > 6 * gibibyte:
+                 workspace_size = free_mem - 4 * gibibyte
+         return workspace_size
+
+     def build_engines(
+         self,
+         engine_dir,
+         framework_model_dir,
+         onnx_dir,
+         onnx_opset,
+         opt_image_height,
+         opt_image_width,
+         opt_batch_size=1,
+         static_batch=False,
+         static_image_shape=True,
+         max_workspace_size=0,
+         device_id=0,
+         timing_cache=None,
+     ):
+         self.torch_device = torch.device("cuda", device_id)
+         self.load_models(framework_model_dir)
+
+         if not os.path.isdir(engine_dir):
+             os.makedirs(engine_dir)
+
+         if not os.path.isdir(onnx_dir):
+             os.makedirs(onnx_dir)
+
+         # Load LoRA only when we need to export the text encoder or UNet to ONNX.
+         load_lora = False
+         if self.pipeline_info.lora_weights:
+             for model_name, model_obj in self.models.items():
+                 if model_name not in ["clip", "clip2", "unet", "unetxl"]:
+                     continue
+                 profile_id = model_obj.get_profile_id(
+                     opt_batch_size, opt_image_height, opt_image_width, static_batch, static_image_shape
+                 )
+                 engine_path = self.get_engine_path(engine_dir, model_name, profile_id)
+                 if not self.has_engine_file(engine_path):
+                     onnx_path = self.get_onnx_path(model_name, onnx_dir, opt=False)
+                     onnx_opt_path = self.get_onnx_path(model_name, onnx_dir, opt=True)
+                     if not os.path.exists(onnx_opt_path):
+                         if not os.path.exists(onnx_path):
+                             load_lora = True
+                             break
+
+         # Export models to ONNX
+         self.disable_torch_spda()
+         pipe = self.load_pipeline_with_lora() if load_lora else None
+
+         for model_name, model_obj in self.models.items():
+             if model_name == "vae" and self.vae_torch_fallback:
+                 continue
+
+             profile_id = model_obj.get_profile_id(
+                 opt_batch_size, opt_image_height, opt_image_width, static_batch, static_image_shape
+             )
+             engine_path = self.get_engine_path(engine_dir, model_name, profile_id)
+             if not self.has_engine_file(engine_path):
+                 onnx_path = self.get_onnx_path(model_name, onnx_dir, opt=False)
+                 onnx_opt_path = self.get_onnx_path(model_name, onnx_dir, opt=True)
+                 if not os.path.exists(onnx_opt_path):
+                     if not os.path.exists(onnx_path):
+                         logger.info("Exporting model: %s", onnx_path)
+                         model = self.get_or_load_model(pipe, model_name, model_obj, framework_model_dir)
+
+                         with torch.inference_mode(), torch.autocast("cuda"):
+                             inputs = model_obj.get_sample_input(opt_batch_size, opt_image_height, opt_image_width)
+                             torch.onnx.export(
+                                 model,
+                                 inputs,
+                                 onnx_path,
+                                 export_params=True,
+                                 opset_version=onnx_opset,
+                                 do_constant_folding=True,
+                                 input_names=model_obj.get_input_names(),
+                                 output_names=model_obj.get_output_names(),
+                                 dynamic_axes=model_obj.get_dynamic_axes(),
+                             )
+                         del model
+                         torch.cuda.empty_cache()
+                         gc.collect()
+                     else:
+                         logger.info("Found cached model: %s", onnx_path)
+
+                     # Optimize ONNX
+                     if not os.path.exists(onnx_opt_path):
+                         logger.info("Generating optimized model: %s", onnx_opt_path)
+                         model_obj.optimize_trt(onnx_path, onnx_opt_path)
+                     else:
+                         logger.info("Found cached optimized model: %s", onnx_opt_path)
+         self.enable_torch_spda()
+
+         built_engines = {}
+         for model_name, model_obj in self.models.items():
+             if model_name == "vae" and self.vae_torch_fallback:
+                 continue
+
+             profile_id = model_obj.get_profile_id(
+                 opt_batch_size, opt_image_height, opt_image_width, static_batch, static_image_shape
+             )
+
+             engine_path = self.get_engine_path(engine_dir, model_name, profile_id)
+             onnx_opt_path = self.get_onnx_path(model_name, onnx_dir, opt=True)
+             if not self.has_engine_file(engine_path):
+                 logger.info(
+                     "Building TensorRT engine for %s from %s to %s. It can take a while to complete...",
+                     model_name,
+                     onnx_opt_path,
+                     engine_path,
+                 )
+             else:
+                 logger.info("Reusing cached TensorRT engine in directory %s", engine_path)
+
+             input_profile = model_obj.get_input_profile(
+                 opt_batch_size,
+                 opt_image_height,
+                 opt_image_width,
+                 static_batch=static_batch,
+                 static_image_shape=static_image_shape,
+             )
+
+             engine = OrtTensorrtEngine(
+                 engine_path,
+                 device_id,
+                 onnx_opt_path,
+                 fp16=True,
+                 input_profile=input_profile,
+                 workspace_size=self.get_work_space_size(model_name, max_workspace_size),
+                 enable_cuda_graph=self.use_cuda_graph,
+                 timing_cache_path=timing_cache,
+             )
+
+             built_engines[model_name] = engine
+
+         self.engines = built_engines
+
+     def run_engine(self, model_name, feed_dict):
+         return self.engines[model_name].infer(feed_dict)
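
Note: the `trt_profile_*_shapes` provider options built in `get_tensorrt_provider_options` above encode one comma-separated `name:dim0xdim1x...` entry per input, with matching min/opt/max triples. A minimal, self-contained sketch of that encoding (the `sample` input name and its shapes are illustrative placeholders, not taken from this package):

    # Sketch of the shape-profile string encoding used by get_tensorrt_provider_options.
    # The input name and shapes below are made-up examples.
    input_profile = {"sample": [[1, 4, 64, 64], [2, 4, 64, 64], [8, 4, 96, 96]]}  # [min, opt, max]

    options = {}
    for key, index in (("min", 0), ("opt", 1), ("max", 2)):
        entries = [f"{name}:" + "x".join(str(d) for d in dims[index]) for name, dims in input_profile.items()]
        options[f"trt_profile_{key}_shapes"] = ",".join(entries)

    print(options["trt_profile_min_shapes"])  # sample:1x4x64x64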
onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py
@@ -0,0 +1,395 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+ # Modified from TensorRT demo diffusion, which has the following license:
+ #
+ # SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # --------------------------------------------------------------------------
+
+ import gc
+ import os
+ import pathlib
+ from collections import OrderedDict
+
+ import numpy as np
+ import tensorrt as trt
+ import torch
+ from cuda import cudart
+ from diffusion_models import PipelineInfo
+ from engine_builder import EngineBuilder, EngineType
+ from polygraphy.backend.common import bytes_from_path
+ from polygraphy.backend.trt import (
+     CreateConfig,
+     ModifyNetworkOutputs,
+     Profile,
+     engine_from_bytes,
+     engine_from_network,
+     network_from_onnx_path,
+     save_engine,
+ )
+
+ # Map of numpy dtype -> torch dtype
+ numpy_to_torch_dtype_dict = {
+     np.int32: torch.int32,
+     np.int64: torch.int64,
+     np.float16: torch.float16,
+     np.float32: torch.float32,
+ }
+
+
+ def _cuda_assert(cuda_ret):
+     err = cuda_ret[0]
+     if err != cudart.cudaError_t.cudaSuccess:
+         raise RuntimeError(
+             f"CUDA ERROR: {err}, error code reference: https://nvidia.github.io/cuda-python/module/cudart.html#cuda.cudart.cudaError_t"
+         )
+     if len(cuda_ret) > 1:
+         return cuda_ret[1]
+     return None
+
+
+ class TensorrtEngine:
+     def __init__(
+         self,
+         engine_path,
+     ):
+         self.engine_path = engine_path
+         self.engine = None
+         self.context = None
+         self.buffers = OrderedDict()
+         self.tensors = OrderedDict()
+         self.cuda_graph_instance = None
+
+     def __del__(self):
+         del self.engine
+         del self.context
+         del self.buffers
+         del self.tensors
+
+     def build(
+         self,
+         onnx_path,
+         fp16,
+         input_profile=None,
+         enable_all_tactics=False,
+         timing_cache=None,
+         update_output_names=None,
+     ):
+         print(f"Building TensorRT engine for {onnx_path}: {self.engine_path}")
+         p = Profile()
+         if input_profile:
+             for name, dims in input_profile.items():
+                 assert len(dims) == 3
+                 p.add(name, min=dims[0], opt=dims[1], max=dims[2])
+
+         config_kwargs = {}
+         if not enable_all_tactics:
+             config_kwargs["tactic_sources"] = []
+
+         network = network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM])
+         if update_output_names:
+             print(f"Updating network outputs to {update_output_names}")
+             network = ModifyNetworkOutputs(network, update_output_names)
+         engine = engine_from_network(
+             network,
+             config=CreateConfig(
+                 fp16=fp16, refittable=False, profiles=[p], load_timing_cache=timing_cache, **config_kwargs
+             ),
+             save_timing_cache=timing_cache,
+         )
+         save_engine(engine, path=self.engine_path)
+
+     def load(self):
+         print(f"Loading TensorRT engine: {self.engine_path}")
+         self.engine = engine_from_bytes(bytes_from_path(self.engine_path))
+
+     def activate(self, reuse_device_memory=None):
+         if reuse_device_memory:
+             self.context = self.engine.create_execution_context_without_device_memory()
+             self.context.device_memory = reuse_device_memory
+         else:
+             self.context = self.engine.create_execution_context()
+
+     def allocate_buffers(self, shape_dict=None, device="cuda"):
+         for idx in range(self.engine.num_io_tensors):
+             binding = self.engine[idx]
+             if shape_dict and binding in shape_dict:
+                 shape = shape_dict[binding]
+             else:
+                 shape = self.engine.get_binding_shape(binding)
+             dtype = trt.nptype(self.engine.get_binding_dtype(binding))
+             if self.engine.binding_is_input(binding):
+                 self.context.set_binding_shape(idx, shape)
+             tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]).to(device=device)
+             self.tensors[binding] = tensor
+
+     def infer(self, feed_dict, stream, use_cuda_graph=False):
+         for name, buf in feed_dict.items():
+             self.tensors[name].copy_(buf)
+
+         for name, tensor in self.tensors.items():
+             self.context.set_tensor_address(name, tensor.data_ptr())
+
+         if use_cuda_graph:
+             if self.cuda_graph_instance is not None:
+                 _cuda_assert(cudart.cudaGraphLaunch(self.cuda_graph_instance, stream))
+                 _cuda_assert(cudart.cudaStreamSynchronize(stream))
+             else:
+                 # do inference before CUDA graph capture
+                 noerror = self.context.execute_async_v3(stream)
+                 if not noerror:
+                     raise ValueError("ERROR: inference failed.")
+                 # capture cuda graph
+                 _cuda_assert(
+                     cudart.cudaStreamBeginCapture(stream, cudart.cudaStreamCaptureMode.cudaStreamCaptureModeGlobal)
+                 )
+                 self.context.execute_async_v3(stream)
+                 self.graph = _cuda_assert(cudart.cudaStreamEndCapture(stream))
+
+                 from cuda import nvrtc  # noqa: PLC0415
+
+                 result, major, minor = nvrtc.nvrtcVersion()
+                 assert result == nvrtc.nvrtcResult(0)
+                 if major < 12:
+                     self.cuda_graph_instance = _cuda_assert(
+                         cudart.cudaGraphInstantiate(self.graph, b"", 0)
+                     )  # cuda < 12
+                 else:
+                     self.cuda_graph_instance = _cuda_assert(cudart.cudaGraphInstantiate(self.graph, 0))  # cuda >= 12
+         else:
+             noerror = self.context.execute_async_v3(stream)
+             if not noerror:
+                 raise ValueError("ERROR: inference failed.")
+
+         return self.tensors
+
+
+ class TensorrtEngineBuilder(EngineBuilder):
182
+ """
183
+ Helper class to hide the detail of TensorRT Engine from pipeline.
184
+ """
185
+
186
+ def __init__(
187
+ self,
188
+ pipeline_info: PipelineInfo,
189
+ max_batch_size=16,
190
+ device="cuda",
191
+ use_cuda_graph=False,
192
+ ):
193
+ """
194
+ Initializes the ONNX Runtime TensorRT ExecutionProvider Engine Builder.
195
+
196
+ Args:
197
+ pipeline_info (PipelineInfo):
198
+ Version and Type of pipeline.
199
+ max_batch_size (int):
200
+ Maximum batch size for dynamic batch engine.
201
+ device (str):
202
+ device to run.
203
+ use_cuda_graph (bool):
204
+ Use CUDA graph to capture engine execution and then launch inference
205
+ """
206
+ super().__init__(
207
+ EngineType.TRT,
208
+ pipeline_info,
209
+ max_batch_size=max_batch_size,
210
+ device=device,
211
+ use_cuda_graph=use_cuda_graph,
212
+ )
213
+
214
+ self.stream = None
215
+ self.shared_device_memory = None
216
+
217
+ def load_resources(self, image_height, image_width, batch_size):
218
+ super().load_resources(image_height, image_width, batch_size)
219
+
220
+ self.stream = _cuda_assert(cudart.cudaStreamCreate())
221
+
222
+ def teardown(self):
223
+ super().teardown()
224
+
225
+ if self.shared_device_memory:
226
+ cudart.cudaFree(self.shared_device_memory)
227
+
228
+ cudart.cudaStreamDestroy(self.stream)
229
+ del self.stream
230
+
231
+ def load_engines(
232
+ self,
233
+ engine_dir,
234
+ framework_model_dir,
235
+ onnx_dir,
236
+ onnx_opset,
237
+ opt_batch_size,
238
+ opt_image_height,
239
+ opt_image_width,
240
+ static_batch=False,
241
+ static_shape=True,
242
+ enable_all_tactics=False,
243
+ timing_cache=None,
244
+ ):
245
+ """
246
+ Build and load engines for TensorRT accelerated inference.
247
+ Export ONNX models first, if applicable.
248
+
249
+ Args:
250
+ engine_dir (str):
251
+ Directory to write the TensorRT engines.
252
+ framework_model_dir (str):
253
+ Directory to write the framework model ckpt.
254
+ onnx_dir (str):
255
+ Directory to write the ONNX models.
256
+ onnx_opset (int):
257
+ ONNX opset version to export the models.
258
+ opt_batch_size (int):
259
+ Batch size to optimize for during engine building.
260
+ opt_image_height (int):
261
+ Image height to optimize for during engine building. Must be a multiple of 8.
262
+ opt_image_width (int):
263
+ Image width to optimize for during engine building. Must be a multiple of 8.
264
+ static_batch (bool):
265
+ Build engine only for specified opt_batch_size.
266
+ static_shape (bool):
267
+ Build engine only for specified opt_image_height & opt_image_width. Default = True.
268
+ enable_all_tactics (bool):
269
+ Enable all tactic sources during TensorRT engine builds.
270
+ timing_cache (str):
271
+ Path to the timing cache to accelerate build or None
272
+ """
273
+ # Create directory
274
+ for directory in [engine_dir, onnx_dir]:
275
+ if not os.path.exists(directory):
276
+ print(f"[I] Create directory: {directory}")
277
+ pathlib.Path(directory).mkdir(parents=True)
278
+
279
+ self.load_models(framework_model_dir)
280
+
281
+ # Load lora only when we need export text encoder or UNet to ONNX.
282
+ load_lora = False
283
+ if self.pipeline_info.lora_weights:
284
+ for model_name, model_obj in self.models.items():
285
+ if model_name not in ["clip", "clip2", "unet", "unetxl"]:
286
+ continue
287
+ profile_id = model_obj.get_profile_id(
288
+ opt_batch_size, opt_image_height, opt_image_width, static_batch, static_shape
289
+ )
290
+ engine_path = self.get_engine_path(engine_dir, model_name, profile_id)
291
+ if not os.path.exists(engine_path):
292
+ onnx_path = self.get_onnx_path(model_name, onnx_dir, opt=False)
293
+ onnx_opt_path = self.get_onnx_path(model_name, onnx_dir, opt=True)
294
+ if not os.path.exists(onnx_opt_path):
295
+ if not os.path.exists(onnx_path):
296
+ load_lora = True
297
+ break
298
+
299
+ # Export models to ONNX
300
+ self.disable_torch_spda()
301
+ pipe = self.load_pipeline_with_lora() if load_lora else None
302
+
303
+ for model_name, model_obj in self.models.items():
304
+ if model_name == "vae" and self.vae_torch_fallback:
305
+ continue
306
+ profile_id = model_obj.get_profile_id(
307
+ opt_batch_size, opt_image_height, opt_image_width, static_batch, static_shape
308
+ )
309
+ engine_path = self.get_engine_path(engine_dir, model_name, profile_id)
310
+ if not os.path.exists(engine_path):
311
+ onnx_path = self.get_onnx_path(model_name, onnx_dir, opt=False)
312
+ onnx_opt_path = self.get_onnx_path(model_name, onnx_dir, opt=True)
313
+ if not os.path.exists(onnx_opt_path):
314
+ if not os.path.exists(onnx_path):
315
+ print(f"Exporting model: {onnx_path}")
316
+ model = self.get_or_load_model(pipe, model_name, model_obj, framework_model_dir)
317
+
318
+ with torch.inference_mode(), torch.autocast("cuda"):
319
+ inputs = model_obj.get_sample_input(1, opt_image_height, opt_image_width)
320
+ torch.onnx.export(
321
+ model,
322
+ inputs,
323
+ onnx_path,
324
+ export_params=True,
325
+ opset_version=onnx_opset,
326
+ do_constant_folding=True,
327
+ input_names=model_obj.get_input_names(),
328
+ output_names=model_obj.get_output_names(),
329
+ dynamic_axes=model_obj.get_dynamic_axes(),
330
+ )
331
+ del model
332
+ torch.cuda.empty_cache()
333
+ gc.collect()
334
+ else:
335
+ print(f"Found cached model: {onnx_path}")
336
+
337
+ # Optimize onnx
338
+ if not os.path.exists(onnx_opt_path):
339
+ print(f"Generating optimizing model: {onnx_opt_path}")
340
+ model_obj.optimize_trt(onnx_path, onnx_opt_path)
341
+ else:
342
+ print(f"Found cached optimized model: {onnx_opt_path} ")
343
+ self.enable_torch_spda()
344
+
345
+ # Build TensorRT engines
346
+ for model_name, model_obj in self.models.items():
347
+ if model_name == "vae" and self.vae_torch_fallback:
348
+ continue
349
+ profile_id = model_obj.get_profile_id(
350
+ opt_batch_size, opt_image_height, opt_image_width, static_batch, static_shape
351
+ )
352
+ engine_path = self.get_engine_path(engine_dir, model_name, profile_id)
353
+ engine = TensorrtEngine(engine_path)
354
+ onnx_opt_path = self.get_onnx_path(model_name, onnx_dir, opt=True)
355
+
356
+ if not os.path.exists(engine.engine_path):
357
+ engine.build(
358
+ onnx_opt_path,
359
+ fp16=True,
360
+ input_profile=model_obj.get_input_profile(
361
+ opt_batch_size,
362
+ opt_image_height,
363
+ opt_image_width,
364
+ static_batch,
365
+ static_shape,
366
+ ),
367
+ enable_all_tactics=enable_all_tactics,
368
+ timing_cache=timing_cache,
369
+ update_output_names=None,
370
+ )
371
+ self.engines[model_name] = engine
372
+
373
+ # Load TensorRT engines
374
+ for model_name in self.models:
375
+ if model_name == "vae" and self.vae_torch_fallback:
376
+ continue
377
+ self.engines[model_name].load()
378
+
379
+ def max_device_memory(self):
380
+ max_device_memory = 0
381
+ for engine in self.engines.values():
382
+ max_device_memory = max(max_device_memory, engine.engine.device_memory_size)
383
+ return max_device_memory
384
+
385
+ def activate_engines(self, shared_device_memory=None):
386
+ if shared_device_memory is None:
387
+ max_device_memory = self.max_device_memory()
388
+ _, shared_device_memory = cudart.cudaMalloc(max_device_memory)
389
+ self.shared_device_memory = shared_device_memory
390
+ # Load and activate TensorRT engines
391
+ for engine in self.engines.values():
392
+ engine.activate(reuse_device_memory=self.shared_device_memory)
393
+
394
+ def run_engine(self, model_name, feed_dict):
395
+ return self.engines[model_name].infer(feed_dict, self.stream, use_cuda_graph=self.use_cuda_graph)
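
Taken together, `TensorrtEngineBuilder` follows a build/load/activate/run lifecycle: `load_engines` exports and optimizes the ONNX models and builds or deserializes the engines, `activate_engines` creates execution contexts over a shared device-memory arena sized by `max_device_memory`, `load_resources` creates the CUDA stream and allocates buffers, and `run_engine` executes a named engine. A hedged driver sketch, assuming a CUDA-capable environment with TensorRT and polygraphy installed; the directory names, the `PipelineInfo("1.5")` version argument, and the `"unet"`/`"sample"` names below are assumptions, not taken from this package:

    # Hypothetical driver for TensorrtEngineBuilder; paths, pipeline version, and feed names are assumptions.
    import torch
    from diffusion_models import PipelineInfo
    from engine_builder_tensorrt import TensorrtEngineBuilder

    builder = TensorrtEngineBuilder(PipelineInfo("1.5"), max_batch_size=4)
    builder.load_engines(
        engine_dir="engine", framework_model_dir="models", onnx_dir="onnx", onnx_opset=17,
        opt_batch_size=1, opt_image_height=512, opt_image_width=512,
    )
    builder.activate_engines()  # contexts share one arena sized by the largest engine
    builder.load_resources(image_height=512, image_width=512, batch_size=1)

    # run_engine copies feeds into preallocated tensors and executes on the builder's stream.
    latent = torch.randn(1, 4, 64, 64, device="cuda", dtype=torch.float16)
    outputs = builder.run_engine("unet", {"sample": latent})

    builder.teardown()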