onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,67 @@
1
+ # automatically generated by the FlatBuffers compiler, do not modify
2
+
3
+ # namespace: fbs
4
+
5
+ import flatbuffers
6
+ from flatbuffers.compat import import_numpy
7
+ np = import_numpy()
8
+
9
class ArgTypeAndIndex(object):
    """Read-only accessor over an ``ArgTypeAndIndex`` FlatBuffers table.

    The table exposes two scalar fields, looked up through its vtable:
    ``ArgType`` (read as int8 from vtable offset 4) and ``Index`` (read as
    uint32 from vtable offset 6). A field missing from the buffer reads as 0.
    """

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return an ``ArgTypeAndIndex`` bound to the root table in ``buf``.

        The uoffset stored at ``offset`` is added to ``offset`` to find the
        start of the table.
        """
        table_pos = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + offset
        root = ArgTypeAndIndex()
        root.Init(buf, table_pos)
        return root

    @classmethod
    def GetRootAsArgTypeAndIndex(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def ArgTypeAndIndexBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return whether ``buf`` carries the ``b"ORTM"`` file identifier."""
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, b"\x4F\x52\x54\x4D", size_prefixed=size_prefixed
        )

    def Init(self, buf, pos):
        """Bind this accessor to the table that starts at ``pos`` in ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)

    def ArgType(self):
        """Return the int8 ``ArgType`` field (vtable offset 4), or 0 if absent."""
        field = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if field == 0:
            return 0
        return self._tab.Get(flatbuffers.number_types.Int8Flags, field + self._tab.Pos)

    def Index(self):
        """Return the uint32 ``Index`` field (vtable offset 6), or 0 if absent."""
        field = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if field == 0:
            return 0
        return self._tab.Get(flatbuffers.number_types.Uint32Flags, field + self._tab.Pos)
44
+
45
def ArgTypeAndIndexStart(builder):
    """Begin building an ``ArgTypeAndIndex`` table with room for 2 fields."""
    builder.StartObject(2)


def Start(builder):
    """Short alias of ``ArgTypeAndIndexStart``."""
    ArgTypeAndIndexStart(builder)


def ArgTypeAndIndexAddArgType(builder, argType):
    """Prepend ``argType`` as an int8 into field slot 0 (default 0)."""
    builder.PrependInt8Slot(0, argType, 0)


def AddArgType(builder, argType):
    """Short alias of ``ArgTypeAndIndexAddArgType``."""
    ArgTypeAndIndexAddArgType(builder, argType)


def ArgTypeAndIndexAddIndex(builder, index):
    """Prepend ``index`` as a uint32 into field slot 1 (default 0)."""
    builder.PrependUint32Slot(1, index, 0)


def AddIndex(builder, index):
    """Short alias of ``ArgTypeAndIndexAddIndex``."""
    ArgTypeAndIndexAddIndex(builder, index)


def ArgTypeAndIndexEnd(builder):
    """Finish the table and return whatever ``builder.EndObject()`` returns."""
    return builder.EndObject()


def End(builder):
    """Short alias of ``ArgTypeAndIndexEnd``."""
    return ArgTypeAndIndexEnd(builder)
@@ -0,0 +1,337 @@
1
+ # automatically generated by the FlatBuffers compiler, do not modify
2
+
3
+ # namespace: fbs
4
+
5
+ import flatbuffers
6
+ from flatbuffers.compat import import_numpy
7
+ np = import_numpy()
8
+
9
class Attribute(object):
    """Read-only accessor over an ``Attribute`` table in a FlatBuffers buffer.

    Every accessor resolves its field through the table's vtable and names
    the vtable offset it reads. Missing fields fall back to a default:
    ``0``/``0.0`` for scalars, ``None`` for strings and sub-tables, ``0`` for
    numeric vector elements and ``""`` for ``Strings`` elements. For an
    absent vector, ``*Length`` returns ``0``, ``*IsNone`` returns ``True``
    and ``*AsNumpy`` returns ``0``.
    """

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return an ``Attribute`` bound to the root table in ``buf``.

        The uoffset stored at ``offset`` is added to ``offset`` to find the
        start of the table.
        """
        table_pos = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset) + offset
        root = Attribute()
        root.Init(buf, table_pos)
        return root

    @classmethod
    def GetRootAsAttribute(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def AttributeBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return whether ``buf`` carries the ``b"ORTM"`` file identifier."""
        return flatbuffers.util.BufferHasIdentifier(
            buf, offset, b"\x4F\x52\x54\x4D", size_prefixed=size_prefixed
        )

    def Init(self, buf, pos):
        """Bind this accessor to the table that starts at ``pos`` in ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)

    # ----- scalar and string fields -------------------------------------

    def Name(self):
        """Return the ``Name`` string (vtable offset 4), or None if absent."""
        field = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if field == 0:
            return None
        return self._tab.String(field + self._tab.Pos)

    def DocString(self):
        """Return the ``DocString`` string (vtable offset 6), or None if absent."""
        field = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if field == 0:
            return None
        return self._tab.String(field + self._tab.Pos)

    def Type(self):
        """Return the int32 ``Type`` field (vtable offset 8), or 0 if absent."""
        nt = flatbuffers.number_types
        field = nt.UOffsetTFlags.py_type(self._tab.Offset(8))
        if field == 0:
            return 0
        return self._tab.Get(nt.Int32Flags, field + self._tab.Pos)

    def F(self):
        """Return the float32 ``F`` field (vtable offset 10), or 0.0 if absent."""
        nt = flatbuffers.number_types
        field = nt.UOffsetTFlags.py_type(self._tab.Offset(10))
        if field == 0:
            return 0.0
        return self._tab.Get(nt.Float32Flags, field + self._tab.Pos)

    def I(self):
        """Return the int64 ``I`` field (vtable offset 12), or 0 if absent."""
        nt = flatbuffers.number_types
        field = nt.UOffsetTFlags.py_type(self._tab.Offset(12))
        if field == 0:
            return 0
        return self._tab.Get(nt.Int64Flags, field + self._tab.Pos)

    def S(self):
        """Return the ``S`` string (vtable offset 14), or None if absent."""
        field = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14))
        if field == 0:
            return None
        return self._tab.String(field + self._tab.Pos)

    # ----- sub-table fields ---------------------------------------------

    def T(self):
        """Return the ``T`` sub-table (vtable offset 16) as a ``Tensor``, or None."""
        field = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(16))
        if field == 0:
            return None
        target = self._tab.Indirect(field + self._tab.Pos)
        # Lazy import: ``ort_flatbuffers_py`` must be importable at call time.
        from ort_flatbuffers_py.fbs.Tensor import Tensor
        tensor = Tensor()
        tensor.Init(self._tab.Bytes, target)
        return tensor

    def G(self):
        """Return the ``G`` sub-table (vtable offset 18) as a ``Graph``, or None."""
        field = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(18))
        if field == 0:
            return None
        target = self._tab.Indirect(field + self._tab.Pos)
        # Lazy import: ``ort_flatbuffers_py`` must be importable at call time.
        from ort_flatbuffers_py.fbs.Graph import Graph
        graph = Graph()
        graph.Init(self._tab.Bytes, target)
        return graph

    # ----- Floats: vector of float32 (vtable offset 20) -----------------

    def Floats(self, j):
        """Return element ``j`` of ``Floats`` (4-byte stride), or 0 if absent."""
        nt = flatbuffers.number_types
        field = nt.UOffsetTFlags.py_type(self._tab.Offset(20))
        if field == 0:
            return 0
        start = self._tab.Vector(field)
        return self._tab.Get(nt.Float32Flags, start + nt.UOffsetTFlags.py_type(j * 4))

    def FloatsAsNumpy(self):
        """Return ``Floats`` via ``GetVectorAsNumpy``, or 0 if absent."""
        nt = flatbuffers.number_types
        field = nt.UOffsetTFlags.py_type(self._tab.Offset(20))
        if field == 0:
            return 0
        return self._tab.GetVectorAsNumpy(nt.Float32Flags, field)

    def FloatsLength(self):
        """Return the element count of ``Floats`` (0 if absent)."""
        field = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(20))
        return self._tab.VectorLen(field) if field != 0 else 0

    def FloatsIsNone(self):
        """Return True when the ``Floats`` vector is absent."""
        field = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(20))
        return field == 0

    # ----- Ints: vector of int64 (vtable offset 22) ---------------------

    def Ints(self, j):
        """Return element ``j`` of ``Ints`` (8-byte stride), or 0 if absent."""
        nt = flatbuffers.number_types
        field = nt.UOffsetTFlags.py_type(self._tab.Offset(22))
        if field == 0:
            return 0
        start = self._tab.Vector(field)
        return self._tab.Get(nt.Int64Flags, start + nt.UOffsetTFlags.py_type(j * 8))

    def IntsAsNumpy(self):
        """Return ``Ints`` via ``GetVectorAsNumpy``, or 0 if absent."""
        nt = flatbuffers.number_types
        field = nt.UOffsetTFlags.py_type(self._tab.Offset(22))
        if field == 0:
            return 0
        return self._tab.GetVectorAsNumpy(nt.Int64Flags, field)

    def IntsLength(self):
        """Return the element count of ``Ints`` (0 if absent)."""
        field = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(22))
        return self._tab.VectorLen(field) if field != 0 else 0

    def IntsIsNone(self):
        """Return True when the ``Ints`` vector is absent."""
        field = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(22))
        return field == 0

    # ----- Strings: vector of strings (vtable offset 24) ----------------

    def Strings(self, j):
        """Return element ``j`` of ``Strings`` (4-byte stride), or "" if absent."""
        nt = flatbuffers.number_types
        field = nt.UOffsetTFlags.py_type(self._tab.Offset(24))
        if field == 0:
            return ""
        start = self._tab.Vector(field)
        return self._tab.String(start + nt.UOffsetTFlags.py_type(j * 4))

    def StringsLength(self):
        """Return the element count of ``Strings`` (0 if absent)."""
        field = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(24))
        return self._tab.VectorLen(field) if field != 0 else 0

    def StringsIsNone(self):
        """Return True when the ``Strings`` vector is absent."""
        field = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(24))
        return field == 0

    # ----- Tensors: vector of Tensor tables (vtable offset 26) ----------

    def Tensors(self, j):
        """Return element ``j`` of ``Tensors`` as a ``Tensor``, or None if absent."""
        nt = flatbuffers.number_types
        field = nt.UOffsetTFlags.py_type(self._tab.Offset(26))
        if field == 0:
            return None
        element = self._tab.Vector(field) + nt.UOffsetTFlags.py_type(j) * 4
        target = self._tab.Indirect(element)
        from ort_flatbuffers_py.fbs.Tensor import Tensor
        tensor = Tensor()
        tensor.Init(self._tab.Bytes, target)
        return tensor

    def TensorsLength(self):
        """Return the element count of ``Tensors`` (0 if absent)."""
        field = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(26))
        return self._tab.VectorLen(field) if field != 0 else 0

    def TensorsIsNone(self):
        """Return True when the ``Tensors`` vector is absent."""
        field = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(26))
        return field == 0

    # ----- Graphs: vector of Graph tables (vtable offset 28) ------------

    def Graphs(self, j):
        """Return element ``j`` of ``Graphs`` as a ``Graph``, or None if absent."""
        nt = flatbuffers.number_types
        field = nt.UOffsetTFlags.py_type(self._tab.Offset(28))
        if field == 0:
            return None
        element = self._tab.Vector(field) + nt.UOffsetTFlags.py_type(j) * 4
        target = self._tab.Indirect(element)
        from ort_flatbuffers_py.fbs.Graph import Graph
        graph = Graph()
        graph.Init(self._tab.Bytes, target)
        return graph

    def GraphsLength(self):
        """Return the element count of ``Graphs`` (0 if absent)."""
        field = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(28))
        return self._tab.VectorLen(field) if field != 0 else 0

    def GraphsIsNone(self):
        """Return True when the ``Graphs`` vector is absent."""
        field = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(28))
        return field == 0
218
+
219
def _prepend_uoffset_slot(builder, slot, value):
    # Common body of every offset-valued Attribute field adder: pass the
    # offset through UOffsetTFlags.py_type and prepend it with default 0.
    builder.PrependUOffsetTRelativeSlot(slot, flatbuffers.number_types.UOffsetTFlags.py_type(value), 0)


def AttributeStart(builder):
    """Begin building an ``Attribute`` table with room for 13 fields."""
    builder.StartObject(13)


def Start(builder):
    """Short alias of ``AttributeStart``."""
    AttributeStart(builder)


def AttributeAddName(builder, name):
    """Store the ``name`` offset in field slot 0."""
    _prepend_uoffset_slot(builder, 0, name)


def AddName(builder, name):
    """Short alias of ``AttributeAddName``."""
    AttributeAddName(builder, name)


def AttributeAddDocString(builder, docString):
    """Store the ``docString`` offset in field slot 1."""
    _prepend_uoffset_slot(builder, 1, docString)


def AddDocString(builder, docString):
    """Short alias of ``AttributeAddDocString``."""
    AttributeAddDocString(builder, docString)


def AttributeAddType(builder, type):
    """Prepend ``type`` as an int32 into field slot 2 (default 0)."""
    builder.PrependInt32Slot(2, type, 0)


def AddType(builder, type):
    """Short alias of ``AttributeAddType``."""
    AttributeAddType(builder, type)


def AttributeAddF(builder, f):
    """Prepend ``f`` as a float32 into field slot 3 (default 0.0)."""
    builder.PrependFloat32Slot(3, f, 0.0)


def AddF(builder, f):
    """Short alias of ``AttributeAddF``."""
    AttributeAddF(builder, f)


def AttributeAddI(builder, i):
    """Prepend ``i`` as an int64 into field slot 4 (default 0)."""
    builder.PrependInt64Slot(4, i, 0)


def AddI(builder, i):
    """Short alias of ``AttributeAddI``."""
    AttributeAddI(builder, i)


def AttributeAddS(builder, s):
    """Store the ``s`` string offset in field slot 5."""
    _prepend_uoffset_slot(builder, 5, s)


def AddS(builder, s):
    """Short alias of ``AttributeAddS``."""
    AttributeAddS(builder, s)


def AttributeAddT(builder, t):
    """Store the ``t`` sub-table offset in field slot 6."""
    _prepend_uoffset_slot(builder, 6, t)


def AddT(builder, t):
    """Short alias of ``AttributeAddT``."""
    AttributeAddT(builder, t)


def AttributeAddG(builder, g):
    """Store the ``g`` sub-table offset in field slot 7."""
    _prepend_uoffset_slot(builder, 7, g)


def AddG(builder, g):
    """Short alias of ``AttributeAddG``."""
    AttributeAddG(builder, g)


def AttributeAddFloats(builder, floats):
    """Store the ``floats`` vector offset in field slot 8."""
    _prepend_uoffset_slot(builder, 8, floats)


def AddFloats(builder, floats):
    """Short alias of ``AttributeAddFloats``."""
    AttributeAddFloats(builder, floats)


def AttributeStartFloatsVector(builder, numElems):
    """Start a ``Floats`` vector: 4-byte elements, 4-byte alignment."""
    return builder.StartVector(4, numElems, 4)


def StartFloatsVector(builder, numElems: int) -> int:
    """Short alias of ``AttributeStartFloatsVector``."""
    return AttributeStartFloatsVector(builder, numElems)


def AttributeAddInts(builder, ints):
    """Store the ``ints`` vector offset in field slot 9."""
    _prepend_uoffset_slot(builder, 9, ints)


def AddInts(builder, ints):
    """Short alias of ``AttributeAddInts``."""
    AttributeAddInts(builder, ints)


def AttributeStartIntsVector(builder, numElems):
    """Start an ``Ints`` vector: 8-byte elements, 8-byte alignment."""
    return builder.StartVector(8, numElems, 8)


def StartIntsVector(builder, numElems: int) -> int:
    """Short alias of ``AttributeStartIntsVector``."""
    return AttributeStartIntsVector(builder, numElems)


def AttributeAddStrings(builder, strings):
    """Store the ``strings`` vector offset in field slot 10."""
    _prepend_uoffset_slot(builder, 10, strings)


def AddStrings(builder, strings):
    """Short alias of ``AttributeAddStrings``."""
    AttributeAddStrings(builder, strings)


def AttributeStartStringsVector(builder, numElems):
    """Start a ``Strings`` vector: 4-byte elements, 4-byte alignment."""
    return builder.StartVector(4, numElems, 4)


def StartStringsVector(builder, numElems: int) -> int:
    """Short alias of ``AttributeStartStringsVector``."""
    return AttributeStartStringsVector(builder, numElems)


def AttributeAddTensors(builder, tensors):
    """Store the ``tensors`` vector offset in field slot 11."""
    _prepend_uoffset_slot(builder, 11, tensors)


def AddTensors(builder, tensors):
    """Short alias of ``AttributeAddTensors``."""
    AttributeAddTensors(builder, tensors)


def AttributeStartTensorsVector(builder, numElems):
    """Start a ``Tensors`` vector: 4-byte elements, 4-byte alignment."""
    return builder.StartVector(4, numElems, 4)


def StartTensorsVector(builder, numElems: int) -> int:
    """Short alias of ``AttributeStartTensorsVector``."""
    return AttributeStartTensorsVector(builder, numElems)


def AttributeAddGraphs(builder, graphs):
    """Store the ``graphs`` vector offset in field slot 12."""
    _prepend_uoffset_slot(builder, 12, graphs)


def AddGraphs(builder, graphs):
    """Short alias of ``AttributeAddGraphs``."""
    AttributeAddGraphs(builder, graphs)


def AttributeStartGraphsVector(builder, numElems):
    """Start a ``Graphs`` vector: 4-byte elements, 4-byte alignment."""
    return builder.StartVector(4, numElems, 4)


def StartGraphsVector(builder, numElems: int) -> int:
    """Short alias of ``AttributeStartGraphsVector``."""
    return AttributeStartGraphsVector(builder, numElems)


def AttributeEnd(builder):
    """Finish the table and return whatever ``builder.EndObject()`` returns."""
    return builder.EndObject()


def End(builder):
    """Short alias of ``AttributeEnd``."""
    return AttributeEnd(builder)
@@ -0,0 +1,18 @@
1
+ # automatically generated by the FlatBuffers compiler, do not modify
2
+
3
+ # namespace: fbs
4
+
5
class AttributeType:
    """Integer tags for the kind of value an attribute holds.

    NOTE(review): the numbering appears to mirror ONNX
    ``AttributeProto.AttributeType`` for 0-12; confirm against the schema.
    """

    UNDEFINED = 0
    # Single-value kinds.
    FLOAT = 1
    INT = 2
    STRING = 3
    TENSOR = 4
    GRAPH = 5
    # Repeated (list) kinds.
    FLOATS = 6
    INTS = 7
    STRINGS = 8
    TENSORS = 9
    GRAPHS = 10
    # Sparse tensor kinds.
    SPARSE_TENSOR = 11
    SPARSE_TENSORS = 12
@@ -0,0 +1,125 @@
1
+ # automatically generated by the FlatBuffers compiler, do not modify
2
+
3
+ # namespace: fbs
4
+
5
+ import flatbuffers
6
+ from flatbuffers.compat import import_numpy
7
+ np = import_numpy()
8
+
9
class Checkpoint:
    """Read-only view over a serialized ``Checkpoint`` flatbuffer table.

    Every getter resolves its field through the table's vtable on each call.
    When the field is absent, it returns a default (0 or ``None``).
    """

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return a ``Checkpoint`` bound to the root table referenced at ``offset`` in ``buf``."""
        root = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        # Always builds the concrete Checkpoint class, even when invoked through a subclass.
        checkpoint = Checkpoint()
        checkpoint.Init(buf, root + offset)
        return checkpoint

    @classmethod
    def GetRootAsCheckpoint(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def CheckpointBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return whether ``buf`` at ``offset`` carries the ``ODTC`` file identifier."""
        file_identifier = b"\x4F\x44\x54\x43"  # ASCII "ODTC"
        return flatbuffers.util.BufferHasIdentifier(buf, offset, file_identifier, size_prefixed=size_prefixed)

    def Init(self, buf, pos):
        """Bind this view to the table that starts at ``pos`` in ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)

    def Version(self):
        """Return the int32 field at vtable offset 4, or 0 when it is absent."""
        field = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if field == 0:
            return 0
        return self._tab.Get(flatbuffers.number_types.Int32Flags, field + self._tab.Pos)

    def ModuleState(self):
        """Return the sub-table at vtable offset 6 as a ``ModuleState``, or ``None`` when absent."""
        field = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if field == 0:
            return None
        target = self._tab.Indirect(field + self._tab.Pos)
        # Imported lazily, as the generator emits it (presumably to avoid import cycles).
        from ort_flatbuffers_py.fbs.ModuleState import ModuleState
        state = ModuleState()
        state.Init(self._tab.Bytes, target)
        return state

    def OptimizerGroups(self, j):
        """Return element ``j`` of the vector at vtable offset 8 as an ``OptimizerGroup``.

        Returns ``None`` when the vector is absent. ``j`` is not bounds-checked
        against ``OptimizerGroupsLength()``.
        """
        field = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
        if field == 0:
            return None
        element = self._tab.Vector(field)
        # Each vector slot holds a 4-byte offset to the element table.
        element += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
        element = self._tab.Indirect(element)
        from ort_flatbuffers_py.fbs.OptimizerGroup import OptimizerGroup
        group = OptimizerGroup()
        group.Init(self._tab.Bytes, element)
        return group

    def OptimizerGroupsLength(self):
        """Return the element count of the vector at vtable offset 8, or 0 when absent."""
        field = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
        return self._tab.VectorLen(field) if field != 0 else 0

    def OptimizerGroupsIsNone(self):
        """Return True when the vector at vtable offset 8 is absent."""
        return flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) == 0

    def PropertyBag(self):
        """Return the sub-table at vtable offset 10 as a ``PropertyBag``, or ``None`` when absent."""
        field = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
        if field == 0:
            return None
        target = self._tab.Indirect(field + self._tab.Pos)
        from ort_flatbuffers_py.fbs.PropertyBag import PropertyBag
        bag = PropertyBag()
        bag.Init(self._tab.Bytes, target)
        return bag
84
+
85
def CheckpointStart(builder):
    """Begin a Checkpoint object on ``builder`` with 4 field slots; returns ``None``."""
    field_count = 4
    builder.StartObject(field_count)

def Start(builder):
    """Unprefixed alias of ``CheckpointStart``; returns ``None``."""
    CheckpointStart(builder)
90
+
91
def CheckpointAddVersion(builder, version):
    """Write the int32 ``version`` into field slot 0, whose default is 0."""
    slot, default = 0, 0
    builder.PrependInt32Slot(slot, version, default)

def AddVersion(builder, version):
    """Unprefixed alias of ``CheckpointAddVersion``."""
    CheckpointAddVersion(builder, version)
96
+
97
def CheckpointAddModuleState(builder, moduleState):
    """Record the ``moduleState`` table offset in field slot 1 (default 0)."""
    relative = flatbuffers.number_types.UOffsetTFlags.py_type(moduleState)
    builder.PrependUOffsetTRelativeSlot(1, relative, 0)

def AddModuleState(builder, moduleState):
    """Unprefixed alias of ``CheckpointAddModuleState``."""
    CheckpointAddModuleState(builder, moduleState)
102
+
103
def CheckpointAddOptimizerGroups(builder, optimizerGroups):
    """Record the ``optimizerGroups`` vector offset in field slot 2 (default 0)."""
    relative = flatbuffers.number_types.UOffsetTFlags.py_type(optimizerGroups)
    builder.PrependUOffsetTRelativeSlot(2, relative, 0)

def AddOptimizerGroups(builder, optimizerGroups):
    """Unprefixed alias of ``CheckpointAddOptimizerGroups``."""
    CheckpointAddOptimizerGroups(builder, optimizerGroups)
108
+
109
def CheckpointStartOptimizerGroupsVector(builder, numElems: int) -> int:
    """Begin the optimizer-groups vector: ``numElems`` 4-byte offset elements, 4-byte aligned."""
    element_size = alignment = 4
    return builder.StartVector(element_size, numElems, alignment)

def StartOptimizerGroupsVector(builder, numElems: int) -> int:
    """Unprefixed alias of ``CheckpointStartOptimizerGroupsVector``."""
    started = CheckpointStartOptimizerGroupsVector(builder, numElems)
    return started
114
+
115
def CheckpointAddPropertyBag(builder, propertyBag):
    """Record the ``propertyBag`` table offset in field slot 3 (default 0)."""
    relative = flatbuffers.number_types.UOffsetTFlags.py_type(propertyBag)
    builder.PrependUOffsetTRelativeSlot(3, relative, 0)

def AddPropertyBag(builder, propertyBag):
    """Unprefixed alias of ``CheckpointAddPropertyBag``."""
    CheckpointAddPropertyBag(builder, propertyBag)
120
+
121
def CheckpointEnd(builder):
    """Finish the Checkpoint object and return ``builder.EndObject()``'s result."""
    finished = builder.EndObject()
    return finished

def End(builder):
    """Unprefixed alias of ``CheckpointEnd``."""
    finished = CheckpointEnd(builder)
    return finished
@@ -0,0 +1,120 @@
1
+ # automatically generated by the FlatBuffers compiler, do not modify
2
+
3
+ # namespace: fbs
4
+
5
+ import flatbuffers
6
+ from flatbuffers.compat import import_numpy
7
+ np = import_numpy()
8
+
9
+ # deprecated: no longer using kernel def hashes
10
class DeprecatedKernelCreateInfos:
    """Read-only view over a serialized ``DeprecatedKernelCreateInfos`` flatbuffer table.

    The table holds two parallel-looking vectors: uint32 node indices and
    uint64 kernel-def hashes. Getters return 0 when a vector is absent.
    """

    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return an instance bound to the root table referenced at ``offset`` in ``buf``."""
        root = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        # Always builds the concrete class, even when invoked through a subclass.
        infos = DeprecatedKernelCreateInfos()
        infos.Init(buf, root + offset)
        return infos

    @classmethod
    def GetRootAsDeprecatedKernelCreateInfos(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    @classmethod
    def DeprecatedKernelCreateInfosBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Return whether ``buf`` at ``offset`` carries the ``ORTM`` file identifier."""
        file_identifier = b"\x4F\x52\x54\x4D"  # ASCII "ORTM"
        return flatbuffers.util.BufferHasIdentifier(buf, offset, file_identifier, size_prefixed=size_prefixed)

    def Init(self, buf, pos):
        """Bind this view to the table that starts at ``pos`` in ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)

    def NodeIndices(self, j):
        """Return uint32 element ``j`` of the vector at vtable offset 4, or 0 when absent.

        ``j`` is not bounds-checked against ``NodeIndicesLength()``.
        """
        field = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if field == 0:
            return 0
        start = self._tab.Vector(field)
        return self._tab.Get(flatbuffers.number_types.Uint32Flags, start + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))

    def NodeIndicesAsNumpy(self):
        """Return the node-index vector as a numpy view, or the int 0 (not an empty array) when absent."""
        field = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if field == 0:
            return 0
        return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Uint32Flags, field)

    def NodeIndicesLength(self):
        """Return the node-index element count, or 0 when the vector is absent."""
        field = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        return self._tab.VectorLen(field) if field != 0 else 0

    def NodeIndicesIsNone(self):
        """Return True when the node-index vector is absent."""
        return flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) == 0

    def KernelDefHashes(self, j):
        """Return uint64 element ``j`` of the vector at vtable offset 6, or 0 when absent.

        ``j`` is not bounds-checked against ``KernelDefHashesLength()``.
        """
        field = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if field == 0:
            return 0
        start = self._tab.Vector(field)
        return self._tab.Get(flatbuffers.number_types.Uint64Flags, start + flatbuffers.number_types.UOffsetTFlags.py_type(j * 8))

    def KernelDefHashesAsNumpy(self):
        """Return the kernel-def-hash vector as a numpy view, or the int 0 (not an empty array) when absent."""
        field = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if field == 0:
            return 0
        return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Uint64Flags, field)

    def KernelDefHashesLength(self):
        """Return the kernel-def-hash element count, or 0 when the vector is absent."""
        field = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        return self._tab.VectorLen(field) if field != 0 else 0

    def KernelDefHashesIsNone(self):
        """Return True when the kernel-def-hash vector is absent."""
        return flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) == 0
85
+
86
def DeprecatedKernelCreateInfosStart(builder):
    """Begin a DeprecatedKernelCreateInfos object with 2 field slots; returns ``None``."""
    field_count = 2
    builder.StartObject(field_count)

def Start(builder):
    """Unprefixed alias of ``DeprecatedKernelCreateInfosStart``; returns ``None``."""
    DeprecatedKernelCreateInfosStart(builder)
91
+
92
def DeprecatedKernelCreateInfosAddNodeIndices(builder, nodeIndices):
    """Record the ``nodeIndices`` vector offset in field slot 0 (default 0)."""
    relative = flatbuffers.number_types.UOffsetTFlags.py_type(nodeIndices)
    builder.PrependUOffsetTRelativeSlot(0, relative, 0)

def AddNodeIndices(builder, nodeIndices):
    """Unprefixed alias of ``DeprecatedKernelCreateInfosAddNodeIndices``."""
    DeprecatedKernelCreateInfosAddNodeIndices(builder, nodeIndices)
97
+
98
def DeprecatedKernelCreateInfosStartNodeIndicesVector(builder, numElems: int) -> int:
    """Begin the node-index vector: ``numElems`` 4-byte elements, 4-byte aligned."""
    element_size = alignment = 4
    return builder.StartVector(element_size, numElems, alignment)

def StartNodeIndicesVector(builder, numElems: int) -> int:
    """Unprefixed alias of ``DeprecatedKernelCreateInfosStartNodeIndicesVector``."""
    started = DeprecatedKernelCreateInfosStartNodeIndicesVector(builder, numElems)
    return started
103
+
104
def DeprecatedKernelCreateInfosAddKernelDefHashes(builder, kernelDefHashes):
    """Record the ``kernelDefHashes`` vector offset in field slot 1 (default 0)."""
    relative = flatbuffers.number_types.UOffsetTFlags.py_type(kernelDefHashes)
    builder.PrependUOffsetTRelativeSlot(1, relative, 0)

def AddKernelDefHashes(builder, kernelDefHashes):
    """Unprefixed alias of ``DeprecatedKernelCreateInfosAddKernelDefHashes``."""
    DeprecatedKernelCreateInfosAddKernelDefHashes(builder, kernelDefHashes)
109
+
110
def DeprecatedKernelCreateInfosStartKernelDefHashesVector(builder, numElems: int) -> int:
    """Begin the kernel-def-hash vector: ``numElems`` 8-byte elements, 8-byte aligned."""
    element_size = alignment = 8
    return builder.StartVector(element_size, numElems, alignment)

def StartKernelDefHashesVector(builder, numElems: int) -> int:
    """Unprefixed alias of ``DeprecatedKernelCreateInfosStartKernelDefHashesVector``."""
    started = DeprecatedKernelCreateInfosStartKernelDefHashesVector(builder, numElems)
    return started
115
+
116
def DeprecatedKernelCreateInfosEnd(builder):
    """Finish the DeprecatedKernelCreateInfos object and return ``builder.EndObject()``'s result."""
    finished = builder.EndObject()
    return finished

def End(builder):
    """Unprefixed alias of ``DeprecatedKernelCreateInfosEnd``."""
    finished = DeprecatedKernelCreateInfosEnd(builder)
    return finished