onnxruntime_directml-1.20.0-cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
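For orientation before the file list: this wheel is the standard onnxruntime Python package bundled with the DirectML execution provider (note onnxruntime/capi/DirectML.dll below). A minimal usage sketch with the public onnxruntime API; the model path is a placeholder:

import onnxruntime as ort

# DirectML builds register "DmlExecutionProvider" in addition to the default CPU provider.
print(ort.get_available_providers())

# "model.onnx" is a placeholder; ONNX Runtime falls back to the CPU provider
# for any node the DirectML provider cannot run.
session = ort.InferenceSession("model.onnx", providers=["DmlExecutionProvider", "CPUExecutionProvider"])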
Files changed (305)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
onnxruntime/transformers/convert_tf_models_to_pytorch.py
@@ -0,0 +1,205 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+
+ import glob
+ import os
+
+ import requests
+
+ TFMODELS = {
+     "bert-base-uncased": (
+         "bert",
+         "BertConfig",
+         "",
+         "https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip",
+     ),
+     "bert-base-cased": (
+         "bert",
+         "BertConfig",
+         "",
+         "https://storage.googleapis.com/bert_models/2019_05_30/wwm_cased_L-24_H-1024_A-16.zip",
+     ),
+     "bert-large-uncased": (
+         "bert",
+         "BertConfig",
+         "",
+         "https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-24_H-1024_A-16.zip",
+     ),
+     "albert-base": (
+         "albert",
+         "AlbertConfig",
+         "",
+         "https://storage.googleapis.com/albert_models/albert_base_v1.tar.gz",
+     ),
+     "albert-large": (
+         "albert",
+         "AlbertConfig",
+         "",
+         "https://storage.googleapis.com/albert_models/albert_large_v1.tar.gz",
+     ),
+     "gpt-2-117M": (
+         "gpt2",
+         "GPT2Config",
+         "GPT2Model",
+         "https://storage.googleapis.com/gpt-2/models/117M",
+     ),
+     "gpt-2-124M": (
+         "gpt2",
+         "GPT2Config",
+         "GPT2Model",
+         "https://storage.googleapis.com/gpt-2/models/124M",
+     ),
+ }
+
+
+ def download_compressed_file(tf_ckpt_url, ckpt_dir):
+     r = requests.get(tf_ckpt_url)
+     compressed_file_name = tf_ckpt_url.split("/")[-1]
+     compressed_file_dir = os.path.join(ckpt_dir, compressed_file_name)
+     with open(compressed_file_dir, "wb") as f:
+         f.write(r.content)
+     return compressed_file_dir
+
+
+ def get_ckpt_prefix_path(ckpt_dir):
+     # get prefix
+     sub_folder_dir = None
+     for o in os.listdir(ckpt_dir):
+         sub_folder_dir = os.path.join(ckpt_dir, o)
+         break
+     if os.path.isfile(sub_folder_dir):
+         sub_folder_dir = ckpt_dir
+     unique_file_name = str(glob.glob(sub_folder_dir + "/*data-00000-of-00001"))
+     prefix = (unique_file_name.rpartition(".")[0]).split("/")[-1]
+
+     return os.path.join(sub_folder_dir, prefix)
+
+
+ def download_tf_checkpoint(model_name, tf_models_dir="tf_models"):
+     import pathlib
+
+     base_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), tf_models_dir)
+     ckpt_dir = os.path.join(base_dir, model_name)
+
+     if not os.path.exists(ckpt_dir):
+         os.makedirs(ckpt_dir)
+
+     tf_ckpt_url = TFMODELS[model_name][3]
+
+     import re
+
+     if re.search(".zip$", tf_ckpt_url) is not None:
+         zip_dir = download_compressed_file(tf_ckpt_url, ckpt_dir)
+
+         # unzip file
+         import zipfile
+
+         with zipfile.ZipFile(zip_dir, "r") as zip_ref:
+             zip_ref.extractall(ckpt_dir)
+         os.remove(zip_dir)
+
+         return get_ckpt_prefix_path(ckpt_dir)
+
+     elif re.search(".tar.gz$", tf_ckpt_url) is not None:
+         tar_dir = download_compressed_file(tf_ckpt_url, ckpt_dir)
+
+         # untar file
+         import tarfile
+
+         with tarfile.open(tar_dir, "r") as tar_ref:
+             tar_ref.extractall(ckpt_dir)
+         os.remove(tar_dir)
+
+         return get_ckpt_prefix_path(ckpt_dir)
+
+     else:
+         for filename in [
+             "checkpoint",
+             "model.ckpt.data-00000-of-00001",
+             "model.ckpt.index",
+             "model.ckpt.meta",
+         ]:
+             r = requests.get(tf_ckpt_url + "/" + filename)
+             with open(os.path.join(ckpt_dir, filename), "wb") as f:
+                 f.write(r.content)
+
+         return get_ckpt_prefix_path(ckpt_dir)
+
+
+ def init_pytorch_model(model_name, tf_checkpoint_path):
+     config_name = TFMODELS[model_name][1]
+     config_module = __import__("transformers", fromlist=[config_name])
+     model_config = getattr(config_module, config_name)
+
+     parent_path = tf_checkpoint_path.rpartition("/")[0]
+     config_path = glob.glob(parent_path + "/*config.json")
+     config = model_config() if len(config_path) == 0 else model_config.from_json_file(str(config_path[0]))
+
+     if not TFMODELS[model_name][2]:
+         from transformers import AutoModelForPreTraining
+
+         init_model = AutoModelForPreTraining.from_config(config)
+     else:
+         model_category_name = TFMODELS[model_name][2]
+         module = __import__("transformers", fromlist=[model_category_name])
+         model_category = getattr(module, model_category_name)
+         init_model = model_category(config)
+     return config, init_model
+
+
+ def convert_tf_checkpoint_to_pytorch(model_name, config, init_model, tf_checkpoint_path, is_tf2):
+     load_tf_weight_func_name = "load_tf_weights_in_" + TFMODELS[model_name][0]
+
+     module = __import__("transformers", fromlist=[load_tf_weight_func_name])
+
+     if is_tf2 is False:
+         load_tf_weight_func = getattr(module, load_tf_weight_func_name)
+     else:
+         if TFMODELS[model_name][0] != "bert":
+             raise NotImplementedError("Only TF2 checkpoints of BERT models are supported")
+         from transformers import convert_bert_original_tf2_checkpoint_to_pytorch
+
+         load_tf_weight_func = convert_bert_original_tf2_checkpoint_to_pytorch.load_tf2_weights_in_bert
+
+     # Expect the transformers team to unify the order of the signature in the future.
+     model = (
+         load_tf_weight_func(init_model, config, tf_checkpoint_path)
+         if is_tf2 is False
+         else load_tf_weight_func(init_model, tf_checkpoint_path, config)
+     )
+     model.eval()
+     return model
+
+
+ def tf2pt_pipeline(model_name, is_tf2=False):
+     if model_name not in TFMODELS:
+         raise NotImplementedError(model_name + " not implemented")
+     tf_checkpoint_path = download_tf_checkpoint(model_name)
+     config, init_model = init_pytorch_model(model_name, tf_checkpoint_path)
+     model = convert_tf_checkpoint_to_pytorch(model_name, config, init_model, tf_checkpoint_path, is_tf2)
+     # Could then use the model in Benchmark
+     return config, model
+
+
+ def tf2pt_pipeline_test():
+     # For test on Linux only
+     import logging
+
+     import torch
+
+     logger = logging.getLogger("")
+     for model_name in TFMODELS:
+         config, model = tf2pt_pipeline(model_name)
+         assert config.model_type == TFMODELS[model_name][0]
+
+         input = torch.randint(low=0, high=config.vocab_size - 1, size=(4, 128), dtype=torch.long)
+         try:
+             model(input)
+         except RuntimeError as e:
+             logger.exception(e)
+
+
+ if __name__ == "__main__":
+     tf2pt_pipeline_test()
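The module above is a self-contained download-and-convert pipeline: each TFMODELS entry maps a model name to (model type, config class name, model class name, checkpoint URL), where an empty model class name means AutoModelForPreTraining is used. A minimal driver sketch, assuming transformers, torch, and requests are installed and the module is imported via the package path shown in the file list:

from onnxruntime.transformers.convert_tf_models_to_pytorch import tf2pt_pipeline

# Downloads the TF checkpoint into a tf_models/ folder next to the module,
# builds a matching PyTorch model, and loads the TF weights into it.
config, model = tf2pt_pipeline("bert-base-uncased", is_tf2=False)
print(config.model_type)  # "bert"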
onnxruntime/transformers/convert_to_packing_mode.py
@@ -0,0 +1,387 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+
+ import argparse
+ import logging
+ import os
+ from typing import List, Union
+
+ import coloredlogs
+ from constants import (
+     AttentionInputIDs,
+     AttentionOutputIDs,
+     MultiHeadAttentionInputIDs,
+     MultiHeadAttentionOutputIDs,
+     Operators,
+ )
+ from onnx import helper, load_model
+ from onnx_model import NodeProto, OnnxModel
+ from shape_infer_helper import SymbolicShapeInferenceHelper
+
+ logger = logging.getLogger(__name__)
+
+
+ class PackingAttentionBase:
+     def __init__(self, model: OnnxModel, attention_op_type: str):
+         self.model: OnnxModel = model
+         self.nodes_to_remove: List = []
+         self.nodes_to_add: List = []
+         self.prune_graph: bool = False
+         self.node_name_to_graph_name: dict = {}
+         self.this_graph_name: str = self.model.model.graph.name
+         self.attention_op_type = attention_op_type
+         self.attention_nodes = self.model.get_nodes_by_op_type(attention_op_type)
+
+     def _try_getting_attention_mask(self) -> Union[str, None]:
+         mask_index = (
+             AttentionInputIDs.MASK_INDEX
+             if self.attention_op_type == Operators.ATTENTION
+             else MultiHeadAttentionInputIDs.KEY_PADDING_MASK
+         )
+         first_attention_node = self._try_getting_first_attention()
+         # check if attention has mask
+         if not first_attention_node or len(first_attention_node.input) <= mask_index:
+             return None
+
+         attention_mask = first_attention_node.input[mask_index]
+
+         # check if all attention nodes have the same mask
+         for node in self.attention_nodes:
+             if len(node.input) <= mask_index or node.input[mask_index] != attention_mask:
+                 return None
+
+         return attention_mask
+
+     def _try_getting_first_attention(self) -> Union[NodeProto, None]:
+         if len(self.attention_nodes) <= 0:
+             return None
+
+         return self.attention_nodes[0]
+
+     def _try_getting_last_layernorm(self) -> Union[NodeProto, None]:
+         last_layernorm_node = None
+         for node in self.model.nodes():
+             if node.op_type == Operators.LAYERNORM or node.op_type == Operators.SKIPLAYERNORM:
+                 last_layernorm_node = node
+         return last_layernorm_node
+
+     def _are_attentions_supported(self) -> bool:
+         raise NotImplementedError()
+
+     def _insert_removepadding_node(self, inputs: List[str], outputs: List[str]) -> None:
+         new_node = helper.make_node(
+             Operators.REMOVEPADDING,
+             inputs=inputs,
+             outputs=outputs,
+             name=self.model.create_node_name(Operators.REMOVEPADDING),
+         )
+
+         new_node.domain = "com.microsoft"
+         self.nodes_to_add.append(new_node)
+         self.node_name_to_graph_name[new_node.name] = self.this_graph_name
+
+     def _insert_restorepadding_node(self, inputs: List[str], outputs: List[str]) -> None:
+         new_node = helper.make_node(
+             Operators.RESTOREPADDING,
+             inputs=inputs,
+             outputs=outputs,
+             name=self.model.create_node_name(Operators.RESTOREPADDING),
+         )
+
+         new_node.domain = "com.microsoft"
+         self.nodes_to_add.append(new_node)
+         self.node_name_to_graph_name[new_node.name] = self.this_graph_name
+
+     def _replace_attention_with_packing_attention(self, token_offset: str, cumulative_sequence_length: str) -> None:
+         raise NotImplementedError()
+
+     def _get_input_to_remove_padding(self, first_attention_node) -> Union[str, None]:
+         if self.attention_op_type == Operators.ATTENTION:
+             return first_attention_node.input[AttentionInputIDs.INPUT]
+         return None
+
+     def convert(self, use_symbolic_shape_infer: bool = True) -> None:
+         logger.debug("start converting to packing model...")
+
+         if not self._are_attentions_supported():
+             return
+
+         attention_mask = self._try_getting_attention_mask()
+         if not attention_mask:
+             return
+
+         first_attention_node = self._try_getting_first_attention()
+         last_layernorm_node = self._try_getting_last_layernorm()
+         if not last_layernorm_node:
+             return
+
+         # insert RemovePadding
+         input_to_remove_padding = self._get_input_to_remove_padding(first_attention_node)
+         if not input_to_remove_padding:
+             return
+
+         output_without_padding = input_to_remove_padding + "_no_padding"
+         token_offset = input_to_remove_padding + "_token_offset"
+         cumulated_seq_len = input_to_remove_padding + "_cumulated_seq_len"
+         max_seq_len = input_to_remove_padding + "_max_seq_len"
+         self._insert_removepadding_node(
+             [input_to_remove_padding, attention_mask],
+             [output_without_padding, token_offset, cumulated_seq_len, max_seq_len],
+         )
+         self.model.replace_input_of_all_nodes(input_to_remove_padding, output_without_padding)
+         logger.debug("inserted RemovePadding before Attention")
+
+         # insert RestorePadding
+         restorepadding_input = last_layernorm_node.output[0] + "_restore_input"
+         self._insert_restorepadding_node([restorepadding_input, token_offset], [last_layernorm_node.output[0]])
+         self.model.replace_output_of_all_nodes(last_layernorm_node.output[0], restorepadding_input)
+         logger.debug(f"inserted RestorePadding after last {last_layernorm_node.op_type} layer")
+
+         # insert PackedAttention
+         self._replace_attention_with_packing_attention(token_offset, cumulated_seq_len)
+         logger.debug(f"replaced {self.attention_op_type} with Packed{self.attention_op_type}")
+
+         self.model.remove_nodes(self.nodes_to_remove)
+         self.model.add_nodes(self.nodes_to_add, self.node_name_to_graph_name)
+
+         if self.prune_graph:
+             self.model.prune_graph()
+         elif self.nodes_to_remove or self.nodes_to_add:
+             self.model.update_graph()
+         self.model.clean_shape_infer()
+         if use_symbolic_shape_infer:
+             # Use symbolic shape inference since custom operators (like Gelu, SkipLayerNormalization etc.)
+             # are not recognized by onnx shape inference.
+             shape_infer_helper = SymbolicShapeInferenceHelper(self.model.model, verbose=0)
+             inferred_model = shape_infer_helper.infer_shapes(self.model.model, auto_merge=True, guess_output_rank=False)
+             if inferred_model:
+                 self.model.model = inferred_model
+
+
+ class PackingAttention(PackingAttentionBase):
+     def __init__(self, model: OnnxModel):
+         super().__init__(model, Operators.ATTENTION)
+
+     def _are_attentions_supported(self) -> bool:
+         for node in self.attention_nodes:
+             if OnnxModel.get_node_attribute(node, "past_present_share_buffer") is not None:
+                 return False
+             if OnnxModel.get_node_attribute(node, "do_rotary") is not None:
+                 return False
+             unidirection_attr = OnnxModel.get_node_attribute(node, "unidirectional")
+             if unidirection_attr is not None and unidirection_attr != 0:
+                 return False
+             if len(node.input) > AttentionInputIDs.PAST and not node.input[AttentionInputIDs.PAST]:
+                 return False
+             if (
+                 len(node.input) > AttentionInputIDs.PAST_SEQUENCE_LENGTH
+                 and not node.input[AttentionInputIDs.PAST_SEQUENCE_LENGTH]
+             ):
+                 return False
+         return True
+
+     def _replace_attention_with_packing_attention(self, token_offset: str, cumulative_sequence_length: str) -> None:
+         for attention in self.attention_nodes:
+             attention_bias = (
+                 attention.input[AttentionInputIDs.ATTENTION_BIAS]
+                 if len(attention.input) > AttentionInputIDs.ATTENTION_BIAS
+                 else ""
+             )
+             packed_attention = helper.make_node(
+                 Operators.PACKEDATTENTION,
+                 inputs=[
+                     attention.input[AttentionInputIDs.INPUT],
+                     attention.input[AttentionInputIDs.WEIGHTS],
+                     attention.input[AttentionInputIDs.BIAS],
+                     token_offset,
+                     cumulative_sequence_length,
+                     attention_bias,
+                 ],
+                 outputs=[attention.output[AttentionOutputIDs.OUTPUT]],
+                 name=self.model.create_node_name(Operators.PACKEDATTENTION),
+             )
+
+             attributes = []
+             for attr in attention.attribute:
+                 if attr.name in ["num_heads", "qkv_hidden_sizes", "scale"]:
+                     attributes.append(attr)
+
+             packed_attention.attribute.extend(attributes)
+             packed_attention.domain = "com.microsoft"
+             self.nodes_to_add.append(packed_attention)
+             self.nodes_to_remove.append(attention)
+             self.node_name_to_graph_name[packed_attention.name] = self.this_graph_name
+
+         logger.info("Converted %d Attention nodes to PackedAttention.", len(self.attention_nodes))
+
+
+ class PackingMultiHeadAttention(PackingAttentionBase):
+     def __init__(self, model: OnnxModel):
+         super().__init__(model, Operators.MULTI_HEAD_ATTENTION)
+
+     def _check_empty_input(self, node, index: int, name: str):
+         """Check that a node does not have the given input."""
+         if len(node.input) > index:
+             if len(node.input[index]) > 0:
+                 logger.error(f"node input {index} ({name}) is not supported in PackedMultiHeadAttention: {node}")
+                 return False
+         return True
+
+     def _check_empty_output(self, node, index: int, name: str):
+         """Check that a node does not have the given output."""
+         if len(node.output) > index:
+             if len(node.output[index]) > 0:
+                 logger.error(f"node output {index} ({name}) is not supported in PackedMultiHeadAttention: {node}")
+                 return False
+         return True
+
+     def _are_attentions_supported(self) -> bool:
+         for node in self.attention_nodes:
+             for attr in node.attribute:
+                 if attr.name not in ["num_heads", "mask_filter_value", "scale"]:
+                     logger.error(f"node attribute {attr.name} is not supported in PackedMultiHeadAttention: {node}")
+                     return False
+
+             if node.input[MultiHeadAttentionInputIDs.KEY] and not node.input[MultiHeadAttentionInputIDs.VALUE]:
+                 logger.error("packed kv format is not supported in PackedMultiHeadAttention")
+                 return False
+
+             if not (
+                 self._check_empty_input(node, MultiHeadAttentionInputIDs.PAST_KEY, "past_key")
+                 and self._check_empty_input(node, MultiHeadAttentionInputIDs.PAST_VALUE, "past_value")
+                 and self._check_empty_output(node, MultiHeadAttentionOutputIDs.PRESENT_KEY, "present_key")
+                 and self._check_empty_output(node, MultiHeadAttentionOutputIDs.PRESENT_VALUE, "present_value")
+             ):
+                 return False
+
+         return True
+
+     def _replace_attention_with_packing_attention(self, token_offset: str, cumulative_sequence_length: str) -> None:
+         gated_relative_pos_bias_count = 0
+         for mha in self.attention_nodes:
+             attention_bias = (
+                 mha.input[MultiHeadAttentionInputIDs.ATTENTION_BIAS]
+                 if len(mha.input) > MultiHeadAttentionInputIDs.ATTENTION_BIAS
+                 else ""
+             )
+             packed_mha = helper.make_node(
+                 Operators.PACKED_MULTI_HEAD_ATTENTION,
+                 inputs=[
+                     mha.input[MultiHeadAttentionInputIDs.QUERY],
+                     mha.input[MultiHeadAttentionInputIDs.KEY],
+                     mha.input[MultiHeadAttentionInputIDs.VALUE],
+                     mha.input[MultiHeadAttentionInputIDs.BIAS],
+                     token_offset,
+                     cumulative_sequence_length,
+                     attention_bias,
+                 ],
+                 outputs=[mha.output[MultiHeadAttentionOutputIDs.OUTPUT]],
+                 name=self.model.create_node_name(Operators.PACKED_MULTI_HEAD_ATTENTION),
+             )
+
+             attributes = []
+             for attr in mha.attribute:
+                 if attr.name in ["num_heads", "mask_filter_value", "scale"]:
+                     attributes.append(attr)
+
+             packed_mha.attribute.extend(attributes)
+             packed_mha.domain = "com.microsoft"
+             self.nodes_to_add.append(packed_mha)
+             self.nodes_to_remove.append(mha)
+             self.node_name_to_graph_name[packed_mha.name] = self.this_graph_name
+
+             # Append token_offset input to GatedRelativePositionBias
+             if attention_bias:
+                 rel_pos_bias_node = self.model.get_parent(mha, MultiHeadAttentionInputIDs.ATTENTION_BIAS)
+                 if (
+                     rel_pos_bias_node
+                     and rel_pos_bias_node.op_type == "GatedRelativePositionBias"
+                     and len(rel_pos_bias_node.input) == 6
+                 ):
+                     rel_pos_bias_node.input.append(token_offset)
+                     gated_relative_pos_bias_count += 1
+
+         logger.info("Converted %d MultiHeadAttention nodes to PackedMultiHeadAttention.", len(self.attention_nodes))
+         logger.info("Converted %d GatedRelativePositionBias nodes to packing mode.", gated_relative_pos_bias_count)
+
+     def _get_input_to_remove_padding(self, first_attention_node) -> Union[str, None]:
+         # When there are query, key and value inputs, we need to find the first input of the parent MatMul node.
+         matmul = self.model.get_parent(first_attention_node, 0)
+         if matmul and matmul.op_type == "MatMul":
+             return matmul.input[0]
+         return None
+
+
+ class PackingMode:
+     def __init__(self, model: OnnxModel):
+         self.model = model
+
+     def convert(self, use_symbolic_shape_infer: bool = True) -> None:
+         if self.model.get_nodes_by_op_type(Operators.ATTENTION):
+             if self.model.get_nodes_by_op_type(Operators.MULTI_HEAD_ATTENTION):
+                 logger.error("Packing mode does not support both Attention and MultiHeadAttention in the same graph.")
+                 return None
+             packing = PackingAttention(self.model)
+             return packing.convert(use_symbolic_shape_infer)
+         elif self.model.get_nodes_by_op_type(Operators.MULTI_HEAD_ATTENTION):
+             packing = PackingMultiHeadAttention(self.model)
+             return packing.convert(use_symbolic_shape_infer)
+         else:
+             logger.error("Packing mode requires either an Attention or a MultiHeadAttention node in the ONNX graph.")
+             return None
+
+
+ def _parse_arguments():
+     parser = argparse.ArgumentParser(
+         description="Convert to packing mode tool for ONNX Runtime. It converts a BERT-like model to use packing mode."
+     )
+     parser.add_argument("--input", required=True, type=str, help="input onnx model path")
+
+     parser.add_argument("--output", required=True, type=str, help="optimized onnx model path")
+
+     parser.add_argument("--verbose", required=False, action="store_true", help="show debug information.")
+     parser.set_defaults(verbose=False)
+
+     parser.add_argument(
+         "--use_external_data_format",
+         required=False,
+         action="store_true",
+         help="use external data format to store large model (>2GB)",
+     )
+     parser.set_defaults(use_external_data_format=False)
+
+     args = parser.parse_args()
+
+     return args
+
+
+ def _setup_logger(verbose):
+     if verbose:
+         coloredlogs.install(
+             level="DEBUG",
+             fmt="[%(filename)s:%(lineno)s - %(funcName)20s()] %(message)s",
+         )
+     else:
+         coloredlogs.install(fmt="%(funcName)20s: %(message)s")
+
+
+ def main():
+     args = _parse_arguments()
+
+     _setup_logger(args.verbose)
+
+     logger.debug(f"arguments:{args}")
+
+     if os.path.realpath(args.input) == os.path.realpath(args.output):
+         logger.warning("Specified the same input and output path. Note that this may overwrite the original model.")
+
+     model = load_model(args.input)
+     packing_mode = PackingMode(OnnxModel(model))
+     packing_mode.convert()
+     packing_mode.model.save_model_to_file(args.output, use_external_data_format=args.use_external_data_format)
+
+
+ if __name__ == "__main__":
+     main()
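In short, the converter removes padding once at the graph input (RemovePadding), lets every attention node run over the packed token stream (PackedAttention or PackedMultiHeadAttention, both in the com.microsoft domain), and restores padding after the last LayerNormalization/SkipLayerNormalization (RestorePadding). Besides the CLI in main(), the classes can be driven directly; a sketch assuming the sibling modules (onnx_model, constants, shape_infer_helper) are importable, e.g. when running from the onnxruntime/transformers directory, with placeholder paths:

from convert_to_packing_mode import PackingMode
from onnx import load_model
from onnx_model import OnnxModel

# Equivalent to: python convert_to_packing_mode.py --input bert.onnx --output bert_packed.onnx
onnx_model = OnnxModel(load_model("bert.onnx"))
PackingMode(onnx_model).convert()
onnx_model.save_model_to_file("bert_packed.onnx", use_external_data_format=False)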
onnxruntime/transformers/dynamo_onnx_helper.py
@@ -0,0 +1,104 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+ import logging
+
+ import onnx
+
+
+ class DynamoOnnxHelper:
+     """
+     Helper class for processing ONNX models exported by torch Dynamo.
+     """
+
+     def __init__(self, model: onnx.ModelProto):
+         self.model = model
+
+     def update_edges(self, edge_mapping: dict) -> None:
+         """
+         Updates the edges in the model according to the given mapping.
+         """
+         for node in self.model.graph.node:
+             for i in range(len(node.input)):
+                 if node.input[i] in edge_mapping:
+                     node.input[i] = edge_mapping[node.input[i]]
+             for i in range(len(node.output)):
+                 if node.output[i] in edge_mapping:
+                     node.output[i] = edge_mapping[node.output[i]]
+
+         for graph_input in self.model.graph.input:
+             if graph_input.name in edge_mapping:
+                 graph_input.name = edge_mapping[graph_input.name]
+         for graph_output in self.model.graph.output:
+             if graph_output.name in edge_mapping:
+                 graph_output.name = edge_mapping[graph_output.name]
+
+     def unroll_function(self, func_name: str) -> None:
+         """
+         Unrolls the function with the given name in the model.
+         """
+         logging.info(f"Unrolling function {func_name}...")
+         nodes_to_remove = []
+         nodes_to_add = []
+         edges_to_remove = []
+         edges_to_add = []
+         for node in self.model.graph.node:
+             if node.op_type == func_name:
+                 nodes_to_remove.append(node)
+                 edges_to_remove.extend(list(node.input) + list(node.output))
+
+         func_to_remove = None
+         for f in self.model.functions:
+             if f.name == func_name:
+                 nodes_to_add.extend(list(f.node))
+                 edges_to_add.extend(list(f.input) + list(f.output))
+                 func_to_remove = f
+
+         assert len(edges_to_remove) == len(edges_to_add)
+
+         for node in nodes_to_remove:
+             self.model.graph.node.remove(node)
+         for node in nodes_to_add:
+             self.model.graph.node.append(node)
+         if func_to_remove is not None:
+             self.model.functions.remove(func_to_remove)
+
+         edge_mapping = {}
+         for i in range(len(edges_to_remove)):
+             k = edges_to_remove[i]
+             v = edges_to_add[i]
+             if k != v:
+                 edge_mapping[k] = v
+
+         return self.update_edges(edge_mapping)
+
+     def remove_function(self, func_name: str, input_id: int, output_id: int) -> None:
+         """
+         Removes the function in the model.
+         """
+         edge_mapping = {}
+         nodes_to_remove = []
+         for node in self.model.graph.node:
+             if node.op_type.find(func_name) != -1:
+                 edge_mapping[node.input[input_id]] = node.output[output_id]
+                 nodes_to_remove.append(node)
+         for node in nodes_to_remove:
+             self.model.graph.node.remove(node)
+
+         self.update_edges(edge_mapping)
+
+     def remove_dropout_layer(self) -> None:
+         """
+         Removes the dropout layer in the model.
+         """
+         logging.info("Removing dropout layer...")
+         self.remove_function("Dropout", 0, 0)
+
+     def remove_lm_head_layer(self) -> None:
+         """
+         Removes the LM head layer in the model.
+         """
+         logging.info("Removing LM head layer...")
+         # bugbug: need to copy the right vi over
+         self.remove_function("Linear_lm_head", 2, 0)
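DynamoOnnxHelper works directly on the ModelProto: unroll_function inlines the nodes of a local ONNX function into the main graph and renames edges to match, while remove_function splices matching nodes out by mapping one chosen input straight to one chosen output. A usage sketch with placeholder paths, assuming a Dynamo-exported model that stores layers as local functions:

import onnx

from dynamo_onnx_helper import DynamoOnnxHelper

helper = DynamoOnnxHelper(onnx.load("exported_by_dynamo.onnx"))
helper.remove_dropout_layer()  # splices out nodes whose op_type contains "Dropout"
onnx.save(helper.model, "cleaned.onnx")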