onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
onnxruntime/transformers/models/whisper/whisper_jump_times.py
@@ -0,0 +1,477 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License. See License.txt in the project root for
+ # license information.
+ # --------------------------------------------------------------------------
+
+ import logging
+ import os
+ import tempfile
+ import textwrap
+ from pathlib import Path
+
+ import numpy as np
+ import onnx
+ import torch
+ import torch.nn.functional as F
+ import torch.utils.cpp_extension
+ from onnx_model import OnnxModel
+ from transformers import WhisperConfig
+ from whisper_inputs import convert_inputs_for_ort, get_model_dynamic_axes, get_sample_jump_times_inputs
+
+ from onnxruntime import InferenceSession
+ from onnxruntime.tools import pytorch_export_contrib_ops
+
+ logger = logging.getLogger(__name__)
+
+ ##################################################
+ # Functions that have to be outside of the class
+ # for torch.jit.script_if_tracing to work
+ ##################################################
+
+
+ @torch.jit.script_if_tracing
+ def index_QKs(alignment_heads: torch.Tensor, QKs: list[torch.Tensor]):  # noqa: N802
+     """
+     Compute the following to get stacked QK tensor that has been indexed for the desired attention heads:
+         weights = torch.stack([QKs[_l][:, _h] for _l, _h in alignment_heads], dim=1)
+     """
+     indexed_QKs = []  # noqa: N806
+     for pair in alignment_heads:
+         # Each QK is of shape (batch_size, num_heads, sequence_length, num_frames // 2)
+         # The `QKs[_l]` selects the right QK from the list of QKs
+         # The `QKs[_l][:, _h]` selects the right attention heads from the chosen QK. The `:` is to do this for the batch dim.
+         #
+         # PyTorch:
+         # QKs[_l] is of shape (batch_size, num_heads, sequence_length, num_frames // 2)
+         # QKs[_l][:, _h] is of shape (batch_size, sequence_length, num_frames // 2)
+         #
+         # ONNX:
+         # QKs[_l] is of shape (batch_size, num_heads, sequence_length, num_frames // 2)
+         # QKs[_l][:, _h] is of shape (batch_size, 1, sequence_length, num_frames // 2) because
+         # the `[:, _h]` operation maps to a Gather op and that op does not reduce dimensions
+         _l, _h = pair[0], pair[1]
+         indexed_QKs.append(QKs[_l][:, _h])
+
+     # PyTorch:
+     # torch.stack will return a tensor of shape (batch_size, num_alignment_heads, sequence_length, num_frames // 2).
+     #
+     # ONNX:
+     # torch.stack will return a tensor of shape (batch_size, num_alignment_heads, 1, sequence_length, num_frames // 2)
+     # because the Gather op does not reduce dimensions. To remove the unneeded dimension, torch.squeeze with a specified
+     # dim (dim = 2) is added. The torch.squeeze op with a specified dim only runs if the specified dim has a size of 1.
+     # Since the dim won't be of size 1 in the PyTorch tensor but it is of size 1 in the ONNX tensor, it will be a no-op
+     # in PyTorch and an op in ONNX. Thus, the Squeeze op will only affect the ONNX model.
+     weights = torch.stack(indexed_QKs, dim=1)
+     weights = torch.squeeze(weights, dim=2)
+     return weights
+
+
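As a quick illustration of the shape subtlety the comments above describe (an editorial sketch, not part of the packaged file): in eager PyTorch, indexing a head drops that dimension, while the exported Gather keeps it, which is exactly why the `torch.squeeze(..., dim=2)` is a no-op in PyTorch but meaningful in ONNX.

```
import torch

qk = torch.randn(2, 6, 8, 50)  # (batch_size, num_heads, sequence_length, num_frames // 2)
print(qk[:, 3].shape)          # torch.Size([2, 8, 50]) -- eager indexing reduces the head dim
# In the exported graph the same indexing becomes a Gather node whose output keeps a
# singleton head dim: (2, 1, 8, 50). Stacking such tensors on dim=1 therefore yields a
# 5-D tensor, and the Squeeze(dim=2) removes the extra dim only in the ONNX graph.
```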
+ def jump_timings(text_indices, time_indices):
+     """
+     Calculate jump times from text_indices and time_indices where
+     text_indices and time_indices are both 1d vectors
+     """
+     TOKENS_PER_SECOND = 50.0  # noqa: N806
+     diff = text_indices[1:] - text_indices[:-1]
+     padding = torch.tensor([1], dtype=torch.int32)
+     jumps = torch.cat((padding, diff)).to(torch.bool)
+     jump_times = time_indices[jumps].to(torch.float) / TOKENS_PER_SECOND
+     return jump_times
+
+
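A small worked example of `jump_timings` (illustrative values, not from the package): the text index advances at steps 0, 2, and 3, so only those time indices survive, converted to seconds at 50 tokens per second.

```
import torch

text_indices = torch.tensor([0, 0, 1, 2, 2], dtype=torch.int32)
time_indices = torch.tensor([0, 3, 5, 7, 9], dtype=torch.int32)

diff = text_indices[1:] - text_indices[:-1]  # [0, 1, 1, 0]
jumps = torch.cat((torch.tensor([1], dtype=torch.int32), diff)).to(torch.bool)
print(time_indices[jumps].to(torch.float) / 50.0)  # tensor([0.0000, 0.1000, 0.1400])
```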
+ def padded_jump_from_dtw(matrix_2d: torch.Tensor, max_length: torch.Tensor):
+     """
+     Run Dynamic Time Warping (DTW) on a 2D cost matrix for one batch entry
+     and pad the resulting jump times to `max_length`
+     """
+     trace = torch.ops.onnxruntime.DynamicTimeWarping(matrix_2d)
+     text_indices = trace[0, :]
+     time_indices = trace[1, :]
+     jump_times = jump_timings(text_indices, time_indices)
+     return F.pad(jump_times, [0, int((max_length - jump_times.size(-1)).item())], mode="constant", value=-1.0)
+
+
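`torch.ops.onnxruntime.DynamicTimeWarping` is registered later by `create_torch_ops` (see below); it implements the standard DTW recurrence `cost[i, j] = x[i-1, j-1] + min(cost[i-1, j-1], cost[i-1, j], cost[i, j-1])`. A rough NumPy reference of the same computation (an editorial sketch; tie-breaking may differ slightly from the C++ kernel):

```
import numpy as np

def dtw_reference(x: np.ndarray) -> np.ndarray:
    """Trace the lowest-cost path through cost matrix x; returns (2, K) indices."""
    N, M = x.shape
    cost = np.full((N + 1, M + 1), np.inf, dtype=np.float32)
    trace = np.full((N + 1, M + 1), -1, dtype=np.int32)
    cost[0, 0] = 0.0
    for j in range(1, M + 1):
        for i in range(1, N + 1):
            candidates = [cost[i - 1, j - 1], cost[i - 1, j], cost[i, j - 1]]
            t = int(np.argmin(candidates))
            cost[i, j] = x[i - 1, j - 1] + candidates[t]
            trace[i, j] = t
    trace[0, :], trace[:, 0] = 2, 1  # force the backtrace toward (0, 0)
    i, j, pairs = N, M, []
    while i > 0 or j > 0:
        pairs.append((i - 1, j - 1))
        if trace[i, j] == 0:
            i, j = i - 1, j - 1
        elif trace[i, j] == 1:
            i -= 1
        else:
            j -= 1
    return np.array(pairs[::-1], dtype=np.int32).T  # row 0: text indices, row 1: time indices
```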
+ @torch.jit.script_if_tracing
+ def batch_jump_times(matrix: torch.Tensor, max_decoded_length: torch.Tensor):
+     """
+     Compute the following to calculate jump times for all batches:
+         batched_jump_times = torch.stack([padded_jump_from_dtw(matrix[b], max_decoded_length) for b in range(matrix.size(0))])
+     """
+     list_of_jump_times = []
+     for b in range(matrix.size(0)):
+         jump_times = padded_jump_from_dtw(matrix[b], max_decoded_length)
+         list_of_jump_times.append(jump_times)
+     batched_jump_times = torch.stack(list_of_jump_times)
+     return batched_jump_times
+
+
+ class WhisperJumpTimes(torch.nn.Module):
+     """Whisper jump times component"""
+
+     def __init__(self, config: WhisperConfig, device: torch.device, cache_dir: str | os.PathLike):
+         super().__init__()
+         self.config = config
+         self.device = device
+         self.cache_dir = cache_dir
+
+         self.filter_width = 7
+         self.qk_scale = 1.0
+
+     def median_filter(self, weights: torch.Tensor):
+         """
+         Apply a median filter of width `filter_width` along the last dimension of `weights`
+         """
+         pad_width = self.filter_width // 2
+         x = F.pad(weights, (pad_width, pad_width, 0, 0), mode="reflect")
+         x_unfolded = torch.ops.onnxruntime.UnfoldTensor(x, -1, self.filter_width, 1)
+         result = torch.select(x_unfolded.sort()[0], dim=-1, index=pad_width)
+         return result
+
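For reference, the same median filter can be written with plain `Tensor.unfold` in eager PyTorch; the custom `UnfoldTensor` op exists so the unfold maps to a single exportable node. A minimal sketch (editorial, assuming an odd `filter_width` and a 4-D `weights` tensor as in `forward` below):

```
import torch
import torch.nn.functional as F

def median_filter_reference(weights: torch.Tensor, filter_width: int = 7) -> torch.Tensor:
    pad_width = filter_width // 2
    # Reflect-pad the last dim, take sliding windows of size filter_width,
    # then pick the middle element of each sorted window (the median).
    x = F.pad(weights, (pad_width, pad_width, 0, 0), mode="reflect")
    windows = x.unfold(-1, filter_width, 1)  # (..., L, filter_width)
    return windows.sort(dim=-1).values[..., pad_width]

w = torch.randn(2, 6, 8, 100)  # (batch, heads, seq, frames // 2)
assert median_filter_reference(w).shape == w.shape
```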
+     def forward(
+         self,
+         alignment_heads: torch.Tensor,
+         sot_sequence_length: torch.Tensor,
+         segment_length: torch.Tensor,
+         QKs: list[torch.Tensor],
+     ):
+         # Get stacked QKs tensor
+         weights = index_QKs(alignment_heads, QKs)
+         weights = weights[:, :, : segment_length // 2]
+         weights = weights.to(torch.float32)
+
+         weights = (weights * self.qk_scale).softmax(dim=-1)
+         std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
+         weights = (weights - mean) / std
+         weights = self.median_filter(weights)
+
+         matrix = torch.mean(weights, 1)
+         matrix = -matrix[:, sot_sequence_length:-1]
+
+         max_decoded_length = torch.tensor([matrix.size(1)], dtype=torch.int64)
+         batched_jump_times = batch_jump_times(matrix, max_decoded_length)
+         return batched_jump_times
+
+     def input_names(self):
+         input_names = [
+             "alignment_heads",
+             "sot_sequence_length",
+             "segment_length",
+             *[f"cross_qk_{i}" for i in range(self.config.decoder_layers)],
+         ]
+         return input_names
+
+     def output_names(self):
+         output_names = ["jump_times"]
+         return output_names
+
+     def inputs(self, use_fp16_inputs: bool, use_int32_inputs: bool, return_dict: bool = False):
+         inputs = get_sample_jump_times_inputs(
+             self.config,
+             self.device,
+             batch_size=2,
+             sequence_length=8,
+             num_alignment_heads=6,
+             sot_sequence_length=3,
+             segment_length=1332,
+             use_fp16=use_fp16_inputs,
+             use_int32=use_int32_inputs,
+         )
+         if return_dict:
+             return inputs
+         return (
+             inputs["alignment_heads"],
+             inputs["sot_sequence_length"],
+             inputs["segment_length"],
+             inputs["QKs"],
+         )
+
+     def create_torch_ops(self):
+         """
+         1) Create UnfoldTensor and DynamicTimeWarping as torch ops
+         2) Provide a symbolic mapping from torch ops to ORT contrib ops
+
+         See https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html#building-with-jit-compilation
+         for more details on how this works.
+         """
+         # Set torch extensions directory to cache directory
+         os.environ["TORCH_EXTENSIONS_DIR"] = self.cache_dir
+
+         # Verify that the `ninja` build system is available; install it if not
+         try:
+             # verify_ninja_availability() raises if `ninja` is missing, so no assert is needed
+             torch.utils.cpp_extension.verify_ninja_availability()
+         except Exception as e:
+             logger.error(f"An error occurred while verifying `ninja` is available: {e}", exc_info=True)  # noqa: G201
+             install_cmd = "pip install ninja"
+             logger.warning(f"Could not import `ninja`. Attempting to install `ninja` via `{install_cmd}`.")
+             os.system(install_cmd)
+
+         # Create UnfoldTensor torch op
+         unfold_op_source = textwrap.dedent("""\
+             #include "torch/script.h"
+
+             torch::Tensor UnfoldTensor(torch::Tensor input, int64_t dim, int64_t size, int64_t step) {
+                 return input.unfold(dim, size, step);
+             }
+
+             // namespace is onnxruntime
+             static auto registry = torch::RegisterOperators("onnxruntime::UnfoldTensor", &UnfoldTensor);
+         """)
+
+         torch.utils.cpp_extension.load_inline(
+             name="UnfoldTensor",
+             cpp_sources=unfold_op_source,
+             is_python_module=False,
+             verbose=True,
+         )
+
+         # Create DynamicTimeWarping torch op
+         dtw_op_source = textwrap.dedent("""\
+             #include "torch/script.h"
+             #include "torch/torch.h"
+             #include <stdexcept>
+             #include <tuple>
+             #include <vector>
+
+             torch::Tensor Backtrace(torch::Tensor trace) {
+                 int64_t i = trace.size(0) - 1;
+                 int64_t j = trace.size(1) - 1;
+                 trace.index({0, torch::indexing::Slice()}) = 2;
+                 trace.index({torch::indexing::Slice(), 0}) = 1;
+
+                 std::vector<int32_t> result_vec;
+                 while (i > 0 || j > 0) {
+                     result_vec.push_back(static_cast<int32_t>(i - 1));
+                     result_vec.push_back(static_cast<int32_t>(j - 1));
+                     int value = trace[i][j].item<int>();
+
+                     if (value == 0) {
+                         i--;
+                         j--;
+                     } else if (value == 1) {
+                         i--;
+                     } else if (value == 2) {
+                         j--;
+                     } else {
+                         throw std::runtime_error("Unexpected trace[i, j]");
+                     }
+                 }
+
+                 // Compute result[::-1, :].T
+                 torch::Tensor result = torch::from_blob(result_vec.data(), {static_cast<long int>(result_vec.size() / 2), 2}, torch::kInt32).clone();
+                 torch::Tensor reversed = result.flip(0);  // result[::-1, :]
+                 torch::Tensor transposed = reversed.transpose(0, 1);  // .T
+                 return transposed;
+             }
+
+             torch::Tensor DynamicTimeWarping(torch::Tensor x) {
+                 int64_t N = x.size(0);
+                 int64_t M = x.size(1);
+                 torch::Tensor cost = torch::full({N + 1, M + 1}, std::numeric_limits<float>::infinity(), torch::dtype(torch::kFloat32));
+                 torch::Tensor trace = torch::full({N + 1, M + 1}, -1, torch::dtype(torch::kFloat32));
+
+                 cost[0][0] = 0;
+                 for (int j = 1; j < M + 1; j++) {
+                     for (int i = 1; i < N + 1; i++) {
+                         float c0 = cost[i - 1][j - 1].item<float>();
+                         float c1 = cost[i - 1][j].item<float>();
+                         float c2 = cost[i][j - 1].item<float>();
+
+                         float c = 0;
+                         float t = 0;
+
+                         if (c0 < c1 && c0 < c2) {
+                             c = c0;
+                             t = 0;
+                         } else if (c1 < c0 && c1 < c2) {
+                             c = c1;
+                             t = 1;
+                         } else {
+                             c = c2;
+                             t = 2;
+                         }
+
+                         cost[i][j] = x[i - 1][j - 1].item<float>() + c;
+                         trace[i][j] = t;
+                     }
+                 }
+
+                 return Backtrace(trace);
+             }
+
+             // namespace is onnxruntime
+             static auto registry = torch::RegisterOperators("onnxruntime::DynamicTimeWarping", &DynamicTimeWarping);
+         """)
+
+         torch.utils.cpp_extension.load_inline(
+             name="DynamicTimeWarping",
+             cpp_sources=dtw_op_source,
+             is_python_module=False,
+             verbose=True,
+         )
+
+         # Create symbolic mapping from torch ops to ORT contrib ops
+         pytorch_export_contrib_ops.register()
+
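Once `load_inline` has built and registered the kernels above, they are callable like any other TorchScript op, which makes a quick smoke test easy (an editorial sketch, assuming `create_torch_ops()` has already run in the current process):

```
import torch

x = torch.arange(10.0).reshape(1, 10)
windows = torch.ops.onnxruntime.UnfoldTensor(x, -1, 4, 1)
print(windows.shape)  # torch.Size([1, 7, 4]), identical to x.unfold(-1, 4, 1)
```

`pytorch_export_contrib_ops.register()` then installs the symbolic functions that emit these ops into the `com.microsoft` domain during `torch.onnx.export`, matching the `custom_opsets={"com.microsoft": 1}` argument used in `export_onnx` below.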
+     def export_onnx(
+         self,
+         onnx_model_path: str,
+         provider: str,
+         verbose: bool = True,
+         use_external_data_format: bool = False,
+         use_fp16_inputs: bool = False,
+         use_int32_inputs: bool = True,
+     ):
+         """Export word-level timestamps to ONNX
+
+         Args:
+             onnx_model_path (str): path to save the ONNX model to
+             provider (str): provider to use for verifying parity on the ONNX model
+             verbose (bool, optional): print verbose information. Defaults to True.
+             use_external_data_format (bool, optional): use external data format or not. Defaults to False.
+             use_fp16_inputs (bool, optional): use float16 inputs for the cross_qk_{i}. Defaults to False.
+             use_int32_inputs (bool, optional): use int32 inputs for the alignment_heads and sot_sequence_length. Defaults to True.
+         """
+         # Shapes of the timestamps tensors:
+         # Inputs:
+         #     alignment_heads: (num_alignment_heads, 2)
+         #     sot_sequence_length: (1)
+         #     segment_length: (1)
+         #     cross_qk_*: (batch_size, num_heads, sequence_length, num_frames // 2)
+         # Outputs:
+         #     jump_times: (batch_size, max_length)
+
+         # Definitions:
+         # alignment_heads: the attention head indices where the Q*K values are highly correlated with word-level timestamps
+         #     (i.e. the alignment between audio and text tokens)
+         #     This is calculated as follows:
+         #
+         # ```
+         # import base64
+         # import gzip
+         # import numpy as np
+         # import torch
+         #
+         # # base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
+         # # highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
+         # _ALIGNMENT_HEADS = {
+         #     "tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
+         #     "tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
+         #     "base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
+         #     "base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<FaQ7m",
+         #     "small.en": b"ABzY8>?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
+         #     "small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000",
+         #     "medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
+         #     "medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
+         #     "large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
+         #     "large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
+         #     "large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
+         #     "large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
+         #     "large-v3-turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
+         #     "turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
+         # }
+         #
+         # model_name = "large-v3-turbo"
+         # array = np.frombuffer(
+         #     gzip.decompress(base64.b85decode(_ALIGNMENT_HEADS[model_name])), dtype=bool
+         # ).copy()
+         # mask = torch.from_numpy(array).reshape(
+         #     self.dims.n_text_layer, self.dims.n_text_head
+         # )
+         # self.alignment_heads = mask.to_sparse().indices().T
+         # ```
+         #
+         # sot_sequence_length: the length of the start-of-transcription sequence before the first token is generated
+         #     Typically the start-of-transcription sequence is [<|startoftranscription|>, <|language_token|>, <|task_token|>]
+         #     so its length is 3.
+         #
+         # segment_length: the length (in frames) of the audio segment that is being transcribed
+         #
+         # cross_qk_*: the Q*K values for the cross-attention blocks in the decoder
+         #     Every decoder layer has a self-attention block and a cross-attention block so there are `n` cross-attention blocks
+         #     where `n` is the number of decoder layers.
+         #
+         # jump_times: the timings where jumps occur in speech
+         #     This allows us to detect when a word began to be spoken by the speaker (start_times) and when a word was finished
+         #     being spoken by the speaker (end_times).
+
+         inputs = self.inputs(use_fp16_inputs=use_fp16_inputs, use_int32_inputs=use_int32_inputs)
+         input_names = self.input_names()
+         output_names = self.output_names()
+         dynamic_axes = get_model_dynamic_axes(self.config, input_names, output_names)
+
+         Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
+         with tempfile.TemporaryDirectory() as tmp_dir_name:
+             temp_onnx_model_path = os.path.join(tmp_dir_name, "encoder.onnx")
+             Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
+             out_path = temp_onnx_model_path if use_external_data_format else onnx_model_path
+
+             # Create torch ops and map them to ORT contrib ops before export
+             self.create_torch_ops()
+             torch.onnx.export(
+                 self,
+                 args=inputs,
+                 f=out_path,
+                 export_params=True,
+                 input_names=input_names,
+                 output_names=output_names,
+                 dynamic_axes=dynamic_axes,
+                 opset_version=17,
+                 do_constant_folding=True,
+                 verbose=verbose,
+                 custom_opsets={"com.microsoft": 1},
+             )
+
+             if use_external_data_format:
+                 model = onnx.load_model(out_path, load_external_data=use_external_data_format)
+                 OnnxModel.save(
+                     model,
+                     onnx_model_path,
+                     save_as_external_data=True,
+                     all_tensors_to_one_file=True,
+                 )
+
+         self.verify_onnx(onnx_model_path, provider, use_fp16_inputs, use_int32_inputs)
+
+     def verify_onnx(
+         self,
+         onnx_model_path: str,
+         provider: str,
+         use_fp16_inputs: bool,
+         use_int32_inputs: bool,
+     ):
+         """Verify that the ONNX model outputs and the PyTorch model outputs match
+
+         Args:
+             onnx_model_path (str): path to the saved ONNX model
+             provider (str): execution provider for the ONNX model
+             use_fp16_inputs (bool): use float16 inputs for the cross_qk_{i}
+             use_int32_inputs (bool): use int32 inputs for the alignment_heads and sot_sequence_length
+         """
+         # Shapes of the jump times tensors:
+         # Inputs:
+         #     alignment_heads: (num_alignment_heads, 2)
+         #     sot_sequence_length: (1)
+         #     segment_length: (1)
+         #     cross_qk_*: (batch_size, num_heads, sequence_length, num_frames // 2)
+         # Outputs:
+         #     jump_times: (batch_size, max_length)
+         inputs = self.inputs(use_fp16_inputs=use_fp16_inputs, use_int32_inputs=use_int32_inputs, return_dict=True)
+
+         # Run PyTorch model
+         pt_outputs = (
+             self.forward(
+                 inputs["alignment_heads"], inputs["sot_sequence_length"], inputs["segment_length"], inputs["QKs"]
+             )
+             .detach()
+             .cpu()
+             .numpy()
+         )
+
+         # Run ONNX model
+         sess = InferenceSession(onnx_model_path, providers=[provider])
+         ort_outputs = sess.run(None, convert_inputs_for_ort(inputs, sess))
+
+         # Calculate output difference
+         diff = np.abs(pt_outputs - ort_outputs)
+         print("Comparing batched jump_times...", flush=True)
+         print(f"Max diff: {np.max(diff)}", flush=True)