onnxruntime-directml 1.20.0__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (305)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,663 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ import json
5
+ import typing
6
+ from abc import ABC, abstractmethod
7
+
8
+ import ort_flatbuffers_py.fbs as fbs
9
+
10
+ from .types import FbsTypeInfo, value_name_to_typestr
11
+
12
+
13
+ def _create_op_key(domain: str, optype: str):
14
+ return f"{domain}:{optype}"
15
+
16
+
17
+ def _ort_constant_for_domain(domain: str):
18
+ """
19
+ Map a string domain value to the internal ONNX Runtime constant for that domain.
20
+ :param domain: Domain string to map.
21
+ :return: Internal ONNX Runtime constant
22
+ """
23
+
24
+ # constants are defined in <ORT root>/include/onnxruntime/core/graph/constants.h
25
+ # This list is limited to just the domains we have processors for
26
+ domain_to_constant_map = {"ai.onnx": "kOnnxDomain", "ai.onnx.ml": "kMLDomain", "com.microsoft": "kMSDomain"}
27
+
28
+ if domain not in domain_to_constant_map:
29
+ raise ValueError(f"Domain {domain} not found in map to ONNX Runtime constant. Please update map.")
30
+
31
+ return domain_to_constant_map[domain]
32
+
33
+
34
+ def _reg_type_to_cpp_type(reg_type: str):
35
+ if reg_type == "string":
36
+ return "std::string"
37
+ return reg_type
38
+
39
+
40
+ def _split_reg_types(reg_types_str: str):
41
+ """
42
+ Split on underscores but append "_t" to the previous element.
43
+ """
44
+ tokens = reg_types_str.split("_")
45
+ reg_types = []
46
+ for token in tokens:
47
+ if token == "t" and len(reg_types) > 0:
48
+ reg_types[-1] += "_t"
49
+ else:
50
+ reg_types += [token]
51
+ return reg_types
52
+
53
+
54
class TypeUsageProcessor(ABC):
    """
    Abstract base class for processors which implement operator specific logic to determine the type or types required.
    """

    def __init__(self, domain: str, optype: str):
        self.domain = domain
        self.optype = optype
        # Combined key used to identify this processor, e.g. "ai.onnx:Cast".
        self.name = _create_op_key(domain, optype)

    @abstractmethod
    def process_node(self, node: fbs.Node, value_name_to_typeinfo: dict):
        """Record the type usage information from a graph node for this operator."""
        pass

    def is_typed_registration_needed(
        self, type_in_registration: str, globally_allowed_types: typing.Optional[typing.Set[str]]
    ):
        """
        Given the string from a kernel registration, determine if the registration is required or not.
        :param type_in_registration: Type string from kernel registration
        :param globally_allowed_types: Optional set of globally allowed types. If provided, these types take precedence
                                       in determining the required types.
        :return: True is required. False if not.
        """
        # Not all operators have typed registrations, so this is optionally implemented by derived classes
        raise RuntimeError(f"Did not expect processor for {self.name} to have typed registrations.")

    def get_cpp_entry(self):
        """
        Get the C++ code that specifies this operator's required types.
        :return: List with any applicable C++ code for this operator's required types. One line per entry.
        """
        # Not applicable for some ops, so return no lines by default.
        return []

    @abstractmethod
    def to_config_entry(self):
        """
        Generate a configuration file entry in JSON format with the required types for the operator.
        :return: JSON string with required type information.
        """

    @abstractmethod
    def from_config_entry(self, entry: str):
        """
        Re-create the types required from a configuration file entry created with to_config_entry.
        NOTE: Any existing type information should be cleared prior to re-creating from a config file entry.
        :param entry: Configuration file entry
        """
103
+
104
+
105
class DefaultTypeUsageProcessor(TypeUsageProcessor):
    """
    Operator processor which tracks the types used for selected input/s and/or output/s.
    """

    def __init__(
        self,
        domain: str,
        optype: str,
        inputs: typing.Optional[typing.Iterable[int]] = None,
        outputs: typing.Optional[typing.Iterable[int]] = None,
        required_input_types: typing.Optional[typing.Dict[int, typing.Set[str]]] = None,
        required_output_types: typing.Optional[typing.Dict[int, typing.Set[str]]] = None,
    ):
        """
        Create DefaultTypeUsageProcessor. Types for one or more inputs and/or outputs can be tracked by the processor.
        The default is to track the types required for input 0, as this is the most common use case in ONNX.

        Required input and output types may be specified. These are only applicable to is_typed_registration_needed().
        If a registration type matches a required type, the typed registration is needed.
        There is a separate mechanism for specifying required types from C++ for kernels with untyped registration.

        :param domain: Operator domain.
        :param optype: Operator name.
        :param inputs: Inputs to track. Zero based index. May be empty. Defaults to [0].
        :param outputs: Outputs to track. Zero based index. May be empty. Defaults to [].
        :param required_input_types: Required input types. May be empty.
        :param required_output_types: Required output types. May be empty.
        """
        super().__init__(domain, optype)

        # The previous defaults ([0], [], {}) were mutable default arguments shared
        # between calls (flake8-bugbear B006); use None sentinels and build fresh
        # objects per instance instead. Passing lists/dicts explicitly still works.
        inputs = [0] if inputs is None else inputs
        outputs = [] if outputs is None else outputs

        # Per-index sets of type strings observed while processing nodes.
        self._input_types = {i: set() for i in inputs}
        self._output_types = {o: set() for o in outputs}

        if not inputs and not outputs:
            raise ValueError("At least one input or output must be tracked")

        self._required_input_types = {} if required_input_types is None else required_input_types
        self._required_output_types = {} if required_output_types is None else required_output_types

    def _is_type_enabled(self, reg_type, index, required_types, allowed_type_set):
        # A type is enabled if it is explicitly required for this index, or present
        # in the applicable allowed-type set.
        cpp_type = _reg_type_to_cpp_type(reg_type)
        return cpp_type in required_types.get(index, set()) or cpp_type in allowed_type_set

    def is_input_type_enabled(self, reg_type, index, allowed_type_set=None):
        """Whether input type is enabled based on required and allowed types."""
        if allowed_type_set is None:
            allowed_type_set = self._input_types[index]
        return self._is_type_enabled(reg_type, index, self._required_input_types, allowed_type_set)

    def is_output_type_enabled(self, reg_type, index, allowed_type_set=None):
        """Whether output type is enabled based on required and allowed types."""
        if allowed_type_set is None:
            allowed_type_set = self._output_types[index]
        return self._is_type_enabled(reg_type, index, self._required_output_types, allowed_type_set)

    def process_node(self, node: fbs.Node, value_name_to_typeinfo: dict):
        """Record the types used by the tracked inputs/outputs of this node."""
        for i in self._input_types:
            # Some operators have fewer inputs in earlier versions where data that was an attribute
            # became an input in later versions to allow it to be dynamically provided. Allow for that.
            # e.g. Slice-1 had attributes for the indices, and Slice-10 moved those to be inputs.
            if i < node.InputsLength():
                type_str = value_name_to_typestr(node.Inputs(i), value_name_to_typeinfo)
                self._input_types[i].add(type_str)

        for o in self._output_types:
            # Don't know of any ops where the number of outputs changed across versions, so require a valid length
            if o >= node.OutputsLength():
                raise RuntimeError(
                    f"Node has {node.OutputsLength()} outputs. Tracker for {self.name} incorrectly configured as it requires {o}."
                )

            type_str = value_name_to_typestr(node.Outputs(o), value_name_to_typeinfo)
            self._output_types[o].add(type_str)

    def is_typed_registration_needed(
        self, type_in_registration: str, globally_allowed_types: typing.Optional[typing.Set[str]]
    ):
        if 0 not in self._input_types:
            # currently all standard typed registrations are for input 0.
            # custom registrations can be handled by operator specific processors (e.g. OneHotProcessor below).
            raise RuntimeError(f"Expected typed registration to use type from input 0. Node:{self.name}")

        return self.is_input_type_enabled(type_in_registration, 0, globally_allowed_types)

    def get_cpp_entry(self):
        """
        Get the C++ code that specifies this operator's required types.
        :return: List with one ORT_SPECIFY_OP_KERNEL_ARG_ALLOWED_TYPES line per tracked
                 input/output that has recorded types.
        """
        entries = []
        domain = _ort_constant_for_domain(self.domain)
        for i in sorted(self._input_types.keys()):
            if self._input_types[i]:
                entries.append(
                    "ORT_SPECIFY_OP_KERNEL_ARG_ALLOWED_TYPES({}, {}, Input, {}, {});".format(
                        domain, self.optype, i, ", ".join(sorted(self._input_types[i]))
                    )
                )

        for o in sorted(self._output_types.keys()):
            if self._output_types[o]:
                entries.append(
                    "ORT_SPECIFY_OP_KERNEL_ARG_ALLOWED_TYPES({}, {}, Output, {}, {});".format(
                        domain, self.optype, o, ", ".join(sorted(self._output_types[o]))
                    )
                )

        return entries

    def to_config_entry(self):
        """
        Generate a configuration file entry in JSON format with the required types for the operator.
        :return: JSON string with required type information, or None if no types were recorded.
        """
        # convert the sets of types to lists so they can easily written out using the json model
        aggregate_info = {"inputs": {}, "outputs": {}}

        # filter out empty entries and sort the types
        for i in sorted(self._input_types.keys()):
            if self._input_types[i]:
                aggregate_info["inputs"][i] = sorted(self._input_types[i])

        for o in sorted(self._output_types.keys()):
            if self._output_types[o]:
                aggregate_info["outputs"][o] = sorted(self._output_types[o])

        # remove any empty keys
        if not aggregate_info["inputs"]:
            aggregate_info.pop("inputs")
        if not aggregate_info["outputs"]:
            aggregate_info.pop("outputs")

        entry = json.dumps(aggregate_info) if aggregate_info else None
        return entry

    def from_config_entry(self, entry: str):
        """
        Re-create the types required from a configuration file entry created with to_config_entry.
        Existing type information is cleared first.
        :param entry: Configuration file entry
        """
        self._input_types.clear()
        self._output_types.clear()

        aggregate_info = json.loads(entry)
        # JSON object keys are always strings, so convert the indices back to int.
        if "inputs" in aggregate_info:
            for i_str, values in aggregate_info["inputs"].items():
                self._input_types[int(i_str)] = set(values)

        if "outputs" in aggregate_info:
            for o_str, values in aggregate_info["outputs"].items():
                self._output_types[int(o_str)] = set(values)
254
+
255
+
256
class Input1TypedRegistrationProcessor(DefaultTypeUsageProcessor):
    """
    Processor for operators where the second input type is used in a typed kernel registration.
    """

    def __init__(self, domain: str, optype: str):
        # Only input 1 drives the typed registration for these operators, so track just that.
        super().__init__(domain, optype, inputs=[1], outputs=[])

    def is_typed_registration_needed(
        self, type_in_registration: str, globally_allowed_types: typing.Optional[typing.Set[str]]
    ):
        """Check the registration type against the tracked/allowed types for input 1."""
        return self.is_input_type_enabled(type_in_registration, 1, globally_allowed_types)
269
+
270
+
271
class Output0TypedRegistrationProcessor(DefaultTypeUsageProcessor):
    """
    Processor for operators where the first output type is used in a typed kernel registration.
    """

    def __init__(self, domain: str, optype: str):
        # Only output 0 drives the typed registration for these operators, so track just that.
        super().__init__(domain, optype, inputs=[], outputs=[0])

    def is_typed_registration_needed(
        self, type_in_registration: str, globally_allowed_types: typing.Optional[typing.Set[str]]
    ):
        """Check the registration type against the tracked/allowed types for output 0."""
        return self.is_output_type_enabled(type_in_registration, 0, globally_allowed_types)
284
+
285
+
286
class OneHotProcessor(TypeUsageProcessor):
    """
    Processor for the OneHot operator.
    OneHot requires custom logic because its type registration key is a concatenation of the
    three types involved instead of a single type name.
    """

    def __init__(self):
        super().__init__("ai.onnx", "OneHot")
        # set of (input 0 type, input 2 type, input 1 type) tuples seen in processed models
        self._triples = set()

    def process_node(self, node: fbs.Node, value_name_to_typeinfo: dict):
        """Record the type combination used by this OneHot node."""
        in0_type = value_name_to_typestr(node.Inputs(0), value_name_to_typeinfo)
        in1_type = value_name_to_typestr(node.Inputs(1), value_name_to_typeinfo)
        in2_type = value_name_to_typestr(node.Inputs(2), value_name_to_typeinfo)
        # types in kernel registration are ordered this way: input (T1), output (T3), depth (T2)
        self._triples.add((in0_type, in2_type, in1_type))

    def is_typed_registration_needed(
        self, type_in_registration: str, globally_allowed_types: typing.Optional[typing.Set[str]]
    ):
        """Check whether the concatenated three-type registration string is required."""
        # the OneHot registration involves a concatenation of the 3 types involved
        reg_types = tuple(_reg_type_to_cpp_type(t) for t in _split_reg_types(type_in_registration))
        if globally_allowed_types is not None:
            return all(t in globally_allowed_types for t in reg_types)

        return reg_types in self._triples

    def to_config_entry(self):
        """Serialize the recorded type triples to JSON, or None when nothing was recorded."""
        if not self._triples:
            return None

        return json.dumps({"custom": sorted(self._triples)})

    def from_config_entry(self, entry: str):
        """Restore the recorded type triples from a JSON config entry."""
        self._triples.clear()
        aggregate_info = json.loads(entry)
        if "custom" in aggregate_info:
            # json round-trips tuples as lists, so convert back
            self._triples = {tuple(triple) for triple in aggregate_info["custom"]}
327
+
328
+
329
def _create_operator_type_usage_processors():
    """
    Create a set of processors that determine the required types for all enabled operators.
    :return: Dictionary of operator key to processor. Key is 'domain:operator (e.g. ai.onnx:Cast)'.
    """
    operator_processors = {}

    def add(processor):
        # guard against two processors being registered for the same operator
        if processor.name in operator_processors:
            raise RuntimeError("Duplicate processor for " + processor.name)

        operator_processors[processor.name] = processor

    # Starting with ops from:
    #   - Priority 1P models
    #   - Mobilenet + SSD Mobilenet + MobileBert
    #   - some known large kernels
    #
    # Ops we are ignoring currently so as not to produce meaningless/unused output:
    #   - Implementation is type agnostic:
    #     ai.onnx: If, Loop, Reshape, Scan, Shape, Squeeze, Tile, Unsqueeze
    #     com.microsoft: DynamicQuantizeMatMul, MatMulIntegerToFloat
    #   - Only one type supported in the ORT implementation:
    #     ai.onnx: NonMaxSuppression
    #     com.microsoft: FusedConv, FusedGemm, FusedMatMul
    #   - Implementation does not have any significant type specific code:
    #     ai.onnx: Concat, Flatten, Not, Reshape, Shape, Squeeze, Unsqueeze
    #
    default_processor_onnx_ops = [
        "Abs",
        "ArgMax",
        "ArgMin",
        "AveragePool",
        "BatchNormalization",
        "BitShift",
        "Ceil",
        "Clip",
        "Conv",
        "CumSum",
        "Exp",
        "Expand",
        "Floor",
        "Gemm",
        "IsNaN",
        "Log",
        "LogSoftmax",
        "LpNormalization",
        "MatMul",
        "Max",
        "MaxPool",
        "Mean",
        "Min",
        "NonZero",
        "Pad",
        "QLinearConv",
        "QLinearMatMul",
        "Range",
        "Reciprocal",
        "ReduceL1",
        "ReduceL2",
        "ReduceLogSum",
        "ReduceLogSumExp",
        "ReduceMax",
        "ReduceMean",
        "ReduceMin",
        "ReduceProd",
        "ReduceSum",
        "ReduceSumSquare",
        "Relu",
        "Resize",
        "ReverseSequence",
        "RoiAlign",
        "Round",
        "Scatter",
        "ScatterElements",
        "ScatterND",
        "Shrink",
        "Sigmoid",
        "Sign",
        "Sin",
        "Softmax",
        "Split",
        "SplitToSequence",
        "Sqrt",
        "Sum",
        "Tanh",
        "TopK",
        "Transpose",
        "Unique",
    ]

    # ops that are used to manipulate shapes or indices so require int32_t and int64_t to be available
    default_processor_onnx_ops_requiring_ints_for_input_0 = [
        "Add",
        "Concat",
        "Div",
        "Equal",
        "Greater",
        "Less",
        "Mul",
        "Neg",  # used in tflite TransposeConv conversion
        "Sub",
    ]

    # NOTE: QLinearConv has ONNX and internal implementations
    internal_ops = ["QLinearAdd", "QLinearMul", "QLinearConv"]

    # TODO - review and add ML ops as needed
    # ML Op notes.
    #  CastMap: Switch on value type of input map type, and output type
    #  DictVectorizer: Templatized on key+value of input so need to handle like OneHot with custom processor
    #  LabelEncoder: Implementation switches on input and output types (only supports string and int64 in T1 and T2)
    #  LinearClassifier: Internal switch on input type and also switch on output type
    #  SVMClassifier: ditto
    #  TreeEnsembleClassifier: Templatized on input type and also switch on output type
    #  ZipMap: Switch on output type (derived from attributes)
    default_processor_onnxml_ops = []

    # NOTE: plain loops rather than list comprehensions - the registration calls are executed
    # purely for their side effect, so building (and discarding) a list was misleading.
    for op in default_processor_onnx_ops:
        add(DefaultTypeUsageProcessor("ai.onnx", op))

    for op in default_processor_onnx_ops_requiring_ints_for_input_0:
        add(DefaultTypeUsageProcessor("ai.onnx", op, required_input_types={0: {"int32_t", "int64_t"}}))

    for op in default_processor_onnxml_ops:
        add(DefaultTypeUsageProcessor("ai.onnx.ml", op))

    for op in internal_ops:
        add(DefaultTypeUsageProcessor("com.microsoft", op))

    #
    # Operators that require custom handling
    #

    # Cast switches on types of input 0 and output 0
    add(DefaultTypeUsageProcessor("ai.onnx", "Cast", inputs=[0], outputs=[0]))

    # Operators that switch on the type of input 0 and 1
    add(DefaultTypeUsageProcessor("ai.onnx", "Gather", inputs=[0, 1]))
    add(DefaultTypeUsageProcessor("ai.onnx", "GatherElements", inputs=[0, 1]))
    add(DefaultTypeUsageProcessor("ai.onnx", "Pow", inputs=[0, 1]))
    add(DefaultTypeUsageProcessor("ai.onnx", "Slice", inputs=[0, 1]))

    # Operators that switch on output type
    add(DefaultTypeUsageProcessor("ai.onnx", "ConstantOfShape", inputs=[], outputs=[0]))

    # Random generator ops produce new data so we track the output type
    onnx_random_ops = ["RandomNormal", "RandomNormalLike", "RandomUniform", "RandomUniformLike", "Multinomial"]
    for op in onnx_random_ops:
        add(DefaultTypeUsageProcessor("ai.onnx", op, inputs=[], outputs=[0]))

    # Where always has a boolean first input so track the second input type for typed registration
    add(Input1TypedRegistrationProcessor("ai.onnx", "Where"))

    # we only support 'float' as input for [Dynamic]QuantizeLinear so just track the output type
    # as that's what is used in the typed registration
    add(Output0TypedRegistrationProcessor("ai.onnx", "QuantizeLinear"))
    add(Output0TypedRegistrationProcessor("ai.onnx", "DynamicQuantizeLinear"))

    # make sure all the dequantize types are enabled. we use int32_t for parts of GEMM and Conv so just
    # enabling int8 and uint8 is not enough.
    # TODO: Only apply required types to the global type list and ignore if it's model based per-op type reduction
    add(
        DefaultTypeUsageProcessor(
            "ai.onnx", "DequantizeLinear", inputs=[0], required_input_types={0: {"int8_t", "uint8_t", "int32_t"}}
        )
    )

    # OneHot concatenates type strings into a triple in the typed registration
    # e.g. float_int64_t_int64_t
    add(OneHotProcessor())

    return operator_processors
497
+
498
+
499
class OpTypeImplFilterInterface(ABC):
    """
    Interface for a filter that decides which operator implementations to keep based on type.
    """

    @abstractmethod
    def is_typed_registration_needed(self, domain: str, optype: str, type_registration_str: str):
        """
        Decide, from the string in a kernel registration, whether that registration is required.
        :param domain: Operator domain.
        :param optype: Operator type.
        :param type_registration_str: Type string from kernel registration
        :return: True is required. False if not.
        """

    @abstractmethod
    def get_cpp_entries(self):
        """
        Get the C++ code that specifies the operator types to enable.
        :return: List of strings. One line of C++ code per entry.
        """
520
+
521
+
522
class OperatorTypeUsageManager:
    """
    Class to manage the operator type usage processors.
    TODO: Currently the type tracking is not specific to a version of the operator.
    It's unclear how/where version specific logic could/should be added, and it would add significant complexity
    to track types on a per-version basis. Not clear there's enough benefit from doing so either.
    """

    def __init__(self):
        self._all_operator_processors = _create_operator_type_usage_processors()  # all possible processors
        self._operator_processors = {}  # processors we have actually used so we can limit output to be meaningful

    def _get_op_processor(self, key):
        "Add the processor to _operator_processors as it is about to be used."
        processor = None
        if key in self._all_operator_processors:
            if key not in self._operator_processors:
                self._operator_processors[key] = self._all_operator_processors[key]

            processor = self._operator_processors[key]

        return processor

    def process_node(self, node: fbs.Node, value_name_to_typeinfo: dict):
        """
        Process a Node and record info on the types used.
        :param node: Node from ORT format model
        :param value_name_to_typeinfo: Map of value names to TypeInfo instances
        """
        optype = node.OpType().decode()
        domain = node.Domain().decode() or "ai.onnx"  # empty domain defaults to ai.onnx

        key = _create_op_key(domain, optype)
        op_processor = self._get_op_processor(key)
        if op_processor:
            op_processor.process_node(node, value_name_to_typeinfo)

    def get_config_entry(self, domain: str, optype: str):
        """
        Get the config entry specifying the types for this operator.
        :param domain: Operator domain.
        :param optype: Operator type.
        :return: JSON string with type info if available, else None
        """
        key = _create_op_key(domain, optype)
        config_str = None
        if key in self._operator_processors:
            config_str = self._operator_processors[key].to_config_entry()

        return config_str

    def restore_from_config_entry(self, domain: str, optype: str, config_entry: str):
        """
        Restore the per-operator type information from a configuration file entry.
        :param domain: Operator domain.
        :param optype: Operator type.
        :param config_entry: JSON string with type info as created by get_config_entry
        """
        key = _create_op_key(domain, optype)
        op_processor = self._get_op_processor(key)
        if op_processor:
            op_processor.from_config_entry(config_entry)

    def debug_dump(self):
        """Print the C++ entries and per-operator config entries for the processors used so far."""
        print("C++ code that will be emitted:")
        # FIX: this class defines no get_cpp_entries (and has no base class providing one), so the
        # previous `self.get_cpp_entries()` call would raise AttributeError. The entries live on
        # the filter produced by make_op_type_impl_filter. Also use a plain loop instead of a list
        # comprehension executed only for its side effects.
        for cpp_line in self.make_op_type_impl_filter().get_cpp_entries():
            print(cpp_line)

        print("Config file type information that will be returned by get_config_entry:")
        for key in sorted(self._operator_processors.keys()):
            entry = self._operator_processors[key].to_config_entry()
            if entry:
                print(f"{key} -> {entry}")

                # roundtrip test to validate that we can initialize the processor from the entry and get the
                # same values back
                self._operator_processors[key].from_config_entry(entry)
                assert entry == self._operator_processors[key].to_config_entry()

    class _OpTypeImplFilter(OpTypeImplFilterInterface):
        # filter implementation backed by the manager's recorded per-operator type usage
        def __init__(self, manager):
            self._manager = manager

        def is_typed_registration_needed(self, domain: str, optype: str, type_registration_str: str):
            needed = True  # we keep the registration unless the per-operator processor says not to
            key = _create_op_key(domain, optype)
            if key in self._manager._operator_processors:
                needed = self._manager._operator_processors[key].is_typed_registration_needed(
                    type_in_registration=type_registration_str, globally_allowed_types=None
                )

            return needed

        def get_cpp_entries(self):
            entries = []
            for key in sorted(self._manager._operator_processors.keys()):
                entries.extend(self._manager._operator_processors[key].get_cpp_entry())

            return entries

    def make_op_type_impl_filter(self):
        """
        Creates an OpTypeImplFilterInterface instance from this manager.
        Filtering uses the manager's operator type usage processor state.
        """
        return OperatorTypeUsageManager._OpTypeImplFilter(self)
627
+
628
+
629
class GloballyAllowedTypesOpTypeImplFilter(OpTypeImplFilterInterface):
    """
    Operator implementation filter driven by a single global set of allowed types.
    """

    # every valid C++ type name that may appear in the allowed set
    _valid_allowed_types = set(FbsTypeInfo.tensordatatype_to_string.values())  # noqa: RUF012

    def __init__(self, globally_allowed_types: typing.Set[str]):
        self._operator_processors = _create_operator_type_usage_processors()

        # reject any type name that is not a known tensor data type
        if not globally_allowed_types.issubset(self._valid_allowed_types):
            raise ValueError(
                f"Globally allowed types must all be valid. Invalid types: {sorted(globally_allowed_types - self._valid_allowed_types)}"
            )

        self._globally_allowed_types = globally_allowed_types

    def is_typed_registration_needed(self, domain: str, optype: str, type_registration_str: str):
        """Return True when the registration's type(s) are in the globally allowed set."""
        processor = self._operator_processors.get(_create_op_key(domain, optype))
        if processor is not None:
            # the per-operator processor knows how to interpret the registration string
            return processor.is_typed_registration_needed(
                type_in_registration=type_registration_str, globally_allowed_types=self._globally_allowed_types
            )

        # no custom processor: the registration string is a single type name
        return _reg_type_to_cpp_type(type_registration_str) in self._globally_allowed_types

    def get_cpp_entries(self):
        """Return the single C++ line that sets the global allowed type list."""
        allowed_types_csv = ", ".join(sorted(self._globally_allowed_types))
        return [f"ORT_SPECIFY_OP_KERNEL_GLOBAL_ALLOWED_TYPES({allowed_types_csv});"]

    def global_type_list(self):
        """Return the set of globally allowed C++ type names."""
        return self._globally_allowed_types
@@ -0,0 +1,7 @@
1
+ # automatically generated by the FlatBuffers compiler, do not modify
2
+
3
+ # namespace: fbs
4
+
5
class ArgType(object):
    """Generated fbs enum: whether a graph value is a node input (0) or output (1)."""

    # NOTE: file is FlatBuffers-generated; values must stay in sync with the .fbs schema
    INPUT = 0
    OUTPUT = 1