onnxruntime_directml-1.20.0-cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (305)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
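For orientation, a minimal sketch of how this wheel is typically consumed (assuming it is installed on Windows; `DmlExecutionProvider` is the execution provider this package registers, and `sigmoid.onnx` is one of the sample models bundled under onnxruntime/datasets in the list above):

    import numpy
    import onnxruntime
    from onnxruntime.datasets import get_example

    # Load one of the tiny sample models shipped in this wheel (see onnxruntime/datasets above).
    model_path = get_example("sigmoid.onnx")

    # DmlExecutionProvider is what onnxruntime-directml adds; CPU is the fallback in the same package.
    session = onnxruntime.InferenceSession(
        model_path,
        providers=["DmlExecutionProvider", "CPUExecutionProvider"],
    )

    inp = session.get_inputs()[0]
    # Replace any symbolic dimensions with 1 so the random input matches the model.
    shape = [d if isinstance(d, int) else 1 for d in inp.shape]
    x = numpy.random.rand(*shape).astype(numpy.float32)

    outputs = session.run(None, {inp.name: x})
    print(outputs[0].shape)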
onnxruntime/transformers/models/gpt2/gpt2_tester.py
@@ -0,0 +1,501 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License. See License.txt in the project root for
+ # license information.
+ # --------------------------------------------------------------------------
+ # This script helps evaluate the GPT-2 model.
+ import logging
+ import math
+ import os
+ import statistics
+ import timeit
+
+ import numpy
+ import torch
+ from benchmark_helper import Precision
+ from gpt2_helper import Gpt2Helper, Gpt2Inputs
+
+ logger = logging.getLogger(__name__)
+
+
+ class Gpt2Metric:
+     def __init__(self, treatment_name, baseline_name="Torch", top_k=20):
+         assert top_k > 1 and top_k <= 100
+         self.baseline = baseline_name
+         self.treatment = treatment_name
+         self.name: str = f"{treatment_name} vs {baseline_name}"
+         self.top_k = top_k
+         self.top_1_error: int = 0
+         self.top_k_error: int = 0
+         self.total_samples: int = 0
+         self.max_logits_diff: float = 0  # for non-empty past state
+         self.max_logits_diff_no_past: float = 0  # for empty past state
+         self.batch_top1_error: torch.FloatTensor = None  # top 1 error for current batch
+         self.batch_topk_error: torch.FloatTensor = None  # top k error for current batch
+         self.seq_len_latency = {}
+
+     def print(self):
+         if self.baseline != self.treatment:
+             print("---")
+             print(f"Metrics for {self.treatment} (baseline={self.baseline}):")
+             if self.total_samples > 0:
+                 top_1_error_rate = 100.0 * self.top_1_error / self.total_samples
+                 top_k_error_rate = 100.0 * self.top_k_error / self.total_samples
+                 print(
+                     f"Total={self.total_samples} Top1Error={self.top_1_error} ({top_1_error_rate:.2f}%) Top{self.top_k}Error={self.top_k_error} ({top_k_error_rate:.2f}%)"
+                 )
+             print("Max logits diffs:")
+             print(f"\twith past = {self.max_logits_diff:.6f}")
+             print(f"\tempty past = {self.max_logits_diff_no_past:.6f}")
+         else:
+             print(f"Metrics for {self.treatment} (baseline):")
+
+         if self.seq_len_latency:
+             print("Past sequence length range and average latency:")
+             total = 0
+             count = 0
+             for key in sorted(self.seq_len_latency.keys()):
+                 average = statistics.mean(self.seq_len_latency[key]) * 1000.0
+                 if key == 0:
+                     print(f"\t{key}: \t{average:.2f} ms")
+                 else:
+                     print(f"\t[{2**key}, {2 ** (key + 1) - 1}]:\t{average:.2f} ms")
+                 total += average * len(self.seq_len_latency[key])
+                 count += len(self.seq_len_latency[key])
+             print(f"Average Latency: {total / count:.2f} ms")
+
+     def diff_logits(self, baseline_logits, treatment_logits, is_empty_past: bool):
+         diff = (baseline_logits - treatment_logits).abs().max()
+         if is_empty_past:
+             self.max_logits_diff_no_past = max(self.max_logits_diff_no_past, diff)
+         else:
+             self.max_logits_diff = max(self.max_logits_diff, diff)
+
+         return diff
+
+     def start_batch(self, batch_size: int):
+         self.total_samples += batch_size
+         self.batch_top1_error = torch.zeros((batch_size, 1), dtype=torch.bool)
+         self.batch_topk_error = torch.zeros((batch_size, 1), dtype=torch.bool)
+
+     def eval_batch(self, baseline, treatment, past_seq_len, verbose=True):
+         self._eval_topk(baseline.top_1_tokens, treatment.top_1_tokens, 1, verbose)
+         self._eval_topk(baseline.top_k_tokens, treatment.top_k_tokens, self.top_k, verbose)
+
+         max_diff = self.diff_logits(baseline.logits, treatment.logits, past_seq_len == 0)
+         if verbose:
+             print(f"Max logits diffs of {self.name}: {max_diff}")
+
+     def _eval_topk(self, baseline_topk, treatment_topk, top_k, verbose=True):
+         if not torch.all(torch.eq(baseline_topk, treatment_topk)):
+             if top_k == 1:
+                 if verbose:
+                     print(f"Generated tokens not matched for {self.name}")
+                 self.batch_top1_error |= torch.eq(baseline_topk, treatment_topk).logical_not()
+             else:
+                 if verbose:
+                     print(
+                         f"Top {top_k} tokens not matched for {self.name}. This will lead to wrong beam search results"
+                     )
+                 self.batch_topk_error |= (
+                     torch.eq(baseline_topk, treatment_topk).logical_not().sum(1).unsqueeze(dim=1) > 0
+                 )
+
+     def end_batch(self):
+         self.top_1_error += self.batch_top1_error.sum()
+         self.top_k_error += self.batch_topk_error.sum()
+
+     def add_latency(self, past_seq_len, latency):
+         key = int(math.log2(past_seq_len)) + 1 if past_seq_len > 0 else 0
+         if key not in self.seq_len_latency:
+             self.seq_len_latency[key] = []
+         self.seq_len_latency[key].append(latency)
+
+
+ class Gpt2Tester:
+     def __init__(
+         self,
+         input_ids,
+         position_ids,
+         attention_mask,
+         num_attention_heads,
+         hidden_size,
+         num_layer,
+         device,
+         is_fp16=False,
+         top_k=20,
+         top_k_required_order=False,
+     ):
+         self.batch_size = input_ids.shape[0]
+         self.input_length = input_ids.shape[1]
+         self.n_layer = num_layer
+
+         self.input_ids = input_ids
+         self.position_ids = position_ids
+         self.attention_mask = attention_mask
+
+         self.has_position_ids = position_ids is not None
+         self.has_attention_mask = attention_mask is not None
+
+         # Empty past state for first inference
+         self.past = []
+         past_shape = [
+             2,
+             self.batch_size,
+             num_attention_heads,
+             0,
+             hidden_size // num_attention_heads,
+         ]
+         for _i in range(num_layer):
+             empty_past = torch.empty(past_shape).type(torch.float16 if is_fp16 else torch.float32)
+             self.past.append(empty_past.to(device))
+
+         self.logits = None
+         self.top_1_tokens = None
+         self.top_k_tokens = None
+         self.top_k = top_k
+         self.top_k_required_order = top_k_required_order
+
+     def get_inputs(self) -> Gpt2Inputs:
+         return Gpt2Inputs(self.input_ids, self.position_ids, self.attention_mask, self.past)
+
+     def save_test_data(self, session, output, save_test_data_dir, test_case_id):
+         from onnx import numpy_helper
+
+         path = os.path.join(save_test_data_dir, "test_data_set_" + str(test_case_id))
+         if os.path.exists(path):
+             print(f"Directory {path} already exists. Skipping saving test data.")
+             return
+
+         os.makedirs(path, exist_ok=True)
+
+         def add_tensor(input_tensors, torch_tensor, name):
+             input_tensors.append(numpy_helper.from_array(torch_tensor.clone().cpu().numpy(), name))
+
+         input_tensors = []
+         add_tensor(input_tensors, self.input_ids, "input_ids")
+
+         if self.has_position_ids:
+             add_tensor(input_tensors, self.position_ids, "position_ids")
+
+         if self.has_attention_mask:
+             add_tensor(input_tensors, self.attention_mask, "attention_mask")
+
+         for i in range(self.n_layer):
+             add_tensor(input_tensors, self.past[i], "past_" + str(i))
+
+         for i, tensor in enumerate(input_tensors):
+             with open(os.path.join(path, f"input_{i}.pb"), "wb") as f:
+                 f.write(tensor.SerializeToString())
+
+         output_names = [output.name for output in session.get_outputs()]
+         for i, _name in enumerate(output_names):
+             tensor = numpy_helper.from_array(
+                 output[i] if isinstance(output[i], numpy.ndarray) else output[i].clone().cpu().numpy()
+             )
+             with open(os.path.join(path, f"output_{i}.pb"), "wb") as f:
+                 f.write(tensor.SerializeToString())
+
+         print(f"Test data saved to directory {path}")
+
+     def update(self, output, step, device):
+         """
+         Update the inputs for the next inference.
+         """
+         self.logits = (
+             torch.from_numpy(output[0]) if isinstance(output[0], numpy.ndarray) else output[0].clone().detach().cpu()
+         )
+
+         self.top_1_tokens = Gpt2Tester.predict_next_token(self.logits)
+         self.top_k_tokens = Gpt2Tester.predict_next_token(self.logits, self.top_k, self.top_k_required_order)
+
+         self.input_ids = self.top_1_tokens.clone().detach().reshape([self.batch_size, 1]).to(device)
+
+         if self.has_position_ids:
+             self.position_ids = (
+                 torch.tensor([self.input_length + step - 1]).unsqueeze(0).repeat(self.batch_size, 1).to(device)
+             )
+
+         if self.has_attention_mask:
+             self.attention_mask = torch.cat(
+                 [
+                     self.attention_mask,
+                     torch.ones([self.batch_size, 1]).type_as(self.attention_mask),
+                 ],
+                 1,
+             ).to(device)
+
+         self.past = []
+
+         if isinstance(output[1], tuple):  # past in torch output is tuple
+             self.past = list(output[1])
+         else:
+             for i in range(self.n_layer):
+                 past_i = (
+                     torch.from_numpy(output[i + 1])
+                     if isinstance(output[i + 1], numpy.ndarray)
+                     else output[i + 1].clone().detach()
+                 )
+                 self.past.append(past_i.to(device))
+
+     def diff(self, baseline):
+         """
+         Compare inputs and logits output.
+         """
+
+         print("start diff...")
+         if self.logits is not None:
+             max_io_diff = (self.logits - baseline.logits).abs().max()
+             if max_io_diff > 1e-4:
+                 print(f"Max logits difference is too large: {max_io_diff}")
+
+         if not torch.all(self.input_ids == baseline.input_ids):
+             print("Input_ids is different", self.input_ids, baseline.input_ids)
+
+         if self.has_position_ids:
+             if not torch.all(self.position_ids == baseline.position_ids):
+                 print(
+                     "position_ids is different",
+                     self.position_ids,
+                     baseline.position_ids,
+                 )
+
+         if self.has_attention_mask:
+             if not torch.all(self.attention_mask == baseline.attention_mask):
+                 print(
+                     "attention_mask is different",
+                     self.attention_mask,
+                     baseline.attention_mask,
+                 )
+
+         assert len(self.past) == len(baseline.past)
+
+         for i, past_i in enumerate(self.past):
+             assert past_i.shape == baseline.past[i].shape
+             if past_i.nelement() > 0:
+                 max_past_diff = (past_i - baseline.past[i]).abs().max()
+                 if max_past_diff > 1e-4:
+                     print(f"max_past_diff[{i}]={max_past_diff}")
+
+     @staticmethod
+     def predict_next_token(logits, top_k=1, required_order=False):
+         """
+         Get top k tokens based on logits.
+         """
+
+         # logits has shape (batch_size, seq_len, vocab_size)
+         # last token logits has shape (batch_size, vocab_size)
+         lastTokenLogits = logits[:, -1]  # noqa: N806
+         if top_k == 1:
+             generatedTokens = torch.argmax(lastTokenLogits, 1, True)  # noqa: N806
+             return generatedTokens
+         else:
+             topk = torch.argsort(lastTokenLogits, -1, descending=True)[:, :top_k]
+             if not required_order:
+                 sorted_topk, _ = topk.sort()
+                 return sorted_topk
+             return topk
+
+     @staticmethod
+     def diff_present(onnx_output, onnx_io_output, n_layer):
+         """
+         Compare the present state outputs from two ONNX Runtime runs.
+         """
+         present_diff_max = []
+         for i in range(n_layer):
+             onnx_present_i = (
+                 torch.from_numpy(onnx_output[i + 1])
+                 if isinstance(onnx_output[i + 1], numpy.ndarray)
+                 else onnx_output[i + 1]
+             )
+             onnx_io_present_i = (
+                 torch.from_numpy(onnx_io_output[i + 1])
+                 if isinstance(onnx_io_output[i + 1], numpy.ndarray)
+                 else onnx_io_output[i + 1]
+             )
+             max_diff = (onnx_present_i - onnx_io_present_i).abs().max()
+             present_diff_max.append(max_diff)
+         print(f"present_diff_max={present_diff_max}")
+
+     @staticmethod
+     def is_quantized_onnx_model(onnx_model_path):
+         """
+         Returns True if the ONNX model is quantized.
+         """
+         from onnx import load
+
+         model = load(onnx_model_path)
+         from onnxruntime.quantization.quantize import __producer__ as quantize_producer
+
+         return model.producer_name == quantize_producer
+
+     @staticmethod
+     def test_generation(
+         session,
+         model,
+         device,
+         test_inputs,
+         precision=Precision.FLOAT32,
+         model_class="Gpt2LMHeadModel",
+         top_k=20,
+         top_k_no_order=True,
+         max_steps=24,
+         max_inputs=0,
+         verbose=False,
+         save_test_data=0,
+         save_test_data_dir=".",
+     ):
+         """
+         Test generation using greedy search (without sampling) to compare the PyTorch and ONNX models.
+         It will print top 1 and top k errors on the given test inputs.
+         """
+         print(
+             f"start test generation: (top_k={top_k} top_k_no_order={top_k_no_order} max_steps={max_steps} test_inputs={len(test_inputs)} max_inputs={max_inputs})"
+         )
+         n_layer = model.config.n_layer
+         n_head = model.config.n_head
+         n_embd = model.config.n_embd
+         eos_token_id = model.config.eos_token_id
+         test_data_saved = 0
+
+         is_float16 = precision == Precision.FLOAT16
+         if is_float16:
+             assert "float16" in session.get_outputs()[0].type
+
+         # We still use the fp32 torch model as the baseline when the onnx model is fp16
+         model.eval().to(device)
+
+         # Allocate initial buffers for IO Binding of ONNX Runtime. The buffer size will automatically increase later.
+         init_output_shapes = Gpt2Helper.get_output_shapes(
+             batch_size=4,
+             past_sequence_length=128,
+             sequence_length=32,
+             config=model.config,
+             model_class=model_class,
+         )
+         output_buffers = Gpt2Helper.get_output_buffers(init_output_shapes, device, is_float16=is_float16)
+
+         baseline_name = "Torch"
+         treatment_name = "Quantized Onnx" if precision == Precision.INT8 else "Onnx"
+         torch_metric = Gpt2Metric(baseline_name, baseline_name, top_k)
+         onnx_metric = Gpt2Metric(treatment_name, baseline_name, top_k)
+         onnx_io_metric = Gpt2Metric(treatment_name + " with IO Binding", baseline_name, top_k)
+
+         for i, inputs in enumerate(test_inputs):
+             if max_inputs > 0 and i == max_inputs:
+                 break
+             if i % 10 == 0:
+                 print(f"{i}")
+             input_ids = inputs["input_ids"]
+             position_ids = inputs.get("position_ids", None)
+             attention_mask = inputs.get("attention_mask", None)
+
+             onnx_runner = Gpt2Tester(
+                 input_ids,
+                 position_ids,
+                 attention_mask,
+                 n_head,
+                 n_embd,
+                 n_layer,
+                 device,
+                 is_float16,
+                 top_k,
+                 not top_k_no_order,
+             )
+             onnx_io_runner = Gpt2Tester(
+                 input_ids,
+                 position_ids,
+                 attention_mask,
+                 n_head,
+                 n_embd,
+                 n_layer,
+                 device,
+                 is_float16,
+                 top_k,
+                 not top_k_no_order,
+             )
+             torch_runner = Gpt2Tester(
+                 input_ids,
+                 position_ids,
+                 attention_mask,
+                 n_head,
+                 n_embd,
+                 n_layer,
+                 device,
+                 False,
+                 top_k,
+                 not top_k_no_order,
+             )  # Torch model baseline is fp32
+
+             batch_size = torch_runner.batch_size
+             onnx_metric.start_batch(batch_size)
+             onnx_io_metric.start_batch(batch_size)
+
+             with torch.no_grad():
+                 done = torch.zeros(batch_size, dtype=torch.bool)
+                 for step in range(max_steps):
+                     seq_len = list(onnx_runner.input_ids.size())[1]
+                     past_seq_len = list(onnx_runner.past[0].size())[3]
+
+                     start_time = timeit.default_timer()
+                     pytorch_output = Gpt2Helper.pytorch_inference(model, torch_runner.get_inputs())
+                     torch_metric.add_latency(past_seq_len, timeit.default_timer() - start_time)
+                     torch_runner.update(pytorch_output, step, device)
+
+                     onnx_output, avg_latency_ms = Gpt2Helper.onnxruntime_inference(
+                         session, onnx_runner.get_inputs(), total_runs=1
+                     )
+                     onnx_metric.add_latency(past_seq_len, avg_latency_ms / 1000.0)
+                     onnx_runner.update(onnx_output, step, device)
+
+                     output_shapes = Gpt2Helper.get_output_shapes(
+                         batch_size,
+                         past_seq_len,
+                         seq_len,
+                         model.config,
+                         model_class=model_class,
+                     )
+                     Gpt2Helper.auto_increase_buffer_size(output_buffers, output_shapes)
+
+                     (
+                         onnx_io_output,
+                         avg_latency_ms,
+                     ) = Gpt2Helper.onnxruntime_inference_with_binded_io(
+                         session,
+                         onnx_io_runner.get_inputs(),
+                         output_buffers,
+                         output_shapes,
+                         total_runs=1,
+                         return_numpy=False,
+                         include_copy_output_latency=True,
+                     )
+                     onnx_io_metric.add_latency(past_seq_len, avg_latency_ms / 1000.0)
+
+                     if test_data_saved < save_test_data:
+                         onnx_io_runner.save_test_data(session, onnx_io_output, save_test_data_dir, test_data_saved)
+                         test_data_saved += 1
+
+                     onnx_io_runner.update(onnx_io_output, step, device)
+
+                     if verbose:
+                         onnx_runner.diff(onnx_io_runner)
+                         Gpt2Tester.diff_present(onnx_output, onnx_io_output, n_layer)
+
+                         print("Top 1 tokens:")
+                         print("\tTorch", torch_runner.top_1_tokens)
+                         print("\tONNX", onnx_runner.top_1_tokens)
+                         print("\tONNX with IO binding", onnx_io_runner.top_1_tokens)
+
+                     onnx_metric.eval_batch(torch_runner, onnx_runner, past_seq_len, verbose=verbose)
+                     onnx_io_metric.eval_batch(torch_runner, onnx_io_runner, past_seq_len, verbose=verbose)
+
+                     done = done | (torch_runner.top_1_tokens == eos_token_id).any()
+                     if torch.all(done):
+                         break
+
+             onnx_metric.end_batch()
+             onnx_io_metric.end_batch()
+
+         torch_metric.print()
+         onnx_metric.print()
+         onnx_io_metric.print()
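To make the token-selection logic in `predict_next_token` above concrete, here is a standalone sketch with random logits (the shapes are hypothetical, not taken from the diff):

    import torch

    batch_size, seq_len, vocab_size = 2, 5, 50257  # hypothetical shapes
    logits = torch.randn(batch_size, seq_len, vocab_size)

    # Only the logits at the last position matter for choosing the next token.
    last_token_logits = logits[:, -1]  # (batch_size, vocab_size)

    # top_k == 1: greedy next token, shape (batch_size, 1).
    top_1 = torch.argmax(last_token_logits, 1, True)

    # top_k > 1: candidates sorted by score, then re-sorted by token id so the
    # comparison is order-insensitive (the top_k_required_order=False case).
    top_k = torch.argsort(last_token_logits, -1, descending=True)[:, :20]
    top_k_unordered, _ = top_k.sort()

    print(top_1.shape, top_k_unordered.shape)  # torch.Size([2, 1]) torch.Size([2, 20])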
onnxruntime/transformers/models/gpt2/parity_check_helper.py
@@ -0,0 +1,146 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License. See License.txt in the project root for
+ # license information.
+ # --------------------------------------------------------------------------
+ # This script helps debug parity issues between the fp16 and fp32 versions of the same ONNX model
+ # Please build ORT with --cmake_extra_defines onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS=ON
+
+ import math
+ import multiprocessing
+ import os
+ from pathlib import Path
+
+ import numpy
+ import torch
+ from benchmark_helper import create_onnxruntime_session
+ from gpt2_helper import Gpt2Helper
+ from onnx import TensorProto, numpy_helper
+
+ NON_ZERO_VALUE = str(1)
+ ZERO_VALUE = str(0)
+
+
+ def environ_setting_nodes(node_name_filter=None, node_type_filter=None):
+     # Set I/O data as default
+     os.environ["ORT_DEBUG_NODE_IO_DUMP_SHAPE_DATA"] = ZERO_VALUE
+     os.environ["ORT_DEBUG_NODE_IO_DUMP_INPUT_DATA"] = NON_ZERO_VALUE
+     os.environ["ORT_DEBUG_NODE_IO_DUMP_OUTPUT_DATA"] = NON_ZERO_VALUE
+     if node_name_filter is not None:
+         os.environ["ORT_DEBUG_NODE_IO_NAME_FILTER"] = node_name_filter
+     elif node_type_filter is not None:
+         os.environ["ORT_DEBUG_NODE_IO_OP_TYPE_FILTER"] = node_type_filter
+     else:
+         os.environ["ORT_DEBUG_NODE_IO_DUMPING_DATA_TO_FILES_FOR_ALL_NODES_IS_OK"] = NON_ZERO_VALUE
+
+
+ def environ_setting_paths(output_path):
+     # Set dumping values to files as default
+     os.environ["ORT_DEBUG_NODE_IO_DUMP_DATA_DESTINATION"] = "files"
+     os.environ["ORT_DEBUG_NODE_IO_OUTPUT_DIR"] = output_path
+
+
+ def environ_reset():
+     for flag in [
+         "ORT_DEBUG_NODE_IO_DUMP_SHAPE_DATA",
+         "ORT_DEBUG_NODE_IO_DUMP_INPUT_DATA",
+         "ORT_DEBUG_NODE_IO_DUMP_OUTPUT_DATA",
+         "ORT_DEBUG_NODE_IO_NAME_FILTER",
+         "ORT_DEBUG_NODE_IO_OP_TYPE_FILTER",
+         "ORT_DEBUG_NODE_IO_DUMP_DATA_TO_FILES",
+         "ORT_DEBUG_NODE_IO_OUTPUT_DIR",
+         "ORT_DEBUG_NODE_IO_DUMPING_DATA_TO_FILES_FOR_ALL_NODES_IS_OK",
+     ]:
+         if flag in os.environ:
+             del os.environ[flag]
+
+
+ def inference(model_path, dummy_inputs, outputs_path, use_gpu):
+     environ_reset()
+     environ_setting_nodes()
+     environ_setting_paths(outputs_path)
+     session = create_onnxruntime_session(model_path, use_gpu, enable_all_optimization=False)
+     Gpt2Helper.onnxruntime_inference(session, dummy_inputs)
+
+
+ def generate_outputs_files(model_path, dummy_inputs, outputs_path, use_gpu):
+     dir_path = Path(outputs_path)
+     if dir_path.exists() and dir_path.is_dir():
+         import shutil
+
+         shutil.rmtree(outputs_path)
+     dir_path.mkdir(parents=True, exist_ok=True)
+
+     process = multiprocessing.Process(target=inference, args=(model_path, dummy_inputs, outputs_path, use_gpu))
+     process.start()
+     process.join()
+
+
+ def post_processing(outputs_path, outputs_path_other):
+     # Compare outputs (e.g. fp16 vs fp32)
+     record = {}
+     if_close = {}
+
+     import glob
+
+     for filename in glob.glob(os.path.join(outputs_path, "*.tensorproto")):
+         filename_other = os.path.join(outputs_path_other, Path(filename).name)
+         if not os.path.exists(filename_other):
+             continue
+         with open(filename, "rb") as f:
+             tensor = TensorProto()
+             tensor.ParseFromString(f.read())
+             array = numpy_helper.to_array(tensor)
+         with open(filename_other, "rb") as f:  # noqa: PLW2901
+             tensor_other = TensorProto()
+             tensor_other.ParseFromString(f.read())
+             array_other = numpy_helper.to_array(tensor_other)
+         if array_other.size == 0:
+             continue
+         diff = numpy.average(numpy.abs(array_other - array) / (numpy.abs(array_other) + 1e-6))
+         if math.isnan(diff):
+             continue
+         record[Path(filename).name.split(".")[0]] = diff
+         if_close[Path(filename).name.split(".")[0]] = numpy.allclose(array, array_other, rtol=1e-04, atol=1e-04)
+
+     results = ["Node\tDiff\tClose"]
+     for k, v in sorted(record.items(), key=lambda x: x[1], reverse=True):
+         results.append(f"{k}\t{v}\t{if_close[k]}")
+     for line in results:
+         print(line)
+
+
+ if __name__ == "__main__":
+     # The example below shows how to use this helper to investigate a parity issue between gpt-2 fp32 and fp16 onnx models
+     # Please build ORT with --cmake_extra_defines onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS=ON !!
+     multiprocessing.set_start_method("spawn")
+
+     # Generate inputs
+     sequence_length = 8
+     past_sequence_length = 8
+     batch_size = 5
+     dummy_inputs_fp16 = Gpt2Helper.get_dummy_inputs(
+         batch_size,
+         past_sequence_length,
+         sequence_length,
+         12,
+         768,
+         12,
+         50257,
+         device=torch.device("cpu"),
+         float16=True,
+     )
+     dummy_inputs_fp32 = dummy_inputs_fp16.to_fp32()
+
+     # Get the GPT-2 model from huggingface using convert_to_onnx.py
+     os.system("python convert_to_onnx.py -m gpt2 --output gpt2_fp32.onnx -o -p fp32 --use_gpu")
+     os.system("python convert_to_onnx.py -m gpt2 --output gpt2_fp16.onnx -o -p fp16 --use_gpu")
+
+     # Specify the directory to dump each node's I/O
+     outputs_path_fp32_gpu = "./fp32_gpu"
+     outputs_path_fp16_gpu = "./fp16_gpu"
+     generate_outputs_files("./gpt2_fp32.onnx", dummy_inputs_fp32, outputs_path_fp32_gpu, use_gpu=True)
+     generate_outputs_files("./gpt2_fp16.onnx", dummy_inputs_fp16, outputs_path_fp16_gpu, use_gpu=True)
+
+     # Compare each node's I/O values and sort by average relative difference
+     post_processing(outputs_path_fp16_gpu, outputs_path_fp32_gpu)
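The per-node comparison in `post_processing` above reduces to an average relative difference plus an `allclose` check. A standalone sketch of that metric on synthetic arrays (an fp16 round-trip stands in for the dumped fp16 tensors):

    import numpy

    reference = numpy.random.rand(4, 8).astype(numpy.float32)  # stands in for an fp32 dump
    perturbed = reference.astype(numpy.float16).astype(numpy.float32)  # fp16 round-trip

    # Same metric as post_processing: mean of |a - b| / (|b| + eps).
    diff = numpy.average(numpy.abs(perturbed - reference) / (numpy.abs(reference) + 1e-6))
    close = numpy.allclose(perturbed, reference, rtol=1e-04, atol=1e-04)
    print(f"avg relative diff={diff:.6f} close={close}")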
onnxruntime/transformers/models/llama/__init__.py
@@ -0,0 +1,12 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+ import os
+ import sys
+
+ sys.path.append(os.path.dirname(__file__))
+
+ transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
+ if transformers_dir not in sys.path:
+     sys.path.append(transformers_dir)
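The `sys.path` additions in this `__init__.py` are what let the model scripts above use flat imports such as `from benchmark_helper import ...` against the shared onnxruntime/transformers directory. A self-contained sketch of the same pattern (all names here are hypothetical, created on the fly):

    import os
    import sys
    import tempfile

    # Create a stand-in for the shared directory containing one helper module.
    shared_dir = tempfile.mkdtemp()
    with open(os.path.join(shared_dir, "shared_helper.py"), "w") as f:
        f.write("GREETING = 'imported flat via sys.path'\n")

    # Same pattern as the __init__.py above: append the directory if not present.
    if shared_dir not in sys.path:
        sys.path.append(shared_dir)

    import shared_helper  # resolved through the appended path, not as a package

    print(shared_helper.GREETING)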