onnxruntime-directml 1.20.0__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (305)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,106 @@
1
+ from .operators.activation import QDQRemovableActivation, QLinearActivation
2
+ from .operators.argmax import QArgMax
3
+ from .operators.attention import AttentionQuant
4
+ from .operators.base_operator import QuantOperatorBase
5
+ from .operators.binary_op import QLinearBinaryOp
6
+ from .operators.concat import QLinearConcat
7
+ from .operators.conv import ConvInteger, QDQConv, QLinearConv
8
+ from .operators.direct_q8 import Direct8BitOp, QDQDirect8BitOp
9
+ from .operators.embed_layernorm import EmbedLayerNormalizationQuant
10
+ from .operators.gather import GatherQuant, QDQGather
11
+ from .operators.gavgpool import QGlobalAveragePool
12
+ from .operators.gemm import QDQGemm, QLinearGemm
13
+ from .operators.lstm import LSTMQuant
14
+ from .operators.matmul import MatMulInteger, QDQMatMul, QLinearMatMul
15
+ from .operators.maxpool import QDQMaxPool, QMaxPool
16
+ from .operators.norm import QDQNormalization
17
+ from .operators.pad import QPad
18
+ from .operators.pooling import QLinearPool
19
+ from .operators.qdq_base_operator import QDQOperatorBase
20
+ from .operators.resize import QDQResize, QResize
21
+ from .operators.softmax import QLinearSoftmax
22
+ from .operators.split import QDQSplit, QSplit
23
+ from .operators.where import QDQWhere, QLinearWhere
24
+ from .quant_utils import QuantizationMode
25
+
26
+ CommonOpsRegistry = {
27
+ "Gather": GatherQuant,
28
+ "Transpose": Direct8BitOp,
29
+ "EmbedLayerNormalization": EmbedLayerNormalizationQuant,
30
+ }
31
+
32
+ IntegerOpsRegistry = {
33
+ "Conv": ConvInteger,
34
+ "MatMul": MatMulInteger,
35
+ "Attention": AttentionQuant,
36
+ "LSTM": LSTMQuant,
37
+ }
38
+ IntegerOpsRegistry.update(CommonOpsRegistry)
39
+
40
+ QLinearOpsRegistry = {
41
+ "ArgMax": QArgMax,
42
+ "Conv": QLinearConv,
43
+ "Gemm": QLinearGemm,
44
+ "MatMul": QLinearMatMul,
45
+ "Add": QLinearBinaryOp,
46
+ "Mul": QLinearBinaryOp,
47
+ "Relu": QLinearActivation,
48
+ "Clip": QLinearActivation,
49
+ "LeakyRelu": QLinearActivation,
50
+ "Sigmoid": QLinearActivation,
51
+ "MaxPool": QMaxPool,
52
+ "GlobalAveragePool": QGlobalAveragePool,
53
+ "Split": QSplit,
54
+ "Pad": QPad,
55
+ "Reshape": Direct8BitOp,
56
+ "Squeeze": Direct8BitOp,
57
+ "Unsqueeze": Direct8BitOp,
58
+ "Resize": QResize,
59
+ "AveragePool": QLinearPool,
60
+ "Concat": QLinearConcat,
61
+ "Softmax": QLinearSoftmax,
62
+ "Where": QLinearWhere,
63
+ }
64
+ QLinearOpsRegistry.update(CommonOpsRegistry)
65
+
66
+ QDQRegistry = {
67
+ "Conv": QDQConv,
68
+ "ConvTranspose": QDQConv,
69
+ "Gemm": QDQGemm,
70
+ "Clip": QDQRemovableActivation,
71
+ "Relu": QDQRemovableActivation,
72
+ "Reshape": QDQDirect8BitOp,
73
+ "Transpose": QDQDirect8BitOp,
74
+ "Squeeze": QDQDirect8BitOp,
75
+ "Unsqueeze": QDQDirect8BitOp,
76
+ "Resize": QDQResize,
77
+ "MaxPool": QDQMaxPool,
78
+ "AveragePool": QDQDirect8BitOp,
79
+ "MatMul": QDQMatMul,
80
+ "Split": QDQSplit,
81
+ "Gather": QDQGather,
82
+ "GatherElements": QDQGather,
83
+ "Where": QDQWhere,
84
+ "InstanceNormalization": QDQNormalization,
85
+ "LayerNormalization": QDQNormalization,
86
+ "BatchNormalization": QDQNormalization,
87
+ }
88
+
89
+
def CreateDefaultOpQuantizer(onnx_quantizer, node):  # noqa: N802
    """Build the pass-through quantizer for *node*.

    Args:
        onnx_quantizer: The quantizer driving the current quantization pass.
        node: The ONNX node to wrap.

    Returns:
        A ``QuantOperatorBase`` instance (no op-specific quantization logic).
    """
    return QuantOperatorBase(onnx_quantizer, node)
92
+
93
+
def CreateOpQuantizer(onnx_quantizer, node):  # noqa: N802
    """Build the op-specific quantizer for *node*, or fall back to the base.

    Selects the registry matching the quantizer's mode (IntegerOps vs.
    QLinearOps), instantiates the registered handler for the node's op type,
    and uses it only if it reports the node should be quantized; otherwise a
    pass-through ``QuantOperatorBase`` is returned.
    """
    if onnx_quantizer.mode == QuantizationMode.IntegerOps:
        registry = IntegerOpsRegistry
    else:
        registry = QLinearOpsRegistry

    quantizer_cls = registry.get(node.op_type)
    if quantizer_cls is not None:
        candidate = quantizer_cls(onnx_quantizer, node)
        if candidate.should_quantize():
            return candidate
    return QuantOperatorBase(onnx_quantizer, node)
101
+
102
+
def CreateQDQQuantizer(onnx_quantizer, node):  # noqa: N802
    """Build the QDQ-format quantizer for *node*.

    Looks the op type up in ``QDQRegistry``; unknown op types get the generic
    ``QDQOperatorBase`` handler.
    """
    quantizer_cls = QDQRegistry.get(node.op_type, QDQOperatorBase)
    return quantizer_cls(onnx_quantizer, node)
@@ -0,0 +1,187 @@
1
+ # --------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft, Intel Corporation. All rights reserved.
3
+ # Licensed under the MIT License. See License.txt in the project root for
4
+ # license information.
5
+ # --------------------------------------------------------------------------
6
+
7
+
8
+ import logging
9
+ import tempfile
10
+ import traceback
11
+ from pathlib import Path
12
+ from typing import Optional, Union
13
+
14
+ import onnx
15
+
16
+ import onnxruntime
17
+ from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
18
+ from onnxruntime.transformers.onnx_utils import extract_raw_data_from_model, has_external_data
19
+
20
+ from .quant_utils import add_pre_process_metadata
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ def quant_pre_process(
26
+ input_model: Optional[Union[str, Path, onnx.ModelProto]] = None,
27
+ output_model_path: Optional[Union[str, Path]] = None,
28
+ skip_optimization: bool = False,
29
+ skip_onnx_shape: bool = False,
30
+ skip_symbolic_shape: bool = False,
31
+ auto_merge: bool = False,
32
+ int_max: int = 2**31 - 1,
33
+ guess_output_rank: bool = False,
34
+ verbose: int = 0,
35
+ save_as_external_data: bool = False,
36
+ all_tensors_to_one_file: bool = False,
37
+ external_data_location: Optional[str] = None,
38
+ external_data_size_threshold: int = 1024,
39
+ **deprecated_kwargs,
40
+ ) -> None:
41
+ """Shape inference and model optimization, in preparation for quantization.
42
+
43
+ Args:
44
+ input_model: Path to the input model file or ModelProto
45
+ output_model_path: Path to the output model file
46
+ skip_optimization: Skip model optimization step if true. This may result in ONNX shape
47
+ inference failure for some models.
48
+ skip_onnx_shape: Skip ONNX shape inference. Symbolic shape inference is most effective
49
+ with transformer based models. Skipping all shape inferences may
50
+ reduce the effectiveness of quantization, as a tensor with unknown
51
+ shape can not be quantized.
52
+ skip_symbolic_shape: Skip symbolic shape inference. Symbolic shape inference is most
53
+ effective with transformer based models. Skipping all shape
54
+ inferences may reduce the effectiveness of quantization, as a tensor
55
+ with unknown shape can not be quantized.
56
+ auto_merge: For symbolic shape inference, automatically merge symbolic dims when
57
+ conflict happens.
58
+ int_max: For symbolic shape inference, specify the maximum value for integer to be
59
+ treated as boundless for ops like slice
60
+ guess_output_rank: Guess output rank to be the same as input 0 for unknown ops
61
+ verbose: Logs detailed info of inference, 0: turn off, 1: warnings, 3: detailed
62
+ save_as_external_data: Saving an ONNX model to external data
63
+ all_tensors_to_one_file: Saving all the external data to one file
64
+ external_data_location: The file location to save the external file
65
+ external_data_size_threshold: The size threshold for external data
66
+ """
67
+
68
+ if input_model is None:
69
+ input_model = deprecated_kwargs.pop("input_model_path", None)
70
+ assert input_model is not None
71
+
72
+ assert output_model_path is not None, "output_model_path is required."
73
+
74
+ with tempfile.TemporaryDirectory(prefix="pre.quant.") as quant_tmp_dir:
75
+ temp_path = Path(quant_tmp_dir)
76
+ model = None
77
+
78
+ if not skip_symbolic_shape:
79
+ logger.info("Performing symbolic shape inference...")
80
+ loaded_model = input_model if isinstance(input_model, onnx.ModelProto) else onnx.load(input_model)
81
+ model = SymbolicShapeInference.infer_shapes(
82
+ loaded_model,
83
+ int_max,
84
+ auto_merge,
85
+ guess_output_rank,
86
+ verbose,
87
+ )
88
+
89
+ if not skip_optimization:
90
+ # Use ORT optimizers (native code) to optimize model
91
+ if not skip_symbolic_shape:
92
+ # Need to save the inferenced model to file so as to run the optimizer
93
+ input_model = str(temp_path / "symbolic_shape_inferred.onnx")
94
+ if save_as_external_data:
95
+ onnx.save_model(
96
+ model,
97
+ input_model,
98
+ save_as_external_data=True,
99
+ all_tensors_to_one_file=all_tensors_to_one_file,
100
+ size_threshold=external_data_size_threshold,
101
+ convert_attribute=False,
102
+ )
103
+ else:
104
+ onnx.save(model, input_model)
105
+ model = None
106
+
107
+ opt_model_path = str(temp_path / "optimized.onnx")
108
+ try:
109
+ sess_option = onnxruntime.SessionOptions()
110
+ sess_option.optimized_model_filepath = opt_model_path
111
+ sess_option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_BASIC
112
+ # For large model, extract external data from model and add to session options
113
+ if isinstance(input_model, onnx.ModelProto):
114
+ if has_external_data(input_model):
115
+ raise ValueError(
116
+ "ModelProto has external data not loaded into memory, ORT cannot create session. "
117
+ "Please load external data before calling this function. "
118
+ "See https://onnx.ai/onnx/repo-docs/ExternalData.html for more information."
119
+ )
120
+ external_names, external_values = extract_raw_data_from_model(input_model)
121
+ sess_option.add_external_initializers(list(external_names), list(external_values))
122
+ input_model = input_model.SerializeToString()
123
+
124
+ sess = onnxruntime.InferenceSession(input_model, sess_option, providers=["CPUExecutionProvider"])
125
+ # Close the session to avoid the cleanup error on Windows for temp folders
126
+ # https://github.com/microsoft/onnxruntime/issues/17627
127
+ del sess
128
+ except Exception:
129
+ logger.error(
130
+ "ONNX Runtime Model Optimization Failed! Consider rerun with option `--skip_optimization'."
131
+ )
132
+ logger.error(traceback.format_exc())
133
+
134
+ input_model = opt_model_path
135
+
136
+ if not skip_onnx_shape:
137
+ # ONNX shape inference.
138
+ # According to docs, infer_shapes_path should be used for 2G+ models.
139
+ # If the skip optimization is specified, we could be dealing with a
140
+ # large model. So be on the safe side, save the model
141
+ if model is not None:
142
+ input_model = str(temp_path / "symbolic_shape_inferred.onnx")
143
+ if save_as_external_data:
144
+ onnx.save_model(
145
+ model,
146
+ input_model,
147
+ save_as_external_data=True,
148
+ all_tensors_to_one_file=all_tensors_to_one_file,
149
+ size_threshold=external_data_size_threshold,
150
+ convert_attribute=False,
151
+ )
152
+ else:
153
+ onnx.save(model, input_model)
154
+ model = None
155
+
156
+ if isinstance(input_model, onnx.ModelProto):
157
+ input_model = str(Path(quant_tmp_dir) / "model_input.onnx")
158
+ onnx.save_model(
159
+ model,
160
+ input_model,
161
+ save_as_external_data=True,
162
+ all_tensors_to_one_file=all_tensors_to_one_file,
163
+ size_threshold=external_data_size_threshold,
164
+ convert_attribute=False,
165
+ )
166
+
167
+ inferred_model_path = str(temp_path / "onnx_shape_inferred.onnx")
168
+ onnx.shape_inference.infer_shapes_path(input_model, inferred_model_path)
169
+ model = onnx.load(inferred_model_path)
170
+
171
+ if model is None:
172
+ model = input_model if isinstance(input_model, onnx.ModelProto) else onnx.load(input_model)
173
+
174
+ add_pre_process_metadata(model)
175
+
176
+ if save_as_external_data:
177
+ onnx.save_model(
178
+ model,
179
+ output_model_path,
180
+ save_as_external_data=True,
181
+ all_tensors_to_one_file=all_tensors_to_one_file,
182
+ location=external_data_location,
183
+ size_threshold=external_data_size_threshold,
184
+ convert_attribute=False,
185
+ )
186
+ else:
187
+ onnx.save(model, output_model_path)