onnxruntime-directml 1.20.0__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (305)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,76 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License. See License.txt in the project root for
4
+ # license information.
5
+ # --------------------------------------------------------------------------
6
+
7
+ import logging
8
+ import os
9
+
10
+ import onnx
11
+ import torch
12
+ from transformers.modeling_utils import Conv1D
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
def _conv1d_to_linear(module):
    """Build a ``torch.nn.Linear`` layer equivalent to a ``Conv1D``-style *module*.

    The module's weight is read as ``(in_features, out_features)``. ``nn.Linear`` expects
    ``(out_features, in_features)``, so the weight is transposed and made contiguous.
    The bias is reused as-is.
    """
    n_in, n_out = module.weight.shape
    replacement = torch.nn.Linear(n_in, n_out)
    source_weight = module.weight.data
    replacement.weight.data = source_weight.T.contiguous()
    replacement.bias.data = module.bias.data
    return replacement
23
+
24
+
25
def conv1d_to_linear(model):
    """Replace, in place, every transformers ``Conv1D`` submodule of *model* with ``torch.nn.Linear``.

    This is for dynamic quantization: Conv1D is not recognized by PyTorch, so it is
    converted to nn.Linear first. The module tree is walked recursively.

    Args:
        model: a ``torch.nn.Module``. It is mutated in place and nothing is returned.
    """
    logger.debug("replace Conv1D with Linear")
    # Iterate over a snapshot of the names since entries are replaced during the loop.
    for name in list(model._modules):
        module = model._modules[name]
        if module is None:
            # nn.Module allows a submodule slot to be registered as None. There is nothing to
            # convert, and recursing would fail with AttributeError on ``None._modules``.
            continue
        if isinstance(module, Conv1D):
            model._modules[name] = _conv1d_to_linear(module)
        else:
            conv1d_to_linear(module)
37
+
38
+
39
def _get_size_of_pytorch_model(model):
    """Return the size in MB of *model*'s ``state_dict`` as serialized by ``torch.save``.

    The state dict is written to a private temporary directory rather than a fixed
    ``temp.p`` in the current working directory. This has three benefits:
    an existing user file with that name is never overwritten or deleted,
    concurrent callers cannot collide, and the file is removed even if saving fails.
    """
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_path = os.path.join(tmp_dir, "model_state_dict.p")
        torch.save(model.state_dict(), tmp_path)
        # Measure before the directory (and file) is cleaned up on exit.
        return os.path.getsize(tmp_path) / (1024 * 1024)
44
+
45
+
46
class QuantizeHelper:
    """Helpers to apply dynamic quantization to PyTorch models and ONNX model files."""

    @staticmethod
    def quantize_torch_model(model, dtype=torch.qint8):
        """Dynamically quantize the ``nn.Linear`` layers of a PyTorch model.

        Usage: model = QuantizeHelper.quantize_torch_model(model)

        Args:
            model: the PyTorch model. NOTE: it is modified in place first; its Conv1D
                submodules are replaced by nn.Linear before quantization.
            dtype: forwarded to ``torch.quantization.quantize_dynamic``.

        Returns:
            The model returned by ``torch.quantization.quantize_dynamic``.

        TODO: mix of in-place and return, but results are different
        """
        conv1d_to_linear(model)
        quantized_model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=dtype)
        # Measuring a model's size serializes it to disk, so only do it when the
        # messages will actually be emitted.
        if logger.isEnabledFor(logging.INFO):
            logger.info(f"Size of full precision Torch model(MB):{_get_size_of_pytorch_model(model)}")
            logger.info(f"Size of quantized Torch model(MB):{_get_size_of_pytorch_model(quantized_model)}")
        return quantized_model

    @staticmethod
    def quantize_onnx_model(onnx_model_path, quantized_model_path, use_external_data_format=False):
        """Dynamically quantize an ONNX model file with onnxruntime's ``quantize_dynamic``.

        Args:
            onnx_model_path: path of the full-precision ONNX model.
            quantized_model_path: output path. Its parent directory is created if missing.
            use_external_data_format: forwarded to ``quantize_dynamic``.
        """
        from pathlib import Path

        from onnxruntime.quantization import quantize_dynamic

        Path(quantized_model_path).parent.mkdir(parents=True, exist_ok=True)
        logger.info(f"Size of full precision ONNX model(MB):{os.path.getsize(onnx_model_path)/(1024*1024)}")
        quantize_dynamic(
            onnx_model_path,
            quantized_model_path,
            use_external_data_format=use_external_data_format,
            # NOTE(review): presumably the fallback element type for tensors whose type
            # cannot be determined during quantization; confirm against quantize_dynamic docs.
            extra_options={"DefaultTensorType": onnx.TensorProto.FLOAT},
        )
        logger.info(f"quantized model saved to:{quantized_model_path}")
        # TODO: include external data in total model size.
        logger.info(f"Size of quantized ONNX model(MB):{os.path.getsize(quantized_model_path)/(1024*1024)}")
@@ -0,0 +1,122 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+
6
+ import logging
7
+ import os
8
+ import sys
9
+ from typing import Dict
10
+
11
+ # In ORT Package the symbolic_shape_infer.py is in ../tools
12
+ file_path = os.path.dirname(__file__)
13
+ if os.path.exists(os.path.join(file_path, "../tools/symbolic_shape_infer.py")):
14
+ sys.path.append(os.path.join(file_path, "../tools"))
15
+ else:
16
+ sys.path.append(os.path.join(file_path, ".."))
17
+
18
+ from symbolic_shape_infer import SymbolicShapeInference, get_shape_from_type_proto, sympy # noqa: E402
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
class SymbolicShapeInferenceHelper(SymbolicShapeInference):
    """Symbolic shape inference that can substitute concrete sizes for named dynamic axes.

    After :meth:`infer` has run with a mapping such as ``{"batch_size": 4}``, edge shapes can
    be queried with :meth:`get_edge_shape` and compared with :meth:`compare_shape`.
    """

    def __init__(self, model, verbose=0, int_max=2**31 - 1, auto_merge=True, guess_output_rank=False):
        super().__init__(int_max, auto_merge, guess_output_rank, verbose)
        self.model_ = model
        # Result returned by the most recent inference pass.
        self.all_shapes_inferred_: bool = False
        # Set once infer() has completed at least once.
        self.is_inferred_: bool = False
        # Dynamic-axis name -> concrete size used by the last infer() call.
        self.dynamic_axis_mapping_: Dict[str, int] = {}

    def infer(self, dynamic_axis_mapping: Dict[str, int], max_runs: int = 200) -> bool:
        """Run shape inference, replacing dynamic axes with integers where the mapping provides them.

        Args:
            dynamic_axis_mapping (Dict[str, int]): a dictionary with name of dynamic axis as key,
                like {"batch_size" : 4}
            max_runs (int, optional): limit maximum number of runs to avoid infinite loop.
                A value <= 0 means no limit. Defaults to 200.

        Returns:
            bool: whether all shapes have been inferred or not.
        """
        assert dynamic_axis_mapping is not None

        # Reuse the previous result when called again with an equal mapping.
        if self.is_inferred_ and self.dynamic_axis_mapping_ == dynamic_axis_mapping:
            return self.all_shapes_inferred_

        # Keep a private copy. Holding the caller's dict by reference would let later
        # mutations of it change substituted dims and make the cache check above compare
        # the dict with itself, returning stale results.
        self.dynamic_axis_mapping_ = dict(dynamic_axis_mapping)

        self._preprocess(self.model_)

        # NOTE(review): self.run_ is not reset here; if the base class clears it after a
        # completed run, a second call with a different mapping may not re-run inference.
        count = 0
        while self.run_:
            logger.debug(f"shape infer run {count}")
            self.all_shapes_inferred_ = self._infer_impl()
            count += 1
            if max_runs > 0 and count >= max_runs:
                break

        self.is_inferred_ = True
        return self.all_shapes_inferred_

    def _get_sympy_shape(self, node, idx):
        """Override it to ensure shape inference by giving the actual value of dynamic axis.

        String dims are resolved in order: the user-supplied dynamic axis mapping, then
        already-known symbolic dims, and otherwise a new integer sympy Symbol.
        """
        sympy_shape = []

        shape = self._get_shape(node, idx)
        if shape:
            for dim in shape:
                if isinstance(dim, str):
                    if dim in self.dynamic_axis_mapping_:
                        sympy_shape.append(self.dynamic_axis_mapping_[dim])
                    elif dim in self.symbolic_dims_:
                        sympy_shape.append(self.symbolic_dims_[dim])
                    else:
                        sympy_shape.append(sympy.Symbol(dim, integer=True))
                else:
                    assert dim is not None
                    sympy_shape.append(dim)
        return sympy_shape

    def get_edge_shape(self, edge):
        """Get shape of an edge.

        Args:
            edge (str): name of edge

        Returns:
            Optional[List]: the shape, or None if shape is unknown. Dims named in the dynamic
            axis mapping are replaced by their integer values; other symbolic dims stay strings.
        """
        assert self.all_shapes_inferred_
        if edge not in self.known_vi_:
            logger.warning("Cannot retrieve the shape of %s", edge)
            return None

        type_proto = self.known_vi_[edge].type
        shape = get_shape_from_type_proto(type_proto)

        if shape is not None:
            for i, dim in enumerate(shape):
                if isinstance(dim, str) and dim in self.dynamic_axis_mapping_:
                    shape[i] = self.dynamic_axis_mapping_[dim]

        return shape

    def compare_shape(self, edge, edge_other):
        """Compare shape of two edges.

        Args:
            edge (str): name of edge
            edge_other (str): name of another edge

        Raises:
            ValueError: At least one shape is missed for edges to compare
                (a subclass of Exception, so existing ``except Exception`` handlers still apply).

        Returns:
            bool: whether the shape is same or not
        """
        assert self.all_shapes_inferred_
        shape = self.get_edge_shape(edge)
        shape_other = self.get_edge_shape(edge_other)
        if shape is None or shape_other is None:
            raise ValueError("At least one shape is missed for edges to compare")
        return shape == shape_other
@@ -0,0 +1,401 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+
6
# This tool is not used directly in BERT optimization. It can assist in developing the optimization script in the following scenarios:
# (1) It can simplify the graph by removing many sub-graphs related to Reshape.
# (2) It can reduce the graph's inputs and outputs to fit other tools. The scripts compare_bert_results.py and bert_perf_test.py require 3 inputs.
9
+
10
+ import argparse
11
+ import logging
12
+ import os
13
+ import re # noqa: F401
14
+ import sys
15
+ import tempfile
16
+ from collections import deque # noqa: F401
17
+ from datetime import datetime
18
+ from pathlib import Path # noqa: F401
19
+ from typing import List, Optional
20
+
21
+ import numpy as np
22
+ import onnx
23
+ from onnx import ModelProto, TensorProto, numpy_helper
24
+ from onnx_model import OnnxModel
25
+
26
+ import onnxruntime
27
+
28
logger: logging.Logger = logging.getLogger(__name__)

# Name prefix requested for the constant-shape initializers this tool creates; a model that already has an
# initializer with this prefix is treated as already optimized.
CONSTANT_SHAPE_NAME_PREFIX: str = "constant_shape_opt__"
# Name prefix requested for the outputs of the Shape nodes added for each Reshape data input.
RESHAPE_INPUT_SHAPE_PREFIX: str = "reshape_input_shape__"
32
+
33
+
34
class BertOnnxModelShapeOptimizer(OnnxModel):
    """
    Replace the outputs of Shape nodes and/or the shape inputs of Reshape nodes with constant initializers.

    The constant values are obtained by running the model once with ONNX Runtime on dummy inputs of a fixed
    batch size and sequence length, so the optimized model might only work for inputs of that size.
    Currently, it requires model inputs to have static shape.
    """

    def __init__(self, onnx_model):
        """Wrap the ModelProto held by ``onnx_model.model``."""
        super().__init__(onnx_model.model)

    def add_shape_initializer(self, shape):
        """
        Add an INT64 initializer holding ``shape`` and return the created TensorProto.

        Its name is requested from ``create_node_name`` with CONSTANT_SHAPE_NAME_PREFIX, which ``optimize``
        uses to detect a model that has already been processed.
        """
        shape_value = np.asarray(shape, dtype=np.int64)
        constant_shape_name = self.create_node_name("Constant", CONSTANT_SHAPE_NAME_PREFIX)
        tensor = onnx.helper.make_tensor(
            name=constant_shape_name,
            data_type=TensorProto.INT64,
            dims=shape_value.shape,
            vals=shape_value,
        )
        self.add_initializer(tensor)
        return tensor

    def get_shape_outputs(self):
        """
        Return the output names of all Shape nodes whose output is consumed by at least one node.
        """
        input_name_to_nodes = self.input_name_to_nodes()
        return [
            node.output[0]
            for node in self.model.graph.node
            if node.op_type == "Shape" and node.output[0] in input_name_to_nodes
        ]

    def get_reshape_shape_inputs(self):
        """
        Return the shape input name (input index 1) of every Reshape node, in graph order.
        """
        return [node.input[1] for node in self.model.graph.node if node.op_type == "Reshape"]

    def add_shape_for_reshape_input(self):
        """
        For each Reshape node, create a Shape node for its first input.
        Returns the output names of these Shape nodes, in the same order as the Reshape nodes.
        """
        output_names = []
        nodes_to_add = []
        for node in self.model.graph.node:
            if node.op_type == "Reshape":
                data_input = node.input[0]
                output_name = self.create_node_name("Reshape_Input", RESHAPE_INPUT_SHAPE_PREFIX)
                shape_node = onnx.helper.make_node("Shape", inputs=[data_input], outputs=[output_name])
                nodes_to_add.append(shape_node)
                output_names.append(output_name)

        # Added after the loop so the graph is not modified while it is being iterated.
        self.add_nodes(nodes_to_add)
        return output_names

    def add_extra_graph_output(self, extra_outputs):
        """
        Add a list of output names to graph output.

        Names that are initializers are skipped because their value is already constant. Returns the other
        names, in order; these are the tensors that must be evaluated by running the model.
        """
        names_to_evaluate = []
        existing_outputs = {output.name for output in self.model.graph.output}
        for name in extra_outputs:
            if self.get_initializer(name) is not None:  # already a constant
                continue
            names_to_evaluate.append(name)

            if name not in existing_outputs:
                output_info = onnx.helper.ValueInfoProto()
                output_info.name = name
                self.model.graph.output.extend([output_info])
                existing_outputs.add(name)

        return names_to_evaluate

    # Update input and output shape to be static
    def use_static_input(self, inputs, batch_size=1, max_seq_len=128):
        """
        Update the model to use static axes instead of dynamic axes for graph inputs.

        For every graph input whose name is in ``inputs``, axis 0 is set to ``batch_size``. Axis 1 is set to
        ``max_seq_len`` when it is symbolic.

        Raises:
            ValueError: axis 1 already has a static value that differs from ``max_seq_len``.
        """
        for graph_input in self.model.graph.input:
            if graph_input.name not in inputs:
                continue
            dims = graph_input.type.tensor_type.shape.dim
            # Axis 0 is overwritten unconditionally, whether it was symbolic or static.
            dims[0].dim_value = batch_size
            seq_dim = dims[1]
            if seq_dim.HasField("dim_param"):
                seq_dim.dim_value = max_seq_len
            elif seq_dim.HasField("dim_value") and seq_dim.dim_value != max_seq_len:
                raise ValueError(
                    f"Unable to set dimension value to {max_seq_len} for axis {1} of {graph_input.name}. Contradicts existing dimension value {seq_dim.dim_value}."
                )

    def create_dummy_inputs(
        self,
        input_ids,
        segment_ids,
        input_mask,
        batch_size,
        sequence_length,
        elem_type,
        dictionary_size=8,
    ):
        """
        Create dummy data for model inputs. If the model has more than 3 inputs, please update this function accordingly before running the tool.

        ``input_ids`` gets random ids in [0, dictionary_size), ``input_mask`` gets ones and ``segment_ids``
        gets zeros, all of shape (batch_size, sequence_length). ``elem_type`` is an ONNX TensorProto data
        type: 1 (float32), 6 (int32) or 7 (int64).
        """
        assert elem_type in [1, 6, 7]  # only int32, int64 and float32 are supported.

        # Create dummy inputs
        input_1 = np.random.randint(dictionary_size, size=(batch_size, sequence_length), dtype=np.int32)
        input_2 = np.ones((batch_size, sequence_length), dtype=np.int32)
        input_3 = np.zeros((batch_size, sequence_length), dtype=np.int32)

        # Here we assume that 3 inputs have same data type
        if elem_type == 1:  # float32
            input_1 = np.float32(input_1)
            input_2 = np.float32(input_2)
            input_3 = np.float32(input_3)
        elif elem_type == 7:  # int64
            input_1 = np.int64(input_1)
            input_2 = np.int64(input_2)
            input_3 = np.int64(input_3)

        inputs = {input_ids: input_1, input_mask: input_2, segment_ids: input_3}
        return inputs

    def shape_optimization(
        self,
        temp_model_path,
        input_ids,
        segment_ids,
        input_mask,
        output_names,
        batch_size,
        sequence_length,
        enable_shape_opt,
        enable_reshape_opt,
        verbose,
    ):
        """
        Evaluate shape tensors with ONNX Runtime and replace them with constant initializers.

        The requested tensors are exposed as extra graph outputs. The model is written to ``temp_model_path``
        and run once on dummy inputs. Each evaluated tensor is then replaced by an initializer, and the graph
        is pruned back to ``output_names``. Does nothing if there is no tensor to evaluate.
        """
        self.bert_inputs = [input_ids, segment_ids, input_mask]

        extra_outputs = []
        if enable_shape_opt:
            extra_outputs.extend(self.get_shape_outputs())

        if enable_reshape_opt:
            reshape_shape_inputs = self.get_reshape_shape_inputs()
            reshape_input_shapes = self.add_shape_for_reshape_input()
            extra_outputs.extend(reshape_shape_inputs)
            extra_outputs.extend(reshape_input_shapes)

        if not extra_outputs:
            return

        names_to_evaluate = self.add_extra_graph_output(extra_outputs)

        # This tool does not support dynamic axes right now.
        self.use_static_input(self.bert_inputs, batch_size, sequence_length)

        with open(temp_model_path, "wb") as out:
            out.write(self.model.SerializeToString())
        sess_options = onnxruntime.SessionOptions()
        # Evaluate the graph as written, with all ORT graph optimizations disabled.
        sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
        session = onnxruntime.InferenceSession(
            temp_model_path,
            sess_options,
            providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
        )

        # Default to int64 unless the input_ids graph input declares another element type.
        elem_type = 7
        for graph_input in self.model.graph.input:
            if graph_input.name == input_ids:
                elem_type = graph_input.type.tensor_type.elem_type
        inputs = self.create_dummy_inputs(input_ids, segment_ids, input_mask, batch_size, sequence_length, elem_type)

        outputs = session.run(names_to_evaluate, inputs)
        shapes = dict(zip(names_to_evaluate, outputs))

        logger.debug(f"shapes={shapes}")

        if enable_reshape_opt:
            for shape_input, input_shape in zip(reshape_shape_inputs, reshape_input_shapes):
                self.update_target_shape(shapes, shape_input, input_shape, verbose)

        for name, shape in shapes.items():
            tensor = self.add_shape_initializer(shape)
            self.replace_input_of_all_nodes(name, tensor.name)

        # Remove extra outputs, and prune all nodes not linked to output.
        self.prune_graph(output_names)

    def _lookup_shape_value(self, shapes, name):
        """Return the evaluated value of ``name`` from ``shapes``, falling back to the initializer of that name."""
        if name in shapes:
            return shapes[name]
        initializer = self.get_initializer(name)
        assert initializer is not None
        return numpy_helper.to_array(initializer)

    def update_target_shape(self, shapes, shape_input, input_shape, verbose):
        """
        Update the target shape to use 0 to represent that dimension value does not change.
        For example, shape of source data is (2, 5, 8) and target shape is (2, 5, 4, 2), the target shape will be updated to (0, 0, 4, 2).

        Args:
            shapes: dict from tensor name to evaluated value; ``shapes[shape_input]`` is overwritten with the
                new target shape.
            shape_input: name of the Reshape node's shape input.
            input_shape: name of the tensor holding the shape of the Reshape node's data input.
            verbose: unused.
        """
        target_shape = self._lookup_shape_value(shapes, shape_input)
        source_shape = self._lookup_shape_value(shapes, input_shape)

        # In ONNX Reshape (default allowzero=0) a 0 copies the corresponding input dim, which keeps the
        # shape valid for other values of that dim.
        # NOTE(review): a Reshape with allowzero=1 would treat 0 literally; not handled here.
        new_target_shape = [
            0 if i < len(source_shape) and source_shape[i] == dim_value else dim_value
            for i, dim_value in enumerate(target_shape)
        ]
        shapes[shape_input] = new_target_shape

        logger.debug(f"source_shape={source_shape}, target_shape={target_shape}, new_target_shape={new_target_shape}")

    def validate_input(self, input: str):
        """Raise an Exception if ``input`` is not a graph input."""
        if not self.find_graph_input(input):
            valid_names = [graph_input.name for graph_input in self.model.graph.input]
            raise Exception(f"Input {input} does not exist in the graph inputs: {valid_names}")

    def validate_outputs(self, output_names: List[str]):
        """Raise an Exception if any name in ``output_names`` is not a graph output."""
        valid_names = [output.name for output in self.model.graph.output]
        for name in output_names:
            if name not in valid_names:
                raise Exception(f"Output {name} does not exist in the graph outputs: {valid_names}")

    def _write_serialized_model(self, output_path):
        """Write the current model to ``output_path``; do nothing when ``output_path`` is None."""
        if output_path is not None:
            with open(output_path, "wb") as out:
                out.write(self.model.SerializeToString())

    def optimize(
        self,
        output_path: str,
        input_ids: str,
        segment_ids: str,
        input_mask: str,
        enable_shape_opt: bool,
        enable_reshape_opt: bool,
        output_names: Optional[List[str]] = None,
        batch_size=1,
        sequence_length=128,
        verbose=False,
    ):
        """
        Optionally prune graph outputs, run shape optimization, and save the model.

        The model is always written to ``output_path`` (when not None), even if shape optimization is
        skipped because it was done before or because the graph does not have exactly 3 inputs.

        Raises:
            Exception: an input name is not a graph input, or an output name is not a graph output.
        """
        # Skip if shape optimization has been done before.
        if any(tensor.name.startswith(CONSTANT_SHAPE_NAME_PREFIX) for tensor in self.model.graph.initializer):
            logger.info("Skip shape optimization since it has been done before")
            self._write_serialized_model(output_path)
            return

        self.validate_input(input_ids)
        self.validate_input(segment_ids)
        self.validate_input(input_mask)

        if output_names is not None:
            self.validate_outputs(output_names)
            self.prune_graph(output_names)

        remaining_outputs = [output.name for output in self.model.graph.output]

        if enable_shape_opt or enable_reshape_opt:
            if len(self.get_graph_inputs_excluding_initializers()) != 3:
                # Only the shape optimization is skipped; the pruned model is still saved below.
                logger.info("Skip shape optimization since graph input number is not 3")
            else:
                with tempfile.TemporaryDirectory() as temp_dir:
                    temp_file_name = "temp_{}.onnx".format(datetime.now().strftime("%m_%d-%H_%M_%S"))
                    # In verbose mode the temp model is kept in the current directory for inspection.
                    model_dir = "." if verbose else temp_dir
                    temp_file = os.path.join(model_dir, temp_file_name)
                    self.shape_optimization(
                        temp_file,
                        input_ids,
                        segment_ids,
                        input_mask,
                        remaining_outputs,
                        batch_size,
                        sequence_length,
                        enable_shape_opt,
                        enable_reshape_opt,
                        verbose,
                    )
                    logger.debug(f"Temp model with additional outputs: {temp_file}")
                    logger.warning(
                        f"Shape optimization is done. The optimized model might only work for input with batch_size={batch_size} sequence_length={sequence_length}"
                    )

        self._write_serialized_model(output_path)
+
339
+
340
def parse_arguments():
    """Parse the command-line options of the shape optimizer tool.

    Model path, output path and the three BERT input names are required. ``--output_names`` is an optional
    ';'-separated string. Batch size and sequence length default to 1 and 128. The three boolean switches
    default to False.
    """
    parser = argparse.ArgumentParser()
    for required_option in ("--input", "--output", "--input_ids", "--segment_ids", "--input_mask"):
        parser.add_argument(required_option, required=True, type=str)
    parser.add_argument("--output_names", required=False, type=str, default=None)
    parser.add_argument("--batch_size", required=False, type=int, default=1)
    parser.add_argument("--sequence_length", required=False, type=int, default=128)
    for switch in ("--enable_shape_opt", "--enable_reshape_opt", "--verbose"):
        parser.add_argument(switch, required=False, action="store_true")
    parser.set_defaults(enable_shape_opt=False, enable_reshape_opt=False, verbose=False)
    return parser.parse_args()
358
+
359
+
360
def setup_logging(verbose):
    """Attach a stdout handler to the module logger and set its level.

    Verbose mode logs at DEBUG with file, line and function name in each record. Otherwise it logs at INFO
    with only the file name. Each call adds another handler.
    """
    if verbose:
        record_format, level = "[%(filename)s:%(lineno)s - %(funcName)20s()] %(message)s", logging.DEBUG
    else:
        record_format, level = "%(filename)20s: %(message)s", logging.INFO

    stdout_handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setFormatter(logging.Formatter(record_format))
    stdout_handler.setLevel(level)
    logger.addHandler(stdout_handler)
    logger.setLevel(level)
371
+
372
+
373
def main():
    """Command-line entry point: load the model, run the shape optimizer, and write the result."""
    args = parse_arguments()
    setup_logging(args.verbose)

    # --output_names is a ';'-separated list of graph output names to keep.
    output_names = args.output_names.split(";") if args.output_names is not None else None

    with open(args.input, "rb") as model_file:
        serialized_model = model_file.read()
    model = ModelProto()
    model.ParseFromString(serialized_model)

    optimizer = BertOnnxModelShapeOptimizer(OnnxModel(model))
    optimizer.optimize(
        args.output,
        args.input_ids,
        args.segment_ids,
        args.input_mask,
        args.enable_shape_opt,
        args.enable_reshape_opt,
        output_names,
        args.batch_size,
        args.sequence_length,
        args.verbose,
    )


if __name__ == "__main__":
    main()
@@ -0,0 +1,74 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+
6
+ import torch
7
+ from torch._C._onnx import OperatorExportTypes
8
+
9
+ TrainingMode = torch.onnx.TrainingMode
10
+ from packaging.version import Version # noqa: E402
11
+
12
+
13
def torch_onnx_export(
    model,
    args,
    f,
    export_params=True,
    verbose=False,
    training=TrainingMode.EVAL,
    input_names=None,
    output_names=None,
    operator_export_type=OperatorExportTypes.ONNX,
    opset_version=None,
    _retain_param_name=None,
    do_constant_folding=True,
    example_outputs=None,
    strip_doc_string=None,
    dynamic_axes=None,
    keep_initializers_as_inputs=None,
    custom_opsets=None,
    enable_onnx_checker=None,
    use_external_data_format=None,
    export_modules_as_functions=False,
):
    """Call ``torch.onnx.export`` with the keyword arguments that match the installed torch version.

    On torch >= 1.11, ``export_modules_as_functions`` is forwarded. ``_retain_param_name``,
    ``example_outputs``, ``strip_doc_string``, ``enable_onnx_checker`` and ``use_external_data_format``
    are accepted but ignored. On older torch the reverse holds: those five are forwarded and
    ``export_modules_as_functions`` is ignored. Returns None.
    """
    # Arguments accepted by every supported torch version.
    export_kwargs = {
        "model": model,
        "args": args,
        "f": f,
        "export_params": export_params,
        "verbose": verbose,
        "training": training,
        "input_names": input_names,
        "output_names": output_names,
        "operator_export_type": operator_export_type,
        "opset_version": opset_version,
        "do_constant_folding": do_constant_folding,
        "dynamic_axes": dynamic_axes,
        "keep_initializers_as_inputs": keep_initializers_as_inputs,
        "custom_opsets": custom_opsets,
    }

    if Version(torch.__version__) < Version("1.11.0"):
        # Legacy arguments only passed to older torch releases.
        export_kwargs.update(
            _retain_param_name=_retain_param_name,
            example_outputs=example_outputs,
            strip_doc_string=strip_doc_string,
            enable_onnx_checker=enable_onnx_checker,
            use_external_data_format=use_external_data_format,
        )
    else:
        export_kwargs["export_modules_as_functions"] = export_modules_as_functions

    torch.onnx.export(**export_kwargs)