onnxruntime_directml-1.20.0-cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (305)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,164 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License. See License.txt in the project root for
+ # license information.
+ # --------------------------------------------------------------------------
+
+ import datetime
+ import json
+ from typing import Optional
+
+ import pandas as pd
+
+
+ class BaseObject:
+     def __init__(self):
+         self.customized = {}
+
+     def to_dict(self):
+         default_values = self.__dict__.copy()
+         default_values.pop("customized", None)
+         default_values.update(self.customized)
+
+         for k, v in default_values.items():
+             if isinstance(v, BaseObject):
+                 default_values[k] = v.to_dict()
+
+         return {k: v for k, v in default_values.items() if v}
+
+
+ class ModelInfo(BaseObject):
+     def __init__(
+         self,
+         full_name: Optional[str] = None,
+         is_huggingface: Optional[bool] = False,
+         is_text_generation: Optional[bool] = False,
+         short_name: Optional[str] = None,
+     ):
+         super().__init__()
+         self.full_name = full_name
+         self.is_huggingface = is_huggingface
+         self.is_text_generation = is_text_generation
+         self.short_name = short_name
+         self.input_shape = []
+
+
+ class BackendOptions(BaseObject):
+     def __init__(
+         self,
+         enable_profiling: Optional[bool] = False,
+         execution_provider: Optional[str] = None,
+         use_io_binding: Optional[bool] = False,
+     ):
+         super().__init__()
+         self.enable_profiling = enable_profiling
+         self.execution_provider = execution_provider
+         self.use_io_binding = use_io_binding
+
+
+ class Config(BaseObject):
+     def __init__(
+         self,
+         backend: Optional[str] = "onnxruntime",
+         batch_size: Optional[int] = 1,
+         seq_length: Optional[int] = 0,
+         precision: Optional[str] = "fp32",
+         warmup_runs: Optional[int] = 1,
+         measured_runs: Optional[int] = 10,
+     ):
+         super().__init__()
+         self.backend = backend
+         self.batch_size = batch_size
+         self.seq_length = seq_length
+         self.precision = precision
+         self.warmup_runs = warmup_runs
+         self.measured_runs = measured_runs
+         self.model_info = ModelInfo()
+         self.backend_options = BackendOptions()
+
+
+ class Metadata(BaseObject):
+     def __init__(
+         self,
+         device: Optional[str] = None,
+         package_name: Optional[str] = None,
+         package_version: Optional[str] = None,
+         platform: Optional[str] = None,
+         python_version: Optional[str] = None,
+     ):
+         super().__init__()
+         self.device = device
+         self.package_name = package_name
+         self.package_version = package_version
+         self.platform = platform
+         self.python_version = python_version
+
+
+ class Metrics(BaseObject):
+     def __init__(
+         self,
+         latency_ms_mean: Optional[float] = 0.0,
+         throughput_qps: Optional[float] = 0.0,
+         max_memory_usage_GB: Optional[float] = 0.0,
+     ):
+         super().__init__()
+         self.latency_ms_mean = latency_ms_mean
+         self.throughput_qps = throughput_qps
+         self.max_memory_usage_GB = max_memory_usage_GB
+
+
+ class BenchmarkRecord:
+     def __init__(
+         self,
+         model_name: str,
+         precision: str,
+         backend: str,
+         device: str,
+         package_name: str,
+         package_version: str,
+         batch_size: Optional[int] = 1,
+         warmup_runs: Optional[int] = 1,
+         measured_runs: Optional[int] = 10,
+         trigger_date: Optional[str] = None,
+     ):
+         self.config = Config()
+         self.metrics = Metrics()
+         self.metadata = Metadata()
+         self.trigger_date = trigger_date or datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+
+         self.config.model_info.full_name = model_name
+         self.config.precision = precision
+         self.config.backend = backend
+         self.config.batch_size = batch_size
+         self.config.warmup_runs = warmup_runs
+         self.config.measured_runs = measured_runs
+         self.metadata.device = device
+         self.metadata.package_name = package_name
+         self.metadata.package_version = package_version
+
+     def to_dict(self) -> dict:
+         return {
+             "config": self.config.to_dict(),
+             "metadata": self.metadata.to_dict(),
+             "metrics": self.metrics.to_dict(),
+             "trigger_date": self.trigger_date,
+         }
+
+     def to_json(self) -> str:
+         return json.dumps(self.to_dict(), default=str)
+
+     @classmethod
+     def save_as_csv(cls, file_name: str, records: list) -> None:
+         if records is None or len(records) == 0:
+             return
+         rds = [record.to_dict() for record in records]
+         df = pd.json_normalize(rds)
+         df.to_csv(file_name, index=False)
+
+     @classmethod
+     def save_as_json(cls, file_name: str, records: list) -> None:
+         if records is None or len(records) == 0:
+             return
+         rds = [record.to_dict() for record in records]
+         with open(file_name, "w") as f:
+             json.dump(rds, f, indent=4, default=str)
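The classes above form a small container hierarchy for benchmark results. A minimal usage sketch follows; it is not part of the diff, the metric values are invented, and the import path is inferred from the file listing (onnxruntime/transformers/metrics.py):

from onnxruntime.transformers.metrics import BenchmarkRecord  # path inferred from the listing above

# Describe one benchmark run; the positional fields fan out into Config and Metadata.
record = BenchmarkRecord(
    model_name="bert-base-uncased",
    precision="fp16",
    backend="onnxruntime",
    device="gpu",
    package_name="onnxruntime-directml",
    package_version="1.20.0",
    batch_size=1,
)

# Fill in measured metrics. Note that BaseObject.to_dict() drops falsy values,
# so zero-valued metrics disappear from the serialized output.
record.metrics.latency_ms_mean = 4.2
record.metrics.throughput_qps = 238.0

BenchmarkRecord.save_as_json("benchmark.json", [record])
BenchmarkRecord.save_as_csv("benchmark.csv", [record])  # flattened with pandas.json_normalize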
@@ -0,0 +1,12 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+ import os.path
+ import sys
+
+ sys.path.append(os.path.dirname(__file__))
+
+ transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
+ if transformers_dir not in sys.path:
+     sys.path.append(transformers_dir)
@@ -0,0 +1,98 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License. See License.txt in the project root for
+ # license information.
+ # --------------------------------------------------------------------------
+
+ import argparse
+ import logging
+ import os
+ import sys
+
+ from utils import (
+     chain_enc_dec_with_beamsearch,
+     export_summarization_edinit,
+     export_summarization_enc_dec_past,
+     onnx_inference,
+ )
+
+ # GLOBAL ENVS
+ logging.basicConfig(
+     format="%(asctime)s | %(levelname)s | %(name)s | [%(filename)s:%(lineno)d] %(message)s",
+     datefmt="%Y-%m-%d %H:%M:%S",
+     level=os.environ.get("LOGLEVEL", "INFO").upper(),
+     stream=sys.stdout,
+ )
+ logger = logging.getLogger("generate")
+
+
+ def print_args(args):
+     for arg in vars(args):
+         logger.info(f"{arg}: {getattr(args, arg)}")
+
+
+ def user_command():
+     parent_parser = argparse.ArgumentParser(add_help=False)
+     parent_parser.add_argument("--max_length", type=int, default=20, help="default to 20")
+     parent_parser.add_argument("--min_length", type=int, default=0, help="default to 0")
+     parent_parser.add_argument("-o", "--output", type=str, default="onnx_models", help="default name is onnx_models.")
+     parent_parser.add_argument("-i", "--input_text", type=str, default=None, help="input text")
+     parent_parser.add_argument("-s", "--spm_path", type=str, default=None, help="tokenizer model from sentencepiece")
+     parent_parser.add_argument("-v", "--vocab_path", type=str, help="vocab dictionary")
+     parent_parser.add_argument("-b", "--num_beams", type=int, default=5, help="default to 5")
+     parent_parser.add_argument("--repetition_penalty", type=float, default=1.0, help="default to 1.0")
+     parent_parser.add_argument("--no_repeat_ngram_size", type=int, default=3, help="default to 3")
+     # type=bool is an argparse pitfall (bool("False") is True), so use a store_true flag instead.
+     parent_parser.add_argument("--early_stopping", action="store_true", help="default to False")
+     parent_parser.add_argument("--opset_version", type=int, default=14, help="minimum is 14")
+
+     parent_parser.add_argument("--no_encoder", action="store_true")
+     parent_parser.add_argument("--no_decoder", action="store_true")
+     parent_parser.add_argument("--no_chain", action="store_true")
+     parent_parser.add_argument("--no_inference", action="store_true")
+
+     required_args = parent_parser.add_argument_group("required input arguments")
+     required_args.add_argument(
+         "-m",
+         "--model_dir",
+         type=str,
+         required=True,
+         help="The directory containing the input huggingface model. "
+         "An official model like facebook/bart-base is also acceptable.",
+     )
+
+     # Parse once and log the arguments instead of parsing twice.
+     args = parent_parser.parse_args()
+     print_args(args)
+     return args
+
+
+ if __name__ == "__main__":
+     args = user_command()
+     if args.opset_version < 14:
+         raise ValueError(f"The minimum supported opset version is 14! The given one was {args.opset_version}.")
+
+     os.makedirs(args.output, exist_ok=True)
+
+     # beam search op only supports CPU for now
+     args.device = "cpu"
+     logger.info("ENV: CPU ...")
+
+     if not args.input_text:
+         args.input_text = (
+             "PG&E stated it scheduled the blackouts in response to forecasts for high winds "
+             "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were "
+             "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."
+         )
+
+     if not args.no_encoder:
+         logger.info("========== EXPORTING ENCODER ==========")
+         export_summarization_edinit.export_encoder(args)
+     if not args.no_decoder:
+         logger.info("========== EXPORTING DECODER ==========")
+         export_summarization_enc_dec_past.export_decoder(args)
+     if not args.no_chain:
+         logger.info("========== CONVERTING MODELS ==========")
+         chain_enc_dec_with_beamsearch.convert_model(args)
+     if not args.no_inference:
+         logger.info("========== INFERENCING WITH ONNX MODEL ==========")
+         onnx_inference.run_inference(args)
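For reference, a typical invocation of this BART export script would look like the following (illustrative only; the flags match the parser defined in user_command(), and the utils module with the export helpers ships next to this file):

python3 export.py -m facebook/bart-base -o onnx_models -b 5 --max_length 20 --opset_version 14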
@@ -0,0 +1,12 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+ import os.path
+ import sys
+
+ sys.path.append(os.path.dirname(__file__))
+
+ transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
+ if transformers_dir not in sys.path:
+     sys.path.append(transformers_dir)
@@ -0,0 +1,329 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+ #
+ # This script evaluates the accuracy of ONNX models for the question-answering task on the SQuAD data set.
+ # Example of evaluating the raw and optimized models for CUDA on Linux:
+ # pip3 install datasets evaluate optimum transformers onnxruntime-gpu
+ #
+ # python3 eval_squad.py -m bert-large-uncased-whole-word-masking-finetuned-squad -s 384 -b 1 --use_io_binding
+ #
+ # python3 -m onnxruntime.transformers.optimizer \
+ #     --input ./bert-large-uncased-whole-word-masking-finetuned-squad/model.onnx \
+ #     --output ./bert-large-uncased-whole-word-masking-finetuned-squad/optimized_model.onnx
+ #
+ # python3 eval_squad.py -m bert-large-uncased-whole-word-masking-finetuned-squad -s 384 -b 1 --use_io_binding \
+ #     --onnx ./bert-large-uncased-whole-word-masking-finetuned-squad/optimized_model.onnx
+ #
+ # Snippet of example output on an A100:
+ # {'exact': 86.65089877010406, 'f1': 92.99433524952254, 'total': 10570, 'HasAns_exact': 86.65089877010406,
+ #  'total_time_in_seconds': 81.69239814393222, 'samples_per_second': 129.387804008115,
+ #  'latency_in_seconds': 0.007728703703304846, 'provider': 'CUDAExecutionProvider',
+ #  'pretrained_model_name': 'bert-large-uncased-whole-word-masking-finetuned-squad',
+ #  'batch_size': 1, 'sequence_length': 384, 'use_io_binding': True}
+ import argparse
+ import csv
+ import os
+ import time
+
+ try:
+     from importlib.metadata import PackageNotFoundError, version
+ except ImportError:
+     from importlib_metadata import PackageNotFoundError, version
+
+ from pathlib import Path
+ from typing import Any, Dict, List, Optional
+
+ from datasets import load_dataset
+ from evaluate import evaluator
+ from optimum.onnxruntime import ORTModelForQuestionAnswering
+ from optimum.version import __version__ as optimum_version
+ from packaging import version as version_check
+ from transformers import AutoTokenizer, pipeline
+
+ if version_check.parse(optimum_version) < version_check.parse("1.13.1"):
+     raise ImportError(f"Please install optimum>=1.13.1. Current version: {optimum_version}.")
+
+ PRETRAINED_SQUAD_MODELS = [
+     "bert-large-uncased-whole-word-masking-finetuned-squad",
+     "deepset/roberta-base-squad2",
+     "distilbert-base-cased-distilled-squad",
+ ]
+
+
+ def get_package_version(package_name: str):
+     try:
+         return version(package_name)
+     except PackageNotFoundError:
+         return None
+
+
+ def load_onnx_model(
+     model_id: str, onnx_path: Optional[str] = None, provider="CUDAExecutionProvider", use_io_binding: bool = False
+ ):
+     """Load an onnx model given a pretrained model name and an optional ONNX model path. If onnx_path is None,
+     the default onnx model exported by optimum will be used.
+
+     Args:
+         model_id (str): pretrained model name or checkpoint path
+         onnx_path (Optional[str], optional): path of onnx model to evaluate. Defaults to None.
+
+     Returns:
+         model: ORTModel for the onnx model
+         onnx_path: the path of the onnx model
+     """
+     if onnx_path is None:
+         # Export onnx to a sub-directory named by the model id
+         model = ORTModelForQuestionAnswering.from_pretrained(
+             model_id, export=True, provider=provider, use_io_binding=use_io_binding
+         )
+         save_onnx_dir = os.path.join(".", model_id)
+         model.save_pretrained(save_onnx_dir)
+         onnx_path = os.path.join(save_onnx_dir, "model.onnx")
+         print("Model is exported to onnx file:", onnx_path)
+     else:
+         model = ORTModelForQuestionAnswering.from_pretrained(
+             os.path.dirname(onnx_path),
+             file_name=Path(onnx_path).name,
+             provider=provider,
+             use_io_binding=use_io_binding,
+             # provider_options={"enable_skip_layer_norm_strict_mode": True},
+         )
+
+     return model, onnx_path
+
+
+ def output_details(results: List[Dict[str, Any]], csv_filename: str):
+     """Output a CSV file with details of each test result.
+
+     Args:
+         results (List[Dict[str, Any]]): list of JSON results.
+         csv_filename (str): path of output CSV file
+     """
+     with open(csv_filename, mode="a", newline="", encoding="ascii") as csv_file:
+         column_names = [
+             "pretrained_model_name",
+             "onnx_path",
+             "provider",
+             "disable_fused_attention",
+             "batch_size",
+             "sequence_length",
+             "use_io_binding",
+             "exact",
+             "f1",
+             "total",
+             "HasAns_exact",
+             "HasAns_f1",
+             "HasAns_total",
+             "best_exact",
+             "best_exact_thresh",
+             "best_f1",
+             "best_f1_thresh",
+             "total_time_in_seconds",
+             "samples_per_second",
+             "latency_in_seconds",
+         ]
+
+         csv_writer = csv.DictWriter(csv_file, fieldnames=column_names)
+         csv_writer.writeheader()
+         for result in results:
+             csv_writer.writerow(result)
+
+         csv_file.flush()
+
+     print(f"Detail results are saved to csv file: {csv_filename}")
+
+
+ def output_summary(results: List[Dict[str, Any]], csv_filename: str, metric_name: str):
+     """Output a CSV file with a summary of a metric on combinations of batch_size and sequence_length.
+
+     Args:
+         results (List[Dict[str, Any]]): list of JSON results.
+         csv_filename (str): path of output CSV file
+         metric_name (str): the metric to summarize
+     """
+     with open(csv_filename, mode="a", newline="", encoding="ascii") as csv_file:
+         header_names = [
+             "pretrained_model_name",
+             "onnx_path",
+             "provider",
+             "disable_fused_attention",
+             "use_io_binding",
+         ]
+
+         model_list = list({result["onnx_path"] for result in results})
+         model_list.sort()
+
+         batch_sizes = list({result["batch_size"] for result in results})
+         batch_sizes.sort()
+
+         sequence_lengths = list({result["sequence_length"] for result in results})
+         sequence_lengths.sort()
+
+         key_names = []
+         for sequence_length in sequence_lengths:
+             for batch_size in batch_sizes:
+                 key_names.append(f"b{batch_size}_s{sequence_length}")
+
+         csv_writer = csv.DictWriter(csv_file, fieldnames=header_names + key_names)
+         csv_writer.writeheader()
+
+         for model in model_list:
+             row = {}
+
+             # Metric value for given pair of batch_size and sequence_length.
+             # Assume that (onnx_path, batch_size and sequence_length) are unique so keep first occurrence only.
+             values = {}
+             values.update({k: "" for k in key_names})
+
+             for result in results:
+                 if result["onnx_path"] == model and result[metric_name]:
+                     headers = {k: v for k, v in result.items() if k in header_names}
+                     if not row:
+                         row.update(headers)
+
+                     batch_size = result["batch_size"]
+                     sequence_length = result["sequence_length"]
+                     key = f"b{batch_size}_s{sequence_length}"
+
+                     if key in key_names:
+                         values[key] = result[metric_name]
+
+             if row:
+                 for key in key_names:
+                     row[key] = values.get(key, "")
+                 csv_writer.writerow(row)
+
+         csv_file.flush()
+
+     print(f"Summary results for {metric_name} are saved to csv file: {csv_filename}")
+
+
+ def main():
+     args = parse_arguments()
+     print(args)
+
+     for name in ["onnxruntime-gpu", "onnxruntime", "onnx", "torch", "transformers", "optimum", "datasets", "evaluate"]:
+         package_version = get_package_version(name)
+         if package_version:
+             print(f"{name} version", package_version)
+
+     pretrained_model_name = args.model_name
+     if args.onnx and not os.path.exists(args.onnx):
+         raise RuntimeError(f"Onnx model path does not exist: {args.onnx}")
+
+     disable_fused_attention = os.environ.get("ORT_DISABLE_FUSED_ATTENTION", "0") == "1"
+
+     all_results = []
+     tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)
+     for sequence_length in args.sequence_lengths:
+         tokenizer.model_max_length = sequence_length
+         tokenizer.doc_stride = min(sequence_length // 2, 128)
+         if args.onnx is None:
+             print("Exporting onnx model. It might take a few minutes...")
+         start_time = time.time()
+         ort_model, onnx_path = load_onnx_model(pretrained_model_name, args.onnx, args.provider, args.use_io_binding)
+         latency = time.time() - start_time
+         print(f"Onnx model exported or loaded in {latency:.1f} seconds")
+
+         print(ort_model.config)
+         if sequence_length > ort_model.config.max_position_embeddings:
+             # Note: the original string was missing its f-prefix, so the limit was never interpolated.
+             raise RuntimeError(f"sequence length should not be larger than {ort_model.config.max_position_embeddings}")
+
+         qa_pipeline = pipeline(
+             "question-answering", model=ort_model, tokenizer=tokenizer, question_first=True, batch_size=args.batch_size
+         )
+
+         task_evaluator = evaluator("question-answering")
+         print("Loading dataset...")
+         start_time = time.time()
+         squad_dataset = load_dataset("squad", split=f"validation[:{args.total}]" if args.total > 0 else "validation")
+         latency = time.time() - start_time
+         print(f"Dataset loaded in {latency:.1f} seconds")
+
+         print("Evaluating squad_v2 with ORT. It might take a few minutes...")
+         start_time = time.time()
+         result = task_evaluator.compute(
+             model_or_pipeline=qa_pipeline,
+             data=squad_dataset,
+             metric="squad_v2",
+             squad_v2_format=True,
+         )
+         latency = time.time() - start_time
+         print(f"Evaluation done in {latency:.1f} seconds")
+
+         result["provider"] = args.provider
+         result["disable_fused_attention"] = disable_fused_attention
+         result["pretrained_model_name"] = pretrained_model_name
+         result["onnx_path"] = onnx_path
+         result["batch_size"] = args.batch_size
+         result["sequence_length"] = sequence_length
+         result["use_io_binding"] = args.use_io_binding
+         print(result)
+
+         all_results.append(result)
+
+     output_details(all_results, "detail.csv")
+
+     for metric_name in ["f1", "exact", "samples_per_second"]:
+         output_summary(all_results, f"{metric_name}.csv", metric_name)
+
+
+ def parse_arguments(argv=None):
+     parser = argparse.ArgumentParser()
+
+     parser.add_argument(
+         "-m",
+         "--model_name",
+         required=False,
+         type=str,
+         default=PRETRAINED_SQUAD_MODELS[0],
+         help=f"Checkpoint directory or pre-trained model names in the list: {PRETRAINED_SQUAD_MODELS}",
+     )
+
+     parser.add_argument(
+         "-s",
+         "--sequence_lengths",
+         nargs="+",
+         type=int,
+         default=[384],
+         help="Sequence lengths for onnx model inputs. It could have multiple values.",
+     )
+
+     parser.add_argument(
+         "-b",
+         "--batch_size",
+         type=int,
+         default=1,
+         help="batch size for inference.",
+     )
+
+     parser.add_argument("-t", "--total", type=int, default=0, help="Total samples to test. 0 means all samples.")
+
+     parser.add_argument(
+         "--onnx",
+         required=False,
+         type=str,
+         default=None,
+         help="Optional onnx model path. If not specified, optimum will be used to export an onnx model for testing.",
+     )
+
+     parser.add_argument(
+         "--provider",
+         required=False,
+         default="CUDAExecutionProvider",
+         help="Select which Execution Provider to use for runs. Default is CUDAExecutionProvider.",
+     )
+
+     parser.add_argument("--use_io_binding", required=False, action="store_true", help="Use IO Binding for GPU.")
+     parser.set_defaults(use_io_binding=False)
+
+     args = parser.parse_args(argv)
+
+     return args
+
+
+ if __name__ == "__main__":
+     main()
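The helpers in eval_squad.py can also be driven programmatically. A minimal sketch, assuming the script's directory is on sys.path so the module is importable, and substituting the CPU execution provider for machines without CUDA:

from transformers import AutoTokenizer, pipeline

from eval_squad import load_onnx_model  # assumption: the script's directory is importable

model_id = "distilbert-base-cased-distilled-squad"  # one of PRETRAINED_SQUAD_MODELS
model, onnx_path = load_onnx_model(
    model_id,
    onnx_path=None,                    # None: export through optimum on first use
    provider="CPUExecutionProvider",   # assumption: CPU EP for a quick local check
    use_io_binding=False,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
qa = pipeline("question-answering", model=model, tokenizer=tokenizer)
print(qa(question="Who scheduled the blackouts?",
         context="PG&E stated it scheduled the blackouts in response to forecasts for high winds."))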
@@ -0,0 +1,12 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+ import os.path
+ import sys
+
+ sys.path.append(os.path.dirname(__file__))
+
+ transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
+ if transformers_dir not in sys.path:
+     sys.path.append(transformers_dir)