onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322) hide show
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,641 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+
6
+ # It is a tool to generate test data for a bert model.
7
+ # The test data can be used by onnxruntime_perf_test tool to evaluate the inference latency.
8
+
9
+ import argparse
10
+ import os
11
+ import random
12
+ from pathlib import Path
13
+
14
+ import numpy as np
15
+ from onnx import ModelProto, TensorProto, numpy_helper
16
+ from onnx_model import OnnxModel
17
+
18
+
19
def fake_input_ids_data(
    input_ids: TensorProto, batch_size: int, sequence_length: int, dictionary_size: int
) -> np.ndarray:
    """Create a random token-id tensor matching the graph input ``input_ids``.

    Args:
        input_ids (TensorProto): graph input of the input_ids input tensor
        batch_size (int): batch size
        sequence_length (int): sequence length
        dictionary_size (int): vocabulary size of dictionary

    Returns:
        np.ndarray: the input tensor created
    """
    elem_type = input_ids.type.tensor_type.elem_type
    assert elem_type in [TensorProto.FLOAT, TensorProto.INT32, TensorProto.INT64]

    # Token IDs are drawn uniformly from [0, dictionary_size).
    data = np.random.randint(dictionary_size, size=(batch_size, sequence_length), dtype=np.int32)

    # Convert to the dtype declared by the graph input; int32 needs no conversion.
    if elem_type == TensorProto.FLOAT:
        data = np.float32(data)
    elif elem_type == TensorProto.INT64:
        data = np.int64(data)

    return data
47
+
48
+
49
def fake_segment_ids_data(segment_ids: TensorProto, batch_size: int, sequence_length: int) -> np.ndarray:
    """Create an all-zero segment-id tensor matching the graph input ``segment_ids``.

    Args:
        segment_ids (TensorProto): graph input of the token_type_ids input tensor
        batch_size (int): batch size
        sequence_length (int): sequence length

    Returns:
        np.ndarray: the input tensor created
    """
    elem_type = segment_ids.type.tensor_type.elem_type
    assert elem_type in [TensorProto.FLOAT, TensorProto.INT32, TensorProto.INT64]

    # A single segment is used for test data, so every token type id is zero.
    zeros = np.zeros((batch_size, sequence_length), dtype=np.int32)

    # Cast to the dtype declared by the graph input; int32 needs no conversion.
    if elem_type == TensorProto.FLOAT:
        return np.float32(zeros)
    if elem_type == TensorProto.INT64:
        return np.int64(zeros)
    return zeros
74
+
75
+
76
def get_random_length(max_sequence_length: int, average_sequence_length: int):
    """Draw a sequence length uniformly so that its mean is average_sequence_length.

    The lower and upper bounds are chosen symmetrically around the average,
    clamped to the range [1, max_sequence_length].
    """
    assert 1 <= average_sequence_length <= max_sequence_length

    doubled_average = 2 * average_sequence_length
    # Pick bounds symmetric about the average so the uniform mean equals it.
    if doubled_average > max_sequence_length:
        low, high = doubled_average - max_sequence_length, max_sequence_length
    else:
        low, high = 1, doubled_average - 1
    return random.randint(low, high)
84
+
85
+
86
def fake_input_mask_data(
    input_mask: TensorProto,
    batch_size: int,
    sequence_length: int,
    average_sequence_length: int,
    random_sequence_length: bool,
    mask_type: int = 2,
) -> np.ndarray:
    """Create input tensor based on the graph input of the attention mask.

    Args:
        input_mask (TensorProto): graph input of the attention mask input tensor
        batch_size (int): batch size
        sequence_length (int): sequence length
        average_sequence_length (int): average sequence length excluding paddings
        random_sequence_length (bool): whether use uniform random number for sequence length
        mask_type (int): mask type - 1: mask index (sequence length excluding paddings). Shape is (batch_size).
                                     2: 2D attention mask. Shape is (batch_size, sequence_length).
                                     3: key len, cumulated lengths of query and key. Shape is (3 * batch_size + 2).

    Returns:
        np.ndarray: the input tensor created
    """

    assert input_mask.type.tensor_type.elem_type in [
        TensorProto.FLOAT,
        TensorProto.INT32,
        TensorProto.INT64,
    ]

    if mask_type == 1:  # sequence length excluding paddings
        # One length per batch entry.
        data = np.ones((batch_size), dtype=np.int32)
        if random_sequence_length:
            for i in range(batch_size):
                data[i] = get_random_length(sequence_length, average_sequence_length)
        else:
            for i in range(batch_size):
                data[i] = average_sequence_length
    elif mask_type == 2:  # 2D attention mask
        # 1 marks a real token, 0 marks padding; real tokens are left-aligned.
        data = np.zeros((batch_size, sequence_length), dtype=np.int32)
        if random_sequence_length:
            for i in range(batch_size):
                actual_seq_len = get_random_length(sequence_length, average_sequence_length)
                for j in range(actual_seq_len):
                    data[i, j] = 1
        else:
            # Same effective length for every row: fill the leading columns with ones.
            temp = np.ones((batch_size, average_sequence_length), dtype=np.int32)
            data[: temp.shape[0], : temp.shape[1]] = temp
    else:
        assert mask_type == 3
        # Packed layout (see docstring): data[0:batch_size] holds per-batch lengths,
        # data[batch_size : 2*batch_size+1] and data[2*batch_size+1 :] each hold a
        # cumulative-sum array of length batch_size + 1 starting at 0.
        data = np.zeros((batch_size * 3 + 2), dtype=np.int32)
        if random_sequence_length:
            for i in range(batch_size):
                data[i] = get_random_length(sequence_length, average_sequence_length)

            # Build both cumulative arrays in place. Note the conditional expression
            # binds as `x = (a + b) if i > 0 else 0`, and data[batch_size + i - 1] is
            # the cumulative value written on the previous iteration.
            for i in range(batch_size + 1):
                data[batch_size + i] = data[batch_size + i - 1] + data[i - 1] if i > 0 else 0
                data[2 * batch_size + 1 + i] = data[batch_size + i - 1] + data[i - 1] if i > 0 else 0
        else:
            for i in range(batch_size):
                data[i] = average_sequence_length
            # With a constant length, the cumulative sums are simple multiples.
            for i in range(batch_size + 1):
                data[batch_size + i] = i * average_sequence_length
                data[2 * batch_size + 1 + i] = i * average_sequence_length

    # Cast to the dtype declared by the graph input; int32 needs no conversion.
    if input_mask.type.tensor_type.elem_type == TensorProto.FLOAT:
        data = np.float32(data)
    elif input_mask.type.tensor_type.elem_type == TensorProto.INT64:
        data = np.int64(data)

    return data
157
+
158
+
159
def output_test_data(directory: str, inputs: dict[str, np.ndarray]):
    """Serialize the input tensors of one test case into a directory.

    Args:
        directory (str): path of a directory
        inputs (dict[str, np.ndarray]): map from input name to value
    """
    if os.path.exists(directory):
        print(f"Warning: directory {directory} existed. Files will be overwritten.")
    else:
        try:
            os.mkdir(directory)
        except OSError:
            print(f"Creation of the directory {directory} failed")
        else:
            print(f"Successfully created the directory {directory} ")

    # One protobuf file per graph input, named by position (input_0.pb, input_1.pb, ...).
    index = 0
    for name, data in inputs.items():
        tensor_proto = numpy_helper.from_array(data, name)
        output_path = os.path.join(directory, f"input_{index}.pb")
        with open(output_path, "wb") as out_file:
            out_file.write(tensor_proto.SerializeToString())
        index += 1
180
+
181
+
182
def fake_test_data(
    batch_size: int,
    sequence_length: int,
    test_cases: int,
    dictionary_size: int,
    verbose: bool,
    random_seed: int,
    input_ids: TensorProto,
    segment_ids: TensorProto,
    input_mask: TensorProto,
    average_sequence_length: int,
    random_sequence_length: bool,
    mask_type: int,
):
    """Create given number of input data for testing

    Args:
        batch_size (int): batch size
        sequence_length (int): sequence length
        test_cases (int): number of test cases
        dictionary_size (int): vocabulary size of dictionary for input_ids
        verbose (bool): print more information or not
        random_seed (int): random seed
        input_ids (TensorProto): graph input of input IDs
        segment_ids (TensorProto): graph input of token type IDs
        input_mask (TensorProto): graph input of attention mask
        average_sequence_length (int): average sequence length excluding paddings
        random_sequence_length (bool): whether use uniform random number for sequence length
        mask_type (int): mask type 1 is mask index; 2 is 2D mask; 3 is key len, cumulated lengths of query and key

    Returns:
        List[Dict[str,numpy.ndarray]]: list of test cases, where each test case is a dictionary
                                       with input name as key and a tensor as value
    """
    assert input_ids is not None

    # Seed both generators so the produced test data is reproducible.
    np.random.seed(random_seed)
    random.seed(random_seed)

    all_inputs = []
    for _ in range(test_cases):
        inputs = {
            input_ids.name: fake_input_ids_data(input_ids, batch_size, sequence_length, dictionary_size)
        }

        # segment_ids and input_mask are optional graph inputs.
        if segment_ids:
            inputs[segment_ids.name] = fake_segment_ids_data(segment_ids, batch_size, sequence_length)

        if input_mask:
            inputs[input_mask.name] = fake_input_mask_data(
                input_mask, batch_size, sequence_length, average_sequence_length, random_sequence_length, mask_type
            )

        # Show only the first generated case to keep the output short.
        if verbose and not all_inputs:
            print("Example inputs", inputs)
        all_inputs.append(inputs)
    return all_inputs
238
+
239
+
240
def generate_test_data(
    batch_size: int,
    sequence_length: int,
    test_cases: int,
    seed: int,
    verbose: bool,
    input_ids: TensorProto,
    segment_ids: TensorProto,
    input_mask: TensorProto,
    average_sequence_length: int,
    random_sequence_length: bool,
    mask_type: int,
    dictionary_size: int = 10000,
):
    """Create given number of input data for testing

    Args:
        batch_size (int): batch size
        sequence_length (int): sequence length
        test_cases (int): number of test cases
        seed (int): random seed
        verbose (bool): print more information or not
        input_ids (TensorProto): graph input of input IDs
        segment_ids (TensorProto): graph input of token type IDs
        input_mask (TensorProto): graph input of attention mask
        average_sequence_length (int): average sequence length excluding paddings
        random_sequence_length (bool): whether use uniform random number for sequence length
        mask_type (int): mask type 1 is mask index; 2 is 2D mask; 3 is key len, cumulated lengths of query and key
        dictionary_size (int): vocabulary size of dictionary for input_ids. Defaults to 10000.

    Returns:
        List[Dict[str,numpy.ndarray]]: list of test cases, where each test case is a dictionary
                                       with input name as key and a tensor as value
    """
    # Delegate the actual generation; this wrapper only fixes the dictionary size
    # default and sanity-checks the result count.
    all_inputs = fake_test_data(
        batch_size=batch_size,
        sequence_length=sequence_length,
        test_cases=test_cases,
        dictionary_size=dictionary_size,
        verbose=verbose,
        random_seed=seed,
        input_ids=input_ids,
        segment_ids=segment_ids,
        input_mask=input_mask,
        average_sequence_length=average_sequence_length,
        random_sequence_length=random_sequence_length,
        mask_type=mask_type,
    )
    if len(all_inputs) != test_cases:
        print("Failed to create test data for test.")
    return all_inputs
290
+
291
+
292
def get_graph_input_from_embed_node(onnx_model, embed_node, input_index):
    """Map an EmbedLayerNormalization node input back to a graph input.

    Looks up the node input at input_index as a graph input. When the value
    is not a direct graph input but is produced by a Cast node, return the
    Cast node's source graph input instead.

    Args:
        onnx_model: OnnxModel wrapper providing find_graph_input and get_parent.
        embed_node: the EmbedLayerNormalization node.
        input_index (int): index into embed_node.input.

    Returns:
        The matching graph input, or None when the index is out of range or
        no graph input can be resolved.
    """
    if input_index >= len(embed_node.input):
        return None

    # Renamed from `input` to avoid shadowing the builtin.
    input_name = embed_node.input[input_index]
    graph_input = onnx_model.find_graph_input(input_name)
    if graph_input is None:
        parent_node = onnx_model.get_parent(embed_node, input_index)
        if parent_node is not None and parent_node.op_type == "Cast":
            # The graph input may be cast (e.g. int64 -> int32) before feeding the node.
            graph_input = onnx_model.find_graph_input(parent_node.input[0])
    return graph_input
303
+
304
+
305
+ def find_bert_inputs(
306
+ onnx_model: OnnxModel,
307
+ input_ids_name: str | None = None,
308
+ segment_ids_name: str | None = None,
309
+ input_mask_name: str | None = None,
310
+ ) -> tuple[np.ndarray | None, np.ndarray | None, np.ndarray | None]:
311
+ """Find graph inputs for BERT model.
312
+ First, we will deduce inputs from EmbedLayerNormalization node.
313
+ If not found, we will guess the meaning of graph inputs based on naming.
314
+
315
+ Args:
316
+ onnx_model (OnnxModel): onnx model object
317
+ input_ids_name (str, optional): Name of graph input for input IDs. Defaults to None.
318
+ segment_ids_name (str, optional): Name of graph input for segment IDs. Defaults to None.
319
+ input_mask_name (str, optional): Name of graph input for attention mask. Defaults to None.
320
+
321
+ Raises:
322
+ ValueError: Graph does not have input named of input_ids_name or segment_ids_name or input_mask_name
323
+ ValueError: Expected graph input number does not match with specified input_ids_name, segment_ids_name
324
+ and input_mask_name
325
+
326
+ Returns:
327
+ Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]: input tensors of input_ids,
328
+ segment_ids and input_mask
329
+ """
330
+
331
+ graph_inputs = onnx_model.get_graph_inputs_excluding_initializers()
332
+
333
+ if input_ids_name is not None:
334
+ input_ids = onnx_model.find_graph_input(input_ids_name)
335
+ if input_ids is None:
336
+ raise ValueError(f"Graph does not have input named {input_ids_name}")
337
+
338
+ segment_ids = None
339
+ if segment_ids_name:
340
+ segment_ids = onnx_model.find_graph_input(segment_ids_name)
341
+ if segment_ids is None:
342
+ raise ValueError(f"Graph does not have input named {segment_ids_name}")
343
+
344
+ input_mask = None
345
+ if input_mask_name:
346
+ input_mask = onnx_model.find_graph_input(input_mask_name)
347
+ if input_mask is None:
348
+ raise ValueError(f"Graph does not have input named {input_mask_name}")
349
+
350
+ expected_inputs = 1 + (1 if segment_ids else 0) + (1 if input_mask else 0)
351
+ if len(graph_inputs) != expected_inputs:
352
+ raise ValueError(f"Expect the graph to have {expected_inputs} inputs. Got {len(graph_inputs)}")
353
+
354
+ return input_ids, segment_ids, input_mask
355
+
356
+ if len(graph_inputs) != 3:
357
+ raise ValueError(f"Expect the graph to have 3 inputs. Got {len(graph_inputs)}")
358
+
359
+ embed_nodes = onnx_model.get_nodes_by_op_type("EmbedLayerNormalization")
360
+ if len(embed_nodes) == 1:
361
+ embed_node = embed_nodes[0]
362
+ input_ids = get_graph_input_from_embed_node(onnx_model, embed_node, 0)
363
+ segment_ids = get_graph_input_from_embed_node(onnx_model, embed_node, 1)
364
+ input_mask = get_graph_input_from_embed_node(onnx_model, embed_node, 7)
365
+
366
+ if input_mask is None:
367
+ for input in graph_inputs:
368
+ input_name_lower = input.name.lower()
369
+ if "mask" in input_name_lower:
370
+ input_mask = input
371
+ if input_mask is None:
372
+ raise ValueError("Failed to find attention mask input")
373
+
374
+ return input_ids, segment_ids, input_mask
375
+
376
+ # Try guess the inputs based on naming.
377
+ input_ids = None
378
+ segment_ids = None
379
+ input_mask = None
380
+ for input in graph_inputs:
381
+ input_name_lower = input.name.lower()
382
+ if "mask" in input_name_lower: # matches input with name like "attention_mask" or "input_mask"
383
+ input_mask = input
384
+ elif (
385
+ "token" in input_name_lower or "segment" in input_name_lower
386
+ ): # matches input with name like "segment_ids" or "token_type_ids"
387
+ segment_ids = input
388
+ else:
389
+ input_ids = input
390
+
391
+ if input_ids and segment_ids and input_mask:
392
+ return input_ids, segment_ids, input_mask
393
+
394
+ raise ValueError("Fail to assign 3 inputs. You might try rename the graph inputs.")
395
+
396
+
397
def get_bert_inputs(
    onnx_file: str,
    input_ids_name: str | None = None,
    segment_ids_name: str | None = None,
    input_mask_name: str | None = None,
) -> tuple[np.ndarray | None, np.ndarray | None, np.ndarray | None]:
    """Load an ONNX model from file and locate its BERT graph inputs.

    Inputs are deduced from the EmbedLayerNormalization node when possible,
    otherwise guessed from the graph input names (see find_bert_inputs).

    Args:
        onnx_file (str): onnx model path
        input_ids_name (str, optional): Name of graph input for input IDs. Defaults to None.
        segment_ids_name (str, optional): Name of graph input for segment IDs. Defaults to None.
        input_mask_name (str, optional): Name of graph input for attention mask. Defaults to None.

    Returns:
        Tuple of graph inputs for input_ids, segment_ids and input_mask.
    """
    model_proto = ModelProto()
    model_proto.ParseFromString(Path(onnx_file).read_bytes())
    return find_bert_inputs(OnnxModel(model_proto), input_ids_name, segment_ids_name, input_mask_name)
423
+
424
+
425
def parse_arguments():
    """Build the CLI parser and return the parsed arguments namespace."""
    parser = argparse.ArgumentParser()

    parser.add_argument("--model", required=True, type=str, help="bert onnx model path.")
    parser.add_argument(
        "--output_dir",
        required=False,
        type=str,
        default=None,
        help="output test data path. Default is current directory.",
    )

    parser.add_argument("--batch_size", required=False, type=int, default=1, help="batch size of input")
    parser.add_argument(
        "--sequence_length", required=False, type=int, default=128, help="maximum sequence length of input"
    )

    # Optional overrides for the three BERT graph input names.
    parser.add_argument("--input_ids_name", required=False, type=str, default=None, help="input name for input ids")
    parser.add_argument("--segment_ids_name", required=False, type=str, default=None, help="input name for segment ids")
    parser.add_argument(
        "--input_mask_name", required=False, type=str, default=None, help="input name for attention mask"
    )

    parser.add_argument("--samples", required=False, type=int, default=1, help="number of test cases to be generated")
    parser.add_argument("--seed", required=False, type=int, default=3, help="random seed")

    parser.add_argument("--verbose", required=False, action="store_true", help="print verbose information")
    parser.set_defaults(verbose=False)

    parser.add_argument(
        "--only_input_tensors",
        required=False,
        action="store_true",
        help="only save input tensors and no output tensors",
    )
    parser.set_defaults(only_input_tensors=False)

    parser.add_argument(
        "-a", "--average_sequence_length", default=-1, type=int, help="average sequence length excluding padding"
    )
    parser.add_argument(
        "-r",
        "--random_sequence_length",
        required=False,
        action="store_true",
        help="use uniform random instead of fixed sequence length",
    )
    parser.set_defaults(random_sequence_length=False)

    parser.add_argument(
        "--mask_type",
        required=False,
        type=int,
        default=2,
        help="mask type: (1: mask index, 2: raw 2D mask, 3: key lengths, cumulated lengths of query and key)",
    )

    return parser.parse_args()
523
+
524
+
525
def create_and_save_test_data(
    model: str,
    output_dir: str,
    batch_size: int,
    sequence_length: int,
    test_cases: int,
    seed: int,
    verbose: bool,
    input_ids_name: str | None,
    segment_ids_name: str | None,
    input_mask_name: str | None,
    only_input_tensors: bool,
    average_sequence_length: int,
    random_sequence_length: bool,
    mask_type: int,
):
    """Create test data for a model, and save test data to a directory.

    Inputs are written as test_data_set_<i>/input_*.pb. Unless only_input_tensors
    is set, the model is also run (CUDA EP when available, otherwise CPU) and the
    outputs are written as output_<j>.pb alongside the inputs.

    Args:
        model (str): path of ONNX bert model
        output_dir (str): output directory
        batch_size (int): batch size
        sequence_length (int): sequence length
        test_cases (int): number of test cases
        seed (int): random seed
        verbose (bool): whether print more information
        input_ids_name (str): graph input name of input_ids
        segment_ids_name (str): graph input name of segment_ids
        input_mask_name (str): graph input name of input_mask
        only_input_tensors (bool): only save input tensors
        average_sequence_length (int): average sequence length excluding paddings
        random_sequence_length (bool): whether use uniform random number for sequence length
        mask_type (int): mask type
    """
    input_ids, segment_ids, input_mask = get_bert_inputs(model, input_ids_name, segment_ids_name, input_mask_name)

    all_inputs = generate_test_data(
        batch_size,
        sequence_length,
        test_cases,
        seed,
        verbose,
        input_ids,
        segment_ids,
        input_mask,
        average_sequence_length,
        random_sequence_length,
        mask_type,
    )

    for i, inputs in enumerate(all_inputs):
        directory = os.path.join(output_dir, "test_data_set_" + str(i))
        output_test_data(directory, inputs)

    if only_input_tensors:
        return

    # Import lazily so input-only generation does not require onnxruntime.
    import onnxruntime  # noqa: PLC0415

    providers = (
        ["CUDAExecutionProvider", "CPUExecutionProvider"]
        if "CUDAExecutionProvider" in onnxruntime.get_available_providers()
        else ["CPUExecutionProvider"]
    )
    session = onnxruntime.InferenceSession(model, providers=providers)
    output_names = [output.name for output in session.get_outputs()]

    for i, inputs in enumerate(all_inputs):
        directory = os.path.join(output_dir, "test_data_set_" + str(i))
        result = session.run(output_names, inputs)
        # Use a distinct index name so the outer test-case index `i` is not shadowed
        # (previously suppressed with noqa: PLW2901).
        for output_index, output_name in enumerate(output_names):
            tensor_result = numpy_helper.from_array(np.asarray(result[output_index]), output_name)
            with open(os.path.join(directory, f"output_{output_index}.pb"), "wb") as file:
                file.write(tensor_result.SerializeToString())
599
+
600
+
601
def main():
    """Entry point: parse arguments, prepare the output directory, and generate test data."""
    args = parse_arguments()

    # A non-positive average means "use the full sequence length" (no padding).
    if args.average_sequence_length <= 0:
        args.average_sequence_length = args.sequence_length

    output_dir = args.output_dir
    if output_dir is None:
        # Default output directory is a sub-directory under the directory of model.
        p = Path(args.model)
        output_dir = os.path.join(p.parent, f"batch_{args.batch_size}_seq_{args.sequence_length}")

    # Fix: output_dir is always non-None here, so the previous
    # `if output_dir is not None: ... else: print(...)` made the "Directory existed"
    # warning unreachable. Warn when the directory pre-exists, then create it.
    path = Path(output_dir)
    if path.exists():
        print("Directory existed. test data files will be overwritten.")
    path.mkdir(parents=True, exist_ok=True)

    create_and_save_test_data(
        args.model,
        output_dir,
        args.batch_size,
        args.sequence_length,
        args.samples,
        args.seed,
        args.verbose,
        args.input_ids_name,
        args.segment_ids_name,
        args.input_mask_name,
        args.only_input_tensors,
        args.average_sequence_length,
        args.random_sequence_length,
        args.mask_type,
    )

    print("Test data is saved to directory:", output_dir)
638
+
639
+
640
+ if __name__ == "__main__":
641
+ main()