onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,653 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ from abc import ABC, abstractmethod
7
+
8
+ import ort_flatbuffers_py.fbs as fbs
9
+
10
+ from .types import FbsTypeInfo, value_name_to_typestr
11
+
12
+
13
+ def _create_op_key(domain: str, optype: str):
14
+ return f"{domain}:{optype}"
15
+
16
+
17
+ def _ort_constant_for_domain(domain: str):
18
+ """
19
+ Map a string domain value to the internal ONNX Runtime constant for that domain.
20
+ :param domain: Domain string to map.
21
+ :return: Internal ONNX Runtime constant
22
+ """
23
+
24
+ # constants are defined in <ORT root>/include/onnxruntime/core/graph/constants.h
25
+ # This list is limited to just the domains we have processors for
26
+ domain_to_constant_map = {"ai.onnx": "kOnnxDomain", "ai.onnx.ml": "kMLDomain", "com.microsoft": "kMSDomain"}
27
+
28
+ if domain not in domain_to_constant_map:
29
+ raise ValueError(f"Domain {domain} not found in map to ONNX Runtime constant. Please update map.")
30
+
31
+ return domain_to_constant_map[domain]
32
+
33
+
34
+ def _reg_type_to_cpp_type(reg_type: str):
35
+ if reg_type == "string":
36
+ return "std::string"
37
+ return reg_type
38
+
39
+
40
+ def _split_reg_types(reg_types_str: str):
41
+ """
42
+ Split on underscores but append "_t" to the previous element.
43
+ """
44
+ tokens = reg_types_str.split("_")
45
+ reg_types = []
46
+ for token in tokens:
47
+ if token == "t" and len(reg_types) > 0:
48
+ reg_types[-1] += "_t"
49
+ else:
50
+ reg_types += [token]
51
+ return reg_types
52
+
53
+
54
class TypeUsageProcessor(ABC):
    """
    Abstract base class for processors which implement operator specific logic to determine the type or types required.
    """

    def __init__(self, domain: str, optype: str):
        self.domain = domain
        self.optype = optype
        # cached "<domain>:<optype>" key identifying this processor
        self.name = _create_op_key(domain, optype)

    @abstractmethod
    def process_node(self, node: fbs.Node, value_name_to_typeinfo: dict):
        """Record the types used by the given node."""

    def is_typed_registration_needed(self, type_in_registration: str, globally_allowed_types: set[str] | None):
        """
        Given the string from a kernel registration, determine if the registration is required or not.
        :param type_in_registration: Type string from kernel registration
        :param globally_allowed_types: Optional set of globally allowed types. If provided, these types take precedence
                                       in determining the required types.
        :return: True is required. False if not.
        """
        # Not all operators have typed registrations, so this is optionally implemented by derived classes
        raise RuntimeError(f"Did not expect processor for {self.name} to have typed registrations.")

    def get_cpp_entry(self):
        """
        Get the C++ code that specifies this operator's required types.
        :return: List with any applicable C++ code for this operator's required types. One line per entry.
        """
        # Not applicable for some ops, so return no lines by default.
        return []

    @abstractmethod
    def to_config_entry(self):
        """
        Generate a configuration file entry in JSON format with the required types for the operator.
        :return: JSON string with required type information.
        """

    @abstractmethod
    def from_config_entry(self, entry: str):
        """
        Re-create the types required from a configuration file entry created with to_config_entry.
        NOTE: Any existing type information should be cleared prior to re-creating from a config file entry.
        :param entry: Configuration file entry
        """
101
+
102
+
103
class DefaultTypeUsageProcessor(TypeUsageProcessor):
    """
    Operator processor which tracks the types used for selected input/s and/or output/s.
    """

    def __init__(
        self,
        domain: str,
        optype: str,
        inputs: list[int] | None = None,
        outputs: list[int] | None = None,
        required_input_types: dict[int, set[str]] | None = None,
        required_output_types: dict[int, set[str]] | None = None,
    ):
        """
        Create DefaultTypeUsageProcessor. Types for one or more inputs and/or outputs can be tracked by the processor.
        The default is to track the types required for input 0, as this is the most common use case in ONNX.

        Required input and output types may be specified. These are only applicable to is_typed_registration_needed().
        If a registration type matches a required type, the typed registration is needed.
        There is a separate mechanism for specifying required types from C++ for kernels with untyped registration.

        :param domain: Operator domain.
        :param optype: Operator name.
        :param inputs: Inputs to track. Zero based index. May be empty. Defaults to [0].
        :param outputs: Outputs to track. Zero based index. May be empty. Defaults to [].
        :param required_input_types: Required input types. May be empty. Defaults to {}.
        :param required_output_types: Required output types. May be empty. Defaults to {}.
        """
        super().__init__(domain, optype)

        # None sentinels replace the previous mutable default arguments ([0], [], {}, {}),
        # which were shared across all calls and stored directly on the instance (flake8 B006 hazard).
        inputs = [0] if inputs is None else inputs
        outputs = [] if outputs is None else outputs

        if not inputs and not outputs:
            raise ValueError("At least one input or output must be tracked")

        # per-index sets of type strings observed while processing models
        self._input_types = {i: set() for i in inputs}
        self._output_types = {o: set() for o in outputs}

        self._required_input_types = {} if required_input_types is None else required_input_types
        self._required_output_types = {} if required_output_types is None else required_output_types

    def _is_type_enabled(self, reg_type, index, required_types, allowed_type_set):
        # A type is enabled when it is explicitly required for this index or present in the allowed set.
        cpp_type = _reg_type_to_cpp_type(reg_type)
        return cpp_type in required_types.get(index, set()) or cpp_type in allowed_type_set

    def is_input_type_enabled(self, reg_type, index, allowed_type_set=None):
        """Whether input type is enabled based on required and allowed types."""
        if allowed_type_set is None:
            allowed_type_set = self._input_types[index]
        return self._is_type_enabled(reg_type, index, self._required_input_types, allowed_type_set)

    def is_output_type_enabled(self, reg_type, index, allowed_type_set=None):
        """Whether output type is enabled based on required and allowed types."""
        if allowed_type_set is None:
            allowed_type_set = self._output_types[index]
        return self._is_type_enabled(reg_type, index, self._required_output_types, allowed_type_set)

    def process_node(self, node: fbs.Node, value_name_to_typeinfo: dict):
        """Record the types used by the tracked inputs and outputs of the node."""
        for i in self._input_types:
            # Some operators have fewer inputs in earlier versions where data that was as an attribute
            # become an input in later versions to allow it to be dynamically provided. Allow for that.
            # e.g. Slice-1 had attributes for the indices, and Slice-10 moved those to be inputs
            if i < node.InputsLength():
                type_str = value_name_to_typestr(node.Inputs(i), value_name_to_typeinfo)
                self._input_types[i].add(type_str)

        for o in self._output_types:
            # Don't know of any ops where the number of outputs changed across versions, so require a valid length
            if o >= node.OutputsLength():
                raise RuntimeError(
                    f"Node has {node.OutputsLength()} outputs. Tracker for {self.name} incorrectly configured as it requires {o}."
                )

            type_str = value_name_to_typestr(node.Outputs(o), value_name_to_typeinfo)
            self._output_types[o].add(type_str)

    def is_typed_registration_needed(self, type_in_registration: str, globally_allowed_types: set[str] | None):
        """Whether the typed registration is needed, based on the type of input 0."""
        if 0 not in self._input_types:
            # currently all standard typed registrations are for input 0.
            # custom registrations can be handled by operator specific processors (e.g. OneHotProcessor below).
            raise RuntimeError(f"Expected typed registration to use type from input 0. Node:{self.name}")

        return self.is_input_type_enabled(type_in_registration, 0, globally_allowed_types)

    def get_cpp_entry(self):
        """
        Get the C++ code that specifies this operator's required types.
        :return: List with any applicable C++ code for this operator's required types. One line per entry.
        """
        entries = []
        domain = _ort_constant_for_domain(self.domain)
        for i in sorted(self._input_types.keys()):
            if self._input_types[i]:
                entries.append(
                    "ORT_SPECIFY_OP_KERNEL_ARG_ALLOWED_TYPES({}, {}, Input, {}, {});".format(
                        domain, self.optype, i, ", ".join(sorted(self._input_types[i]))
                    )
                )

        for o in sorted(self._output_types.keys()):
            if self._output_types[o]:
                entries.append(
                    "ORT_SPECIFY_OP_KERNEL_ARG_ALLOWED_TYPES({}, {}, Output, {}, {});".format(
                        domain, self.optype, o, ", ".join(sorted(self._output_types[o]))
                    )
                )

        return entries

    def to_config_entry(self):
        """
        Generate a configuration file entry in JSON format with the required types for the operator.
        :return: JSON string with required type information, or None if no types were recorded.
        """
        # convert the sets of types to sorted lists so they can easily be written out using the json module
        aggregate_info = {"inputs": {}, "outputs": {}}

        # filter out empty entries and sort the types
        for i in sorted(self._input_types.keys()):
            if self._input_types[i]:
                aggregate_info["inputs"][i] = sorted(self._input_types[i])

        for o in sorted(self._output_types.keys()):
            if self._output_types[o]:
                aggregate_info["outputs"][o] = sorted(self._output_types[o])

        # remove any empty keys
        if not aggregate_info["inputs"]:
            aggregate_info.pop("inputs")
        if not aggregate_info["outputs"]:
            aggregate_info.pop("outputs")

        return json.dumps(aggregate_info) if aggregate_info else None

    def from_config_entry(self, entry: str):
        """
        Re-create the types required from a configuration file entry created with to_config_entry.
        NOTE: Any existing type information is cleared prior to re-creating from the config file entry.
        :param entry: Configuration file entry
        """
        self._input_types.clear()
        self._output_types.clear()

        aggregate_info = json.loads(entry)
        # JSON object keys are strings; convert them back to the int indices used internally
        for i_str, values in aggregate_info.get("inputs", {}).items():
            self._input_types[int(i_str)] = set(values)

        for o_str, values in aggregate_info.get("outputs", {}).items():
            self._output_types[int(o_str)] = set(values)
250
+
251
+
252
class Input1TypedRegistrationProcessor(DefaultTypeUsageProcessor):
    """
    Processor for operators whose typed kernel registration is keyed on the type of input 1
    rather than input 0 (e.g. Where, whose first input is always boolean).
    """

    def __init__(self, domain: str, optype: str):
        # only input 1 needs tracking
        super().__init__(domain, optype, inputs=[1], outputs=[])

    def is_typed_registration_needed(self, type_in_registration: str, globally_allowed_types: set[str] | None):
        """The registration is needed iff the registered type is enabled for input 1."""
        return self.is_input_type_enabled(type_in_registration, 1, globally_allowed_types)
263
+
264
+
265
class Output0TypedRegistrationProcessor(DefaultTypeUsageProcessor):
    """
    Processor for operators whose typed kernel registration is keyed on the type of output 0
    (e.g. quantization ops where the input type is fixed).
    """

    def __init__(self, domain: str, optype: str):
        # only output 0 needs tracking
        super().__init__(domain, optype, inputs=[], outputs=[0])

    def is_typed_registration_needed(self, type_in_registration: str, globally_allowed_types: set[str] | None):
        """The registration is needed iff the registered type is enabled for output 0."""
        return self.is_output_type_enabled(type_in_registration, 0, globally_allowed_types)
276
+
277
+
278
class OneHotProcessor(TypeUsageProcessor):
    """
    Processor for the OneHot operator, which requires custom logic as the type registration key is a concatenation of
    the three types involved instead of a single type name.
    """

    def __init__(self):
        super().__init__("ai.onnx", "OneHot")
        # set of (input type, output type, depth type) triples seen while processing models
        self._triples = set()

    def process_node(self, node: fbs.Node, value_name_to_typeinfo: dict):
        """
        Record the type triple used by a OneHot node.
        :param node: Node from ORT format model
        :param value_name_to_typeinfo: Map of value names to TypeInfo instances
        """
        type0 = value_name_to_typestr(node.Inputs(0), value_name_to_typeinfo)
        type1 = value_name_to_typestr(node.Inputs(1), value_name_to_typeinfo)
        type2 = value_name_to_typestr(node.Inputs(2), value_name_to_typeinfo)
        # types in kernel registration are ordered this way: input (T1), output (T3), depth (T2)
        key = (type0, type2, type1)
        self._triples.add(key)

    def is_typed_registration_needed(self, type_in_registration: str, globally_allowed_types: set[str] | None):
        """
        Check the concatenated registration string against the recorded triples (or the global allow list).
        """
        # the OneHot registration involves a concatenation of the 3 types involved.
        # note: use a generator directly in tuple() instead of building a throwaway list first.
        reg_types = tuple(_reg_type_to_cpp_type(reg_type) for reg_type in _split_reg_types(type_in_registration))
        if globally_allowed_types is not None:
            return all(reg_type in globally_allowed_types for reg_type in reg_types)
        else:
            return reg_types in self._triples

    def to_config_entry(self):
        """Serialize the recorded triples to a JSON string, or None if nothing was recorded."""
        if not self._triples:
            return None

        aggregate_info = {"custom": sorted(self._triples)}
        entry = json.dumps(aggregate_info)
        return entry

    def from_config_entry(self, entry: str):
        """Restore the recorded triples from a JSON config entry produced by to_config_entry."""
        self._triples.clear()
        aggregate_info = json.loads(entry)
        if "custom" in aggregate_info:
            # JSON has no tuple type, so lists must be converted back to tuples
            self._triples = {tuple(triple) for triple in aggregate_info["custom"]}
317
+
318
+
319
def _create_operator_type_usage_processors():
    """
    Create a set of processors that determine the required types for all enabled operators.
    :return: Dictionary of operator key to processor. Key is 'domain:operator (e.g. ai.onnx:Cast)'.
    """
    operator_processors = {}

    def add(processor):
        # guard against accidentally registering two processors for the same operator
        if processor.name in operator_processors:
            raise RuntimeError("Duplicate processor for " + processor.name)

        operator_processors[processor.name] = processor

    # Starting with ops from:
    # - Priority 1P models
    # - Mobilenet + SSD Mobilenet + MobileBert
    # - some known large kernels
    #
    # Ops we are ignoring currently so as not to produce meaningless/unused output:
    # - Implementation is type agnostic:
    #   ai.onnx: If, Loop, Reshape, Scan, Shape, Squeeze, Tile, Unsqueeze
    #   com.microsoft: DynamicQuantizeMatMul, MatMulIntegerToFloat
    # - Only one type supported in the ORT implementation:
    #   ai.onnx: NonMaxSuppression
    #   com.microsoft: FusedConv, FusedGemm, FusedMatMul
    # - Implementation does not have any significant type specific code:
    #   ai.onnx: Concat, Flatten, Not, Reshape, Shape, Squeeze, Unsqueeze
    #
    default_processor_onnx_ops = [
        "Abs",
        "ArgMax",
        "ArgMin",
        "AveragePool",
        "BatchNormalization",
        "BitShift",
        "Ceil",
        "Clip",
        "Conv",
        "CumSum",
        "Exp",
        "Expand",
        "Floor",
        "Gemm",
        "IsNaN",
        "Log",
        "LogSoftmax",
        "LpNormalization",
        "MatMul",
        "Max",
        "MaxPool",
        "Mean",
        "Min",
        "NonZero",
        "Pad",
        "QLinearConv",
        "QLinearMatMul",
        "Range",
        "Reciprocal",
        "ReduceL1",
        "ReduceL2",
        "ReduceLogSum",
        "ReduceLogSumExp",
        "ReduceMax",
        "ReduceMean",
        "ReduceMin",
        "ReduceProd",
        "ReduceSum",
        "ReduceSumSquare",
        "Relu",
        "Resize",
        "ReverseSequence",
        "RoiAlign",
        "Round",
        "Scatter",
        "ScatterElements",
        "ScatterND",
        "Shrink",
        "Sigmoid",
        "Sign",
        "Sin",
        "Softmax",
        "Split",
        "SplitToSequence",
        "Sqrt",
        "Sum",
        "Tanh",
        "TopK",
        "Transpose",
        "Unique",
    ]

    # ops that are used to manipulate shapes or indices so require int32_t and int64_t to be available
    default_processor_onnx_ops_requiring_ints_for_input_0 = [
        "Add",
        "Concat",
        "Div",
        "Equal",
        "Greater",
        "Less",
        "Mul",
        "Neg",  # used in tflite TransposeConv conversion
        "Sub",
    ]

    # NOTE: QLinearConv has ONNX and internal implementations
    internal_ops = ["QLinearAdd", "QLinearMul", "QLinearConv"]

    # TODO - review and add ML ops as needed
    # ML Op notes.
    #  CastMap: Switch on value type of input map type, and output type
    #  DictVectorizer: Templatized on key+value of input so need to handle like OneHot with custom processor
    #  LabelEncoder: Implementation switches on input and output types (only supports string and int64 in T1 and T2)
    #  LinearClassifier: Internal switch on input type and also switch on output type
    #  SVMClassifier: ditto
    #  TreeEnsembleClassifier: Templatized on input type and also switch on output type
    #  ZipMap: Switch on output type (derived from attributes)
    default_processor_onnxml_ops = []

    # plain loops rather than side-effect list comprehensions: clearer intent, and no throwaway lists of None
    for op in default_processor_onnx_ops:
        add(DefaultTypeUsageProcessor("ai.onnx", op))

    for op in default_processor_onnx_ops_requiring_ints_for_input_0:
        add(DefaultTypeUsageProcessor("ai.onnx", op, required_input_types={0: {"int32_t", "int64_t"}}))

    for op in default_processor_onnxml_ops:
        add(DefaultTypeUsageProcessor("ai.onnx.ml", op))

    for op in internal_ops:
        add(DefaultTypeUsageProcessor("com.microsoft", op))

    #
    # Operators that require custom handling
    #

    # Cast switches on types of input 0 and output 0
    add(DefaultTypeUsageProcessor("ai.onnx", "Cast", inputs=[0], outputs=[0]))

    # Operators that switch on the type of input 0 and 1
    add(DefaultTypeUsageProcessor("ai.onnx", "Gather", inputs=[0, 1]))
    add(DefaultTypeUsageProcessor("ai.onnx", "GatherElements", inputs=[0, 1]))
    add(DefaultTypeUsageProcessor("ai.onnx", "Pow", inputs=[0, 1]))
    add(DefaultTypeUsageProcessor("ai.onnx", "Slice", inputs=[0, 1]))

    # Operators that switch on output type
    add(DefaultTypeUsageProcessor("ai.onnx", "ConstantOfShape", inputs=[], outputs=[0]))

    # Random generator ops produce new data so we track the output type
    onnx_random_ops = ["RandomNormal", "RandomNormalLike", "RandomUniform", "RandomUniformLike", "Multinomial"]
    for op in onnx_random_ops:
        add(DefaultTypeUsageProcessor("ai.onnx", op, inputs=[], outputs=[0]))

    # Where always has a boolean first input so track the second input type for typed registration
    add(Input1TypedRegistrationProcessor("ai.onnx", "Where"))

    # we only support 'float' as input for [Dynamic]QuantizeLinear so just track the output type
    # as that's what is used in the typed registration
    add(Output0TypedRegistrationProcessor("ai.onnx", "QuantizeLinear"))
    add(Output0TypedRegistrationProcessor("ai.onnx", "DynamicQuantizeLinear"))

    # make sure all the dequantize types are enabled. we use int32_t for parts of GEMM and Conv so just
    # enabling int8 and uint8 is not enough.
    # TODO: Only apply required types to the global type list and ignore if it's model based per-op type reduction
    add(
        DefaultTypeUsageProcessor(
            "ai.onnx", "DequantizeLinear", inputs=[0], required_input_types={0: {"int8_t", "uint8_t", "int32_t"}}
        )
    )

    # OneHot concatenates type strings into a triple in the typed registration
    # e.g. float_int64_t_int64_t
    add(OneHotProcessor())

    return operator_processors
487
+
488
+
489
class OpTypeImplFilterInterface(ABC):
    """
    Interface for filtering operator implementations based on the types they support.
    """

    @abstractmethod
    def is_typed_registration_needed(self, domain: str, optype: str, type_registration_str: str):
        """
        Decide whether a typed kernel registration must be kept.
        :param domain: Operator domain.
        :param optype: Operator type.
        :param type_registration_str: Type string from kernel registration
        :return: True if required. False if not.
        """

    @abstractmethod
    def get_cpp_entries(self):
        """
        Produce the C++ code that specifies the operator types to enable.
        :return: List of strings. One line of C++ code per entry.
        """
511
+
512
class OperatorTypeUsageManager:
    """
    Class to manage the operator type usage processors.
    TODO: Currently the type tracking is not specific to a version of the operator.
    It's unclear how/where version specific logic could/should be added, and it would add significant complexity
    to track types on a per-version basis. Not clear there's enough benefit from doing so either.
    """

    def __init__(self):
        self._all_operator_processors = _create_operator_type_usage_processors()  # all possible processors
        self._operator_processors = {}  # processors we have actually used so we can limit output to be meaningful

    def _get_op_processor(self, key):
        """Return the processor for key (or None), adding it to _operator_processors as it is about to be used."""
        processor = None
        if key in self._all_operator_processors:
            if key not in self._operator_processors:
                self._operator_processors[key] = self._all_operator_processors[key]

            processor = self._operator_processors[key]

        return processor

    def process_node(self, node: fbs.Node, value_name_to_typeinfo: dict):
        """
        Process a Node and record info on the types used.
        :param node: Node from ORT format model
        :param value_name_to_typeinfo: Map of value names to TypeInfo instances
        """
        optype = node.OpType().decode()
        domain = node.Domain().decode() or "ai.onnx"  # empty domain defaults to ai.onnx

        key = _create_op_key(domain, optype)
        op_processor = self._get_op_processor(key)
        if op_processor:
            op_processor.process_node(node, value_name_to_typeinfo)

    def get_config_entry(self, domain: str, optype: str):
        """
        Get the config entry specifying the types for this operator.
        :param domain: Operator domain.
        :param optype: Operator type.
        :return: JSON string with type info if available, else None
        """
        key = _create_op_key(domain, optype)
        config_str = None
        if key in self._operator_processors:
            config_str = self._operator_processors[key].to_config_entry()

        return config_str

    def restore_from_config_entry(self, domain: str, optype: str, config_entry: str):
        """
        Restore the per-operator type information from a configuration file entry.
        :param domain: Operator domain.
        :param optype: Operator type.
        :param config_entry: JSON string with type info as created by get_config_entry
        """
        key = _create_op_key(domain, optype)
        op_processor = self._get_op_processor(key)
        if op_processor:
            op_processor.from_config_entry(config_entry)

    def debug_dump(self):
        """Print the C++ entries and per-operator config entries, round-tripping each entry as a sanity check."""
        print("C++ code that will be emitted:")
        # BUG FIX: this class has no get_cpp_entries method (it lives on the impl filter),
        # so the previous `self.get_cpp_entries()` call raised AttributeError.
        for cpp_line in self.make_op_type_impl_filter().get_cpp_entries():
            print(cpp_line)

        print("Config file type information that will be returned by get_config_entry:")
        for key in sorted(self._operator_processors.keys()):
            entry = self._operator_processors[key].to_config_entry()
            if entry:
                print(f"{key} -> {entry}")

                # roundtrip test to validate that we can initialize the processor from the entry and get the
                # same values back
                self._operator_processors[key].from_config_entry(entry)
                assert entry == self._operator_processors[key].to_config_entry()

    class _OpTypeImplFilter(OpTypeImplFilterInterface):
        """Impl filter backed by the owning manager's used processors."""

        def __init__(self, manager):
            self._manager = manager

        def is_typed_registration_needed(self, domain: str, optype: str, type_registration_str: str):
            needed = True  # we keep the registration unless the per-operator processor says not to
            key = _create_op_key(domain, optype)
            if key in self._manager._operator_processors:
                needed = self._manager._operator_processors[key].is_typed_registration_needed(
                    type_in_registration=type_registration_str, globally_allowed_types=None
                )

            return needed

        def get_cpp_entries(self):
            entries = []
            for key in sorted(self._manager._operator_processors.keys()):
                entries.extend(self._manager._operator_processors[key].get_cpp_entry())

            return entries

    def make_op_type_impl_filter(self):
        """
        Creates an OpTypeImplFilterInterface instance from this manager.
        Filtering uses the manager's operator type usage processor state.
        """
        return OperatorTypeUsageManager._OpTypeImplFilter(self)
617
+
618
+
619
class GloballyAllowedTypesOpTypeImplFilter(OpTypeImplFilterInterface):
    """
    Operator implementation filter which uses globally allowed types.
    """

    # full set of acceptable type names, derived from the flatbuffers tensor data type mapping
    _valid_allowed_types = set(FbsTypeInfo.tensordatatype_to_string.values())  # noqa: RUF012

    def __init__(self, globally_allowed_types: set[str]):
        self._operator_processors = _create_operator_type_usage_processors()

        invalid = globally_allowed_types - self._valid_allowed_types
        if invalid:
            raise ValueError(
                f"Globally allowed types must all be valid. Invalid types: {sorted(globally_allowed_types - self._valid_allowed_types)}"
            )

        self._globally_allowed_types = globally_allowed_types

    def is_typed_registration_needed(self, domain: str, optype: str, type_registration_str: str):
        """Keep a registration if its per-op processor (or, failing that, the global allow list) enables the type."""
        key = _create_op_key(domain, optype)
        processor = self._operator_processors.get(key)
        if processor is not None:
            return processor.is_typed_registration_needed(
                type_in_registration=type_registration_str, globally_allowed_types=self._globally_allowed_types
            )

        return _reg_type_to_cpp_type(type_registration_str) in self._globally_allowed_types

    def get_cpp_entries(self):
        """Return the single C++ line that declares the global allowed type list."""
        allowed = ", ".join(sorted(self._globally_allowed_types))
        return [f"ORT_SPECIFY_OP_KERNEL_GLOBAL_ALLOWED_TYPES({allowed});"]

    def global_type_list(self):
        """Return the set of globally allowed type names."""
        return self._globally_allowed_types
@@ -0,0 +1,7 @@
1
+ # automatically generated by the FlatBuffers compiler, do not modify
2
+
3
+ # namespace: fbs
4
+
5
class ArgType(object):
    """Direction of a graph node argument (FlatBuffers-generated enum namespace)."""

    INPUT = 0
    OUTPUT = 1