onnxruntime-directml 1.20.0 (cp313-cp313-win_amd64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (305)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
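For context, a minimal sketch of how this DirectML build is typically consumed once the wheel is installed. The snippet is illustrative only and not part of the package diff; "model.onnx" is a placeholder, and a single float32 input is assumed.

import numpy as np
import onnxruntime as ort

# Prefer the DirectML execution provider, falling back to CPU if it is unavailable.
session = ort.InferenceSession(
    "model.onnx",
    providers=["DmlExecutionProvider", "CPUExecutionProvider"],
)

inp = session.get_inputs()[0]
# Replace dynamic (non-integer) dimensions with 1 to build a smoke-test input.
shape = [d if isinstance(d, int) else 1 for d in inp.shape]
outputs = session.run(None, {inp.name: np.zeros(shape, dtype=np.float32)})
print("active providers:", session.get_providers())
print("output shapes:", [o.shape for o in outputs])

The diff below covers onnxruntime/transformers/models/llama/convert_to_onnx.py, the largest newly added script in the transformers tooling.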
onnxruntime/transformers/models/llama/convert_to_onnx.py
@@ -0,0 +1,1027 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License. See License.txt in the project root for
4
+ # license information.
5
+ # --------------------------------------------------------------------------
6
+ from __future__ import annotations
7
+
8
+ import argparse
9
+ import logging
10
+ import os
11
+ import shutil
12
+ import subprocess
13
+ import sys
14
+ import tempfile
15
+ from itertools import chain
16
+
17
+ import onnx
18
+ import torch
19
+ from benchmark_helper import Precision, prepare_environment, setup_logger
20
+ from convert_generation import replace_mha_with_gqa
21
+ from dist_settings import barrier, get_rank, get_size, init_dist
22
+ from llama_inputs import get_merged_sample_with_past_kv_inputs, get_sample_inputs, get_sample_with_past_kv_inputs
23
+ from llama_parity import main as parity_check
24
+ from llama_torch import setup_torch_model
25
+ from onnx_model import OnnxModel
26
+ from optimizer import optimize_model
27
+ from packaging import version
28
+ from transformers import AutoConfig, AutoModelForCausalLM
29
+
30
+ from onnxruntime import quantization as ort_quantization
31
+ from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer
32
+
33
+ torch_export_onnx_opset_version = 14
34
+ logger = logging.getLogger("")
35
+ init_dist()
36
+
37
+
38
+ def get_model_dynamic_axes(input_names: list[str], output_names: list[str]):
39
+ dynamic_axes = {}
40
+ for name in input_names + output_names:
41
+ if name in input_names:
42
+ # shape is (batch_size, sequence_length)
43
+ dynamic_axes[name] = {0: "batch_size", 1: "sequence_length"}
44
+ elif name == "logits":
45
+ # shape is (batch_size, sequence_length, vocab_size)
46
+ dynamic_axes[name] = {0: "batch_size", 1: "sequence_length"}
47
+ elif "present" in name:
48
+ # shape is (batch_size, num_heads, sequence_length, head_size)
49
+ dynamic_axes[name] = {0: "batch_size", 2: "sequence_length"}
50
+ else:
51
+ raise Exception("Unknown input or output name found")
52
+ return dynamic_axes
53
+
54
+
55
+ def get_model_with_past_kv_dynamic_axes(input_names: list[str], output_names: list[str]):
56
+ dynamic_axes = {}
57
+ for name in input_names + output_names:
58
+ if name in {"input_ids", "position_ids"}:
59
+ # shape is (batch_size, 1)
60
+ dynamic_axes[name] = {0: "batch_size"}
61
+ elif name == "attention_mask":
62
+ # shape is (batch_size, past_sequence_length + 1)
63
+ dynamic_axes[name] = {0: "batch_size", 1: "past_sequence_length + 1"}
64
+ elif "past" in name:
65
+ # shape is (batch_size, num_heads, past_sequence_length, head_size)
66
+ dynamic_axes[name] = {0: "batch_size", 2: "past_sequence_length"}
67
+ elif name == "logits":
68
+ # shape is (batch_size, 1, vocab_size)
69
+ dynamic_axes[name] = {0: "batch_size"}
70
+ elif "present" in name:
71
+ # shape is (batch_size, num_heads, past_sequence_length + 1, head_size)
72
+ dynamic_axes[name] = {0: "batch_size", 2: "past_sequence_length + 1"}
73
+ else:
74
+ raise Exception("Unknown input or output name found")
75
+ return dynamic_axes
76
+
77
+
78
+ def get_merged_model_dynamic_axes(input_names: list[str], output_names: list[str]):
79
+ dynamic_axes = {}
80
+ for name in input_names + output_names:
81
+ if name in {"input_ids", "position_ids"}:
82
+ # shape is (batch_size, sequence_length)
83
+ dynamic_axes[name] = {0: "batch_size", 1: "sequence_length"}
84
+ elif name == "attention_mask":
85
+ # shape is (batch_size, past_sequence_length + sequence_length) = (batch_size, total_sequence_length)
86
+ # for prompt generation, past_sequence_length = 0
87
+ # for token generation, sequence_length = 1
88
+ dynamic_axes[name] = {0: "batch_size", 1: "total_sequence_length"}
89
+ elif "past" in name:
90
+ # shape is (batch_size, num_heads, past_sequence_length, head_size)
91
+ dynamic_axes[name] = {0: "batch_size", 2: "past_sequence_length"}
92
+ elif name == "logits":
93
+ # shape is (batch_size, sequence_length, vocab_size)
94
+ dynamic_axes[name] = {0: "batch_size", 1: "sequence_length"}
95
+ elif "present" in name:
96
+ # shape is (batch_size, num_heads, past_sequence_length + sequence_length, head_size) = (batch_size, num_heads, total_sequence_length, head_size)
97
+ # for prompt generation, past_sequence_length = 0
98
+ # for token generation, sequence_length = 1
99
+ dynamic_axes[name] = {0: "batch_size", 2: "total_sequence_length"}
100
+ else:
101
+ raise Exception("Unknown input or output name found")
102
+ return dynamic_axes
103
+
104
+
105
+ def save_onnx_model(onnx_model: onnx.ModelProto, output_path: str, data_path: str):
106
+ onnx.save(
107
+ onnx_model,
108
+ output_path,
109
+ save_as_external_data=True,
110
+ all_tensors_to_one_file=True,
111
+ location=data_path,
112
+ size_threshold=1024,
113
+ convert_attribute=False,
114
+ )
115
+
116
+
117
+ def run_dynamo_export(
118
+ args: argparse.Namespace, l_config: AutoConfig, llama: AutoModelForCausalLM, rank: int = 0, world_size: int = 1
119
+ ):
120
+ from torch._dynamo import config
121
+
122
+ config.capture_scalar_outputs = True
123
+
124
+ # Dummy values for export
125
+ batch_size, sequence_length, past_sequence_length = 2, 8, 0
126
+ device = llama.device if args.model_name == "Llama-2-70b-hf" else torch.device("cpu")
127
+
128
+ temp_name = args.model_name.lower().replace("-", "").replace("_", "")
129
+ max_sequence_length = 16384 if "codellama" in temp_name else 4096 if "llama2" in temp_name else 2048
130
+
131
+ # Export decoder_with_past_model.onnx
132
+ input_ids, attn_mask, pos_ids, past_kv = get_merged_sample_with_past_kv_inputs(
133
+ l_config,
134
+ device,
135
+ batch_size,
136
+ sequence_length,
137
+ past_sequence_length,
138
+ max_seq_len=max_sequence_length,
139
+ use_fp16=False,
140
+ world_size=world_size,
141
+ )
142
+ temp_dir = tempfile.TemporaryDirectory()
143
+ temp_path = os.path.join(temp_dir.name, "temp.onnx")
144
+ torch.onnx.dynamo_export(
145
+ llama, input_ids, attn_mask, pos_ids, past_kv, export_options=torch.onnx.ExportOptions(dynamic_shapes=True)
146
+ ).save(temp_path)
147
+
148
+ # Check decoder_with_past_model.onnx and save all external data to one file
149
+ onnx.checker.check_model(temp_path)
150
+ onnx.shape_inference.infer_shapes_path(temp_path)
151
+
152
+ output_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32.onnx")
153
+ onnx_model = onnx.load_model(temp_path, load_external_data=True)
154
+ save_onnx_model(onnx_model, output_path, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32.onnx.data")
155
+ del onnx_model
156
+ temp_dir.cleanup()
157
+
158
+ logger.info(f"The {args.model_name} ONNX model has been successfully created with the Dynamo exporter!")
159
+
160
+
161
+ def _prepare_dir(dir_path):
162
+ if not os.path.exists(dir_path):
163
+ os.makedirs(dir_path)
164
+
165
+
166
+ def run_torchscript_separate_export(
167
+ args: argparse.Namespace, l_config: AutoConfig, llama: AutoModelForCausalLM, rank: int = 0, world_size: int = 1
168
+ ):
169
+ # Dummy values for export
170
+ batch_size, sequence_length = 2, 8
171
+
172
+ # set device used to export model
173
+ # for llama-2-70b we will use current gpus to speed up export process
174
+ # for other models, we will use CPU to make sure we have enough memory to do export
175
+ device = llama.device if args.model_name == "Llama-2-70b-hf" else torch.device("cpu")
176
+
177
+ # Export decoder_model.onnx
178
+ decoder_inputs = get_sample_inputs(l_config, device, batch_size, sequence_length)
179
+
180
+ input_names = ["input_ids", "attention_mask", "position_ids"]
181
+ output_names = [
182
+ "logits",
183
+ *list(
184
+ chain.from_iterable((f"present.{i}.key", f"present.{i}.value") for i in range(l_config.num_hidden_layers))
185
+ ),
186
+ ]
187
+ dynamic_axes = get_model_dynamic_axes(input_names, output_names)
188
+
189
+ # Avoid using system temp dir to avoid overflood on hard disk as 70b model is very large.
190
+ # Use temp folder per rank to avoid race condition here.
191
+ temp_dir = f"./temp_{rank}"
192
+ _prepare_dir(temp_dir)
193
+ temp_path = os.path.join(temp_dir, "temp.onnx")
194
+ torch.onnx.export(
195
+ llama,
196
+ args=decoder_inputs,
197
+ f=temp_path,
198
+ export_params=True,
199
+ input_names=input_names,
200
+ output_names=output_names,
201
+ dynamic_axes=dynamic_axes,
202
+ opset_version=torch_export_onnx_opset_version,
203
+ do_constant_folding=True,
204
+ verbose=args.verbose,
205
+ )
206
+
207
+ # Check decoder_model.onnx and save all external data to one file
208
+ onnx.checker.check_model(temp_path)
209
+ onnx.shape_inference.infer_shapes_path(temp_path)
210
+
211
+ output_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx")
212
+ onnx_model = onnx.load_model(temp_path, load_external_data=True)
213
+ save_onnx_model(
214
+ onnx_model,
215
+ output_path,
216
+ f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx.data",
217
+ )
218
+ del onnx_model
219
+ shutil.rmtree(temp_dir)
220
+
221
+ # Export decoder_with_past_model.onnx
222
+ decoder_with_past_inputs = get_sample_with_past_kv_inputs(
223
+ l_config,
224
+ device,
225
+ batch_size,
226
+ sequence_length,
227
+ use_fp16=False,
228
+ world_size=world_size,
229
+ )
230
+ input_names = [
231
+ "input_ids",
232
+ "attention_mask",
233
+ "position_ids",
234
+ *list(
235
+ chain.from_iterable(
236
+ (f"past_key_values.{i}.key", f"past_key_values.{i}.value") for i in range(l_config.num_hidden_layers)
237
+ )
238
+ ),
239
+ ]
240
+ output_names = [
241
+ "logits",
242
+ *list(
243
+ chain.from_iterable((f"present.{i}.key", f"present.{i}.value") for i in range(l_config.num_hidden_layers))
244
+ ),
245
+ ]
246
+ dynamic_axes = get_model_with_past_kv_dynamic_axes(input_names, output_names)
247
+
248
+ # Avoid using system temp dir to avoid overflood on hard disk as 70b model is very large.
249
+ # Use temp folder per rank to avoid race condition here.
250
+ temp_dir = f"./temp_past_{rank}"
251
+ _prepare_dir(temp_dir)
252
+ temp_path = os.path.join(temp_dir, "temp.onnx")
253
+ torch.onnx.export(
254
+ llama,
255
+ args=decoder_with_past_inputs,
256
+ f=temp_path,
257
+ export_params=True,
258
+ input_names=input_names,
259
+ output_names=output_names,
260
+ dynamic_axes=dynamic_axes,
261
+ opset_version=torch_export_onnx_opset_version,
262
+ do_constant_folding=True,
263
+ verbose=args.verbose,
264
+ )
265
+
266
+ # Check decoder_with_past_model.onnx and save all external data to one file
267
+ onnx.checker.check_model(temp_path)
268
+ onnx.shape_inference.infer_shapes_path(temp_path)
269
+
270
+ output_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32.onnx")
271
+ onnx_model = onnx.load_model(temp_path, load_external_data=True)
272
+ save_onnx_model(
273
+ onnx_model,
274
+ output_path,
275
+ f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32.onnx.data",
276
+ )
277
+ del onnx_model
278
+ shutil.rmtree(temp_dir)
279
+
280
+ logger.info(
281
+ f"The {args.model_name} separate ONNX model has been successfully created with the TorchScript exporter!"
282
+ )
283
+
284
+
285
+ def run_torchscript_merged_export(
286
+ args: argparse.Namespace, l_config: AutoConfig, llama: AutoModelForCausalLM, rank: int = 0, world_size: int = 1
287
+ ):
288
+ # Dummy values for export
289
+ batch_size, sequence_length, past_sequence_length = 2, 8, 0
290
+
291
+ # set device used to export model
292
+ # for llama-2-70b we will use current gpus to speed up export process
293
+ # for other models, we will use CPU to make sure we have enough memory to do export
294
+ device = llama.device if args.model_name == "Llama-2-70b-hf" else torch.device("cpu")
295
+
296
+ temp_name = args.model_name.lower().replace("-", "").replace("_", "")
297
+ max_sequence_length = 16384 if "codellama" in temp_name else 4096 if "llama2" in temp_name else 2048
298
+
299
+ # Export decoder_merged_model.onnx
300
+ decoder_merged_inputs = get_merged_sample_with_past_kv_inputs(
301
+ l_config,
302
+ device,
303
+ batch_size,
304
+ sequence_length,
305
+ past_sequence_length,
306
+ max_seq_len=max_sequence_length,
307
+ use_fp16=False,
308
+ world_size=world_size,
309
+ )
310
+ input_names = [
311
+ "input_ids",
312
+ "attention_mask",
313
+ "position_ids",
314
+ *list(
315
+ chain.from_iterable(
316
+ (f"past_key_values.{i}.key", f"past_key_values.{i}.value") for i in range(l_config.num_hidden_layers)
317
+ )
318
+ ),
319
+ ]
320
+ output_names = [
321
+ "logits",
322
+ *list(
323
+ chain.from_iterable((f"present.{i}.key", f"present.{i}.value") for i in range(l_config.num_hidden_layers))
324
+ ),
325
+ ]
326
+ dynamic_axes = get_merged_model_dynamic_axes(input_names, output_names)
327
+
328
+ # Avoid using system temp dir to avoid overflood on hard disk as 70b model is very large.
329
+ # Use temp folder per rank to avoid race condition here.
330
+ temp_dir = f"./temp_{rank}"
331
+ _prepare_dir(temp_dir)
332
+ temp_path = os.path.join(temp_dir, "temp.onnx")
333
+ torch.onnx.export(
334
+ llama,
335
+ args=decoder_merged_inputs,
336
+ f=temp_path,
337
+ export_params=True,
338
+ input_names=input_names,
339
+ output_names=output_names,
340
+ dynamic_axes=dynamic_axes,
341
+ opset_version=torch_export_onnx_opset_version,
342
+ do_constant_folding=True,
343
+ verbose=args.verbose,
344
+ )
345
+
346
+ # Check decoder_merged_model.onnx and save all external data to one file
347
+ onnx.checker.check_model(temp_path)
348
+ onnx.shape_inference.infer_shapes_path(temp_path)
349
+
350
+ output_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_fp32.onnx")
351
+ onnx_model = onnx.load_model(temp_path, load_external_data=True)
352
+ save_onnx_model(
353
+ onnx_model,
354
+ output_path,
355
+ f"rank_{rank}_{args.model_name}_decoder_merged_model_fp32.onnx.data",
356
+ )
357
+ del onnx_model
358
+ shutil.rmtree(temp_dir)
359
+
360
+ logger.info(f"The {args.model_name} merged ONNX model has been successfully created with the TorchScript exporter!")
361
+
362
+
363
+ # Optimize the model as FP32
364
+ def optimize_export(
365
+ args: argparse.Namespace,
366
+ config: AutoConfig,
367
+ input_path: str,
368
+ output_path: str,
369
+ remove_model: bool = True,
370
+ world_size: int = 1,
371
+ window_size: int = -1,
372
+ ):
373
+ from fusion_options import FusionOptions
374
+
375
+ optimization_options = FusionOptions("gpt2")
376
+
377
+ model_opt = optimize_model(
378
+ input_path,
379
+ model_type="gpt2",
380
+ num_heads=config.num_attention_heads,
381
+ hidden_size=config.hidden_size,
382
+ opt_level=0,
383
+ optimization_options=optimization_options,
384
+ only_onnxruntime=False,
385
+ )
386
+ if args.use_gqa:
387
+ model_opt = use_group_query_attention(config, model_opt, world_size, window_size)
388
+ model_opt.save_model_to_file(output_path, use_external_data_format=True)
389
+
390
+ # Run symbolic shape inference on optimized model to avoid shape errors during runtime
391
+ # Ex: Before attention fusion, RotaryEmbedding assumes a 4D input and produces a 4D output.
392
+ # After attention fusion, RotaryEmbedding expects a 3D input and produces a 3D output.
393
+ wheel_cmd = [sys.executable, "-m", "onnxruntime.tools.symbolic_shape_infer"]
394
+ source_cmd = [sys.executable, "../symbolic_shape_infer.py"]
395
+ symbolic_shape_infer_args = [
396
+ "--input",
397
+ output_path,
398
+ "--output",
399
+ output_path,
400
+ "--auto_merge",
401
+ "--save_as_external_data",
402
+ "--all_tensors_to_one_file",
403
+ "--external_data_location",
404
+ os.path.basename(output_path) + ".data",
405
+ ]
406
+
407
+ file_path = os.path.dirname(__file__)
408
+ if os.path.exists(os.path.join(file_path, "../../../tools/symbolic_shape_infer.py")):
409
+ main_cmd = wheel_cmd
410
+ else:
411
+ main_cmd = source_cmd
412
+ subprocess.run(main_cmd + symbolic_shape_infer_args) # noqa: PLW1510
413
+
414
+ logger.info(f"The ONNX model at {input_path} has been successfully optimized and saved at {output_path}!")
415
+ if remove_model:
416
+ remove_existing_model(input_path)
417
+
418
+
419
+ def convert_to_float16(args: argparse.Namespace, old_paths: list[str], rank: int = 0):
420
+ decoder_model_fp16_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp16.onnx")
421
+ decoder_with_past_model_fp16_path = os.path.join(
422
+ args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp16.onnx"
423
+ )
424
+ decoder_merged_model_fp16_path = os.path.join(
425
+ args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_fp16.onnx"
426
+ )
427
+ new_paths = [decoder_model_fp16_path, decoder_with_past_model_fp16_path, decoder_merged_model_fp16_path]
428
+
429
+ logger.info("Converting to float16...")
430
+ for fp32_path, fp16_path in zip(old_paths, new_paths):
431
+ if os.path.exists(fp32_path):
432
+ model = OnnxModel(onnx.load_model(fp32_path, load_external_data=True))
433
+ model.convert_float_to_float16(keep_io_types=False)
434
+ model.save_model_to_file(fp16_path, use_external_data_format=True)
435
+ del model
436
+ logger.info(f"The ONNX model at {fp32_path} has been converted to float16 and saved at {fp16_path}!")
437
+ remove_existing_model(fp32_path)
438
+
439
+ logger.info(f"The {args.model_name} ONNX model has been successfully converted to float16!")
440
+ return new_paths
441
+
442
+
443
+ def use_group_query_attention(config: AutoConfig, model_opt: OnnxModel, world_size: int = 1, window_size: int = -1):
444
+ # Replace MultiHeadAttention with GroupQueryAttention
445
+ model_opt = replace_mha_with_gqa(model_opt, "attention_mask", config.num_key_value_heads, world_size, window_size)
446
+ model_opt.prune_graph()
447
+ model_opt.update_graph(allow_remove_graph_inputs=True)
448
+ return model_opt
449
+
450
+
451
+ def smooth_quant(
452
+ args: argparse.Namespace,
453
+ decoder_model_fp32_path: str,
454
+ decoder_with_past_model_fp32_path: str,
455
+ decoder_model_int8_path: str,
456
+ decoder_with_past_model_int8_path: str,
457
+ ):
458
+ from neural_compressor import PostTrainingQuantConfig
459
+ from neural_compressor import quantization as intel_quantization
460
+ from neural_compressor import set_workspace
461
+ from onnx.external_data_helper import load_external_data_for_model
462
+ from quant_kv_dataloader import QuantKVDataLoader
463
+
464
+ set_workspace(args.nc_workspace)
465
+ quantization_config = PostTrainingQuantConfig(
466
+ calibration_sampling_size=[args.calibration_sampling_size],
467
+ recipes={
468
+ "optypes_to_exclude_output_quant": ["MatMul"],
469
+ "smooth_quant": True,
470
+ "smooth_quant_args": {"alpha": args.smooth_quant_alpha},
471
+ },
472
+ op_type_dict={
473
+ "^((?!(MatMul|Gather|Conv)).)*$": {
474
+ "weight": {"dtype": ["fp32"]},
475
+ "activation": {"dtype": ["fp32"]},
476
+ }
477
+ },
478
+ )
479
+
480
+ # Convert decoder_model.onnx to INT8
481
+ decoder_model_int8 = intel_quantization.fit(
482
+ decoder_model_fp32_path,
483
+ quantization_config,
484
+ calib_dataloader=QuantKVDataLoader(args),
485
+ )
486
+ load_external_data_for_model(
487
+ decoder_model_int8._model,
488
+ os.path.split(decoder_model_int8._model_path)[0],
489
+ )
490
+ save_onnx_model(
491
+ decoder_model_int8._model,
492
+ decoder_model_int8_path,
493
+ f"{args.model_name}_decoder_model_int8.onnx.data",
494
+ )
495
+ del decoder_model_int8
496
+ logger.info(
497
+ f"The ONNX model at {decoder_model_fp32_path} has been quantized to int8 and saved at {decoder_model_int8_path}!"
498
+ )
499
+ remove_existing_model(decoder_model_fp32_path)
500
+
501
+ # Convert decoder_with_past_model.onnx to INT8
502
+ decoder_with_past_model_int8 = intel_quantization.fit(
503
+ decoder_with_past_model_fp32_path,
504
+ quantization_config,
505
+ calib_dataloader=QuantKVDataLoader(args, onnx_model_path=decoder_model_fp32_path),
506
+ )
507
+ load_external_data_for_model(
508
+ decoder_with_past_model_int8._model,
509
+ os.path.split(decoder_with_past_model_int8._model_path)[0],
510
+ )
511
+ save_onnx_model(
512
+ decoder_with_past_model_int8._model,
513
+ decoder_with_past_model_int8_path,
514
+ f"{args.model_name}_decoder_with_past_model_int8.onnx.data",
515
+ )
516
+ del decoder_with_past_model_int8
517
+ logger.info(
518
+ f"The ONNX model at {decoder_with_past_model_fp32_path} has been quantized to int8 and saved at {decoder_with_past_model_int8_path}!"
519
+ )
520
+ remove_existing_model(decoder_with_past_model_fp32_path)
521
+
522
+ logger.info(f"The {args.model_name} ONNX model has been successfully quantized to int8!")
523
+
524
+ logger.warning(f"Removing {args.nc_workspace}")
525
+ shutil.rmtree(args.nc_workspace)
526
+
527
+
528
+ def remove_existing_model(model_path: str):
529
+ # Remove ONNX model and its external data
530
+ data_path = os.path.join(model_path + ".data")
531
+ os.remove(model_path)
532
+ os.remove(data_path)
533
+ logger.warning(f"Removed {model_path} and {data_path}")
534
+
535
+
536
+ def remove_existing_files(output_path: str):
537
+ for filename in os.listdir(output_path):
538
+ filepath = os.path.join(output_path, filename)
539
+ if ".onnx" in filename or ".onnx.data" in filename:
540
+ os.remove(filepath)
541
+ logger.warning(f"Removed {filepath}")
542
+
543
+
544
+ def optimize_optimum(config: AutoConfig, args: argparse.Namespace):
545
+ tmp_file = os.path.join(args.output, args.model_name + ".tmp.onnx")
546
+ output_file = os.path.join(args.output, args.model_name + ".onnx")
547
+ window_size = -1 if not hasattr(config, "sliding_window") else config.sliding_window
548
+ optimize_export(args, config, args.input, tmp_file, remove_model=False, window_size=window_size)
549
+ logger.info(f"Model successfully optimized to {tmp_file}")
550
+ opt_model = OnnxModel(onnx.load_model(tmp_file, load_external_data=True))
551
+ if args.precision == Precision.FLOAT16:
552
+ opt_model.convert_float_to_float16(keep_io_types=False)
553
+ logger.info("Model successfully fused and quantized to FP16!")
554
+ opt_model.save_model_to_file(output_file, use_external_data_format=True)
555
+ logger.info(f"Output model successfully saved to {output_file}")
556
+ logger.info(f"Removing {tmp_file}")
557
+ remove_existing_model(tmp_file)
558
+
559
+
560
+ def get_args():
561
+ parser = argparse.ArgumentParser()
562
+
563
+ parser.add_argument(
564
+ "-m",
565
+ "--model_name",
566
+ required=True,
567
+ help="Model name in Hugging Face",
568
+ )
569
+
570
+ parser.add_argument(
571
+ "-i",
572
+ "--input",
573
+ required=False,
574
+ default=os.path.join("."),
575
+ help="Directory path to PyTorch model and associated files if saved on disk, or ONNX model file location if optimize_optimum is passed.",
576
+ )
577
+
578
+ parser.add_argument(
579
+ "-o",
580
+ "--output",
581
+ required=False,
582
+ default=os.path.join(".", "llama_onnx_models"),
583
+ help="Directory path to save exported model files in",
584
+ )
585
+
586
+ parser.add_argument(
587
+ "-p",
588
+ "--precision",
589
+ required=False,
590
+ type=Precision,
591
+ default=Precision.FLOAT32,
592
+ choices=[Precision.FLOAT32, Precision.FLOAT16, Precision.INT8, Precision.INT4],
593
+ help="Precision to export model in",
594
+ )
595
+
596
+ parser.add_argument(
597
+ "-e",
598
+ "--execution_provider",
599
+ required=False,
600
+ default="cpu",
601
+ choices=["cpu", "cuda", "rocm"],
602
+ help="Execution provider to verify parity with",
603
+ )
604
+
605
+ parser.add_argument(
606
+ "-r",
607
+ "--reexport",
608
+ required=False,
609
+ action="store_true",
610
+ help="Re-export models and overwrite existing models in output folder",
611
+ )
612
+ parser.set_defaults(reexport=False)
613
+
614
+ parser.add_argument(
615
+ "--use_gqa",
616
+ required=False,
617
+ action="store_true",
618
+ help="Use GroupQueryAttention instead of MultiHeadAttention",
619
+ )
620
+ parser.set_defaults(use_gqa=False)
621
+
622
+ parser.add_argument(
623
+ "--no_merged",
624
+ required=False,
625
+ action="store_true",
626
+ help="Export models into 2 ONNX files instead of 1. Deprecated in favor of exporting into 1 ONNX file.",
627
+ )
628
+ parser.set_defaults(no_merged=False)
629
+
630
+ parser.add_argument(
631
+ "-q",
632
+ "--quantization_method",
633
+ default="",
634
+ choices=["blockwise", "smooth_quant", "quantize_dynamic"],
635
+ help="Run a specific quantization algorithm (blockwise for int4, smooth_quant for int8, quantize_dynamic for int8). Blockwise is recommended. Need to install extra packages in `requirements-quant.txt` for SmoothQuant.",
636
+ )
637
+
638
+ blockwise_group = parser.add_argument_group("blockwise (4-bit quantization)")
639
+
640
+ blockwise_group.add_argument(
641
+ "--block_size",
642
+ required=False,
643
+ default=32,
644
+ type=int,
645
+ help="Block size to quantize with. See https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py for details.",
646
+ )
647
+
648
+ blockwise_group.add_argument(
649
+ "--int4_accuracy_level",
650
+ required=False,
651
+ type=int,
652
+ help="Accuracy level of the 4-bit quantized MatMul computation. "
653
+ "Refer to the MatMulNBits contrib op's 'accuracy_level' attribute for details "
654
+ "(https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftmatmulnbits).",
655
+ )
656
+
657
+ smooth_quant_group = parser.add_argument_group("smooth_quant (8-bit quantization)")
658
+
659
+ smooth_quant_group.add_argument(
660
+ "--smooth_quant_alpha",
661
+ required=False,
662
+ default=0.8,
663
+ type=float,
664
+ help="Strength to control migration difficulty from activation to weights. Default is 0.8 to match value \
665
+ used in original paper for LLaMA. Paper recommends using values in [0.4, 0.6] range. \
666
+ Link to paper: https://arxiv.org/pdf/2211.10438.pdf",
667
+ )
668
+
669
+ smooth_quant_group.add_argument(
670
+ "--smooth_quant_dataset",
671
+ required=False,
672
+ default="NeelNanda/pile-10k",
673
+ help="Path to dataset for calibration during quantization",
674
+ )
675
+
676
+ smooth_quant_group.add_argument(
677
+ "--pad_max",
678
+ required=False,
679
+ default=196,
680
+ type=int,
681
+ help="Max padding size",
682
+ )
683
+
684
+ smooth_quant_group.add_argument(
685
+ "--calibration_sampling_size",
686
+ required=False,
687
+ type=int,
688
+ default=8,
689
+ help="Calibration sampling size for quantization config",
690
+ )
691
+
692
+ smooth_quant_group.add_argument(
693
+ "--nc_workspace",
694
+ required=False,
695
+ type=str,
696
+ default=os.path.join(".", "nc_workspace"),
697
+ help="Workspace to save intermediate files generated by Intel's Neural Compressor package.",
698
+ )
699
+
700
+ quantize_dynamic_group = parser.add_argument_group("quantize_dynamic (8-bit quantization)")
701
+
702
+ quantize_dynamic_group.add_argument(
703
+ "--quantize_embedding_layer",
704
+ required=False,
705
+ action="store_true",
706
+ help="Quantize MatMul, GEMM, and Gather.",
707
+ )
708
+ quantize_dynamic_group.set_defaults(quantize_embedding_layer=False)
709
+
710
+ quantize_dynamic_group.add_argument(
711
+ "--quantize_per_channel",
712
+ required=False,
713
+ action="store_true",
714
+ help="Quantize weights per each channel.",
715
+ )
716
+ quantize_dynamic_group.set_defaults(quantize_per_channel=False)
717
+
718
+ quantize_dynamic_group.add_argument(
719
+ "--quantize_reduce_range",
720
+ required=False,
721
+ action="store_true",
722
+ help="Quantize weights with 7 bits.",
723
+ )
724
+ quantize_dynamic_group.set_defaults(quantize_reduce_range=False)
725
+
726
+ parser.add_argument(
727
+ "-v",
728
+ "--verbose",
729
+ action="store_true",
730
+ help="Print verbose logs",
731
+ )
732
+ parser.set_defaults(verbose=False)
733
+
734
+ parser.add_argument(
735
+ "-d",
736
+ "--use_dynamo_export",
737
+ action="store_true",
738
+ help="Use the new Dynamo exporter instead of the old TorchScript exporter",
739
+ )
740
+ parser.set_defaults(use_dynamo_export=False)
741
+
742
+ parser.add_argument(
743
+ "--cache_dir",
744
+ required=False,
745
+ type=str,
746
+ default="./model_cache",
747
+ help="model cache dir to override default HF cache dir to avoid overflood the /home dir",
748
+ )
749
+
750
+ parser.add_argument(
751
+ "--optimize_optimum",
752
+ action="store_true",
753
+ help="Avoid exporting model, only apply quantizations and optimizations to existing model exported from optimum.",
754
+ )
755
+
756
+ parser.add_argument(
757
+ "--small_gpu",
758
+ action="store_true",
759
+ help="Load the llama in GPU every time for parity_check if it's running in a machine which GPU memory < 36GB.",
760
+ )
761
+
762
+ parser.set_defaults(optimize_optimum=False)
763
+
764
+ args = parser.parse_args()
765
+ return args
766
+
767
+
768
+ def main():
769
+ if version.parse(torch.__version__) < version.parse("2.2.0"):
770
+ logger.error(f"Detected PyTorch version {torch.__version__}. Please upgrade and use v2.2.0 or newer.")
771
+ return
772
+
773
+ args = get_args()
774
+ setup_logger(args.verbose)
775
+ prepare_environment(args.input, args.output, args.execution_provider != "cpu")
776
+ if args.reexport:
777
+ remove_existing_files(args.output)
778
+ logger.info(f"Arguments: {args}")
779
+
780
+ world_size = get_size()
781
+ rank = get_rank()
782
+ args.world_size = world_size
783
+
784
+ # Load model and config
785
+ use_auth_token = args.input == os.path.join(".")
786
+ setattr(args, "use_auth_token", use_auth_token) # noqa: B010
787
+
788
+ original_model_name = args.model_name
789
+ setattr(args, "original_model_name", original_model_name) # noqa: B010
790
+ args.model_name = args.model_name.split("/")[-1]
791
+
792
+ setattr(args, "device_name", "cpu" if args.execution_provider == "cpu" else f"cuda:{rank}") # noqa: B010
793
+ setattr(args, "device", torch.device(args.device_name)) # noqa: B010
794
+
795
+ location = args.original_model_name if use_auth_token else args.input
796
+
797
+ if args.optimize_optimum:
798
+ config = AutoConfig.from_pretrained(args.original_model_name, cache_dir=args.cache_dir)
799
+ optimize_optimum(config, args)
800
+ return
801
+
802
+ # Use CUDA for LLaMA-2-70B to speed up export and CPU for other models
803
+ l_config, llama = setup_torch_model(
804
+ args, location, use_auth_token, device=args.device if args.model_name == "Llama-2-70b-hf" else None
805
+ )
806
+
807
+ assert l_config.num_attention_heads % world_size == 0 and l_config.num_key_value_heads % world_size == 0
808
+
809
+ barrier()
810
+ for i in range(world_size):
811
+ if i == rank:
812
+ # Set model paths for FP32 model
813
+ decoder_model_fp32_path = os.path.join(
814
+ args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp32.onnx"
815
+ )
816
+ decoder_with_past_model_fp32_path = os.path.join(
817
+ args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32.onnx"
818
+ )
819
+ decoder_merged_model_fp32_path = os.path.join(
820
+ args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_fp32.onnx"
821
+ )
822
+ old_paths = [decoder_model_fp32_path, decoder_with_past_model_fp32_path, decoder_merged_model_fp32_path]
823
+
824
+ missing_separate_exports = (
825
+ args.no_merged
826
+ and not os.path.exists(decoder_model_fp32_path)
827
+ and not os.path.exists(decoder_with_past_model_fp32_path)
828
+ )
829
+ missing_merged_export = not args.no_merged and not os.path.exists(decoder_merged_model_fp32_path)
830
+
831
+ # Export to ONNX
832
+ if missing_separate_exports or missing_merged_export:
833
+ if args.use_dynamo_export:
834
+ logger.warning("Please ensure you have installed PyTorch, ONNX, and ONNX Script as follows.")
835
+ logger.warning("Step 1 - PyTorch nightly: https://pytorch.org/get-started/locally/")
836
+ logger.warning("Step 2 - ONNX weekly: https://pypi.org/project/onnx-weekly/")
837
+ logger.warning(
838
+ "Step 3 - ONNX Script from source: https://github.com/microsoft/onnxscript#installing-onnx-script"
839
+ )
840
+ logger.warning(
841
+ "Note: After you install ONNX weekly, omit `onnx` when running the first line for installing ONNX Script. This is because you already installed `onnx-weekly` in the previous step."
842
+ )
843
+ run_dynamo_export(args, l_config, llama)
844
+ elif args.no_merged:
845
+ run_torchscript_separate_export(args, l_config, llama, rank, world_size)
846
+ else:
847
+ run_torchscript_merged_export(args, l_config, llama, rank, world_size)
848
+ del llama # Delete LLaMA model from memory since it will be loaded again during parity check
849
+
850
+ # Set model paths to store FP32 optimized model
851
+ decoder_model_fp32_opt_path = os.path.join(
852
+ args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp32_opt.onnx"
853
+ )
854
+ decoder_with_past_model_fp32_opt_path = os.path.join(
855
+ args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_fp32_opt.onnx"
856
+ )
857
+ decoder_merged_model_fp32_opt_path = os.path.join(
858
+ args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_fp32_opt.onnx"
859
+ )
860
+ new_paths = [
861
+ decoder_model_fp32_opt_path,
862
+ decoder_with_past_model_fp32_opt_path,
863
+ decoder_merged_model_fp32_opt_path,
864
+ ]
865
+
866
+ if args.use_dynamo_export:
867
+ continue
868
+
869
+ # Run the optimizer script.
870
+ logger.info("Optimizing models...")
871
+ for orig_path, opt_path in zip(old_paths, new_paths):
872
+ if os.path.exists(orig_path):
873
+ optimize_export(args, l_config, input_path=orig_path, output_path=opt_path, world_size=world_size)
874
+
875
+ # Re-assign default FP32 model paths as their optimized versions
876
+ decoder_model_fp32_path = decoder_model_fp32_opt_path
877
+ decoder_with_past_model_fp32_path = decoder_with_past_model_fp32_opt_path
878
+ decoder_merged_model_fp32_path = decoder_merged_model_fp32_opt_path
879
+ old_paths = [decoder_model_fp32_path, decoder_with_past_model_fp32_path, decoder_merged_model_fp32_path]
880
+
881
+ logger.info(
882
+ f"The {args.model_name} ONNX model has been successfully optimized with the ORT transformer optimizer script!"
883
+ )
884
+
885
+ # Change precision of exported models from FP32
886
+ if args.precision == Precision.FLOAT16:
887
+ new_paths = convert_to_float16(args, old_paths, rank)
888
+
889
+ elif args.precision == Precision.INT8:
890
+ decoder_model_int8_path = os.path.join(
891
+ args.output, f"rank_{rank}_{args.model_name}_decoder_model_int8.onnx"
892
+ )
893
+ decoder_with_past_model_int8_path = os.path.join(
894
+ args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_int8.onnx"
895
+ )
896
+ decoder_merged_model_int8_path = os.path.join(
897
+ args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_int8.onnx"
898
+ )
899
+ new_paths = [decoder_model_int8_path, decoder_with_past_model_int8_path, decoder_merged_model_int8_path]
900
+
901
+ if args.quantization_method == "smooth_quant":
902
+ if not args.no_merged:
903
+ logger.error("SmoothQuant must be used on separately exported models")
904
+ else:
905
+ logger.info(
906
+ f"Quantizing {decoder_model_fp32_path} and {decoder_with_past_model_fp32_path} to int8"
907
+ )
908
+ smooth_quant(args, old_paths[0], old_paths[1], new_paths[0], new_paths[1])
909
+
910
+ elif args.quantization_method == "quantize_dynamic":
911
+ logger.warning(
912
+ "The `quantize_dynamic` method is deprecated in favor of `smooth_quant` instead. Precision loss may be high with `quantize_dynamic`."
913
+ )
914
+
915
+ logger.info("Quantizing to int8...")
916
+ for fp32_path, int8_path in zip(old_paths, new_paths):
917
+ if os.path.exists(fp32_path):
918
+ ort_quantization.quantize_dynamic(
919
+ fp32_path,
920
+ int8_path,
921
+ op_types_to_quantize=(
922
+ ["MatMul", "Gemm", "Gather"]
923
+ if args.quantize_embedding_layer
924
+ else ["MatMul", "Gemm"]
925
+ ),
926
+ per_channel=args.quantize_per_channel,
927
+ reduce_range=args.quantize_reduce_range,
928
+ use_external_data_format=True,
929
+ extra_options={"MatMulConstBOnly": True},
930
+ )
931
+ logger.info(
932
+ f"The ONNX model at {fp32_path} has been quantized to int8 and saved at {int8_path}!"
933
+ )
934
+ remove_existing_model(decoder_model_fp32_path)
935
+
936
+ logger.info(f"The {args.model_name} ONNX model has been successfully quantized to int8!")
937
+
938
+ else:
939
+ raise Exception(f"Could not recognize {args.quantization_method} as a quantization method")
940
+
941
+ elif args.precision == Precision.INT4:
942
+ if args.execution_provider != "cpu":
943
+ old_paths = convert_to_float16(args, old_paths, rank)
944
+
945
+ decoder_model_int4_path = os.path.join(
946
+ args.output, f"rank_{rank}_{args.model_name}_decoder_model_int4.onnx"
947
+ )
948
+ decoder_with_past_model_int4_path = os.path.join(
949
+ args.output, f"rank_{rank}_{args.model_name}_decoder_with_past_model_int4.onnx"
950
+ )
951
+ decoder_merged_model_int4_path = os.path.join(
952
+ args.output, f"rank_{rank}_{args.model_name}_decoder_merged_model_int4.onnx"
953
+ )
954
+ new_paths = [decoder_model_int4_path, decoder_with_past_model_int4_path, decoder_merged_model_int4_path]
955
+
956
+ for fp_path, int4_path in zip(old_paths, new_paths):
957
+ if os.path.exists(fp_path):
958
+ model = onnx.load_model(fp_path, load_external_data=True)
959
+ quant = MatMul4BitsQuantizer(
960
+ model=model,
961
+ block_size=args.block_size,
962
+ is_symmetric=True,
963
+ accuracy_level=args.int4_accuracy_level,
964
+ nodes_to_exclude=[],
965
+ )
966
+ quant.process()
967
+ quant.model.save_model_to_file(int4_path, use_external_data_format=True)
968
+ del model
969
+ del quant
970
+ logger.info(f"The ONNX model at {fp_path} has been quantized to int4 and saved at {int4_path}!")
971
+ remove_existing_model(fp_path)
972
+ barrier()
973
+
974
+ if args.use_dynamo_export:
975
+ return
976
+
977
+ logger.info("Verifying parity on all ONNX models created")
978
+
979
+ # Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models
980
+ args.precision = (
981
+ "fp32"
982
+ if args.precision in {Precision.INT8, Precision.FLOAT32}
983
+ or (args.precision == Precision.INT4 and args.execution_provider == "cpu")
984
+ else "fp16"
985
+ )
986
+
987
+ # Verify parity on all saved ONNX models
988
+ for filename in os.listdir(args.output):
989
+ if (
990
+ ".data" in filename
991
+ or ".onnx" not in filename
992
+ or args.precision not in filename
993
+ or f"rank_{rank}" not in filename
994
+ ):
995
+ continue
996
+
997
+ parity_cmd = [
998
+ "-m",
999
+ original_model_name,
1000
+ "-o",
1001
+ os.path.join(args.output, filename),
1002
+ "-ep",
1003
+ args.execution_provider,
1004
+ "--precision",
1005
+ args.precision,
1006
+ "--cache_dir",
1007
+ args.cache_dir,
1008
+ "--torch_model_directory",
1009
+ args.input,
1010
+ ]
1011
+ if args.small_gpu:
1012
+ parity_cmd.append("--small_gpu")
1013
+ if "with_past" in filename:
1014
+ parity_cmd.append("--use_past_kv")
1015
+ if "merged" in filename:
1016
+ parity_cmd.append("--merged")
1017
+
1018
+ try:
1019
+ logger.info(f"check parity with cmd: {parity_cmd}")
1020
+ parity_check(parity_cmd)
1021
+ except Exception as e:
1022
+ logger.exception(f"An error occurred while verifying parity: {e}")
1023
+ sys.exit(-1)
1024
+
1025
+
1026
+ if __name__ == "__main__":
1027
+ main()
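Based on the Precision.INT4 branch of the script above, here is a standalone sketch of the blockwise 4-bit step for readers who already have an exported FP32/FP16 decoder model. File names are placeholders, and the MatMul4BitsQuantizer arguments mirror those used in the script. The full script is normally run from its own directory, e.g. `python convert_to_onnx.py -m <hf_model_name> --output <dir> --precision int4 --execution_provider cpu --quantization_method blockwise --block_size 32` (assuming the Precision enum in benchmark_helper.py parses the values fp32/fp16/int8/int4).

# Standalone sketch of the blockwise int4 quantization step shown above.
# The input and output file names are placeholders for an already-exported decoder model.
import onnx
from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer

model = onnx.load_model("decoder_merged_model_fp32_opt.onnx", load_external_data=True)
quant = MatMul4BitsQuantizer(model=model, block_size=32, is_symmetric=True)
quant.process()
# Large weights are kept in an external .data file next to the output model.
quant.model.save_model_to_file("decoder_merged_model_int4.onnx", use_external_data_format=True)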