onnxruntime_directml-1.24.1-cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
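
Most of the wheel is the pure-Python onnxruntime API; the native pieces are the bundled DirectML.dll, onnxruntime.dll, and onnxruntime_pybind11_state.pyd under onnxruntime/capi. As a quick orientation, here is a minimal usage sketch for a DirectML wheel like this one; the model path, input name, and tensor shape are illustrative assumptions, not files shipped in the package:

    import numpy as np
    import onnxruntime as ort

    # The DirectML build exposes DmlExecutionProvider; listing the CPU
    # provider as a fallback keeps the session usable without a DX12 GPU.
    session = ort.InferenceSession(
        "model.onnx",  # placeholder path, not part of the package
        providers=["DmlExecutionProvider", "CPUExecutionProvider"],
    )

    # The input name and shape depend on the model; values here are illustrative.
    input_name = session.get_inputs()[0].name
    x = np.random.rand(1, 3, 224, 224).astype(np.float32)
    outputs = session.run(None, {input_name: x})
    print(outputs[0].shape)

The diff below shows one of the 322 files in full: the Whisper export script listed at entry 285.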
onnxruntime/transformers/models/whisper/convert_to_onnx.py
@@ -0,0 +1,609 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License. See License.txt in the project root for
+ # license information.
+ # --------------------------------------------------------------------------
+
+ import argparse
+ import logging
+ import os
+ import warnings
+
+ import onnx
+ import torch
+ from benchmark_helper import Precision, create_onnxruntime_session, prepare_environment, setup_logger
+ from whisper_chain import chain_model
+ from whisper_encoder import WhisperEncoder
+ from whisper_helper import PRETRAINED_WHISPER_MODELS, WhisperHelper
+
+ from onnxruntime.quantization.matmul_nbits_quantizer import (
+     KQuantWeightOnlyQuantConfig,
+     MatMulNBitsQuantizer,
+     QuantFormat,
+ )
+
+ logger = logging.getLogger("")
+
+ PROVIDERS = {
+     "cpu": "CPUExecutionProvider",
+     "cuda": "CUDAExecutionProvider",
+ }
+
+
+ def parse_arguments(argv=None):
+     parser = argparse.ArgumentParser()
+
+     conversion_args = parser.add_argument_group("Conversion Process Args")
+     optional_inputs = parser.add_argument_group("Optional Inputs (for WhisperBeamSearch op)")
+     optional_outputs = parser.add_argument_group("Optional Outputs (for WhisperBeamSearch op)")
+     quant_args = parser.add_argument_group("INT8 Quantization Args")
+
+     #################################
+     # Conversion options for Whisper
+     #################################
+
+     conversion_args.add_argument(
+         "-m",
+         "--model_name_or_path",
+         required=False,
+         default=PRETRAINED_WHISPER_MODELS[0],
+         type=str,
+         help="Model path, or pretrained model name in the list: " + ", ".join(PRETRAINED_WHISPER_MODELS),
+     )
+
+     conversion_args.add_argument(
+         "--model_impl",
+         required=False,
+         default="hf",
+         choices=["hf", "openai"],
+         type=str,
+         help="Select implementation for export of encoder and decoder subgraphs",
+     )
+
+     conversion_args.add_argument(
+         "--cache_dir",
+         required=False,
+         type=str,
+         default=os.path.join(".", "cache_models"),
+         help="Directory to cache pre-trained models",
+     )
+
+     conversion_args.add_argument(
+         "--output",
+         required=False,
+         type=str,
+         default=os.path.join(".", "onnx_models"),
+         help="Output directory",
+     )
+
+     conversion_args.add_argument(
+         "-o",
+         "--optimize_onnx",
+         required=False,
+         action="store_true",
+         help="Use optimizer.py to optimize onnx model",
+     )
+     conversion_args.set_defaults(optimize_onnx=False)
+
+     conversion_args.add_argument(
+         "--use_gpu",
+         required=False,
+         action="store_true",
+         help="Use GPU for model inference",
+     )
+     conversion_args.set_defaults(use_gpu=False)
+
+     conversion_args.add_argument(
+         "-p",
+         "--precision",
+         required=False,
+         type=Precision,
+         default=Precision.FLOAT32,
+         choices=[Precision.FLOAT32, Precision.FLOAT16, Precision.INT8, Precision.INT4],
+         help="Precision of model to run. fp32 for full precision, fp16 for half precision, int8/int4 for quantization",
+     )
+
+     conversion_args.add_argument(
+         "--use_int64_inputs",
+         required=False,
+         action="store_true",
+         help="Use int64 instead of int32 for input_ids and attention_mask.",
+     )
+     conversion_args.set_defaults(use_int64_inputs=False)
+
+     conversion_args.add_argument(
+         "-r",
+         "--provider",
+         required=False,
+         type=str,
+         default="cpu",
+         choices=list(PROVIDERS.keys()),
+         help="Provider to benchmark. Default is CPUExecutionProvider.",
+     )
+
+     conversion_args.add_argument(
+         "--verbose",
+         required=False,
+         action="store_true",
+         help="Enable verbose logging",
+     )
+     conversion_args.set_defaults(verbose=False)
+
+     conversion_args.add_argument(
+         "-e",
+         "--use_external_data_format",
+         required=False,
+         action="store_true",
+         help="Save weights in external file. Necessary for 'small', 'medium', and 'large' models. Optional for 'tiny' and 'base' models.",
+     )
+     conversion_args.set_defaults(use_external_data_format=False)
+
+     conversion_args.add_argument(
+         "-w",
+         "--overwrite",
+         required=False,
+         action="store_true",
+         help="Overwrite existing ONNX model",
+     )
+     conversion_args.set_defaults(overwrite=False)
+
+     conversion_args.add_argument(
+         "--separate_encoder_and_decoder_init",
+         required=False,
+         action="store_true",
+         help="Do not merge encoder and decoder init to initialize past KV caches. Output 3 instead of 2 ONNX models.",
+     )
+     conversion_args.set_defaults(separate_encoder_and_decoder_init=False)
+
+     conversion_args.add_argument(
+         "--no_beam_search_op",
+         required=False,
+         action="store_true",
+         help="Do not produce model with WhisperBeamSearch op, which chains encdecinit and decoder models into one op.",
+     )
+     conversion_args.set_defaults(no_beam_search_op=False)
+
+     conversion_args.add_argument(
+         "--use_decoder_masked_mha",
+         required=False,
+         action="store_true",
+         help="Use DecoderMaskedMultiHeadAttention kernel for improved performance. This is currently an experimental feature.",
+     )
+     conversion_args.set_defaults(use_decoder_masked_mha=False)
+
+     #############################################################
+     # Optional inputs for Whisper
+     # (listed below in the order that WhisperBeamSearch expects)
+     #############################################################
+
+     optional_inputs.add_argument(
+         "-v",
+         "--use_vocab_mask",
+         required=False,
+         action="store_true",
+         help="Use vocab_mask as an extra graph input to enable specific logits processing",
+     )
+     optional_inputs.set_defaults(use_vocab_mask=False)
+
+     optional_inputs.add_argument(
+         "-u",
+         "--use_prefix_vocab_mask",
+         required=False,
+         action="store_true",
+         help="Use prefix_vocab_mask as an extra graph input to enable specific logits processing",
+     )
+     optional_inputs.set_defaults(use_prefix_vocab_mask=False)
+
+     optional_inputs.add_argument(
+         "-f",
+         "--use_forced_decoder_ids",
+         required=False,
+         action="store_true",
+         help="Use decoder_input_ids as an extra graph input to the beam search op",
+     )
+     optional_inputs.set_defaults(use_forced_decoder_ids=False)
+
+     optional_inputs.add_argument(
+         "-l",
+         "--use_logits_processor",
+         required=False,
+         action="store_true",
+         help="Use logits_processor as an extra graph input to enable specific logits processing",
+     )
+     optional_inputs.set_defaults(use_logits_processor=False)
+
+     optional_inputs.add_argument(
+         "--collect_cross_qk",
+         required=False,
+         action="store_true",
+         help="Beam search model collects stacked cross QK.",
+     )
+     optional_inputs.set_defaults(collect_cross_qk=False)
+
+     optional_inputs.add_argument(
+         "--extra_decoding_ids",
+         required=False,
+         action="store_true",
+         help="Use extra starting decoding ids for some features, like cross QK. Default is false.",
+     )
+     optional_inputs.set_defaults(extra_decoding_ids=False)
+
+     optional_inputs.add_argument(
+         "-t",
+         "--use_temperature",
+         required=False,
+         action="store_true",
+         help="Use temperature as an extra graph input for the WhisperBeamSearch op",
+     )
+     optional_inputs.set_defaults(use_temperature=False)
+
+     optional_inputs.add_argument(
+         "--no_repeat_ngram_size",
+         type=int,
+         default=0,
+ help="default to 0",
+     )
+
+     #############################################################
+     # Optional outputs for Whisper
+     # (listed below in the order that WhisperBeamSearch expects)
+     #############################################################
+
+     optional_outputs.add_argument(
+         "--output_sequence_scores",
+         required=False,
+         action="store_true",
+         help="Beam search model outputs scores for each generated sequence.",
+     )
+     optional_outputs.set_defaults(output_sequence_scores=False)
+
+     optional_outputs.add_argument(
+         "--output_scores",
+         required=False,
+         action="store_true",
+         help="Beam search model outputs scores over the vocab per generated token.",
+     )
+     optional_outputs.set_defaults(output_scores=False)
+
+     optional_outputs.add_argument(
+         "--output_cross_qk",
+         required=False,
+         action="store_true",
+         help="Beam search model outputs the collected cross QK as an extra output. Also implies --collect_cross_qk.",
+     )
+     optional_outputs.set_defaults(output_cross_qk=False)
+
+     optional_outputs.add_argument(
+         "--cross_qk_onnx_model",
+         required=False,
+         type=str,
+         default=None,
+         help="The model which consumes cross_qk outputs.",
+     )
+
+     optional_outputs.add_argument(
+         "--output_no_speech_probs",
+         required=False,
+         action="store_true",
+         help="Beam search model outputs no-speech probabilities, which are computed from the encoder/context-decoder graph.",
+     )
+     optional_outputs.set_defaults(output_no_speech_probs=False)
+
+     ###################################
+     # Quantization options for Whisper
+     ###################################
+
+     quant_args.add_argument(
+         "--accuracy_level",
+         default=0,
+         required=False,
+         type=int,
+         help="Accuracy level of the 4-bit quantized MatMul computation.",
+     )
+
+     quant_args.add_argument(
+         "--quantize_symmetric",
+         required=False,
+         action="store_true",
+         help="Quantize weights symmetrically",
+     )
+     quant_args.set_defaults(quantize_symmetric=False)
+
+     args = parser.parse_args(argv)
+
+     # Collect cross QKs if either flag is enabled
+     args.collect_cross_qk = args.collect_cross_qk or args.output_cross_qk
+
+     # FP32 CPU can be supported here once the DMMHA CPU kernel bugs are fixed
+     args.use_decoder_masked_mha = args.use_decoder_masked_mha and args.provider == "cuda"
+
+     return args
+
+
+ # quant_method is reserved for mixed precision in future
+ def make_quant_algo_config(precision, quant_method: str, matmul_nodes=None):
+     customized_weight_config = {}
+     quant_algo_config = None
+
+     # need to use k_quant for int8
+     if precision == Precision.INT8:
+         for node_name in matmul_nodes:
+             customized_weight_config[node_name] = {"bits": 8}
+         quant_algo_config = KQuantWeightOnlyQuantConfig(customized_weight_config=customized_weight_config)
+     else:
+         quant_algo_config = KQuantWeightOnlyQuantConfig(customized_weight_config=customized_weight_config)
+
+     return quant_algo_config
+
+
+ def export_onnx_models(
+     model_name_or_path,
+     model_impl,
+     cache_dir,
+     output_dir,
+     use_gpu,
+     use_external_data_format,
+     optimize_onnx,
+     precision,
+     verbose,
+     use_forced_decoder_ids: bool = False,
+     merge_encoder_and_decoder_init: bool = True,
+     no_beam_search_op: bool = False,
+     use_decoder_masked_mha: bool = False,
+     output_qk: bool = False,
+     overwrite: bool = False,
+     use_int32_inputs: bool = True,
+     accuracy_level: int = 0,
+     quantize_symmetric: bool = False,
+     provider: str = "cpu",
+ ):
+     device = torch.device("cuda" if use_gpu else "cpu")
+     if not use_gpu:
+         accuracy_level = 4  # change to 4 for CPU EP
+     use_fp16_inputs = precision == Precision.FLOAT16 or (precision in (Precision.INT8, Precision.INT4) and use_gpu)
+
+     models = WhisperHelper.load_model(
+         model_name_or_path,
+         model_impl,
+         cache_dir,
+         device,
+         torch.float16 if use_fp16_inputs else torch.float32,
+         merge_encoder_and_decoder_init,
+         no_beam_search_op,
+         output_qk,
+     )
+     config = models["decoder"].config
+
+     if (not use_external_data_format) and (config.num_hidden_layers > 24):
+         logger.warning("You MUST pass `--use_external_data_format` because model size > 2GB")
+         raise Exception("Please pass `--use_external_data_format` for this model.")
+
+     output_paths = []
+     for name, model in models.items():
+         print(f"========> Handling {name} model......")
+         filename_suffix = "_" + name
+
+         onnx_path = WhisperHelper.get_onnx_path(
+             output_dir,
+             model_name_or_path,
+             suffix=filename_suffix,
+             new_folder=False,
+         )
+
+         # Export to ONNX
+         if overwrite or not os.path.exists(onnx_path):
+             logger.info(f"Exporting ONNX model to {onnx_path}")
+             WhisperHelper.export_onnx(
+                 model,
+                 onnx_path,
+                 PROVIDERS[provider],
+                 verbose,
+                 use_external_data_format,
+                 use_fp16_inputs=use_fp16_inputs,
+                 use_int32_inputs=use_int32_inputs,
+                 use_encoder_hidden_states=(name == "decoder_init"),
+                 use_kv_cache_inputs=(name == "decoder"),
+             )
+         else:
+             logger.info(f"Skip exporting: existing ONNX model {onnx_path}")
+
+         # Optimize ONNX model
+         if optimize_onnx or precision != Precision.FLOAT32:
+             output_path = WhisperHelper.get_onnx_path(
+                 output_dir,
+                 model_name_or_path,
+                 suffix=filename_suffix + "_" + str(precision),
+                 new_folder=False,
+             )
+
+             if overwrite or not os.path.exists(output_path):
+                 if optimize_onnx:
+                     logger.info(f"Optimizing model to {output_path}")
+                     WhisperHelper.optimize_onnx(
+                         onnx_path,
+                         output_path,
+                         precision == Precision.FLOAT16,
+                         model.config.encoder_attention_heads,
+                         model.config.d_model,
+                         model.config.decoder_layers,
+                         use_external_data_format,
+                         use_gpu=use_gpu,
+                         provider=provider,
+                         is_decoder=(name == "decoder"),
+                         no_beam_search_op=no_beam_search_op,
+                         use_decoder_masked_mha=use_decoder_masked_mha,
+                         output_qk=output_qk,
+                     )
+                     # Remove old ONNX model and old data file
+                     if os.path.exists(onnx_path):
+                         os.remove(onnx_path)
+                     if os.path.exists(onnx_path + ".data"):
+                         os.remove(onnx_path + ".data")
+                     onnx_path = output_path
+
+                 if isinstance(model, WhisperEncoder):
+                     model.verify_onnx(
+                         onnx_path,
+                         PROVIDERS[provider],
+                         use_fp16_inputs=use_fp16_inputs,
+                     )
+                 else:
+                     model.verify_onnx(
+                         onnx_path,
+                         PROVIDERS[provider],
+                         use_fp16_inputs=use_fp16_inputs,
+                         use_int32_inputs=use_int32_inputs,
+                     )
+
+                 if precision in (Precision.INT8, Precision.INT4):
+                     onnx_model = onnx.load(onnx_path, load_external_data=True)
+                     matmul_nodes = [node.name for node in onnx_model.graph.node if node.op_type == "MatMul"]
+                     quant_algo_config = make_quant_algo_config(precision, "k_quant", matmul_nodes)
+
+                     quant = MatMulNBitsQuantizer(
+                         model=onnx_model,
+                         block_size=32,
+                         is_symmetric=quantize_symmetric,
+                         accuracy_level=accuracy_level,
+                         quant_format=QuantFormat.QOperator,
+                         op_types_to_quantize=("MatMul",),
+                         algo_config=quant_algo_config,
+                     )
+                     quant.process()
+                     if os.path.exists(output_path):
+                         os.remove(output_path)
+                     if os.path.exists(output_path + ".data"):
+                         os.remove(output_path + ".data")
+                     onnx.save_model(
+                         quant.model.model,
+                         output_path,
+                         save_as_external_data=True,
+                         all_tensors_to_one_file=True,
+                         location=os.path.basename(output_path) + ".data",
+                         size_threshold=0,
+                         convert_attribute=False,
+                     )
+             else:
+                 logger.info(f"Skip optimizing: existing ONNX model {onnx_path}")
+         else:
+             output_path = onnx_path
+
+         output_paths.append(output_path)
+
+     return output_paths
+
+
+ def main(argv=None):
+     warnings.warn(
+         "This example is deprecated. Use the Olive recipe instead: "
+         "https://github.com/microsoft/olive-recipes/tree/main",
+         DeprecationWarning,
+         stacklevel=2,
+     )
+     args = parse_arguments(argv)
+
+     setup_logger(args.verbose)
+
+     logger.info(f"Arguments:{args}")
+
+     cache_dir = args.cache_dir
+     output_dir = args.output if not args.output.endswith(".onnx") else os.path.dirname(args.output)
+     prepare_environment(cache_dir, output_dir, args.use_gpu)
+
+     if args.precision == Precision.FLOAT16:
+         assert args.use_gpu, "fp16 requires --use_gpu"
+
+     output_paths = export_onnx_models(
+         args.model_name_or_path,
+         args.model_impl,
+         cache_dir,
+         output_dir,
+         args.use_gpu,
+         args.use_external_data_format,
+         args.optimize_onnx,
+         args.precision,
+         args.verbose,
+         args.use_forced_decoder_ids,
+         not args.separate_encoder_and_decoder_init,
+         args.no_beam_search_op,
+         args.use_decoder_masked_mha,
+         args.output_cross_qk,
+         args.overwrite,
+         not args.use_int64_inputs,
+         args.accuracy_level,
+         args.quantize_symmetric,
+         args.provider,
+     )
+
+     max_diff = 0
+     if not args.no_beam_search_op:
+         logger.info("Chaining model ... :")
+         args.beam_model_output_dir = WhisperHelper.get_onnx_path(
+             output_dir,
+             args.model_name_or_path,
+             suffix="_beamsearch",
+             new_folder=False,
+         )
+         for path in output_paths:
+             if "encoder_decoder" in path or "encoder" in path:
+                 args.encoder_path = path
+             elif "decoder" in path:
+                 args.decoder_path = path
+         chain_model(args)
+         output_paths.append(args.beam_model_output_dir)
+
+         # Check chained model
+         ort_session = create_onnxruntime_session(
+             args.beam_model_output_dir,
+             use_gpu=args.use_gpu,
+             provider=args.provider,
+         )
+         device = torch.device("cuda" if args.use_gpu else "cpu")
+
+         # Wrap parity check in try-except to allow export to continue in case this produces an error
+         try:
+             with torch.no_grad():
+                 # Verify batched decoding with prompts for OpenAI implementation
+                 if args.model_impl == "openai" and args.use_forced_decoder_ids:
+                     max_diff = WhisperHelper.verify_onnx(
+                         args.model_name_or_path, cache_dir, ort_session, device, batch_size=2, prompt_mode=True
+                     )
+                 else:
+                     max_diff = WhisperHelper.verify_onnx(args.model_name_or_path, cache_dir, ort_session, device)
+             if max_diff > 1e-4:
+                 logger.warning("PyTorch and ONNX Runtime results are NOT close")
+             else:
+                 logger.info("PyTorch and ONNX Runtime results are close")
+         except Exception as e:
+             logger.warning(
+                 f"An error occurred while trying to verify parity between PyTorch and ONNX Runtime: {e}", exc_info=True
+             )
+
+         # Remove extra ONNX models saved in output directory
+         for _file in os.listdir(output_dir):
+             if "_beamsearch" not in _file and "_jump_times" not in _file:
+                 path = os.path.join(output_dir, _file)
+                 os.remove(path)
+                 if path in output_paths:
+                     output_paths.remove(path)
+
+     else:
+         # Create ancillary JSON files for ONNX Runtime GenAI and/or Hugging Face's Optimum
+         WhisperHelper.save_processing(
+             args.model_name_or_path,
+             args.provider,
+             args.separate_encoder_and_decoder_init,
+             args.use_decoder_masked_mha,
+             args.output_cross_qk,
+             next(iter(filter(lambda path: "encoder" in path, output_paths))),
+             next(iter(filter(lambda path: "decoder" in path, output_paths))),
+             output_dir,
+             cache_dir,
+         )
+
+     logger.info(f"Done! Outputs: {output_paths}")
+     return max_diff
+
+
+ if __name__ == "__main__":
+     main()
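
To tie the script together: parse_arguments builds the CLI, export_onnx_models exports (and optionally optimizes and quantizes) the encoder/decoder ONNX graphs, and main chains them with the WhisperBeamSearch op unless --no_beam_search_op is set. Below is a minimal driver sketch, assuming the whisper scripts directory is importable so the script's flat sibling imports (benchmark_helper, whisper_helper, whisper_chain) resolve; the model name is an example, not necessarily the packaged default from PRETRAINED_WHISPER_MODELS[0]:

    import sys

    # Assumed layout: point sys.path at the directory containing
    # convert_to_onnx.py so its sibling imports resolve.
    sys.path.insert(0, "onnxruntime/transformers/models/whisper")
    from convert_to_onnx import main

    # Equivalent to: python convert_to_onnx.py -m openai/whisper-tiny \
    #     --output ./onnx_models --optimize_onnx --use_external_data_format
    max_diff = main(
        [
            "-m", "openai/whisper-tiny",  # example model name
            "--output", "./onnx_models",
            "--optimize_onnx",
            "--use_external_data_format",
        ]
    )
    print(f"Max parity difference vs PyTorch: {max_diff}")

Note that the script itself warns it is deprecated in favor of the Olive recipes linked in main().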