onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322) hide show
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,154 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+ """
6
+ Check OS requirements for ONNX Runtime Python Bindings.
7
+ """
8
+
9
+ import linecache
10
+ import platform
11
+ import warnings
12
+
13
+
14
def check_distro_info():
    """Warn when the host operating system is outside ONNX Runtime's supported set.

    The function works purely through side effects: it emits a
    ``warnings.warn`` message for unsupported Windows releases, unsupported
    macOS releases and unknown platforms, and returns ``None``. The distro
    name/version gathered on Linux and AIX is informational only and is not
    used after it is read.

    Raises:
        OSError / subprocess.CalledProcessError: on AIX only, if the
            ``oslevel`` command cannot be run (same as before).
    """
    system = platform.system().lower()
    distro = ""
    distro_ver = ""

    os_release_file = "/etc/os-release"
    lsb_release_file = "/etc/lsb-release"

    if system == "windows":
        distro = system
        distro_ver = platform.release().lower()

        supported_windows = ("10", "11", "2016server", "2019server", "2022server", "2025server")
        if distro_ver not in supported_windows:
            warnings.warn(
                f"Unsupported Windows version ({distro_ver}). ONNX Runtime supports Windows 10 and above, or Windows Server 2016 and above."
            )
    elif system == "linux":
        # Although the 'platform' python module for getting distro information works well on
        # standard OS images running on real hardware, it is not accurate when running on
        # Azure VMs, Git Bash, Cygwin, etc. The returned values for release and version are
        # unpredictable for virtualized or emulated environments.
        # The /etc/os-release and /etc/lsb-release files are read instead: the former is the
        # current standard file for OS info and the latter is its predecessor.
        # NOTE(review): the fixed line numbers / slice offsets below assume a particular
        # field order inside those files — confirm against the distros that matter.

        # Newer systems have /etc/os-release with relevant distro info.
        # linecache.getline returns "" (it never raises) when the file or line is missing.
        distro = linecache.getline(os_release_file, 3)[3:-1]
        distro_ver = linecache.getline(os_release_file, 6)[12:-2]

        # Older systems may only have /etc/lsb-release instead.
        if not distro:
            distro = linecache.getline(lsb_release_file, 1)[11:-1]
            distro_ver = linecache.getline(lsb_release_file, 2)[16:-1]

        # Instead of trying to parse distro specific files,
        # warn the user ONNX Runtime may not work out of the box
        distro = distro.lower()
        distro_ver = distro_ver.lower()
    elif system == "darwin":
        distro = system
        # platform.release() is the Darwin *kernel* release on macOS, not the product
        # version: Darwin 20.x corresponds to macOS 11, Darwin 19.x to macOS 10.15.
        # Comparing the kernel major against 11 would therefore never flag macOS 10.x.
        kernel_release = platform.release().lower()
        # Prefer the product version for the message; fall back to the kernel release.
        distro_ver = platform.mac_ver()[0] or kernel_release

        try:
            kernel_major = int(kernel_release.split(".")[0])
        except ValueError:
            # Unparsable release string: skip the check rather than crash at import time.
            kernel_major = None

        first_supported_darwin_major = 20  # Darwin 20 == macOS 11
        if kernel_major is not None and kernel_major < first_supported_darwin_major:
            warnings.warn(
                f"Unsupported macOS version ({distro_ver}). ONNX Runtime supports macOS 11.0 or later."
            )
    elif system == "aix":
        import subprocess  # noqa: PLC0415

        oslevel_output = subprocess.check_output("oslevel").decode("utf-8")
        distro = system
        # Only the first three characters of the `oslevel` output are kept.
        distro_ver = oslevel_output[:3]
    else:
        warnings.warn(
            f"Unsupported platform ({system}). ONNX Runtime supports Linux, macOS, AIX and Windows platforms, only."
        )
69
+
70
+
71
def get_package_name_and_version_info():
    """Read the package name, version and CUDA version recorded at build time.

    The values come from the sibling ``build_and_package_info`` module. Any
    value that cannot be imported is returned as an empty string; if the
    lookup fails, a warning is issued and the exception is printed.

    Returns:
        A ``(package_name, version, cuda_version)`` tuple.
    """
    pkg, ver, cuda = "", "", ""

    try:
        from .build_and_package_info import __version__ as ver  # noqa: PLC0415
        from .build_and_package_info import package_name as pkg  # noqa: PLC0415

        try:  # noqa: SIM105
            from .build_and_package_info import cuda_version as cuda  # noqa: PLC0415
        except ImportError:
            # cuda_version is optional: a CPU-only package does not define it.
            pass
    except Exception as err:
        warnings.warn("WARNING: failed to collect package name and version info")
        print(err)

    return pkg, ver, cuda
90
+
91
+
92
def check_training_module():
    """Detect ORTModule and compare the local CUDA runtime with the build info.

    Returns:
        A ``(has_ortmodule, package_name, version, cuda_version)`` tuple.

    Raises:
        Exception: the exception raised while importing ORTModule, when it is
            neither an ``ImportError`` nor an ``ORTModuleInitException``. It is
            re-raised only after the version checks below have run.
    """
    deferred_error = None
    ortmodule_available = False

    try:
        from onnxruntime.training.ortmodule import ORTModule  # noqa: F401, PLC0415
    except ImportError:
        # ORTModule not present
        ortmodule_available = False
    except Exception as import_error:
        # This may happen if CUDA is not installed. For anything other than a
        # missing ORTModule, carry on with version validation and raise later.
        try:
            from onnxruntime.training.ortmodule._fallback import ORTModuleInitException  # noqa: PLC0415

            if isinstance(import_error, ORTModuleInitException):
                # ORTModule is present but not ready to run yet
                ortmodule_available = True
        except Exception:
            # ORTModule not present
            ortmodule_available = False

        if not ortmodule_available:
            deferred_error = import_error
    else:
        ortmodule_available = True

    # collect onnxruntime package name, version, and cuda version
    package_name, version, cuda_version = get_package_name_and_version_info()

    if ortmodule_available and cuda_version:
        try:
            # The cudart build info may be unavailable when the build
            # environment had none or multiple CUDA libraries installed.
            try:
                from .build_and_package_info import cudart_version  # noqa: PLC0415
            except ImportError:
                warnings.warn("WARNING: failed to get cudart_version from onnxruntime build info.")
                cudart_version = None

            # Collect CUDA library info from the current environment.
            from onnxruntime.capi.onnxruntime_collect_build_info import find_cudart_versions  # noqa: PLC0415

            found_versions = find_cudart_versions(build_env=False, build_cuda_version=cuda_version)
            if cudart_version and found_versions and cudart_version not in found_versions:
                build_info_lines = (
                    f"onnxruntime training package info: package_name: {package_name}",
                    f"onnxruntime training package info: __version__: {version}",
                    f"onnxruntime training package info: cuda_version: {cuda_version}",
                    f"onnxruntime build info: cudart_version: {cudart_version}",
                )
                for line in build_info_lines:
                    warnings.warn(line)
                warnings.warn("WARNING: failed to find cudart version that matches onnxruntime build info")
                warnings.warn(f"WARNING: found cudart versions: {found_versions}")
        except Exception as err:
            warnings.warn("WARNING: failed to collect onnxruntime version and build info")
            print(err)

    if deferred_error:
        raise deferred_error

    return ortmodule_available, package_name, version, cuda_version
@@ -0,0 +1,2 @@
1
+ use_cuda = False
2
+ vs2019 = False
@@ -0,0 +1,18 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+ """
4
+ Short examples used in the documentation.
5
+ """
6
+
7
+ import os
8
+
9
+
10
def get_example(name):
    """
    Return the absolute path of the example file called *name*.

    The file is looked up in the directory that contains this module.
    Raises ``FileNotFoundError`` when no such file exists.
    """
    examples_dir = os.path.abspath(os.path.dirname(__file__))
    candidate = os.path.join(examples_dir, name)
    if os.path.exists(candidate):
        return candidate
    raise FileNotFoundError(f"Unable to find example '{name}'")
Binary file
Binary file
@@ -0,0 +1,13 @@
1
+  backend-test:Q
2
+ 
3
+ xy"Sigmoid test_sigmoidZ
4
+ x
5
+ 
6
+ 
7
+ 
8
+ b
9
+ y
10
+ 
11
+ 
12
+ 
13
+ B
@@ -0,0 +1,78 @@
1
+ # automatically generated by the FlatBuffers compiler, do not modify
2
+
3
+ # namespace: CalTableFlatBuffers
4
+
5
+ import flatbuffers
6
+ from flatbuffers.compat import import_numpy
7
+
8
+ np = import_numpy()
9
+
10
+
11
+ class KeyValue:
12
+ __slots__ = ["_tab"]
13
+
14
+ @classmethod
15
+ def GetRootAs(cls, buf, offset=0): # noqa: N802
16
+ n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
17
+ x = KeyValue()
18
+ x.Init(buf, n + offset)
19
+ return x
20
+
21
+ @classmethod
22
+ def GetRootAsKeyValue(cls, buf, offset=0): # noqa: N802
23
+ """This method is deprecated. Please switch to GetRootAs."""
24
+ return cls.GetRootAs(buf, offset)
25
+
26
+ # KeyValue
27
+ def Init(self, buf, pos): # noqa: N802
28
+ self._tab = flatbuffers.table.Table(buf, pos)
29
+
30
+ # KeyValue
31
+ def Key(self): # noqa: N802
32
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
33
+ if o != 0:
34
+ return self._tab.String(o + self._tab.Pos)
35
+ return None
36
+
37
+ # KeyValue
38
+ def Value(self): # noqa: N802
39
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
40
+ if o != 0:
41
+ return self._tab.String(o + self._tab.Pos)
42
+ return None
43
+
44
+
45
+ def Start(builder): # noqa: N802
46
+ builder.StartObject(2)
47
+
48
+
49
+ def KeyValueStart(builder): # noqa: N802
50
+ """This method is deprecated. Please switch to Start."""
51
+ return Start(builder)
52
+
53
+
54
+ def AddKey(builder, key): # noqa: N802
55
+ builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(key), 0)
56
+
57
+
58
+ def KeyValueAddKey(builder, key): # noqa: N802
59
+ """This method is deprecated. Please switch to AddKey."""
60
+ return AddKey(builder, key)
61
+
62
+
63
+ def AddValue(builder, value): # noqa: N802
64
+ builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(value), 0)
65
+
66
+
67
+ def KeyValueAddValue(builder, value): # noqa: N802
68
+ """This method is deprecated. Please switch to AddValue."""
69
+ return AddValue(builder, value)
70
+
71
+
72
+ def End(builder): # noqa: N802
73
+ return builder.EndObject()
74
+
75
+
76
+ def KeyValueEnd(builder): # noqa: N802
77
+ """This method is deprecated. Please switch to End."""
78
+ return End(builder)
@@ -0,0 +1,90 @@
1
+ # automatically generated by the FlatBuffers compiler, do not modify
2
+
3
+ # namespace: CalTableFlatBuffers
4
+
5
+ import flatbuffers
6
+ from flatbuffers.compat import import_numpy
7
+
8
+ np = import_numpy()
9
+
10
+
11
+ class TrtTable:
12
+ __slots__ = ["_tab"]
13
+
14
+ @classmethod
15
+ def GetRootAs(cls, buf, offset=0): # noqa: N802
16
+ n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
17
+ x = TrtTable()
18
+ x.Init(buf, n + offset)
19
+ return x
20
+
21
+ @classmethod
22
+ def GetRootAsTrtTable(cls, buf, offset=0): # noqa: N802
23
+ """This method is deprecated. Please switch to GetRootAs."""
24
+ return cls.GetRootAs(buf, offset)
25
+
26
+ # TrtTable
27
+ def Init(self, buf, pos): # noqa: N802
28
+ self._tab = flatbuffers.table.Table(buf, pos)
29
+
30
+ # TrtTable
31
+ def Dict(self, j): # noqa: N802
32
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
33
+ if o != 0:
34
+ x = self._tab.Vector(o)
35
+ x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
36
+ x = self._tab.Indirect(x)
37
+ from onnxruntime.quantization.CalTableFlatBuffers.KeyValue import KeyValue # noqa: PLC0415
38
+
39
+ obj = KeyValue()
40
+ obj.Init(self._tab.Bytes, x)
41
+ return obj
42
+ return None
43
+
44
+ # TrtTable
45
+ def DictLength(self): # noqa: N802
46
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
47
+ if o != 0:
48
+ return self._tab.VectorLen(o)
49
+ return 0
50
+
51
+ # TrtTable
52
+ def DictIsNone(self): # noqa: N802
53
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
54
+ return o == 0
55
+
56
+
57
+ def Start(builder): # noqa: N802
58
+ builder.StartObject(1)
59
+
60
+
61
+ def TrtTableStart(builder): # noqa: N802
62
+ """This method is deprecated. Please switch to Start."""
63
+ return Start(builder)
64
+
65
+
66
+ def AddDict(builder, dict): # noqa: N802
67
+ builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(dict), 0)
68
+
69
+
70
+ def TrtTableAddDict(builder, dict): # noqa: N802
71
+ """This method is deprecated. Please switch to AddDict."""
72
+ return AddDict(builder, dict)
73
+
74
+
75
+ def StartDictVector(builder, numElems): # noqa: N802
76
+ return builder.StartVector(4, numElems, 4)
77
+
78
+
79
+ def TrtTableStartDictVector(builder, numElems): # noqa: N802
80
+ """This method is deprecated. Please switch to Start."""
81
+ return StartDictVector(builder, numElems)
82
+
83
+
84
+ def End(builder): # noqa: N802
85
+ return builder.EndObject()
86
+
87
+
88
+ def TrtTableEnd(builder): # noqa: N802
89
+ """This method is deprecated. Please switch to End."""
90
+ return End(builder)
@@ -0,0 +1,19 @@
1
+ from .calibrate import ( # noqa: F401
2
+ CalibraterBase,
3
+ CalibrationDataReader,
4
+ CalibrationMethod,
5
+ MinMaxCalibrater,
6
+ create_calibrator,
7
+ )
8
+ from .qdq_quantizer import QDQQuantizer # noqa: F401
9
+ from .quant_utils import QuantFormat, QuantType, write_calibration_table # noqa: F401
10
+ from .quantize import (
11
+ DynamicQuantConfig, # noqa: F401
12
+ QuantizationMode, # noqa: F401
13
+ StaticQuantConfig, # noqa: F401
14
+ get_qdq_config, # noqa: F401
15
+ quantize, # noqa: F401
16
+ quantize_dynamic, # noqa: F401
17
+ quantize_static, # noqa: F401
18
+ )
19
+ from .shape_inference import quant_pre_process # noqa: F401