onnxruntime_directml-1.24.1-cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
onnxruntime/transformers/large_model_exporter.py
@@ -0,0 +1,395 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+
+ """
+ Export a large language model (LLM) to ONNX.
+ """
+
+ import argparse
+ import inspect
+ import math
+ import os
+ import tempfile
+ from pathlib import Path
+
+ import onnx
+ import torch
+ import transformers
+ from torch import nn
+
+
+ def disable_huggingface_init():
+     """Replace the random weight initializers with no-ops; initializing weights that will be overwritten by pretrained checkpoints only slows down loading."""
+
+     torch.nn.init.kaiming_uniform_ = lambda x, *args, **kwargs: x
+     torch.nn.init.uniform_ = lambda x, *args, **kwargs: x
+     torch.nn.init.normal_ = lambda x, *args, **kwargs: x
+     torch.nn.init.constant_ = lambda x, *args, **kwargs: x
+     torch.nn.init.xavier_uniform_ = lambda x, *args, **kwargs: x
+     torch.nn.init.xavier_normal_ = lambda x, *args, **kwargs: x
+     torch.nn.init.kaiming_normal_ = lambda x, *args, **kwargs: x
+     torch.nn.init.orthogonal_ = lambda x, *args, **kwargs: x
+
+
+ def get_model_parameter_size(model: nn.Module):
+     """Calculate how much memory (in MB) the model's parameters and buffers need."""
+     param_size = 0
+     param_sum = 0
+     for param in model.parameters():
+         param_size += param.nelement() * param.element_size()
+         param_sum += param.nelement()
+     buffer_size = 0
+     buffer_sum = 0
+     for buffer in model.buffers():
+         buffer_size += buffer.nelement() * buffer.element_size()
+         buffer_sum += buffer.nelement()
+     all_size = (param_size + buffer_size) / 1024 / 1024
+     return all_size
+
+
+ def initialize_model_and_sample_inputs(hf_model: str, cache_dir: str | None, tokenizer=None):
+     """
+     Get the pretrained torch model from Hugging Face,
+     along with sample model inputs.
+     """
+
+     disable_huggingface_init()
+
+     model = transformers.AutoModelForCausalLM.from_pretrained(  # type: ignore
+         hf_model, torch_dtype=torch.float16, cache_dir=cache_dir, trust_remote_code=True
+     )
+     if tokenizer is None:
+         tokenizer = hf_model
+     tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer)  # type: ignore
+
+     sample_inputs = tuple(tokenizer("Hello, my dog is cute", return_tensors="pt").values())
+     return model, sample_inputs
+
+
+ def auto_pipeline_parallel(model: nn.Module, gpulist: list, sample_inputs: tuple):
+     """Make the model executable across multiple GPUs."""
+
+     def input_gpu_device_hook(mod, inputs, kwargs):
+         # Move every tensor input to the device the module's weights live on.
+         modified_inputs = []
+         first_dev = None
+         for layer_input in inputs:
+             if type(layer_input) is not torch.Tensor:
+                 modified_inputs.append(layer_input)
+             elif hasattr(mod, "weight"):
+                 modified_inputs.append(layer_input.to(mod.weight.device))
+             elif hasattr(mod, "parameters"):
+                 device = next(mod.parameters(), layer_input).device
+                 modified_inputs.append(layer_input.to(device))
+             elif hasattr(next(mod.children(), None), "weight"):
+                 modified_inputs.append(layer_input.to(next(mod.children()).weight.device))
+             elif first_dev is not None and layer_input.device != first_dev:
+                 modified_inputs.append(layer_input.to(first_dev))
+             else:
+                 modified_inputs.append(layer_input)
+             if first_dev is None:
+                 first_dev = modified_inputs[0].device
+         for key, value in kwargs.items():
+             if type(value) is torch.Tensor:
+                 kwargs[key] = value.to(first_dev)
+
+         return (tuple(modified_inputs), kwargs)
+
+     def move_layer_to_device_recursive(mod, dev):
+         mod.to(dev)
+         for layer in mod.named_children():
+             move_layer_to_device_recursive(layer[1], dev)
+
+     model = model.half()
+     all_hooks = []
+     all_hooks.append(model.register_forward_pre_hook(input_gpu_device_hook, with_kwargs=True))
+     pre_fix = next(iter(model.named_children()))[0]
+     for top_name, top_module in model.named_children():
+         for name, module in top_module.named_children():
+             all_hooks.append(module.register_forward_pre_hook(input_gpu_device_hook, with_kwargs=True))
+             if type(module) in [torch.nn.ModuleList]:
+                 num_layers_on_each_gpu = math.floor(len(module) / len(gpulist))
+                 for idx, attn_layer in enumerate(module):
+                     all_hooks.append(attn_layer.register_forward_pre_hook(input_gpu_device_hook, with_kwargs=True))
+
+                     # Clamp to the last GPU so the index never runs past the end of gpulist.
+                     to_dev = gpulist[min(idx // num_layers_on_each_gpu, len(gpulist) - 1)]
+                     attn_layer.to(to_dev)
+                     move_layer_to_device_recursive(attn_layer, to_dev)
+                     print(f"move {pre_fix}.{name}.{idx} to {to_dev}")
+             else:
+                 module.to(gpulist[0])
+                 print(f"move {pre_fix}.{name} to {gpulist[0]}")
+         if len(list(top_module.named_children())) == 0:
+             top_module.to(gpulist[0])
+             print(f"move {top_name} to {gpulist[0]}")
+
+     with torch.no_grad():
+         model(sample_inputs[0], attention_mask=sample_inputs[1])
+     return model
+
+
+ def retrieve_onnx_inputs(model: nn.Module, sample_inputs: tuple, with_past: bool):
+     """
+     Automatically retrieve the ONNX inputs from the torch model, as we can't
+     enumerate all possibilities for all models.
+     """
+     user_inputs = []
+
+     def hook_for_inputs(_, inputs, kwargs):
+         user_inputs.append((inputs, kwargs))
+         return user_inputs[0]
+
+     hook_handle = model.register_forward_pre_hook(hook_for_inputs, with_kwargs=True)
+
+     forward_params = inspect.signature(model.forward).parameters
+     input_keys = list(forward_params.keys())
+     default_values = [forward_params.get(key).default for key in input_keys]
+     out = model(sample_inputs[0], attention_mask=sample_inputs[1])
+     hook_handle.remove()
+     user_inputs = user_inputs[0]
+     onnx_inputs = default_values
+     for idx, _val in enumerate(user_inputs[0]):
+         onnx_inputs[idx] = user_inputs[0][idx]
+     for key, value in user_inputs[1].items():
+         idx = input_keys.index(key)
+         onnx_inputs[idx] = value
+     for idx, (key, value) in enumerate(zip(input_keys, onnx_inputs, strict=False)):
+         if type(value) is torch.Tensor:
+             # Tensor.to returns a new tensor, so the result must be stored back.
+             onnx_inputs[idx] = value.to(model.device)
+         if "use_cache" in key:
+             onnx_inputs[idx] = with_past
+     out = model(sample_inputs[0], attention_mask=sample_inputs[1], use_cache=with_past) if with_past else out
+
+     return input_keys, onnx_inputs, out.past_key_values
+
+
+ def move_to_appropriate_device(model: nn.Module, sample_inputs_tp: tuple) -> nn.Module:
+     """
+     Depending on the model size, place the model on:
+     CPU if there is no GPU or not enough GPU memory,
+     a single GPU if there is only one locally or the model fits on one GPU,
+     multiple GPUs if there is more than one locally and the model is too large for a single one.
+     """
+     total_mem_per_gpu = torch.cuda.get_device_properties(0).total_memory / 1024 / 1024
+
+     print(f"Model_Size = {get_model_parameter_size(model) / 1024} GB")
+     print(f"total_mem_per_gpu = {total_mem_per_gpu / 1024} GB")
+     if get_model_parameter_size(model) > total_mem_per_gpu * 0.45:
+         device_collection = [torch.device(i) for i in range(torch.cuda.device_count())]
+         if len(device_collection) > 1:
+             print(
+                 f"{len(device_collection)} GPUs are used to export onnx, \
+                 please set CUDA_VISIBLE_DEVICES to use a specific GPU group"
+             )
+             model = auto_pipeline_parallel(model, device_collection, sample_inputs_tp)
+         else:
+             print("!!!! convert model to float and export onnx using CPU")
+             model = model.cpu().float()
+     else:
+         print("Export model on a single GPU")
+         model = model.cuda().half()
+     return model
+
+
+ def adapt_inputs_to_device(sample_inputs: tuple, device: torch.device) -> tuple:
+     """Move sample inputs to the target device."""
+     sample_inputs_ = []
+     for sample_input in sample_inputs:
+         if isinstance(sample_input, torch.Tensor):
+             sample_inputs_.append(sample_input.to(device))
+         else:
+             sample_inputs_.append(sample_input)
+     return tuple(sample_inputs_)
+
+
+ def fetch_onnx_inputs_outputs_name(
+     model: nn.Module,
+     onnx_inputs: list,
+     torch_input_names: tuple,
+     past_key_values: tuple,
+     with_past: bool,
+     input_with_past: bool,
+ ):
+     """Fetch the ONNX input and output names."""
+     num_of_past_key = 0
+     kv_cache_axis = {0: "batch_size"}
+     # try to get num_of_past_key and the shape of past_key_value
+     if past_key_values is not None:
+         num_of_past_key = len(past_key_values)
+         seq_index = (torch.tensor(past_key_values[0][0].shape) == onnx_inputs[0].shape[-1]).nonzero().view(-1)
+         assert seq_index.numel() == 1
+         kv_cache_axis = {0: "batch_size", seq_index.item(): "seq_len"}
+
+     if not num_of_past_key:
+         num_of_past_key = model.config.num_hidden_layers
+
+     # filter out constant inputs
+     onnx_inp_names = tuple(
+         [torch_input_names[i] for i in range(len(torch_input_names)) if isinstance(onnx_inputs[i], torch.Tensor)]
+     )
+     assert "input_ids" in onnx_inp_names and "attention_mask" in onnx_inp_names, (
+         "input_ids and attention_mask must be present in the inputs"
+     )
+     onnx_out_names = ("logits",)
+     onnx_dynamic_axes = {
+         "input_ids": {0: "batch_size", 1: "seq_len"},
+         "attention_mask": {0: "batch_size", 1: "seq_len"},
+     }
+     # add dynamic dimensions for the unknown inputs
+     for idx, name in enumerate(onnx_inp_names):
+         if name not in onnx_dynamic_axes:
+             unknown_dims = {i: f"{idx}__unknown_dims__{i}" for i in range(onnx_inputs[idx].dim())}
+             onnx_dynamic_axes[name] = unknown_dims
+     if input_with_past:
+         for i in range(num_of_past_key):
+             onnx_inp_names += (f"past_key_values.{i}.key",)
+             onnx_inp_names += (f"past_key_values.{i}.value",)
+
+             onnx_dynamic_axes[onnx_inp_names[-1]] = kv_cache_axis
+             onnx_dynamic_axes[onnx_inp_names[-2]] = kv_cache_axis
+
+     if with_past or input_with_past:
+         for i in range(num_of_past_key):
+             onnx_out_names += (f"present.{i}.key",)
+             onnx_out_names += (f"present.{i}.value",)
+
+     for idx, name in enumerate(torch_input_names):
+         if input_with_past:
+             if name == "past_key_values":
+                 onnx_inputs[idx] = past_key_values
+             elif name == "attention_mask":
+                 attn_mask = onnx_inputs[idx]
+                 onnx_inputs[idx] = torch.cat(
+                     (attn_mask, torch.ones((attn_mask.shape[0], 1), device=attn_mask.device, dtype=attn_mask.dtype)),
+                     dim=1,
+                 )
+             elif name == "input_ids":
+                 input_ids = onnx_inputs[idx]
+                 onnx_inputs[idx] = input_ids[:, -1:]
+
+     return onnx_inp_names, onnx_out_names, onnx_dynamic_axes
+
+
+ def do_export_internal(model: nn.Module, onnx_io_tuple: tuple, onnx_inputs: tuple, onnx_path: Path, opset: int):
+     """Do the export with torch.onnx.export."""
+     onnx_model_name = onnx_path.name
+     onnx_inp_names, onnx_out_names, onnx_dynamic_axes = onnx_io_tuple
+     # Export the ONNX model in two steps:
+     # 1. export ONNX with the weights in many separate pieces
+     # 2. save all weights to a single external data file
+     with tempfile.TemporaryDirectory() as tmpdirname:
+         tmp_onnx = os.path.join(tmpdirname, "tmp.onnx")
+
+         torch.onnx.export(
+             model=model,
+             args=tuple(onnx_inputs),
+             f=tmp_onnx,
+             verbose=False,
+             opset_version=opset,
+             input_names=onnx_inp_names,
+             output_names=onnx_out_names,
+             dynamic_axes=onnx_dynamic_axes,
+         )
+
+         onnx_path.unlink(missing_ok=True)
+         (onnx_path.parent / f"{onnx_model_name}_ext.data").unlink(missing_ok=True)
+
+         onnx_model = onnx.load(str(tmp_onnx))
+         onnx.save_model(
+             onnx_model,
+             str(onnx_path),
+             save_as_external_data=(len(os.listdir(tmpdirname)) > 1),
+             all_tensors_to_one_file=True,
+             location=f"{onnx_model_name}_ext.data",
+             size_threshold=1024,
+             convert_attribute=False,
+         )
+
+
+ @torch.no_grad()
+ def export_onnx(hf_model: str, cache_dir: str | None, onnx_path_str: str, with_past: bool, opset: int):
+     """
+     Do the export.
+     hf_model: the Hugging Face model name or path
+     cache_dir: the Hugging Face cache directory, or None
+     onnx_path_str: where the ONNX model is saved to
+     with_past: whether to also export a variant with past-key-values inputs
+     opset: the ONNX opset version to export with
+     """
+     model, sample_inputs_tp = initialize_model_and_sample_inputs(hf_model, cache_dir)
+
+     model = move_to_appropriate_device(model, sample_inputs_tp)
+
+     sample_inputs = adapt_inputs_to_device(sample_inputs_tp, next(model.parameters()).device)
+
+     # input_keys would be useful if the model has some special inputs
+     input_keys, onnx_inputs, past_key_value = retrieve_onnx_inputs(model, sample_inputs, with_past)
+
+     onnx_io_tuple = fetch_onnx_inputs_outputs_name(model, onnx_inputs, input_keys, past_key_value, with_past, False)
+
+     onnx_model_name = "model.onnx"
+     onnx_path: Path = Path(onnx_path_str).absolute()
+     if onnx_path.suffix != ".onnx":
+         onnx_path = onnx_path / onnx_model_name
+
+     do_export_internal(model, onnx_io_tuple, onnx_inputs, onnx_path, opset)
+     if not with_past:
+         return
+
+     onnx_io_tuple = fetch_onnx_inputs_outputs_name(model, onnx_inputs, input_keys, past_key_value, with_past, True)
+
+     onnx_model_name = "model_with_past.onnx"
+     onnx_path = onnx_path.parent / onnx_model_name
+
+     do_export_internal(model, onnx_io_tuple, onnx_inputs, onnx_path, opset)
+
+
+ def parse_arguments():
+     """Parse the command-line arguments."""
+     parser = argparse.ArgumentParser()
+
+     parser.add_argument(
+         "-m",
+         "--model",
+         required=True,
+         type=str,
+         default="meta-llama/Llama-2-70b-hf",
+         help="Pre-trained model name in the Hugging Face model hub",
+     )
+     parser.add_argument(
+         "-s",
+         "--saved_path",
+         required=False,
+         type=str,
+         default="./onnx_models/",
+         help="where the onnx model will be saved",
+     )
+     parser.add_argument(
+         "--cache_dir",
+         required=False,
+         type=str,
+         default=None,
+         help=("Hugging Face cache directory; set this to avoid re-downloading models you already have"),
+     )
+     parser.add_argument(
+         "--with_past",
+         action="store_true",
+         default=False,
+         help=("also export a variant with past-key-values inputs; by default the tool exports onnx without them"),
+     )
+     parser.add_argument(
+         "--opset",
+         required=False,
+         type=int,
+         default=17,
+         help=(
+             "the opset to save the onnx model with; \
+             try to increase it if this opset doesn't have the features you want"
+         ),
+     )
+     return parser.parse_args()
+
+
+ if __name__ == "__main__":
+     args = parse_arguments()
+
+     export_onnx(args.model, args.cache_dir, args.saved_path, args.with_past, args.opset)
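
For reference, a minimal sketch of how the exporter above might be driven programmatically; it assumes only the export_onnx signature shown in the diff, and the model name and output directory are placeholders:

    from onnxruntime.transformers.large_model_exporter import export_onnx

    # Export model.onnx, then a model_with_past.onnx variant with
    # past-key-values inputs, using opset 17 (the script's default).
    export_onnx("meta-llama/Llama-2-7b-hf", None, "./onnx_models/", True, 17)

The same flow should also be reachable from the command line through the argparse entry point at the bottom of the file, e.g. python -m onnxruntime.transformers.large_model_exporter -m meta-llama/Llama-2-7b-hf -s ./onnx_models/ --with_past.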
onnxruntime/transformers/machine_info.py
@@ -0,0 +1,230 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+
+ # This script dumps machine information for use in notebooks.
+
+ import argparse
+ import importlib.metadata
+ import json
+ import logging
+ import platform
+ from os import environ
+
+ import cpuinfo
+ import psutil
+ from py3nvml.py3nvml import (
+     NVMLError,
+     nvmlDeviceGetCount,
+     nvmlDeviceGetHandleByIndex,
+     nvmlDeviceGetMemoryInfo,
+     nvmlDeviceGetName,
+     nvmlInit,
+     nvmlShutdown,
+     nvmlSystemGetDriverVersion,
+ )
+
+
+ class MachineInfo:
+     """Class encapsulating Machine Info logic."""
+
+     def __init__(self, silent=False, logger=None):
+         self.silent = silent
+
+         if logger is None:
+             logging.basicConfig(
+                 format="%(asctime)s - %(name)s - %(levelname)s: %(message)s",
+                 level=logging.INFO,
+             )
+             self.logger = logging.getLogger(__name__)
+         else:
+             self.logger = logger
+
+         self.machine_info = None
+         try:
+             self.machine_info = self.get_machine_info()
+         except Exception:
+             self.logger.exception("Exception in getting machine info.")
+             self.machine_info = None
+
+     def get_machine_info(self):
+         """Get machine info in metric format"""
+         gpu_info = self.get_gpu_info_by_nvml()
+         cpu_info = cpuinfo.get_cpu_info()
+
+         machine_info = {
+             "gpu": gpu_info,
+             "cpu": self.get_cpu_info(),
+             "memory": self.get_memory_info(),
+             "os": platform.platform(),
+             "python": self._try_get(cpu_info, ["python_version"]),
+             "packages": self.get_related_packages(),
+             "onnxruntime": self.get_onnxruntime_info(),
+             "pytorch": self.get_pytorch_info(),
+             "tensorflow": self.get_tensorflow_info(),
+         }
+         return machine_info
+
+     def get_memory_info(self) -> dict:
+         """Get memory info"""
+         mem = psutil.virtual_memory()
+         return {"total": mem.total, "available": mem.available}
+
+     def _try_get(self, cpu_info: dict, names: list) -> str:
+         # Return the first matching entry, joining list values into a comma-separated string.
+         for name in names:
+             if name in cpu_info:
+                 value = cpu_info[name]
+                 if isinstance(value, (list, tuple)):
+                     return ",".join([str(i) for i in value])
+                 return value
+         return ""
+
+     def get_cpu_info(self) -> dict:
+         """Get CPU info"""
+         cpu_info = cpuinfo.get_cpu_info()
+
+         return {
+             "brand": self._try_get(cpu_info, ["brand", "brand_raw"]),
+             "cores": psutil.cpu_count(logical=False),
+             "logical_cores": psutil.cpu_count(logical=True),
+             "hz": self._try_get(cpu_info, ["hz_actual"]),
+             "l2_cache": self._try_get(cpu_info, ["l2_cache_size"]),
+             "flags": self._try_get(cpu_info, ["flags"]),
+             "processor": platform.uname().processor,
+         }
+
+     def get_gpu_info_by_nvml(self) -> dict:
+         """Get GPU info using nvml"""
+         gpu_info_list = []
+         driver_version = None
+         try:
+             nvmlInit()
+             driver_version = nvmlSystemGetDriverVersion()
+             deviceCount = nvmlDeviceGetCount()  # noqa: N806
+             for i in range(deviceCount):
+                 handle = nvmlDeviceGetHandleByIndex(i)
+                 info = nvmlDeviceGetMemoryInfo(handle)
+                 gpu_info = {}
+                 gpu_info["memory_total"] = info.total
+                 gpu_info["memory_available"] = info.free
+                 gpu_info["name"] = nvmlDeviceGetName(handle)
+                 gpu_info_list.append(gpu_info)
+             nvmlShutdown()
+         except NVMLError as error:
+             if not self.silent:
+                 self.logger.error("Error fetching GPU information using nvml: %s", error)
+             return None
+
+         result = {"driver_version": driver_version, "devices": gpu_info_list}
+
+         if "CUDA_VISIBLE_DEVICES" in environ:
+             result["cuda_visible"] = environ["CUDA_VISIBLE_DEVICES"]
+         return result
+
+     def get_related_packages(self) -> dict:
+         related_packages = {
+             "onnxruntime-gpu",
+             "onnxruntime",
+             "onnx",
+             "transformers",
+             "protobuf",
+             "sympy",
+             "torch",
+             "tensorflow",
+             "flatbuffers",
+             "numpy",
+             "onnxconverter-common",
+         }
+         # Map installed package name -> version for the packages of interest.
+         related_packages_list = {}
+         for dist in importlib.metadata.distributions():
+             if dist.metadata["Name"].lower() in related_packages:
+                 related_packages_list[dist.metadata["Name"].lower()] = dist.version
+
+         return related_packages_list
+
+     def get_onnxruntime_info(self) -> dict:
+         try:
+             import onnxruntime  # noqa: PLC0415
+
+             return {
+                 "version": onnxruntime.__version__,
+                 "support_gpu": "CUDAExecutionProvider" in onnxruntime.get_available_providers(),
+             }
+         except ImportError as error:
+             if not self.silent:
+                 self.logger.exception(error)
+             return None
+         except Exception as exception:
+             if not self.silent:
+                 self.logger.exception(exception)
+             return None
+
+     def get_pytorch_info(self) -> dict:
+         try:
+             import torch  # noqa: PLC0415
+
+             return {
+                 "version": torch.__version__,
+                 "support_gpu": torch.cuda.is_available(),
+                 "cuda": torch.version.cuda,
+             }
+         except ImportError as error:
+             if not self.silent:
+                 self.logger.exception(error)
+             return None
+         except Exception as exception:
+             if not self.silent:
+                 self.logger.exception(exception)
+             return None
+
+     def get_tensorflow_info(self) -> dict:
+         try:
+             import tensorflow as tf  # noqa: PLC0415
+
+             return {
+                 "version": tf.version.VERSION,
+                 "git_version": tf.version.GIT_VERSION,
+                 "support_gpu": tf.test.is_built_with_cuda(),
+             }
+         except ImportError as error:
+             # ModuleNotFoundError is a subclass of ImportError, so one handler covers both.
+             if not self.silent:
+                 self.logger.exception(error)
+             return None
+
+
+ def parse_arguments():
+     parser = argparse.ArgumentParser()
+
+     parser.add_argument(
+         "--silent",
+         required=False,
+         action="store_true",
+         help="Do not print error message",
+     )
+     parser.set_defaults(silent=False)
+
+     args = parser.parse_args()
+     return args
+
+
+ def get_machine_info(silent=True) -> str:
+     machine = MachineInfo(silent)
+     return json.dumps(machine.machine_info, indent=2)
+
+
+ def get_device_info(silent=True) -> str:
+     machine = MachineInfo(silent)
+     info = machine.machine_info
+     if info:
+         info = {key: value for key, value in info.items() if key in ["gpu", "cpu", "memory"]}
+     return json.dumps(info, indent=2)
+
+
+ if __name__ == "__main__":
+     args = parse_arguments()
+     print(get_machine_info(args.silent))
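
A short usage sketch for the module above; the module path follows the file list (onnxruntime/transformers/machine_info.py), and silent=True suppresses error logging as in the constructor:

    from onnxruntime.transformers.machine_info import get_device_info, get_machine_info

    # Full machine report as pretty-printed JSON (gpu, cpu, memory, os, packages, ...).
    print(get_machine_info(silent=True))

    # Just the hardware subset: gpu, cpu and memory.
    print(get_device_info(silent=True))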