onnxruntime_directml-1.20.0-cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (305)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
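
Only one of the 305 files is reproduced in full below: onnxruntime/transformers/models/gpt2/convert_to_onnx.py (+561 lines, entry 211 in the list above). As a quick orientation before reading it, the snippet below (not part of the package diff) checks that the wheel this diff describes is installed and that its DirectML execution provider is registered; it assumes the wheel was installed with pip install onnxruntime-directml==1.20.0.

import onnxruntime as ort

# Version should match the wheel this diff describes.
print(ort.__version__)  # expected: 1.20.0

# The DirectML build registers "DmlExecutionProvider" alongside the CPU provider.
print(ort.get_available_providers())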
onnxruntime/transformers/models/gpt2/convert_to_onnx.py
@@ -0,0 +1,561 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License. See License.txt in the project root for
+ # license information.
+ # --------------------------------------------------------------------------
+ """
+ This converts a GPT-2 model to ONNX. Examples:
+ (1) Convert the pretrained model 'gpt2' to ONNX:
+     python convert_to_onnx.py -m gpt2 --output gpt2.onnx
+ (2) Convert the pretrained model 'distilgpt2' to ONNX, and use the optimizer to get a float16 model:
+     python convert_to_onnx.py -m distilgpt2 --output distilgpt2_fp16.onnx -o -p fp16
+ (3) Convert a model checkpoint to ONNX, and run optimization and int8 quantization:
+     python convert_to_onnx.py -m ./my_model_checkpoint/ --output my_model_int8.onnx -o -p int8
+
+ """
+
+ import argparse
+ import csv
+ import json
+ import logging
+ import os
+ import shutil
+ import sys
+ from pathlib import Path
+
+ import numpy
+ import torch
+ from benchmark_helper import (
+     Precision,
+     create_onnxruntime_session,
+     get_ort_environment_variables,
+     prepare_environment,
+     setup_logger,
+ )
+ from gpt2_helper import DEFAULT_TOLERANCE, MODEL_CLASSES, PRETRAINED_GPT2_MODELS, Gpt2Helper
+ from gpt2_tester import Gpt2Tester
+ from packaging import version
+ from quantize_helper import QuantizeHelper
+ from transformers import AutoConfig
+ from transformers import __version__ as transformers_version
+
+ from onnxruntime import __version__ as ort_version
+
+ logger = logging.getLogger("")
+
+
+ def parse_arguments(argv=None):
+     parser = argparse.ArgumentParser()
+
+     parser.add_argument(
+         "-m",
+         "--model_name_or_path",
+         required=True,
+         type=str,
+         help="Model path, or pretrained model name in the list: " + ", ".join(PRETRAINED_GPT2_MODELS),
+     )
+
+     parser.add_argument(
+         "--model_class",
+         required=False,
+         type=str,
+         default="GPT2LMHeadModel",
+         choices=list(MODEL_CLASSES.keys()),
+         help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
+     )
+
+     parser.add_argument(
+         "--cache_dir",
+         required=False,
+         type=str,
+         default=os.path.join(".", "cache_models"),
+         help="Directory to cache pre-trained models",
+     )
+
+     parser.add_argument(
+         "--output",
+         required=False,
+         type=str,
+         default=os.path.join(".", "onnx_models"),
+         help="Output directory, or model path ends with .onnx",
+     )
+
+     parser.add_argument(
+         "-o",
+         "--optimize_onnx",
+         required=False,
+         action="store_true",
+         help="Use optimizer.py to optimize onnx model",
+     )
+     parser.set_defaults(optimize_onnx=False)
+
+     parser.add_argument("--use_gpu", required=False, action="store_true", help="use GPU for inference")
+     parser.set_defaults(use_gpu=False)
+
+     parser.add_argument(
+         "--provider",
+         required=False,
+         default=None,
+         choices=["dml", "rocm", "migraphx", "cuda", "tensorrt"],
+         help="use dml, rocm, cuda, tensorrt or migraphx for respective backend",
+     )
+
+     parser.add_argument(
+         "--tolerance",
+         required=False,
+         type=float,
+         default=0,
+         help="the absolute and relative tolerance for parity verification",
+     )
+
+     parser.add_argument(
+         "--input_test_file",
+         "-i",
+         required=False,
+         type=str,
+         default="",
+         help="Path to the file with inputs to test with",
+     )
+
+     parser.add_argument(
+         "-p",
+         "--precision",
+         required=False,
+         type=Precision,
+         default=Precision.FLOAT32,
+         choices=list(Precision),
+         help="Precision of model to run. fp32 for full precision, fp16 for half or mixed precision, and int8 for quantization",
+     )
+
+     parser.add_argument(
+         "-t",
+         "--test_cases",
+         required=False,
+         type=int,
+         default=1000,
+         help="Number of test cases per run for parity",
+     )
+     parser.add_argument(
+         "-r",
+         "--test_runs",
+         required=False,
+         type=int,
+         default=10,
+         help="Number of runs for parity. It is used for significance test.",
+     )
+
+     parser.add_argument("--verbose", required=False, action="store_true")
+     parser.set_defaults(verbose=False)
+
+     parser.add_argument("-e", "--use_external_data_format", required=False, action="store_true")
+     parser.set_defaults(use_external_data_format=False)
+
+     parser.add_argument("--overwrite", required=False, action="store_true")
+     parser.set_defaults(overwrite=False)
+
+     parser.add_argument(
+         "--use_int64_inputs",
+         required=False,
+         action="store_true",
+         help="Use int64 instead of int32 for input_ids, position_ids and attention_mask.",
+     )
+     parser.set_defaults(use_int64_inputs=False)
+
+     parser.add_argument(
+         "-s",
+         "--stage",
+         type=int,
+         default=0,
+         required=False,
+         choices=[0, 1, 2],
+         help="Stage in generation: 1 (initial decoder), 2 (decoder), 0 (both). "
+         "1 - decode the first token when past_sequence_length is zero; "
+         "2 - decode the remaining tokens when past_sequence_length is not zero; "
+         "0 - one onnx model for both stages 1 and 2. "
+         "Note that we will optimize 1 and 2 differently for best performance.",
+     )
+
+     fp16_option_group = parser.add_argument_group(
+         'float to float16 conversion parameters that work when "--precision fp16" is specified'
+     )
+
+     fp16_option_group.add_argument(
+         "-a",
+         "--auto_mixed_precision",
+         required=False,
+         action="store_true",
+         help="Convert to mixed precision automatically. Other float16 conversion parameters will be ignored.",
+     )
+     fp16_option_group.set_defaults(auto_mixed_precision=False)
+
+     fp16_option_group.add_argument(
+         "--keep_io_types",
+         required=False,
+         action="store_true",
+         help="Use float32 for past inputs, present and logits outputs.",
+     )
+     fp16_option_group.set_defaults(keep_io_types=False)
+
+     fp16_option_group.add_argument(
+         "--io_block_list",
+         nargs="+",
+         default=[],
+         help="List of inputs or outputs in float32 instead of float16",
+     )
+
+     fp16_option_group.add_argument(
+         "--op_block_list",
+         nargs="+",
+         default=[],
+         help="List of operators (like Add LayerNormalization SkipLayerNormalization EmbedLayerNormalization FastGelu) "
+         "to compute in float32 instead of float16.",
+     )
+
+     fp16_option_group.add_argument(
+         "--node_block_list",
+         nargs="+",
+         default=[],
+         help="List of node names to compute in float32 instead of float16.",
+     )
+
+     fp16_option_group.add_argument(
+         "--force_fp16_initializers",
+         required=False,
+         action="store_true",
+         help="Convert all float initializers to float16.",
+     )
+     fp16_option_group.set_defaults(force_fp16_initializers=False)
+
+     args = parser.parse_args(argv)
+
+     return args
+
+
+ def get_onnx_model_size(onnx_path: str, use_external_data_format: bool):
+     if not use_external_data_format:
+         return os.path.getsize(onnx_path)
+     else:
+         return sum([f.stat().st_size for f in Path(onnx_path).parent.rglob("*")])
+
+
+ def get_latency_name(batch_size, sequence_length, past_sequence_length):
+     return f"average_latency(batch_size={batch_size},sequence_length={sequence_length},past_sequence_length={past_sequence_length})"
+
+
+ def main(argv=None, experiment_name: str = "", run_id: str = "0", csv_filename: str = "gpt2_parity_results.csv"):
+     result = {}
+     if version.parse(transformers_version) < version.parse(
+         "3.1.0"
+     ):  # past_key_values name does not exist in 3.0.2 or older
+         raise RuntimeError("This tool requires transformers 3.1.0 or later.")
+
+     args = parse_arguments(argv)
+     setup_logger(args.verbose)
+
+     if not experiment_name:
+         experiment_name = " ".join(argv if argv else sys.argv[1:])
+
+     if args.tolerance == 0:
+         args.tolerance = DEFAULT_TOLERANCE[args.precision]
+
+     logger.info(f"Arguments: {args}")
+
+     cache_dir = args.cache_dir
+     output_dir = args.output if not args.output.endswith(".onnx") else os.path.dirname(args.output)
+     prepare_environment(cache_dir, output_dir, args.use_gpu)
+
+     if args.precision != Precision.FLOAT32:
+         assert args.optimize_onnx, "fp16/int8 requires --optimize_onnx"
+
+     if args.precision == Precision.FLOAT16:
+         assert args.use_gpu, "fp16 requires --use_gpu"
+
+     if args.precision == Precision.INT8:
+         assert not args.use_gpu, "quantization only supports CPU"
+
+     model_class = MODEL_CLASSES[args.model_class][0]
+     use_padding = MODEL_CLASSES[args.model_class][2]
+
+     gpt2helper = Gpt2Helper
+     config = AutoConfig.from_pretrained(args.model_name_or_path, cache_dir=cache_dir)
+     model = model_class.from_pretrained(args.model_name_or_path, config=config, cache_dir=cache_dir)
+
+     device = torch.device("cuda:0" if args.use_gpu else "cpu")
+     model.eval().to(device)
+
+     if (not args.use_external_data_format) and (config.n_layer > 24):
+         logger.info("Try --use_external_data_format when model size > 2GB")
+
+     onnx_model_paths = gpt2helper.get_onnx_paths(
+         output_dir,
+         args.model_name_or_path,
+         args.model_class,
+         new_folder=(args.precision == Precision.INT8),
+         remove_existing=["fp32", "fp16", "int8"],
+     )  # Do not remove raw model to save time in parity test
+
+     raw_onnx_model = onnx_model_paths["raw"]
+
+     int_data_type = torch.int64 if args.use_int64_inputs else torch.int32
+
+     if os.path.exists(raw_onnx_model) and not args.overwrite:
+         logger.warning(f"Skip exporting ONNX model since it already exists: {raw_onnx_model}")
+     else:
+         logger.info(f"Exporting ONNX model to {raw_onnx_model}")
+         gpt2helper.export_onnx(
+             model,
+             device,
+             raw_onnx_model,
+             args.verbose,
+             args.use_external_data_format,
+             has_position_ids=use_padding,
+             has_attention_mask=use_padding,
+             input_ids_dtype=int_data_type,
+             position_ids_dtype=int_data_type,
+             attention_mask_dtype=int_data_type,
+         )
+
+     fp16_params = {"keep_io_types": args.keep_io_types}
+     if args.io_block_list:
+         fp16_params["keep_io_types"] = args.io_block_list
+     if args.node_block_list:
+         fp16_params["node_block_list"] = args.node_block_list
+     if args.op_block_list:
+         fp16_params["op_block_list"] = args.op_block_list
+     if args.force_fp16_initializers:
+         fp16_params["force_fp16_initializers"] = args.force_fp16_initializers
+
+     is_io_float16 = args.precision == Precision.FLOAT16 and not args.keep_io_types
+
+     optimized_ops = ""
+     all_ops = ""
+     if args.optimize_onnx or args.precision != Precision.FLOAT32:
+         output_path = onnx_model_paths[str(args.precision) if args.precision != Precision.INT8 else "fp32"]
+
+         logger.info(f"Optimizing model to {output_path}")
+         m = gpt2helper.optimize_onnx(
+             raw_onnx_model,
+             output_path,
+             args.precision == Precision.FLOAT16,
+             model.config.num_attention_heads,
+             model.config.hidden_size,
+             args.use_external_data_format,
+             auto_mixed_precision=args.auto_mixed_precision,
+             stage=args.stage,
+             **fp16_params,
+         )
+
+         nodes = m.nodes()
+         op_list = {node.op_type for node in nodes}
+         all_ops = ",".join(op_list)
+
+         # print optimized operators
+         optimized_op_counter = m.get_fused_operator_statistics()
+         if optimized_op_counter:
+             optimized_ops = ",".join([key for key in optimized_op_counter if optimized_op_counter[key] > 0])
+     else:
+         output_path = raw_onnx_model
+
+     if args.precision == Precision.INT8:
+         logger.info("quantizing model...")
+         QuantizeHelper.quantize_onnx_model(output_path, onnx_model_paths["int8"], args.use_external_data_format)
+         model = QuantizeHelper.quantize_torch_model(model)
+         logger.info("finished quantizing model")
+         output_path = onnx_model_paths["int8"]
+
+     if args.output.endswith(".onnx") and output_path != args.output and not args.use_external_data_format:
+         shutil.move(output_path, args.output)
+         output_path = args.output
+
+     logger.info(f"Output path: {output_path}")
+     model_size_in_MB = int(get_onnx_model_size(output_path, args.use_external_data_format) / 1024 / 1024)  # noqa: N806
+
+     provider = args.provider
+     if args.provider == "migraphx":
+         provider = "MIGraphXExecutionProvider"
+
+     session = create_onnxruntime_session(
+         output_path, args.use_gpu, provider, enable_all_optimization=True, verbose=args.verbose
+     )
+     if args.model_class == "GPT2LMHeadModel" and session is not None:
+         parity_result = gpt2helper.test_parity(
+             session,
+             model,
+             device,
+             is_io_float16,
+             rtol=args.tolerance,
+             atol=args.tolerance,
+             model_class=args.model_class,
+             has_position_ids=use_padding,
+             has_attention_mask=use_padding,
+             input_ids_dtype=int_data_type,
+             position_ids_dtype=int_data_type,
+             attention_mask_dtype=int_data_type,
+             test_cases_per_run=args.test_cases,
+             total_runs=args.test_runs,
+             stage=args.stage,
+             verbose=args.verbose,
+         )
+
+         # An example configuration for testing performance
+         batch_size = 8
+         sequence_length = 32 if args.stage == 1 else 1
+         past_sequence_length = 0 if args.stage == 1 else 32
+
+         latency = gpt2helper.test_performance(
+             session,
+             model,
+             device,
+             is_io_float16,
+             total_runs=100,
+             use_io_binding=True,
+             model_class=args.model_class,
+             has_position_ids=use_padding,
+             has_attention_mask=use_padding,
+             input_ids_dtype=int_data_type,
+             position_ids_dtype=int_data_type,
+             attention_mask_dtype=int_data_type,
+             batch_size=batch_size,
+             sequence_length=sequence_length,
+             past_sequence_length=past_sequence_length,
+         )
+
+         if args.precision == Precision.FLOAT16:
+             logger.info(f"fp16 conversion parameters: {fp16_params}")
+
+         # Write results to file
+         latency_name = get_latency_name(batch_size, sequence_length, past_sequence_length)
+         csv_file_existed = os.path.exists(csv_filename)
+         with open(csv_filename, mode="a", newline="") as csv_file:
+             column_names = [
+                 "experiment",
+                 "run_id",
+                 "model_name",
+                 "model_class",
+                 "stage",
+                 "gpu",
+                 "precision",
+                 "optimizer",
+                 "test_cases",
+                 "runs",
+                 "keep_io_types",
+                 "io_block_list",
+                 "op_block_list",
+                 "node_block_list",
+                 "force_fp16_initializers",
+                 "auto_mixed_precision",
+                 "optimized_operators",
+                 "operators",
+                 "environment_variables",
+                 "onnxruntime",
+                 latency_name,
+                 "top1_match_rate",
+                 "onnx_size_in_MB",
+                 "diff_50_percentile",
+                 "diff_90_percentile",
+                 "diff_95_percentile",
+                 "diff_99_percentile",
+                 "diff_pass_rate",
+                 "nan_rate",
+                 "top1_match_rate_per_run",
+             ]
+             csv_writer = csv.DictWriter(csv_file, fieldnames=column_names)
+             if not csv_file_existed:
+                 csv_writer.writeheader()
+             row = {
+                 "experiment": experiment_name,
+                 "run_id": run_id,
+                 "model_name": args.model_name_or_path,
+                 "model_class": args.model_class,
+                 "stage": args.stage,
+                 "gpu": args.use_gpu,
+                 "precision": args.precision,
+                 "optimizer": args.optimize_onnx,
+                 "test_cases": args.test_cases,
+                 "runs": args.test_runs,
+                 "keep_io_types": args.keep_io_types,
+                 "io_block_list": args.io_block_list,
+                 "op_block_list": args.op_block_list,
+                 "node_block_list": args.node_block_list,
+                 "force_fp16_initializers": args.force_fp16_initializers,
+                 "auto_mixed_precision": args.auto_mixed_precision,
+                 "optimized_operators": optimized_ops,
+                 "operators": all_ops,
+                 "environment_variables": get_ort_environment_variables(),
+                 "onnxruntime": ort_version,
+                 latency_name: f"{latency:.2f}",
+                 "diff_50_percentile": parity_result["max_diff_percentile_50"],
+                 "diff_90_percentile": parity_result["max_diff_percentile_90"],
+                 "diff_95_percentile": parity_result["max_diff_percentile_95"],
+                 "diff_99_percentile": parity_result["max_diff_percentile_99"],
+                 "diff_pass_rate": parity_result["diff_pass_rate"],
+                 "nan_rate": parity_result["nan_rate"],
+                 "top1_match_rate": parity_result["top1_match_rate"],
+                 "top1_match_rate_per_run": parity_result["top1_match_rate_per_run"],
+                 "onnx_size_in_MB": f"{model_size_in_MB}",
+             }
+             logger.info(f"result: {row}")
+             result.update(row)
+             csv_writer.writerow(row)
+
+     if args.input_test_file:
+         test_inputs = []
+         # Each line of test file is a JSON string like:
+         # {"input_ids": [[14698, 257, 1310, 13688, 319, 326]]}
+         with open(args.input_test_file) as read_f:
+             for _, line in enumerate(read_f):
+                 line = line.rstrip()  # noqa: PLW2901
+                 data = json.loads(line)
+                 input_ids = torch.from_numpy(numpy.asarray(data["input_ids"], dtype=numpy.int64)).to(device)
+
+                 if use_padding:
+                     if "attention_mask" in data:
+                         numpy_float = numpy.float16 if is_io_float16 else numpy.float32
+                         attention_mask = torch.from_numpy(numpy.asarray(data["attention_mask"], dtype=numpy_float)).to(
+                             device
+                         )
+                     else:
+                         padding = -1
+                         attention_mask = (input_ids != padding).type(torch.float16 if is_io_float16 else torch.float32)
+                         input_ids.masked_fill_(input_ids == padding, 0)
+
+                     if "position_ids" in data:
+                         position_ids = torch.from_numpy(numpy.asarray(data["position_ids"], dtype=numpy.int64)).to(
+                             device
+                         )
+                     else:
+                         position_ids = attention_mask.long().cumsum(-1) - 1
+                         position_ids.masked_fill_(position_ids < 0, 0)
+
+                     inputs = {
+                         "input_ids": input_ids.to(int_data_type),
+                         "position_ids": position_ids.to(int_data_type),
+                         "attention_mask": attention_mask.to(int_data_type),
+                     }
+                 else:
+                     inputs = {"input_ids": input_ids.to(int_data_type)}
+
+                 test_inputs.append(inputs)
+
+         Gpt2Tester.test_generation(
+             session,
+             model,
+             device,
+             test_inputs,
+             precision=args.precision,
+             model_class=args.model_class,
+             top_k=20,
+             top_k_no_order=True,
+             max_steps=24,
+             max_inputs=0,
+             verbose=args.verbose,
+             save_test_data=3,
+             save_test_data_dir=Path(output_path).parent,
+         )
+
+     logger.info(f"Done. Output model: {output_path}")
+     return result
+
+
+ if __name__ == "__main__":
+     main()
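
Since this wheel ships the DirectML execution provider (the --provider dml choice above), a model exported by the script can also be loaded outside of it. A minimal sketch, not part of the package, assuming example (1) from the module docstring produced gpt2.onnx in the working directory:

import onnxruntime as ort

# Prefer DirectML; ONNX Runtime falls back to CPU for any node the
# DML provider cannot place.
session = ort.InferenceSession(
    "gpt2.onnx",
    providers=["DmlExecutionProvider", "CPUExecutionProvider"],
)

# Inspect the exported graph I/O; the exact names (input_ids, position_ids,
# attention_mask, past_*/present_* pairs, logits) depend on the Gpt2Helper
# export options chosen above.
for arg in session.get_inputs():
    print("input:", arg.name, arg.shape, arg.type)
for arg in session.get_outputs():
    print("output:", arg.name, arg.shape, arg.type)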