onnxruntime_directml-1.20.0-cp313-cp313-win_amd64.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
Files changed (305)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
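The diff below is the new file onnxruntime/transformers/onnx_exporter.py (entry 279 in the list above; +717 lines).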
@@ -0,0 +1,717 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License. See License.txt in the project root for
+ # license information.
+ # --------------------------------------------------------------------------
+
+ import logging
+ import os
+ from pathlib import Path
+
+ import numpy
+ import torch
+ from affinity_helper import AffinitySetting
+ from benchmark_helper import OptimizerInfo, Precision, create_onnxruntime_session
+ from huggingface_models import MODEL_CLASSES
+ from quantize_helper import QuantizeHelper
+ from torch_onnx_export_helper import torch_onnx_export
+ from transformers import AutoConfig, AutoFeatureExtractor, AutoTokenizer, LxmertConfig, TransfoXLConfig
+
+ from onnxruntime.transformers.models.gpt2.gpt2_helper import (
+     PRETRAINED_GPT2_MODELS,
+     GPT2ModelNoPastState,
+     TFGPT2ModelNoPastState,
+ )
+
+ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
+
+ logger = logging.getLogger(__name__)
+
+ # Workaround by replacing torch.triu using self-defined op
+ # Since torch.triu cannot be exported to ONNX. See https://github.com/pytorch/pytorch/issues/32968
+ torch_func = {"triu": torch.triu}
+
+
+ def triu_onnx(x, diagonal=0, out=None):
+     assert out is None
+     assert len(x.shape) == 2 and x.size(0) == x.size(1)
+
+     torch_triu = torch_func["triu"]
+     template = torch_triu(torch.ones((1024, 1024), dtype=torch.uint8), diagonal)
+     mask = template[: x.size(0), : x.size(1)]
+     return torch.where(mask.bool(), x, torch.zeros_like(x))
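+ # For example, triu_onnx(torch.ones(3, 3)) zeroes everything below the main
+ # diagonal, matching torch.triu for square 2-D inputs up to 1024x1024 (the
+ # size of the precomputed template above).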
+
+
+ def replace_torch_functions():
+     torch.triu = triu_onnx
+
+
+ def restore_torch_functions():
+     torch.triu = torch_func["triu"]
+
+
+ def create_onnxruntime_input(vocab_size, batch_size, sequence_length, input_names, config, data_type=numpy.int64):
+     if config.model_type in ["vit", "swin"]:
+         input_ids = numpy.random.rand(batch_size, 3, config.image_size, config.image_size).astype(numpy.float32)
+         inputs = {"pixel_values": input_ids}
+         return inputs
+
+     input_ids = numpy.random.randint(low=0, high=vocab_size - 1, size=(batch_size, sequence_length), dtype=data_type)
+     inputs = {"input_ids": input_ids}
+
+     if "attention_mask" in input_names:
+         attention_mask = numpy.ones([batch_size, sequence_length], dtype=data_type)
+         inputs["attention_mask"] = attention_mask
+
+     if "token_type_ids" in input_names:
+         segment_ids = numpy.zeros([batch_size, sequence_length], dtype=data_type)
+         inputs["token_type_ids"] = segment_ids
+
+     if config.is_encoder_decoder:
+         inputs["decoder_input_ids"] = input_ids
+
+     if isinstance(config, LxmertConfig):
+         inputs["visual_feats"] = numpy.random.randn(1, 1, config.visual_feat_dim).astype(numpy.float32)
+         inputs["visual_pos"] = numpy.random.randn(1, 1, config.visual_pos_dim).astype(numpy.float32)
+     if isinstance(config, TransfoXLConfig):
+         inputs["tf_transfo_xl_model/transformer/pos_emb/einsum/Einsum/inputs_1:0"] = numpy.zeros(
+             [config.hidden_size], dtype=numpy.float32
+         )
+     return inputs
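+ # For example, with vocab_size=30522, batch_size=1, sequence_length=128 and
+ # input_names=["input_ids", "attention_mask"], this returns two int64 arrays
+ # of shape (1, 128): random token ids and an all-ones attention mask.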
+
+
+ def filter_inputs(inputs, input_names):
+     remaining_model_inputs = {}
+     for input_name in input_names:
+         if input_name in inputs:
+             remaining_model_inputs[input_name] = inputs[input_name]
+     return remaining_model_inputs
+
+
+ def flatten(inputs):
+     return [[flatten(i) for i in inputs] if isinstance(inputs, (list, tuple)) else inputs]
+
+
+ def update_flatten_list(inputs, res_list):
+     for i in inputs:
+         res_list.append(i) if not isinstance(i, (list, tuple)) else update_flatten_list(i, res_list)
+     return res_list
+
+
+ def build_dynamic_axes(example_inputs, outputs_flatten):
+     sequence_length = example_inputs["input_ids"].shape[-1]
+
+     dynamic_axes = {key: {0: "batch_size", 1: "seq_len"} for key in example_inputs}
+
+     output_names = ["output_" + str(i + 1) for i in range(len(outputs_flatten))]
+     for i, output_name in enumerate(output_names):
+         dynamic_axes[output_name] = {0: "batch_size"}
+         dims = outputs_flatten[i].shape
+         for j, dim in enumerate(dims):
+             if dim == sequence_length:
+                 dynamic_axes[output_name].update({j: "seq_len"})
+     return dynamic_axes, output_names
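+ # For example, given "input_ids" of shape (1, 7) and one output of shape
+ # (1, 7, 768), this returns output_names=["output_1"] and dynamic_axes =
+ # {"input_ids": {0: "batch_size", 1: "seq_len"},
+ #  "output_1": {0: "batch_size", 1: "seq_len"}}.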
+
+
+ def validate_onnx_model(
+     onnx_model_path,
+     example_inputs,
+     example_outputs_flatten,
+     use_gpu,
+     fp16,
+     output_names=None,
+ ):
+     test_session = create_onnxruntime_session(onnx_model_path, use_gpu, enable_all_optimization=False)
+     if test_session is None:
+         logger.error(f"{onnx_model_path} is an invalid ONNX model")
+         return False
+
+     logger.info(f"{onnx_model_path} is a valid ONNX model")
+
+     # Compare the inference result with PyTorch or Tensorflow
+     example_ort_inputs = {k: t.numpy() for k, t in example_inputs.items()}
+     example_ort_outputs = test_session.run(output_names, example_ort_inputs)
+     if len(example_outputs_flatten) != len(example_ort_outputs):
+         logger.error(
+             f"Number of output tensors expected {len(example_outputs_flatten)}, got {len(example_ort_outputs)}"
+         )
+         return False
+
+     for i in range(len(example_outputs_flatten)):
+         abs_diff = numpy.amax(numpy.abs(example_ort_outputs[i] - example_outputs_flatten[i].cpu().numpy()))
+         if abs_diff > 1e-4:
+             logger.info(f"Max absolute diff={abs_diff} for output tensor {i}")
+
+         rtol = 5e-02 if fp16 else 1e-4
+         atol = 1e-01 if fp16 else 1e-4
+         if not numpy.allclose(
+             example_ort_outputs[i],
+             example_outputs_flatten[i].cpu().numpy(),
+             rtol=rtol,
+             atol=atol,
+         ):
+             logger.error(f"Output tensor {i} is not close: rtol={rtol}, atol={atol}")
+             return False
+
+     logger.info(f"inference result of onnxruntime is validated on {onnx_model_path}")
+     return True
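+ # For example, with fp16=True an output element of 10.05 vs. a PyTorch value
+ # of 10.0 passes, since |diff| = 0.05 <= atol + rtol * |10.0| = 0.1 + 0.5 = 0.6;
+ # the fp32 tolerances (1e-4) would reject the same element.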
+
+
+ def get_onnx_file_path(
+     onnx_dir: str,
+     model_name: str,
+     input_count: int,
+     optimized_by_script: bool,
+     use_gpu: bool,
+     precision: Precision,
+     optimized_by_onnxruntime: bool,
+     use_external_data: bool,
+ ):
+     from re import sub
+
+     normalized_model_name = sub(r"[^a-zA-Z0-9_]", "_", model_name)
+
+     if not optimized_by_script:
+         filename = f"{normalized_model_name}_{input_count}"
+     else:
+         device = "gpu" if use_gpu else "cpu"
+         filename = f"{normalized_model_name}_{input_count}_{precision}_{device}"
+
+     if optimized_by_onnxruntime:
+         filename += "_ort"
+
+     directory = onnx_dir
+     # ONNXRuntime will not write external data so the raw and optimized models shall be in same directory.
+     if use_external_data and not optimized_by_onnxruntime:
+         directory = os.path.join(onnx_dir, filename)
+         if not os.path.exists(directory):
+             os.makedirs(directory)
+
+     return os.path.join(directory, f"{filename}.onnx")
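+ # For example, get_onnx_file_path("onnx", "bert-base-cased", 2, True, True,
+ # Precision.FLOAT16, False, False) yields "onnx/bert_base_cased_2_fp16_gpu.onnx",
+ # assuming Precision.FLOAT16 renders as "fp16" in the f-string.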
+
+
+ def add_filename_suffix(file_path: str, suffix: str) -> str:
+     """
+     Append a suffix at the filename (before the extension).
+     Args:
+         file_path: The path of the file we would like to add a suffix to
+         suffix: The suffix to add
+     Returns: path with suffix appended at the end of the filename and before extension
+     """
+     path = Path(file_path)
+     return str(path.parent.joinpath(path.stem + suffix).with_suffix(path.suffix))
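+ # For example, add_filename_suffix("onnx/bert_base_cased_2.onnx", "_ort")
+ # returns "onnx/bert_base_cased_2_ort.onnx".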
+
+
+ def optimize_onnx_model_by_ort(onnx_model_path, ort_model_path, use_gpu, overwrite, model_fusion_statistics):
+     if overwrite or not os.path.exists(ort_model_path):
+         Path(ort_model_path).parent.mkdir(parents=True, exist_ok=True)
+         from optimizer import get_fusion_statistics, optimize_by_onnxruntime
+
+         # Use onnxruntime to optimize model, which will be saved to *_ort.onnx
+         _ = optimize_by_onnxruntime(
+             onnx_model_path,
+             use_gpu=use_gpu,
+             optimized_model_path=ort_model_path,
+             opt_level=99,
+         )
+         model_fusion_statistics[ort_model_path] = get_fusion_statistics(ort_model_path)
+     else:
+         logger.info(f"Skip optimization since model existed: {ort_model_path}")
+
+
+ def optimize_onnx_model(
+     onnx_model_path,
+     optimized_model_path,
+     model_type,
+     num_attention_heads,
+     hidden_size,
+     use_gpu,
+     precision,
+     use_raw_attention_mask,
+     overwrite,
+     model_fusion_statistics,
+     use_external_data_format,
+     optimization_options=None,
+ ):
+     if overwrite or not os.path.exists(optimized_model_path):
+         Path(optimized_model_path).parent.mkdir(parents=True, exist_ok=True)
+
+         from fusion_options import FusionOptions
+         from optimizer import optimize_model
+
+         if optimization_options is None:
+             optimization_options = FusionOptions(model_type)
+         optimization_options.use_raw_attention_mask(use_raw_attention_mask)
+         if precision == Precision.FLOAT16:
+             optimization_options.enable_gelu_approximation = True
+         if precision == Precision.INT8:
+             optimization_options.enable_embed_layer_norm = False
+
+         # For swin models, the num_attention_heads is a list, which isn't supported yet, so set to 0 for now
+         if model_type == "swin":
+             num_attention_heads = 0
+             hidden_size = 0
+
+         # Use script to optimize model.
+         # Use opt_level <= 1 for models to be converted to fp16, because some fused op (like FusedGemm) has only fp32 and no fp16.
+         # It is better to be conservative so we use opt_level=0 here, in case MemcpyFromHost is added to the graph by OnnxRuntime.
+         opt_model = optimize_model(
+             onnx_model_path,
+             model_type,
+             num_heads=num_attention_heads,
+             hidden_size=hidden_size,
+             opt_level=0,
+             optimization_options=optimization_options,
+             use_gpu=use_gpu,
+             only_onnxruntime=False,
+         )
+         if model_type == "bert_keras" or model_type == "bert_tf":
+             opt_model.use_dynamic_axes()
+
+         model_fusion_statistics[optimized_model_path] = opt_model.get_fused_operator_statistics()
+
+         if precision == Precision.FLOAT16:
+             opt_model.convert_float_to_float16(keep_io_types=True)
+
+         opt_model.save_model_to_file(optimized_model_path, use_external_data_format)
+     else:
+         logger.info(f"Skip optimization since model existed: {optimized_model_path}")
+
+
+ def modelclass_dispatcher(model_name, custom_model_class):
+     if custom_model_class is not None:
+         if custom_model_class in MODEL_CLASSES:
+             return custom_model_class
+         else:
+             raise Exception("Valid model class: " + " ".join(MODEL_CLASSES))
+
+     if model_name in PRETRAINED_GPT2_MODELS:
+         return "GPT2ModelNoPastState"
+
+     import re
+
+     if re.search("-squad$", model_name) is not None:
+         return "AutoModelForQuestionAnswering"
+     elif re.search("-mprc$", model_name) is not None:
+         return "AutoModelForSequenceClassification"
+     elif re.search("gpt2", model_name) is not None:
+         return "AutoModelWithLMHead"
+
+     return "AutoModel"
+
+
+ def load_pretrained_model(model_name, config, cache_dir, custom_model_class, is_tf_model=False):
+     model_class_name = modelclass_dispatcher(model_name, custom_model_class)
+
+     if model_class_name == "GPT2ModelNoPastState":
+         if is_tf_model:
+             return TFGPT2ModelNoPastState.from_pretrained(model_name, config=config, cache_dir=cache_dir)
+         else:
+             return GPT2ModelNoPastState.from_pretrained(model_name, config=config, cache_dir=cache_dir)
+
+     if is_tf_model:
+         model_class_name = "TF" + model_class_name
+
+     transformers_module = __import__("transformers", fromlist=[model_class_name])
+     logger.info(f"Model class name: {model_class_name}")
+     model_class = getattr(transformers_module, model_class_name)
+
+     return model_class.from_pretrained(model_name, config=config, cache_dir=cache_dir)
+
+
+ def load_pt_model(model_name, model_class, cache_dir, config_modifier):
+     config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir)
+     if hasattr(config, "return_dict"):
+         config.return_dict = False
+
+     config_modifier.modify(config)
+
+     model = load_pretrained_model(model_name, config=config, cache_dir=cache_dir, custom_model_class=model_class)
+
+     return config, model
+
+
+ def load_tf_model(model_name, model_class, cache_dir, config_modifier):
+     config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir)
+
+     config_modifier.modify(config)
+     # Loading tf model from transformers limits the cpu affinity to {0} when KMP_AFFINITY is set
+     # Restore the affinity after model loading for expected ORT performance
+     affinity_setting = AffinitySetting()
+     affinity_setting.get_affinity()
+     model = load_pretrained_model(
+         model_name,
+         config=config,
+         cache_dir=cache_dir,
+         custom_model_class=model_class,
+         is_tf_model=True,
+     )
+     affinity_setting.set_affinity()
+
+     return config, model
+
+
+ # For test only
+ def load_pt_model_from_tf(model_name):
+     # Note that we could get pt model from tf, but model source and its structure in this case is different from directly using
+     # load_pt_model() and load_tf_model() even with the same name. Therefore it should not be used for comparing with them
+     from convert_tf_models_to_pytorch import tf2pt_pipeline
+
+     config, model = tf2pt_pipeline(model_name)
+
+     return config, model
+
+
+ def validate_and_optimize_onnx(
+     model_name,
+     use_external_data_format,
+     model_type,
+     onnx_dir,
+     input_names,
+     use_gpu,
+     precision,
+     optimize_info,
+     validate_onnx,
+     use_raw_attention_mask,
+     overwrite,
+     config,
+     model_fusion_statistics,
+     onnx_model_path,
+     example_inputs,
+     example_outputs_flatten,
+     output_names,
+     fusion_options,
+ ):
+     is_valid_onnx_model = True
+     if validate_onnx:
+         is_valid_onnx_model = validate_onnx_model(
+             onnx_model_path,
+             example_inputs,
+             example_outputs_flatten,
+             use_gpu,
+             False,
+             output_names,
+         )
+     if optimize_info == OptimizerInfo.NOOPT:
+         return onnx_model_path, is_valid_onnx_model, config.vocab_size
+
+     if (
+         optimize_info == OptimizerInfo.BYSCRIPT or precision == Precision.FLOAT16 or precision == Precision.INT8
+     ):  # Use script (optimizer.py) to optimize
+         optimized_model_path = get_onnx_file_path(
+             onnx_dir,
+             model_name,
+             len(input_names),
+             True,
+             use_gpu,
+             precision,
+             False,
+             use_external_data_format,
+         )
+         optimize_onnx_model(
+             onnx_model_path,
+             optimized_model_path,
+             model_type,
+             config.num_attention_heads,
+             config.hidden_size,
+             use_gpu,
+             precision,
+             use_raw_attention_mask,
+             overwrite,
+             model_fusion_statistics,
+             use_external_data_format,
+             fusion_options,
+         )
+
+         onnx_model_path = optimized_model_path
+         if validate_onnx:
+             is_valid_onnx_model = validate_onnx_model(
+                 onnx_model_path,
+                 example_inputs,
+                 example_outputs_flatten,
+                 use_gpu,
+                 precision == Precision.FLOAT16,
+                 output_names,
+             )
+
+         if precision == Precision.INT8:
+             logger.info(f"Quantizing model: {onnx_model_path}")
+             QuantizeHelper.quantize_onnx_model(onnx_model_path, onnx_model_path, use_external_data_format)
+             logger.info(f"Finished quantizing model: {onnx_model_path}")
+
+     if optimize_info == OptimizerInfo.BYORT:  # Use OnnxRuntime to optimize
+         if is_valid_onnx_model:
+             ort_model_path = add_filename_suffix(onnx_model_path, "_ort")
+             optimize_onnx_model_by_ort(
+                 onnx_model_path,
+                 ort_model_path,
+                 use_gpu,
+                 overwrite,
+                 model_fusion_statistics,
+             )
+
+     return (
+         onnx_model_path,
+         is_valid_onnx_model,
+         config.num_labels if model_type in ["vit", "swin"] else config.vocab_size,
+     )
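+ # For example, with optimize_info=OptimizerInfo.BYSCRIPT and
+ # precision=Precision.FLOAT16, the model is fused by optimizer.py, converted
+ # to fp16, then (when validate_onnx is set) re-validated with the relaxed
+ # fp16 tolerances of validate_onnx_model above.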
+
+
+ def export_onnx_model_from_pt(
+     model_name,
+     opset_version,
+     use_external_data_format,
+     model_type,
+     model_class,
+     config_modifier,
+     cache_dir,
+     onnx_dir,
+     input_names,
+     use_gpu,
+     precision,
+     optimizer_info,
+     validate_onnx,
+     use_raw_attention_mask,
+     overwrite,
+     model_fusion_statistics,
+     fusion_options,
+ ):
+     config, model = load_pt_model(model_name, model_class, cache_dir, config_modifier)
+     # config, model = load_pt_model_from_tf(model_name)
+     model.cpu()
+
+     example_inputs = None
+     max_input_size = None
+
+     if model_type in ["vit", "swin"]:
+         image_processor = AutoFeatureExtractor.from_pretrained(model_name, cache_dir=cache_dir)
+         data = numpy.random.randint(
+             low=0, high=256, size=config.image_size * config.image_size * 3, dtype=numpy.uint8
+         ).reshape(config.image_size, config.image_size, 3)
+
+         example_inputs = image_processor(data, return_tensors="pt")
+     else:
+         tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
+         max_input_size = tokenizer.max_model_input_sizes.get(model_name, 1024)
+         example_inputs = tokenizer.encode_plus("This is a sample input", return_tensors="pt")
+
+     example_inputs = filter_inputs(example_inputs, input_names)
+
+     example_outputs = model(**example_inputs)
+
+     assert isinstance(example_outputs, (list, tuple)), f"type of output is not list or tuple: {type(example_outputs)}"
+
+     # Flatten is needed for gpt2 and distilgpt2.
+     example_outputs_flatten = flatten(example_outputs)
+     example_outputs_flatten = update_flatten_list(example_outputs_flatten, [])
+
+     onnx_model_path = get_onnx_file_path(
+         onnx_dir,
+         model_name,
+         len(input_names),
+         False,
+         use_gpu,
+         precision,
+         False,
+         use_external_data_format,
+     )
+
+     if overwrite or not os.path.exists(onnx_model_path):
+         logger.info(f"Exporting ONNX model to {onnx_model_path}")
+         Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
+
+         dynamic_axes = None
+         output_names = None
+
+         if model_type in ["vit", "swin"]:
+             dynamic_axes, output_names = {key: {0: "pixel_values"} for key in example_inputs}, ["logits"]
+         else:
+             dynamic_axes, output_names = build_dynamic_axes(example_inputs, example_outputs_flatten)
+
+         replace_torch_functions()
+         torch_onnx_export(
+             model=model,
+             args=tuple(example_inputs.values()),
+             f=onnx_model_path,
+             input_names=list(example_inputs.keys()),
+             output_names=output_names,
+             dynamic_axes=dynamic_axes,
+             do_constant_folding=True,
+             opset_version=opset_version,
+             use_external_data_format=use_external_data_format,
+         )
+         restore_torch_functions()
+     else:
+         logger.info(f"Skip export since model existed: {onnx_model_path}")
+
+     onnx_model_file, is_valid_onnx_model, vocab_size = validate_and_optimize_onnx(
+         model_name,
+         use_external_data_format,
+         model_type,
+         onnx_dir,
+         input_names,
+         use_gpu,
+         precision,
+         optimizer_info,
+         validate_onnx,
+         use_raw_attention_mask,
+         overwrite,
+         config,
+         model_fusion_statistics,
+         onnx_model_path,
+         example_inputs,
+         example_outputs_flatten,
+         None,
+         fusion_options,
+     )
+
+     return onnx_model_file, is_valid_onnx_model, vocab_size, max_input_size
+
+
+ def export_onnx_model_from_tf(
+     model_name,
+     opset_version,
+     use_external_data_format,
+     model_type,
+     model_class,
+     config_modifier,
+     cache_dir,
+     onnx_dir,
+     input_names,
+     use_gpu,
+     precision,
+     optimizer_info,
+     validate_onnx,
+     use_raw_attention_mask,
+     overwrite,
+     model_fusion_statistics,
+     fusion_options,
+ ):
+     # Use CPU to export
+     import tensorflow as tf
+
+     tf.config.set_visible_devices([], "GPU")
+
+     tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
+     # Fix "Using pad_token, but it is not set yet" error.
+     if tokenizer.pad_token is None:
+         tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+     max_input_size = tokenizer.max_model_input_sizes.get(model_name, 1024)
+
+     config, model = load_tf_model(model_name, model_class, cache_dir, config_modifier)
+     model.resize_token_embeddings(len(tokenizer))
+
+     example_inputs = tokenizer.encode_plus(
+         "This is a sample input",
+         return_tensors="tf",
+         max_length=max_input_size,
+         padding="max_length",
+         truncation=True,
+     )
+     example_inputs = filter_inputs(example_inputs, input_names)
+
+     if config.is_encoder_decoder:
+         example_inputs["decoder_input_ids"] = tokenizer.encode_plus(
+             "This is a sample input",
+             return_tensors="tf",
+             max_length=max_input_size,
+             padding="max_length",
+             truncation=True,
+         ).input_ids
+     if model_name == "unc-nlp/lxmert-base-uncased":
+         example_inputs["visual_feats"] = tf.random.normal([1, 1, config.visual_feat_dim])
+         example_inputs["visual_pos"] = tf.random.normal([1, 1, config.visual_pos_dim])
+
+     try:
+         # Use no past state for these models
+         if config.use_cache:
+             config.use_cache = False
+     except Exception:
+         pass
+
+     example_outputs = model(example_inputs, training=False)
+     output_names = None
+
+     # For xlnet models, only compare the last_hidden_state output.
+     if model_name == "xlnet-base-cased" or model_name == "xlnet-large-cased":
+         output_names = ["last_hidden_state"]
+         example_outputs = example_outputs["last_hidden_state"]
+
+     # Flatten is needed for gpt2 and distilgpt2. Output name sorting is needed for tf2onnx outputs to match onnx outputs.
+     from tensorflow.python.util import nest
+
+     example_outputs_flatten = nest.flatten(example_outputs)
+
+     onnx_model_path = get_onnx_file_path(
+         onnx_dir,
+         model_name,
+         len(input_names),
+         False,
+         use_gpu,
+         precision,
+         False,
+         use_external_data_format,
+     )
+     tf_internal_model_path = onnx_model_path[:-5] if use_external_data_format else onnx_model_path
+
+     if overwrite or not os.path.exists(tf_internal_model_path):
+         logger.info(f"Exporting ONNX model to {onnx_model_path}")
+         if not use_external_data_format:
+             Path(tf_internal_model_path).parent.mkdir(parents=True, exist_ok=True)
+
+         import zipfile
+
+         import tf2onnx
+
+         tf2onnx.logging.set_level(tf2onnx.logging.ERROR)
+         specs = []
+         for name, value in example_inputs.items():
+             dims = [None] * len(value.shape)
+             specs.append(tf.TensorSpec(tuple(dims), value.dtype, name=name))
+         _, _ = tf2onnx.convert.from_keras(
+             model,
+             input_signature=tuple(specs),
+             opset=opset_version,
+             large_model=use_external_data_format,
+             output_path=tf_internal_model_path,
+         )
+         if use_external_data_format:
+             # need to unpack the zip for run_onnxruntime()
+             with zipfile.ZipFile(tf_internal_model_path, "r") as z:
+                 z.extractall(os.path.dirname(tf_internal_model_path))
+             tf_internal_model_path = os.path.join(os.path.dirname(tf_internal_model_path), "__MODEL_PROTO.onnx")
+             if os.path.exists(onnx_model_path):
+                 os.remove(onnx_model_path)
+             os.rename(tf_internal_model_path, onnx_model_path)
+
+     else:
+         logger.info(f"Skip export since model existed: {onnx_model_path}")
+
+     model_type = model_type + "_tf"
+     optimized_onnx_path, is_valid_onnx_model, vocab_size = validate_and_optimize_onnx(
+         model_name,
+         use_external_data_format,
+         model_type,
+         onnx_dir,
+         input_names,
+         use_gpu,
+         precision,
+         optimizer_info,
+         validate_onnx,
+         use_raw_attention_mask,
+         overwrite,
+         config,
+         model_fusion_statistics,
+         onnx_model_path,
+         example_inputs,
+         example_outputs_flatten,
+         output_names,
+         fusion_options,
+     )
+
+     return (
+         optimized_onnx_path,
+         is_valid_onnx_model,
+         vocab_size,
+         max_input_size,
+     )
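
For reference, a minimal sketch of how the PyTorch entry point above might be driven. The NoopConfigModifier stub and the directory names are illustrative, not part of the package (benchmark.py supplies the real config modifier):

    from benchmark_helper import OptimizerInfo, Precision
    from onnx_exporter import export_onnx_model_from_pt

    class NoopConfigModifier:
        # Illustrative stand-in: the exporter only calls .modify(config).
        def modify(self, config):
            pass

    onnx_path, is_valid, vocab_size, max_input_size = export_onnx_model_from_pt(
        "bert-base-cased",      # model_name
        17,                     # opset_version
        False,                  # use_external_data_format
        "bert",                 # model_type
        None,                   # model_class: defer to modelclass_dispatcher
        NoopConfigModifier(),   # config_modifier
        "./cache_models",       # cache_dir
        "./onnx_models",        # onnx_dir
        ["input_ids", "attention_mask", "token_type_ids"],  # input_names
        False,                  # use_gpu
        Precision.FLOAT32,      # precision
        OptimizerInfo.BYSCRIPT, # optimizer_info
        True,                   # validate_onnx
        True,                   # use_raw_attention_mask
        False,                  # overwrite
        {},                     # model_fusion_statistics
        None,                   # fusion_options
    )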