mindspore 2.6.0rc1__cp311-cp311-win_amd64.whl → 2.7.0__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mindspore might be problematic.

Files changed (458)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +2 -2
  5. mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
  8. mindspore/_checkparam.py +42 -11
  9. mindspore/_extends/builtin_operations.py +3 -3
  10. mindspore/{_deprecated → _extends/optimize}/__init__.py +9 -3
  11. mindspore/_extends/optimize/cell_utils.py +96 -0
  12. mindspore/_extends/parallel_compile/akg_compiler/custom.py +1109 -0
  13. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
  14. mindspore/_extends/parse/__init__.py +3 -3
  15. mindspore/_extends/parse/compile_config.py +44 -22
  16. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +1 -2
  17. mindspore/_extends/parse/parser.py +65 -84
  18. mindspore/_extends/parse/resources.py +39 -0
  19. mindspore/_extends/parse/standard_method.py +58 -14
  20. mindspore/_extends/parse/trope.py +8 -1
  21. mindspore/_extends/pijit/__init__.py +1 -2
  22. mindspore/_extends/pijit/pijit_func_white_list.py +2 -5
  23. mindspore/amp.py +4 -22
  24. mindspore/atlprov.dll +0 -0
  25. mindspore/avcodec-59.dll +0 -0
  26. mindspore/avdevice-59.dll +0 -0
  27. mindspore/avfilter-8.dll +0 -0
  28. mindspore/avformat-59.dll +0 -0
  29. mindspore/avutil-57.dll +0 -0
  30. mindspore/boost/adasum.py +1 -1
  31. mindspore/boost/boost_cell_wrapper.py +4 -4
  32. mindspore/c1.dll +0 -0
  33. mindspore/c1xx.dll +0 -0
  34. mindspore/c2.dll +0 -0
  35. mindspore/common/__init__.py +43 -12
  36. mindspore/common/_grad_function.py +2 -1
  37. mindspore/common/_pijit_context.py +28 -7
  38. mindspore/common/_stub_tensor.py +1 -209
  39. mindspore/common/_tensor_cpp_method.py +1 -1
  40. mindspore/common/_tensor_docs.py +178 -53
  41. mindspore/common/_utils.py +9 -1
  42. mindspore/common/api.py +377 -203
  43. mindspore/common/dtype.py +108 -57
  44. mindspore/common/dump.py +11 -16
  45. mindspore/common/dynamic_shape/__init__.py +0 -0
  46. mindspore/common/{auto_dynamic_shape.py → dynamic_shape/auto_dynamic_shape.py} +17 -23
  47. mindspore/common/dynamic_shape/enable_dynamic.py +197 -0
  48. mindspore/common/file_system.py +59 -9
  49. mindspore/common/generator.py +5 -3
  50. mindspore/common/hook_handle.py +33 -5
  51. mindspore/common/jit_config.py +1 -1
  52. mindspore/common/jit_trace.py +84 -105
  53. mindspore/common/np_dtype.py +3 -3
  54. mindspore/common/parameter.py +27 -29
  55. mindspore/common/recompute.py +5 -7
  56. mindspore/common/sparse_tensor.py +0 -3
  57. mindspore/common/symbol.py +0 -1
  58. mindspore/common/tensor.py +117 -131
  59. mindspore/communication/_comm_helper.py +46 -4
  60. mindspore/communication/management.py +79 -7
  61. mindspore/context.py +67 -55
  62. mindspore/dataset/__init__.py +1 -1
  63. mindspore/dataset/audio/transforms.py +1 -1
  64. mindspore/dataset/core/config.py +38 -4
  65. mindspore/dataset/engine/datasets.py +350 -322
  66. mindspore/dataset/engine/datasets_user_defined.py +70 -24
  67. mindspore/dataset/engine/iterators.py +2 -2
  68. mindspore/dataset/engine/obs/config_loader.py +2 -2
  69. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +8 -0
  70. mindspore/dataset/transforms/c_transforms.py +2 -2
  71. mindspore/dataset/transforms/py_transforms.py +7 -3
  72. mindspore/dataset/transforms/transforms.py +10 -6
  73. mindspore/dataset/vision/__init__.py +1 -1
  74. mindspore/dataset/vision/py_transforms.py +8 -8
  75. mindspore/dataset/vision/transforms.py +17 -5
  76. mindspore/dataset/vision/utils.py +632 -21
  77. mindspore/dataset/vision/validators.py +1 -0
  78. mindspore/device_context/ascend/device.py +1 -1
  79. mindspore/device_context/ascend/op_tuning.py +35 -1
  80. mindspore/device_context/gpu/__init__.py +2 -2
  81. mindspore/device_context/gpu/device.py +1 -1
  82. mindspore/device_context/gpu/op_precision.py +4 -2
  83. mindspore/device_context/gpu/op_tuning.py +6 -3
  84. mindspore/device_manager.py +16 -9
  85. mindspore/dnnl.dll +0 -0
  86. mindspore/dpcmi.dll +0 -0
  87. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +3 -4
  88. mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
  89. mindspore/experimental/optim/adadelta.py +13 -20
  90. mindspore/experimental/optim/adagrad.py +15 -22
  91. mindspore/experimental/optim/adam.py +17 -24
  92. mindspore/experimental/optim/adamax.py +14 -22
  93. mindspore/experimental/optim/adamw.py +28 -34
  94. mindspore/experimental/optim/asgd.py +15 -25
  95. mindspore/experimental/optim/lr_scheduler.py +27 -45
  96. mindspore/experimental/optim/nadam.py +14 -24
  97. mindspore/experimental/optim/optimizer.py +13 -23
  98. mindspore/experimental/optim/radam.py +18 -24
  99. mindspore/experimental/optim/rmsprop.py +14 -25
  100. mindspore/experimental/optim/rprop.py +15 -26
  101. mindspore/experimental/optim/sgd.py +9 -19
  102. mindspore/hal/__init__.py +4 -4
  103. mindspore/hal/contiguous_tensors_handle.py +2 -2
  104. mindspore/hal/memory.py +27 -7
  105. mindspore/include/api/cell.h +65 -5
  106. mindspore/include/api/cfg.h +24 -7
  107. mindspore/include/api/context.h +1 -0
  108. mindspore/include/api/delegate.h +10 -2
  109. mindspore/include/api/dual_abi_helper.h +100 -19
  110. mindspore/include/api/graph.h +14 -1
  111. mindspore/include/api/kernel.h +16 -3
  112. mindspore/include/api/kernel_api.h +9 -1
  113. mindspore/include/api/metrics/accuracy.h +9 -0
  114. mindspore/include/api/model.h +8 -1
  115. mindspore/include/api/model_group.h +4 -0
  116. mindspore/include/api/model_parallel_runner.h +2 -0
  117. mindspore/include/api/status.h +48 -10
  118. mindspore/include/api/types.h +8 -3
  119. mindspore/include/c_api/model_c.h +0 -58
  120. mindspore/include/c_api/tensor_c.h +0 -26
  121. mindspore/include/dataset/constants.h +9 -0
  122. mindspore/include/dataset/vision_ascend.h +1 -1
  123. mindspore/jpeg62.dll +0 -0
  124. mindspore/mindrecord/tools/cifar10.py +61 -11
  125. mindspore/mindrecord/tools/cifar10_to_mr.py +5 -0
  126. mindspore/mindspore_backend_common.dll +0 -0
  127. mindspore/mindspore_backend_manager.dll +0 -0
  128. mindspore/mindspore_common.dll +0 -0
  129. mindspore/mindspore_core.dll +0 -0
  130. mindspore/mindspore_cpu_res_manager.dll +0 -0
  131. mindspore/mindspore_dump.dll +0 -0
  132. mindspore/mindspore_frontend.dll +0 -0
  133. mindspore/mindspore_glog.dll +0 -0
  134. mindspore/mindspore_memory_pool.dll +0 -0
  135. mindspore/mindspore_ms_backend.dll +0 -0
  136. mindspore/mindspore_ops.dll +0 -0
  137. mindspore/mindspore_ops_host.dll +0 -0
  138. mindspore/mindspore_ops_kernel_common.dll +0 -0
  139. mindspore/mindspore_profiler.dll +0 -0
  140. mindspore/mindspore_pyboost.dll +0 -0
  141. mindspore/mindspore_pynative.dll +0 -0
  142. mindspore/mindspore_res_manager.dll +0 -0
  143. mindspore/mindspore_runtime_pipeline.dll +0 -0
  144. mindspore/mint/__init__.py +6 -46
  145. mindspore/mint/distributed/__init__.py +5 -0
  146. mindspore/mint/distributed/distributed.py +429 -23
  147. mindspore/mint/nn/__init__.py +1 -1
  148. mindspore/mint/nn/functional.py +53 -6
  149. mindspore/mint/nn/layer/_functions.py +163 -294
  150. mindspore/mint/nn/layer/activation.py +8 -6
  151. mindspore/mint/nn/layer/conv.py +140 -104
  152. mindspore/mint/nn/layer/normalization.py +11 -25
  153. mindspore/mint/optim/adam.py +19 -18
  154. mindspore/mint/optim/adamw.py +14 -8
  155. mindspore/mint/optim/sgd.py +5 -5
  156. mindspore/msobj140.dll +0 -0
  157. mindspore/mspdb140.dll +0 -0
  158. mindspore/mspdbcore.dll +0 -0
  159. mindspore/mspdbst.dll +0 -0
  160. mindspore/mspft140.dll +0 -0
  161. mindspore/msvcdis140.dll +0 -0
  162. mindspore/msvcp140_1.dll +0 -0
  163. mindspore/msvcp140_2.dll +0 -0
  164. mindspore/msvcp140_atomic_wait.dll +0 -0
  165. mindspore/msvcp140_codecvt_ids.dll +0 -0
  166. mindspore/nn/cell.py +491 -623
  167. mindspore/nn/grad/cell_grad.py +11 -12
  168. mindspore/nn/layer/activation.py +36 -36
  169. mindspore/nn/layer/basic.py +74 -77
  170. mindspore/nn/layer/channel_shuffle.py +4 -4
  171. mindspore/nn/layer/combined.py +4 -2
  172. mindspore/nn/layer/conv.py +117 -110
  173. mindspore/nn/layer/dense.py +9 -7
  174. mindspore/nn/layer/embedding.py +50 -52
  175. mindspore/nn/layer/image.py +38 -40
  176. mindspore/nn/layer/math.py +111 -112
  177. mindspore/nn/layer/normalization.py +56 -44
  178. mindspore/nn/layer/pooling.py +58 -63
  179. mindspore/nn/layer/rnn_cells.py +33 -33
  180. mindspore/nn/layer/rnns.py +56 -56
  181. mindspore/nn/layer/thor_layer.py +74 -73
  182. mindspore/nn/layer/transformer.py +11 -1
  183. mindspore/nn/learning_rate_schedule.py +20 -20
  184. mindspore/nn/loss/loss.py +79 -81
  185. mindspore/nn/optim/adam.py +4 -6
  186. mindspore/nn/optim/adasum.py +2 -2
  187. mindspore/nn/optim/asgd.py +2 -0
  188. mindspore/nn/optim/lamb.py +1 -3
  189. mindspore/nn/optim/optimizer.py +1 -1
  190. mindspore/nn/optim/tft_wrapper.py +2 -3
  191. mindspore/nn/optim/thor.py +2 -2
  192. mindspore/nn/probability/distribution/_utils/utils.py +2 -2
  193. mindspore/nn/probability/distribution/exponential.py +2 -1
  194. mindspore/nn/probability/distribution/poisson.py +2 -1
  195. mindspore/nn/sparse/sparse.py +3 -3
  196. mindspore/nn/wrap/cell_wrapper.py +73 -42
  197. mindspore/nn/wrap/grad_reducer.py +37 -52
  198. mindspore/nn/wrap/loss_scale.py +72 -74
  199. mindspore/numpy/array_creations.py +7 -7
  200. mindspore/numpy/fft.py +1 -1
  201. mindspore/numpy/math_ops.py +5 -5
  202. mindspore/numpy/utils_const.py +1 -1
  203. mindspore/opencv_core452.dll +0 -0
  204. mindspore/opencv_imgcodecs452.dll +0 -0
  205. mindspore/opencv_imgproc452.dll +0 -0
  206. mindspore/ops/_grad_experimental/grad_comm_ops.py +51 -13
  207. mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -0
  208. mindspore/ops/_grad_experimental/grad_inner_ops.py +0 -9
  209. mindspore/ops/_op_impl/cpu/__init__.py +1 -0
  210. mindspore/{experimental/es/__init__.py → ops/_op_impl/cpu/joinedstr_op.py} +12 -6
  211. mindspore/ops/_vmap/vmap_array_ops.py +31 -13
  212. mindspore/ops/_vmap/vmap_nn_ops.py +8 -16
  213. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +54 -13
  214. mindspore/ops/auto_generate/gen_extend_func.py +27 -145
  215. mindspore/ops/auto_generate/gen_ops_def.py +1027 -347
  216. mindspore/ops/auto_generate/gen_ops_prim.py +2341 -1117
  217. mindspore/ops/auto_generate/pyboost_inner_prim.py +31 -1
  218. mindspore/ops/composite/__init__.py +10 -0
  219. mindspore/ops/composite/base.py +9 -5
  220. mindspore/ops/composite/multitype_ops/__init__.py +12 -1
  221. mindspore/ops/composite/multitype_ops/_compile_utils.py +133 -109
  222. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -1
  223. mindspore/ops/composite/multitype_ops/add_impl.py +70 -2
  224. mindspore/ops/composite/multitype_ops/div_impl.py +49 -0
  225. mindspore/ops/composite/multitype_ops/floordiv_impl.py +29 -0
  226. mindspore/ops/composite/multitype_ops/getitem_impl.py +11 -0
  227. mindspore/ops/composite/multitype_ops/mod_impl.py +5 -3
  228. mindspore/ops/composite/multitype_ops/mul_impl.py +49 -0
  229. mindspore/ops/composite/multitype_ops/setitem_impl.py +57 -0
  230. mindspore/ops/composite/multitype_ops/sub_impl.py +34 -0
  231. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +14 -0
  232. mindspore/ops/function/__init__.py +4 -1
  233. mindspore/ops/function/_add_attr_func.py +11 -6
  234. mindspore/ops/function/array_func.py +19 -102
  235. mindspore/ops/function/debug_func.py +8 -5
  236. mindspore/ops/function/grad/grad_func.py +5 -13
  237. mindspore/ops/function/math_func.py +77 -572
  238. mindspore/ops/function/nn_func.py +46 -94
  239. mindspore/ops/function/other_func.py +4 -1
  240. mindspore/ops/function/random_func.py +44 -5
  241. mindspore/ops/function/vmap_func.py +2 -1
  242. mindspore/ops/functional.py +4 -4
  243. mindspore/ops/functional_overload.py +594 -18
  244. mindspore/ops/op_info_register.py +21 -0
  245. mindspore/ops/operations/__init__.py +16 -11
  246. mindspore/ops/operations/_custom_ops_utils.py +689 -34
  247. mindspore/ops/operations/_inner_ops.py +14 -18
  248. mindspore/ops/operations/_sequence_ops.py +1 -1
  249. mindspore/ops/operations/array_ops.py +5 -51
  250. mindspore/ops/operations/comm_ops.py +186 -41
  251. mindspore/ops/operations/custom_ops.py +303 -177
  252. mindspore/ops/operations/debug_ops.py +59 -4
  253. mindspore/ops/operations/image_ops.py +13 -13
  254. mindspore/ops/operations/manually_defined/ops_def.py +27 -28
  255. mindspore/ops/operations/math_ops.py +8 -9
  256. mindspore/ops/operations/nn_ops.py +8 -40
  257. mindspore/ops/primitive.py +9 -20
  258. mindspore/ops/tensor_method.py +63 -15
  259. mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +1 -1
  260. mindspore/ops_generate/api/functional_map_cpp_generator.py +10 -9
  261. mindspore/ops_generate/api/functions_cc_generator.py +58 -10
  262. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +1 -1
  263. mindspore/ops_generate/common/base_generator.py +14 -0
  264. mindspore/ops_generate/common/gen_constants.py +8 -3
  265. mindspore/ops_generate/common/gen_utils.py +0 -19
  266. mindspore/ops_generate/common/op_proto.py +11 -4
  267. mindspore/ops_generate/common/template.py +88 -11
  268. mindspore/ops_generate/gen_ops.py +1 -1
  269. mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +4 -4
  270. mindspore/ops_generate/op_def/ops_def_cc_generator.py +0 -3
  271. mindspore/ops_generate/op_def/ops_name_h_generator.py +0 -3
  272. mindspore/ops_generate/op_def/ops_primitive_h_generator.py +0 -4
  273. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -2
  274. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +49 -8
  275. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +2 -2
  276. mindspore/ops_generate/pyboost/gen_pyboost_func.py +31 -16
  277. mindspore/ops_generate/pyboost/op_template_parser.py +98 -72
  278. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +70 -273
  279. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +14 -6
  280. mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +316 -0
  281. mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +1 -1
  282. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +5 -3
  283. mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +1 -1
  284. mindspore/ops_generate/pyboost/pyboost_internal_functions_cpp_generator.py +76 -0
  285. mindspore/ops_generate/pyboost/pyboost_internal_functions_h_generator.py +76 -0
  286. mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +125 -0
  287. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +4 -3
  288. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +348 -61
  289. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +1 -1
  290. mindspore/ops_generate/pyboost/pyboost_utils.py +118 -9
  291. mindspore/ops_generate/tensor_py_cc_generator.py +1 -24
  292. mindspore/parallel/_auto_parallel_context.py +16 -23
  293. mindspore/parallel/_cell_wrapper.py +113 -45
  294. mindspore/parallel/_parallel_serialization.py +4 -3
  295. mindspore/parallel/_ps_context.py +4 -6
  296. mindspore/parallel/_tensor.py +167 -12
  297. mindspore/parallel/_transformer/moe.py +1 -1
  298. mindspore/parallel/_transformer/transformer.py +17 -12
  299. mindspore/parallel/_utils.py +5 -11
  300. mindspore/parallel/auto_parallel.py +35 -14
  301. mindspore/parallel/checkpoint_convert.py +3 -3
  302. mindspore/parallel/checkpoint_transform.py +13 -7
  303. mindspore/parallel/cluster/process_entity/_api.py +88 -49
  304. mindspore/parallel/cluster/process_entity/_utils.py +95 -7
  305. mindspore/parallel/cluster/run.py +48 -7
  306. mindspore/parallel/function/__init__.py +8 -1
  307. mindspore/parallel/function/reshard_func.py +12 -12
  308. mindspore/parallel/nn/__init__.py +15 -2
  309. mindspore/parallel/nn/parallel_cell_wrapper.py +50 -14
  310. mindspore/parallel/nn/parallel_grad_reducer.py +7 -14
  311. mindspore/parallel/shard.py +10 -25
  312. mindspore/parallel/transform_safetensors.py +469 -174
  313. mindspore/pgodb140.dll +0 -0
  314. mindspore/pgort140.dll +0 -0
  315. mindspore/profiler/__init__.py +2 -1
  316. mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +7 -7
  317. mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +3 -0
  318. mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +12 -6
  319. mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +3 -3
  320. mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +3 -3
  321. mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +4 -4
  322. mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +3 -3
  323. mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +4 -1
  324. mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +2 -1
  325. mindspore/profiler/analysis/task_manager.py +1 -1
  326. mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +5 -1
  327. mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +2 -1
  328. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +10 -9
  329. mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +43 -23
  330. mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +3 -2
  331. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +9 -5
  332. mindspore/profiler/analysis/viewer/ms_operator_details_viewer.py +132 -0
  333. mindspore/profiler/common/constant.py +16 -0
  334. mindspore/profiler/common/msprof_cmd_tool.py +2 -2
  335. mindspore/profiler/common/path_manager.py +9 -0
  336. mindspore/profiler/common/profiler_context.py +50 -29
  337. mindspore/profiler/common/profiler_info.py +0 -16
  338. mindspore/profiler/common/profiler_meta_data.py +1 -0
  339. mindspore/profiler/common/profiler_op_analyse.py +239 -0
  340. mindspore/profiler/common/profiler_output_path.py +23 -8
  341. mindspore/profiler/common/profiler_parameters.py +128 -35
  342. mindspore/profiler/dynamic_profile/__init__.py +0 -0
  343. mindspore/profiler/dynamic_profile/dynamic_monitor_proxy.py +39 -0
  344. mindspore/profiler/dynamic_profile/dynamic_profiler_config_context.py +666 -0
  345. mindspore/profiler/dynamic_profile/dynamic_profiler_utils.py +62 -0
  346. mindspore/profiler/dynamic_profiler.py +374 -338
  347. mindspore/profiler/envprofiler.py +42 -12
  348. mindspore/profiler/experimental_config.py +112 -7
  349. mindspore/profiler/mstx.py +33 -12
  350. mindspore/profiler/platform/__init__.py +2 -3
  351. mindspore/profiler/platform/cpu_profiler.py +10 -4
  352. mindspore/profiler/platform/npu_profiler.py +30 -20
  353. mindspore/profiler/profiler.py +218 -154
  354. mindspore/profiler/profiler_action_controller.py +65 -77
  355. mindspore/profiler/profiler_interface.py +2 -2
  356. mindspore/profiler/schedule.py +10 -4
  357. mindspore/rewrite/common/config.py +1 -0
  358. mindspore/rewrite/common/namer.py +1 -0
  359. mindspore/rewrite/common/namespace.py +1 -0
  360. mindspore/rewrite/node/node.py +31 -11
  361. mindspore/rewrite/parsers/assign_parser.py +1 -1
  362. mindspore/rewrite/symbol_tree/symbol_tree.py +2 -2
  363. mindspore/run_check/_check_version.py +7 -10
  364. mindspore/runtime/__init__.py +8 -6
  365. mindspore/runtime/event.py +10 -4
  366. mindspore/runtime/executor.py +87 -45
  367. mindspore/runtime/memory.py +31 -32
  368. mindspore/runtime/thread_bind_core.py +299 -165
  369. mindspore/safeguard/rewrite_obfuscation.py +12 -13
  370. mindspore/swresample-4.dll +0 -0
  371. mindspore/swscale-6.dll +0 -0
  372. mindspore/tbbmalloc.dll +0 -0
  373. mindspore/tinyxml2.dll +0 -0
  374. mindspore/train/_utils.py +17 -7
  375. mindspore/train/amp.py +43 -23
  376. mindspore/train/callback/__init__.py +5 -5
  377. mindspore/train/callback/_callback.py +2 -1
  378. mindspore/train/callback/_checkpoint.py +4 -14
  379. mindspore/train/callback/_flops_collector.py +11 -7
  380. mindspore/train/callback/_landscape.py +0 -1
  381. mindspore/train/callback/_train_fault_tolerance.py +98 -21
  382. mindspore/train/data_sink.py +15 -6
  383. mindspore/train/dataset_helper.py +14 -5
  384. mindspore/train/model.py +133 -69
  385. mindspore/train/serialization.py +168 -126
  386. mindspore/train/summary/summary_record.py +13 -2
  387. mindspore/train/train_thor/model_thor.py +2 -2
  388. mindspore/turbojpeg.dll +0 -0
  389. mindspore/utils/__init__.py +3 -2
  390. mindspore/utils/dryrun.py +0 -6
  391. mindspore/utils/runtime_execution_order_check.py +163 -77
  392. mindspore/utils/sdc_detect.py +68 -0
  393. mindspore/utils/utils.py +14 -17
  394. mindspore/vcmeta.dll +0 -0
  395. mindspore/vcruntime140.dll +0 -0
  396. mindspore/vcruntime140_1.dll +0 -0
  397. mindspore/version.py +1 -1
  398. {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0.dist-info}/METADATA +5 -4
  399. {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0.dist-info}/RECORD +403 -442
  400. mindspore/_deprecated/jit.py +0 -198
  401. mindspore/_extends/remote/kernel_build_server_ascend.py +0 -75
  402. mindspore/communication/_hccl_management.py +0 -297
  403. mindspore/experimental/es/embedding_service.py +0 -891
  404. mindspore/experimental/es/embedding_service_layer.py +0 -581
  405. mindspore/profiler/common/validator/__init__.py +0 -14
  406. mindspore/profiler/common/validator/validate_path.py +0 -84
  407. mindspore/profiler/parser/__init__.py +0 -14
  408. mindspore/profiler/parser/aicpu_data_parser.py +0 -272
  409. mindspore/profiler/parser/ascend_analysis/__init__.py +0 -14
  410. mindspore/profiler/parser/ascend_analysis/constant.py +0 -71
  411. mindspore/profiler/parser/ascend_analysis/file_manager.py +0 -180
  412. mindspore/profiler/parser/ascend_analysis/function_event.py +0 -185
  413. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +0 -136
  414. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +0 -131
  415. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +0 -104
  416. mindspore/profiler/parser/ascend_analysis/path_manager.py +0 -313
  417. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +0 -123
  418. mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +0 -86
  419. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +0 -75
  420. mindspore/profiler/parser/ascend_cluster_generator.py +0 -116
  421. mindspore/profiler/parser/ascend_communicate_generator.py +0 -314
  422. mindspore/profiler/parser/ascend_flops_generator.py +0 -116
  423. mindspore/profiler/parser/ascend_fpbp_generator.py +0 -82
  424. mindspore/profiler/parser/ascend_hccl_generator.py +0 -271
  425. mindspore/profiler/parser/ascend_integrate_generator.py +0 -42
  426. mindspore/profiler/parser/ascend_memory_generator.py +0 -185
  427. mindspore/profiler/parser/ascend_msprof_exporter.py +0 -282
  428. mindspore/profiler/parser/ascend_msprof_generator.py +0 -187
  429. mindspore/profiler/parser/ascend_op_generator.py +0 -334
  430. mindspore/profiler/parser/ascend_steptrace_generator.py +0 -94
  431. mindspore/profiler/parser/ascend_timeline_generator.py +0 -545
  432. mindspore/profiler/parser/base_timeline_generator.py +0 -483
  433. mindspore/profiler/parser/container.py +0 -229
  434. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +0 -697
  435. mindspore/profiler/parser/flops_parser.py +0 -531
  436. mindspore/profiler/parser/framework_enum.py +0 -111
  437. mindspore/profiler/parser/framework_parser.py +0 -464
  438. mindspore/profiler/parser/framework_struct.py +0 -61
  439. mindspore/profiler/parser/gpu_analysis/__init__.py +0 -14
  440. mindspore/profiler/parser/gpu_analysis/function_event.py +0 -44
  441. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +0 -89
  442. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +0 -72
  443. mindspore/profiler/parser/hccl_parser.py +0 -573
  444. mindspore/profiler/parser/hwts_log_parser.py +0 -122
  445. mindspore/profiler/parser/integrator.py +0 -526
  446. mindspore/profiler/parser/memory_usage_parser.py +0 -277
  447. mindspore/profiler/parser/minddata_analyzer.py +0 -800
  448. mindspore/profiler/parser/minddata_parser.py +0 -186
  449. mindspore/profiler/parser/minddata_pipeline_parser.py +0 -299
  450. mindspore/profiler/parser/op_intermediate_parser.py +0 -149
  451. mindspore/profiler/parser/optime_parser.py +0 -250
  452. mindspore/profiler/parser/profiler_info.py +0 -213
  453. mindspore/profiler/parser/step_trace_parser.py +0 -666
  454. mindspore/utils/hooks.py +0 -81
  455. /mindspore/common/{_auto_dynamic.py → dynamic_shape/_auto_dynamic.py} +0 -0
  456. {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0.dist-info}/WHEEL +0 -0
  457. {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0.dist-info}/entry_points.txt +0 -0
  458. {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0.dist-info}/top_level.txt +0 -0
mindspore/_extends/parallel_compile/akg_compiler/custom.py (new file)
@@ -0,0 +1,1109 @@
+ # Copyright 2023 Huawei Technologies Co., Ltd
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Custom op dsl file, used for dynamic format/data type select, update akg info and compile akg info"""
+ from __future__ import absolute_import
+ import os
+ import sys
+ import json
+ import copy
+ import functools
+ import subprocess
+ import shutil
+
+ from tbe.common.buildcfg import get_current_build_config
+ from impl.util.util_select_op_base import gen_param
+ from impl.util.util_select_op_base import get_dynamic_param_in_json
+
+ BLOCK = 16
+ FP16_MAX = 65504
+ OP = "op"
+ STR = "str"
+ NAME = "name"
+ TENSOR_NAME = "tensor_name"
+ ATTR = "attr"
+ VALUE = "value"
+ SHAPE = "shape"
+ FORMAT = "format"
+ DATA_TYPE = "data_type"
+ ORI_SHAPE = "ori_shape"
+ ORI_FORMAT = "ori_format"
+ ORI_DATA_TYPE = "ori_data_type"
+ OP_DESC = "op_desc"
+ INPUT_DESC = "input_desc"
+ OUTPUT_DESC = "output_desc"
+ FRACTAL_NZ = "FRACTAL_NZ"
+ DEFAULT_FORMAT = "DefaultFormat"
+ FLOAT16 = "float16"
+ FLOAT32 = "float32"
+ O_SUFFIX = ".o"
+ JSON_SUFFIX = ".json"
+
+
+ def copy_shape(shape):
+     """Deep copy shape"""
+     res = []
+     if isinstance(shape, int):
+         shape = [shape]
+     for _, s in enumerate(shape):
+         res.append(s)
+     return res
+
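Reviewer note, not part of the diff: copy_shape also normalizes a bare int into a one-element list, so the infer code below can assume list shapes throughout. A quick sketch of the expected behavior:

    copy_shape(5)       # -> [5]
    copy_shape([2, 3])  # -> [2, 3] (a fresh list, safe to mutate)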
+ # InfoGlobalConfig is used to store global configuration for info files.
+ # It can be accessed or modified internally in custom.py using InfoGlobalConfig.xxx.
+
+
+ class InfoGlobalConfig:
+     # whether to enable the akg cce lib
+     enable_cce_lib = False
+     # ascend arch type, for 910B and 910A
+     ascend_arch = ""
+
+
+ class OpInfer:
+     """Base infer class, used to provide the supported formats and data types of each op and update each op"""
+
+     def __init__(self, op_desc):
+         self.name = op_desc[NAME]
+         self.op_desc = op_desc
+         self.input_desc = []
+         self.output_desc = []
+         self.attr = {}
+         if isinstance(op_desc.get(INPUT_DESC), list):
+             for desc in op_desc[INPUT_DESC]:
+                 for item in desc:
+                     self.input_desc.append(item)
+         if isinstance(op_desc.get(ATTR), list):
+             for item in op_desc[ATTR]:
+                 self.attr[item[NAME]] = item
+         if isinstance(op_desc.get(OUTPUT_DESC), list):
+             for item in op_desc[OUTPUT_DESC]:
+                 self.output_desc.append(item)
+
+     @staticmethod
+     def is_nz(shape):
+         """check if shape can be converted to FRACTAL_NZ"""
+         if len(shape) >= 2 and shape[-2] % BLOCK == 0 and shape[-1] % BLOCK == 0:
+             return True
+         return False
+
+     @staticmethod
+     def update_format(formats, new_format):
+         """combine new_format into formats"""
+         new_formats = [new_format] if not isinstance(new_format, (list, tuple)) else new_format
+         for f in new_formats:
+             if f not in formats:
+                 formats.append(f)
+
+     def get_attr(self, key):
+         """get the value of attr"""
+         if key not in self.attr:
+             raise KeyError("Can not find attr '{}' in op '{}'".format(key, self.name))
+         return self.attr.get(key)[VALUE]
+
+     def set_attr(self, key, value):
+         """set the value of attr"""
+         if key not in self.attr:
+             raise KeyError("Can not find attr '{}' in op '{}'".format(key, self.name))
+         self.attr.get(key)[VALUE] = value
+
+     def supported_type(self):
+         """get the supported data types of the current op"""
+         keep_fp32 = False
+         for item in self.input_desc:
+             # check if type can reduce precision
+             value = item.get(VALUE, None)
+             if item[DATA_TYPE] == FLOAT32 and value is not None and abs(value) > FP16_MAX:
+                 keep_fp32 = True
+                 break
+         io_type = ",".join([t[DATA_TYPE] for t in self.input_desc] + [t[DATA_TYPE] for t in self.output_desc])
+         fp32_type = io_type.replace(FLOAT16, FLOAT32)
+         fp16_type = io_type.replace(FLOAT32, FLOAT16)
+         supported_types = [io_type]
+         if fp32_type not in supported_types:
+             supported_types.append(fp32_type)
+         if not keep_fp32 and fp16_type not in supported_types:
+             supported_types.append(fp16_type)
+         return supported_types
+
+     def supported_format(self):
+         """get the supported formats of the current op"""
+         io_num = len(self.input_desc) + len(self.output_desc)
+         nd = ["ND"] * io_num
+         return [",".join(nd)]
+
+     def infer_type(self):
+         """infer data type"""
+         fixed_out_type_ops = ["Equal", "Less", "LessEqual", "Greater", "GreaterEqual"]
+         if self.name not in fixed_out_type_ops:
+             self.output_desc[0][DATA_TYPE] = self.input_desc[0][DATA_TYPE]
+
+     def infer_format(self):
+         """infer format"""
+         self.output_desc[0][FORMAT] = self.input_desc[0][FORMAT]
+
+     def infer_shape(self):
+         """infer shape"""
+         self.output_desc[0][SHAPE] = copy_shape(self.input_desc[0][SHAPE])
+
+     def infer_ori_shape(self):
+         """infer original shape"""
+         for _, desc in enumerate(self.output_desc):
+             desc[ORI_SHAPE] = copy_shape(desc[SHAPE])
+
+     def infer(self):
+         """infer shape, format and data type"""
+         self.infer_type()
+         self.infer_format()
+         self.infer_shape()
+
+     def post_process(self):
+         """post process after infer"""
+
+     def update(self):
+         """update each op"""
+         for _, desc in enumerate(self.output_desc):
+             desc[ORI_DATA_TYPE] = desc[DATA_TYPE]
+             desc[ORI_FORMAT] = desc[FORMAT]
+         self.infer_ori_shape()
+         self.infer()
+         self.post_process()
+
+
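Reviewer note, not part of the diff: the OpInfer.update() pipeline is easiest to see on a toy op_desc. A minimal sketch, assuming the constants defined at the top of this file (tensor names and shapes here are made up):

    op_desc = {
        "name": "Abs",
        "input_desc": [[{"tensor_name": "x", "data_type": "float16",
                         "format": "ND", "shape": [32, 32]}]],
        "output_desc": [{"tensor_name": "y", "data_type": "float16",
                         "format": "ND", "shape": [32, 32]}],
    }
    infer = OpInfer(op_desc)
    infer.supported_type()    # ['float16,float16', 'float32,float32']
    infer.supported_format()  # ['ND,ND']
    infer.update()            # backs up ori_* fields, then re-infers type/format/shape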
+ class Elemwise(OpInfer):
+     """Elemwise op with one input and one output."""
+
+     def supported_format(self):
+         if self.name == "Reciprocal":
+             supported_formats = ["ND,ND"]
+             # pad will cause 'divided by 0'
+             if self.is_nz(self.input_desc[0][SHAPE]):
+                 self.update_format(supported_formats, "FRACTAL_NZ,FRACTAL_NZ")
+             return supported_formats
+         return ["ND,ND", "FRACTAL_NZ,FRACTAL_NZ", "NC1HWC0,NC1HWC0", "FRACTAL_Z,FRACTAL_Z"]
+
+     def infer_ori_shape(self):
+         self.output_desc[0][ORI_SHAPE] = self.input_desc[0][ORI_SHAPE]
+
+
+ class Cast(Elemwise):
+     """Cast op."""
+
+     def supported_type(self):
+         in_type = self.input_desc[0][DATA_TYPE]
+         out_type = self.output_desc[0][DATA_TYPE]
+         io_type = ",".join([in_type, out_type])
+         return [io_type]
+
+     def infer_type(self):
+         self.output_desc[0][DATA_TYPE] = self.output_desc[0][DATA_TYPE]
+
+
+ class ElemwiseBinaryNoBroadcast(OpInfer):
+     """Elemwise op with two inputs and one output, does not support broadcast."""
+
+     def supported_format(self):
+         return ["ND,ND,ND", "FRACTAL_NZ,FRACTAL_NZ,FRACTAL_NZ", "NC1HWC0,NC1HWC0,NC1HWC0",
+                 "FRACTAL_Z,FRACTAL_Z,FRACTAL_Z"]
+
+     def infer_ori_shape(self):
+         self.output_desc[0][ORI_SHAPE] = self.input_desc[0][ORI_SHAPE]
+
+
+ class ElemwiseBinary(OpInfer):
+     """Elemwise op with two inputs and one output, supports broadcast."""
+
+     @staticmethod
+     def nd2fractal_nz(shape):
+         """convert ND shape to FRACTAL_NZ shape"""
+         if len(shape) == 1:
+             if shape[-1] == 1:
+                 return [1, 1, 1, 1]
+             if shape[-1] % BLOCK == 0:
+                 return [shape[-1] // BLOCK, 1, 1, BLOCK]
+         elif len(shape) >= 2:
+             if shape[-2] == 1 and shape[-1] == 1:
+                 return shape[:-2] + [1, 1, 1, 1]
+             if shape[-2] == 1 and shape[-1] % BLOCK == 0:
+                 return shape[:-2] + [shape[-1] // BLOCK, 1, 1, BLOCK]
+             if shape[-2] % BLOCK == 0 and shape[-1] == 1:
+                 return shape[:-2] + [1, shape[-2] // BLOCK, BLOCK, 1]
+         return []
+
+     def broadcast_shape(self, sh0, sh1):
+         """calculate broadcast shape"""
+         out_shape = []
+         max_len = max(len(sh0), len(sh1))
+         pad_sh0 = [1] * (max_len - len(sh0)) + sh0
+         pad_sh1 = [1] * (max_len - len(sh1)) + sh1
+         for i in range(max_len):
+             a, b = pad_sh0[i], pad_sh1[i]
+             if a == 1:
+                 out_shape.append(b)
+             elif b in [1, a]:
+                 out_shape.append(a)
+             else:
+                 raise ValueError("For '{}', input shapes {} and {} can not broadcast".format(self.name, sh0, sh1))
+         return pad_sh0, pad_sh1, out_shape
+
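Reviewer note, not part of the diff: broadcast_shape applies NumPy-style broadcasting and also returns the rank-padded inputs. A worked example with made-up shapes:

    eb = ElemwiseBinary({"name": "Add"})
    pad0, pad1, out = eb.broadcast_shape([16, 1], [8, 1, 32])
    # pad0 == [1, 16, 1], pad1 == [8, 1, 32], out == [8, 16, 32]
    # mismatched non-1 dims (e.g. [3] vs [4]) raise ValueError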
+     def supported_format(self):
+         sh0, sh1 = self.input_desc[0][SHAPE], self.input_desc[1][SHAPE]
+         supported_formats = ["ND,ND,ND"]
+         is_const_0 = (VALUE in self.input_desc[0])
+         is_const_1 = (VALUE in self.input_desc[1])
+         if sh0 == sh1 or is_const_0 or is_const_1:
+             # No broadcast case
+             self.update_format(supported_formats, ["FRACTAL_NZ,FRACTAL_NZ,FRACTAL_NZ", "NC1HWC0,NC1HWC0,NC1HWC0",
+                                                    "FRACTAL_Z,FRACTAL_Z,FRACTAL_Z"])
+         else:
+             # note: (1, 640), (640) "FRACTAL_NZ,ND,FRACTAL_NZ", (1, 640) comes from MatMul
+             if len(sh0) == 2 and len(sh1) == 1:
+                 if sh0[-1] == sh1[-1] and sh1[-1] % BLOCK == 0:
+                     self.update_format(supported_formats, "FRACTAL_NZ,ND,FRACTAL_NZ")
+             elif len(sh0) == 1 and len(sh1) == 2:
+                 if sh0[-1] == sh1[-1] and sh0[-1] % BLOCK == 0:
+                     self.update_format(supported_formats, "ND,FRACTAL_NZ,FRACTAL_NZ")
+             # Broadcast case
+             pad_sh0, pad_sh1, _ = self.broadcast_shape(sh0, sh1)
+             # 1D with broadcast only supports "ND,ND,ND"
+             if len(pad_sh0) > 1:
+                 nz0, nz1 = self.is_nz(pad_sh0), self.is_nz(pad_sh1)
+                 if nz0 and nz1:
+                     self.update_format(supported_formats, "FRACTAL_NZ,FRACTAL_NZ,FRACTAL_NZ")
+                 elif nz0:
+                     self.update_format(supported_formats, "FRACTAL_NZ,ND,FRACTAL_NZ")
+                 elif nz1:
+                     self.update_format(supported_formats, "ND,FRACTAL_NZ,FRACTAL_NZ")
+                 # note: ND,ND,FRACTAL_NZ? e.g. (1024, 1), (1, 5120)
+         return supported_formats
+
+     def infer_format(self):
+         # select special format
+         special_formats = ["FRACTAL", "C0"]
+         format0, format1 = self.input_desc[0][FORMAT], self.input_desc[1][FORMAT]
+         for f in special_formats:
+             if format0.find(f) != -1:
+                 self.output_desc[0][FORMAT] = format0
+                 return
+         self.output_desc[0][FORMAT] = format1
+
+     def infer_shape(self):
+         sh0, sh1 = self.input_desc[0][SHAPE], self.input_desc[1][SHAPE]
+         if sh0 == sh1:
+             self.output_desc[0][SHAPE] = copy_shape(sh0)
+         format0, format1 = self.input_desc[0][FORMAT], self.input_desc[1][FORMAT]
+         if format0 != format1:
+             new_sh0 = self.nd2fractal_nz(sh0)
+             new_sh1 = self.nd2fractal_nz(sh1)
+             if format0 == FRACTAL_NZ and new_sh1:
+                 _, _, out_shape = self.broadcast_shape(sh0, new_sh1)
+                 self.output_desc[0][SHAPE] = out_shape
+                 return
+             if format1 == FRACTAL_NZ and new_sh0:
+                 _, _, out_shape = self.broadcast_shape(new_sh0, sh1)
+                 self.output_desc[0][SHAPE] = out_shape
+                 return
+         _, _, out_shape = self.broadcast_shape(sh0, sh1)
+         self.output_desc[0][SHAPE] = out_shape
+
+     def infer_ori_shape(self):
+         sh0, sh1 = self.input_desc[0][ORI_SHAPE], self.input_desc[1][ORI_SHAPE]
+         _, _, out_shape = self.broadcast_shape(sh0, sh1)
+         self.output_desc[0][ORI_SHAPE] = out_shape
+
+
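Reviewer note, not part of the diff: nd2fractal_nz above only maps the broadcast-friendly degenerate rows/columns into the blocked layout; anything else returns [] and infer_shape falls back to plain ND broadcasting. A sketch with BLOCK == 16:

    ElemwiseBinary.nd2fractal_nz([1])         # -> [1, 1, 1, 1]
    ElemwiseBinary.nd2fractal_nz([64])        # -> [4, 1, 1, 16]
    ElemwiseBinary.nd2fractal_nz([8, 1, 64])  # -> [8, 4, 1, 1, 16]
    ElemwiseBinary.nd2fractal_nz([3, 5])      # -> [] (no NZ mapping)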
+ class MatMul(OpInfer):
+     """MatMul op."""
+
+     def supported_format(self):
+         input_num = len(self.input_desc)
+         # MatMul cce only supports ND
+         if InfoGlobalConfig.enable_cce_lib and input_num == 2:
+             return ["ND,ND,ND"]
+         if InfoGlobalConfig.enable_cce_lib and input_num == 3:
+             return ["ND,ND,ND,ND"]
+         if input_num == 2:
+             return ["FRACTAL_NZ,FRACTAL_NZ,FRACTAL_NZ"]
+         if input_num == 3:
+             bias_shape = self.input_desc[2][SHAPE]
+             if len(bias_shape) == 1 and (bias_shape[-1] == 1 or bias_shape[-1] % BLOCK == 0):
+                 return ["FRACTAL_NZ,FRACTAL_NZ,ND,FRACTAL_NZ"]
+             return ["ND,ND,ND,ND"]
+         raise ValueError("MatMul only supports 2 or 3 input tensors, but got {} input tensors".format(input_num))
+
+     def nd_infer(self, sh0, sh1, trans_a, trans_b):
+         """infer shape with nd format"""
+         if len(sh0) != len(sh1):
+             raise ValueError("For '{}', input shape '{}' and '{}' are not supported".format(self.name, sh0, sh1))
+         m = sh0[-2] if not trans_a else sh0[-1]
+         n = sh1[-1] if not trans_b else sh1[-2]
+         res = sh0[:-2] + [m, n]
+         return res
+
+     def infer_shape(self):
+         sh0, sh1 = self.input_desc[0][SHAPE], self.input_desc[1][SHAPE]
+         format0, format1 = self.input_desc[0][FORMAT], self.input_desc[1][FORMAT]
+         trans_a, trans_b = self.get_attr("transpose_a"), self.get_attr("transpose_b")
+         if format0 != format1 or len(sh0) != len(sh1):
+             raise ValueError("For '{}', input '{}' and '{}' are not supported"
+                              .format(self.name, self.input_desc[0], self.input_desc[1]))
+         if format0 != FRACTAL_NZ and len(sh0) >= 2:
+             self.output_desc[0][SHAPE] = self.nd_infer(sh0, sh1, trans_a, trans_b)
+         elif format0 == FRACTAL_NZ and len(sh0) >= 4:
+             m1, m0 = sh0[-3], sh0[-2]
+             if trans_a:
+                 m1, m0 = sh0[-4], sh0[-1]
+             n1, n0 = sh1[-4], sh1[-1]
+             if trans_b:
+                 n1, n0 = sh1[-3], sh1[-2]
+             self.output_desc[0][SHAPE] = sh0[:-4] + [n1, m1, m0, n0]
+         else:
+             raise ValueError("For '{}', input '{}' and '{}' are not supported"
+                              .format(self.name, self.input_desc[0], self.input_desc[1]))
+
+     def infer_ori_shape(self):
+         sh0, sh1 = self.input_desc[0][ORI_SHAPE], self.input_desc[1][ORI_SHAPE]
+         trans_a, trans_b = self.get_attr("transpose_a"), self.get_attr("transpose_b")
+         self.output_desc[0][ORI_SHAPE] = self.nd_infer(sh0, sh1, trans_a, trans_b)
+
+     def post_process(self):
+         self.op_desc[ATTR].append({DATA_TYPE: STR, NAME: "left_format", VALUE: self.input_desc[0][FORMAT]})
+         self.op_desc[ATTR].append({DATA_TYPE: STR, NAME: "right_format", VALUE: self.input_desc[1][FORMAT]})
+         self.op_desc[ATTR].append({DATA_TYPE: STR, NAME: "dst_type", VALUE: self.output_desc[0][DATA_TYPE]})
+
+     def infer_type(self):
+         """infer data type"""
+         if "910B" in InfoGlobalConfig.ascend_arch and not InfoGlobalConfig.enable_cce_lib:
+             self.output_desc[0][DATA_TYPE] = "float32"
+         else:
+             super().infer_type()
+
+     def supported_type(self):
+         if "910B" in InfoGlobalConfig.ascend_arch and not InfoGlobalConfig.enable_cce_lib:
+             support_types = "float16,float16,float32"
+             return [support_types]
+         return super().supported_type()
+
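Reviewer note, not part of the diff: FRACTAL_NZ stores a 2-D (M, N) matrix as 4-D (N//16, M//16, 16, 16) blocks, which is what the index juggling in infer_shape encodes. A worked example with made-up sizes:

    # (64, 256) x (256, 128), transpose_a == transpose_b == False
    sh0 = [16, 4, 16, 16]   # left operand (64, 256) in FRACTAL_NZ
    sh1 = [8, 16, 16, 16]   # right operand (256, 128) in FRACTAL_NZ
    # m1, m0 = sh0[-3], sh0[-2] -> 4, 16;  n1, n0 = sh1[-4], sh1[-1] -> 8, 16
    # output: sh0[:-4] + [n1, m1, m0, n0] == [8, 4, 16, 16], i.e. (64, 128) in NZ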
+ class BatchMatMul(MatMul):
+     """BatchMatMul op. Only supports the cce lib"""
+     def __init__(self, op_desc):
+         super().__init__(op_desc)
+         if "910B" not in InfoGlobalConfig.ascend_arch or not InfoGlobalConfig.enable_cce_lib:
+             raise ValueError("BatchMatMul only support 910B cce lib")
+
+     def infer_shape(self):
+         sh0, sh1 = self.input_desc[0][SHAPE], self.input_desc[1][SHAPE]
+         format0, format1 = self.input_desc[0][FORMAT], self.input_desc[1][FORMAT]
+         trans_a, trans_b = self.get_attr("transpose_a"), self.get_attr("transpose_b")
+         # only support nd
+         if (format0 != FRACTAL_NZ and format1 != FRACTAL_NZ):
+             self.output_desc[0][SHAPE] = self.nd_infer(sh0, sh1, trans_a, trans_b)
+         else:
+             raise ValueError("For '{}', input '{}' and '{}' are not supported"
+                              .format(self.name, self.input_desc[0], self.input_desc[1]))
+
+     def nd_infer(self, sh0, sh1, trans_a, trans_b):
+         """infer shape with nd format"""
+         m = sh0[-2] if not trans_a else sh0[-1]
+         n = sh1[-1] if not trans_b else sh1[-2]
+         res = sh0[:-2] + [m, n]
+         return res
+
+     def infer_type(self):
+         """infer data type"""
+         self.output_desc[0][DATA_TYPE] = "float16"
+
+     def supported_type(self):
+         """supported type"""
+         return ["float16,float16,float16"]
+
+ class Reduce(OpInfer):
+     """Reduce op."""
+
+     @staticmethod
+     def _out_nz(rank, axis):
+         """check if output remains FRACTAL_NZ"""
+         if rank - 2 not in axis and rank - 1 not in axis:
+             return True
+         return False
+
+     @staticmethod
+     def _reduced_shape(shape, axis, keep_dims):
+         """calc reduced shape"""
+         out_shape = []
+         for i, s in enumerate(shape):
+             if i in axis:
+                 if keep_dims:
+                     out_shape.append(1)
+             else:
+                 out_shape.append(s)
+         return out_shape
+
+     def _get_axis(self, rank):
+         axis_input = self.input_desc[1][VALUE]
+         axis = []
+         if isinstance(axis_input, int):
+             axis = [axis_input + rank if axis_input < 0 else axis_input]
+         else:
+             axis = [i + rank if i < 0 else i for i in axis_input]
+         return axis
+
+     def supported_type(self):
+         in_type = self.input_desc[0][DATA_TYPE]
+         if in_type == FLOAT16:
+             return ["float16,int64,float16", "float32,int64,float32"]
+         if in_type == FLOAT32:
+             return ["float32,int64,float32"]
+         io_type = ",".join([in_type, "int64", in_type])
+         return [io_type]
+
+     def supported_format(self):
+         supported_formats = ["ND,DefaultFormat,ND"]
+         shape = self.input_desc[0][SHAPE]
+         rank = len(shape)
+         axis = self._get_axis(rank)
+         if self.is_nz(shape):
+             if self._out_nz(rank, axis):
+                 supported_formats.append("FRACTAL_NZ,DefaultFormat,FRACTAL_NZ")
+         return supported_formats
+
+     def infer_shape(self):
+         ori_format, cur_format = self.input_desc[0][ORI_FORMAT], self.input_desc[0][FORMAT]
+         if cur_format == FRACTAL_NZ and cur_format != ori_format:
+             ori_shape, cur_shape = self.input_desc[0][ORI_SHAPE], self.input_desc[0][SHAPE]
+             ori_rank = len(ori_shape)
+             rank = len(cur_shape)
+             axis = self._get_axis(ori_rank)
+             new_axis = []
+             for i in axis:
+                 if i == ori_rank - 1:
+                     new_axis.extend([rank - 4, rank - 1])
+                 elif i == ori_rank - 2:
+                     new_axis.extend([rank - 3, rank - 2])
+                 else:
+                     new_axis.append(i)
+             self.input_desc[1][VALUE] = new_axis
+             self.input_desc[1][SHAPE] = [len(new_axis)]
+             self.output_desc[0][SHAPE] = self._reduced_shape(cur_shape, new_axis, self.get_attr("keep_dims"))
+         else:
+             self.output_desc[0][SHAPE] = self.output_desc[0][ORI_SHAPE]
+
+     def infer_ori_shape(self):
+         shape = self.input_desc[0][ORI_SHAPE]
+         rank = len(shape)
+         axis = self._get_axis(rank)
+         self.output_desc[0][ORI_SHAPE] = self._reduced_shape(shape, axis, self.get_attr("keep_dims"))
+
+
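Reviewer note, not part of the diff: the axis remapping in Reduce.infer_shape follows from the same blocked layout: the original last axis is split across the rank-4 and rank-1 block dimensions, the second-to-last across rank-3 and rank-2. A worked example with made-up sizes:

    # ori_shape [64, 256] stored as FRACTAL_NZ [16, 4, 16, 16]; reducing ori axis 1
    # remaps to new_axis == [0, 3]
    Reduce._reduced_shape([16, 4, 16, 16], [0, 3], True)   # -> [1, 4, 16, 1]
    Reduce._reduced_shape([16, 4, 16, 16], [0, 3], False)  # -> [4, 16]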
+ class Reshape(OpInfer):
+     """Reshape op."""
+
+     def supported_format(self):
+         return ["ND,DefaultFormat,ND"]
+
+     def infer_shape(self):
+         """Reshape keeps ND format, so the output shape will not be changed"""
+         self.output_desc[0][SHAPE] = self.output_desc[0][ORI_SHAPE]
+
+     def infer_ori_shape(self):
+         shape = self.input_desc[0][ORI_SHAPE]
+         out_shape = copy_shape(self.input_desc[1][VALUE])
+         if -1 in out_shape:
+             idx = out_shape.index(-1)
+             tmp = []
+             for _, s in enumerate(out_shape):
+                 if s != -1:
+                     tmp.append(s)
+             if len(tmp) + 1 != len(out_shape):
+                 raise ValueError("Find multiple -1 in attr 'shape' {}".format(out_shape))
+             tmp_sz = functools.reduce(lambda x, y: x * y, tmp, 1)
+             out_shape[idx] = functools.reduce(lambda x, y: x * y, shape, 1) // tmp_sz
+         self.output_desc[0][ORI_SHAPE] = out_shape
+
+     def post_process(self):
+         self.input_desc[1]["ori_value"] = self.input_desc[1][VALUE]
+         self.input_desc[1][VALUE] = self.output_desc[0][SHAPE]
+
+
+ class ExpandDimAndSqueeze(Reshape):
+     def copy_axis(self, axis):
+         out_axis = []
+         if isinstance(axis, int):
+             out_axis.append(axis)
+         else:
+             out_axis = copy.deepcopy(axis)
+         return out_axis
+
+
+ class Squeeze(ExpandDimAndSqueeze):
+     def infer_ori_shape(self):
+         axis = self.copy_axis(self.input_desc[1][VALUE])
+         input_shape = copy_shape(self.input_desc[0][SHAPE])
+         for idx in axis:
+             if input_shape[idx] != 1:
+                 raise ValueError("The value of attr 'axis' is wrong, the squeezed axis must be 1, but got {}. 'axis': "
+                                  "{}, input shape: {}".format(input_shape[idx], axis, input_shape))
+             input_shape.pop(idx)
+         self.output_desc[0][ORI_SHAPE] = input_shape
+
+
+ class ExpandDim(ExpandDimAndSqueeze):
+     def infer_ori_shape(self):
+         axis = self.copy_axis(self.input_desc[1][VALUE])
+         input_shape = copy_shape(self.input_desc[0][SHAPE])
+         for idx in axis:
+             input_shape.insert(idx, 1)
+         self.output_desc[0][ORI_SHAPE] = input_shape
+
+
+ class BroadcastTo(OpInfer):
+     """BroadcastTo op."""
+
+     def supported_format(self):
+         io_format = ["ND"] * len(self.input_desc)
+         return [",".join(io_format)]
+
+     def infer_shape(self):
+         """Broadcast op keeps ND format, so the output shape will not be changed"""
+         self.output_desc[0][SHAPE] = self.output_desc[0][ORI_SHAPE]
+
+     def infer_ori_shape(self):
+         shape = self.input_desc[0][ORI_SHAPE]
+         broad_shape = self.get_attr(SHAPE) if SHAPE in self.attr else self.input_desc[1][VALUE]
+         if len(broad_shape) < len(shape):
+             raise ValueError("The length of attr 'shape' must be >= the length of input shape, but got attr 'shape': "
+                              "{}, input shape: {}".format(broad_shape, shape))
+         pad_shape = [1] * (len(broad_shape) - len(shape)) + shape
+         out_shape = []
+         for i, b in enumerate(broad_shape):
+             if b == -1:
+                 out_shape.append(pad_shape[i])
+             else:
+                 out_shape.append(b)
+         self.output_desc[0][ORI_SHAPE] = out_shape
+
+     def post_process(self):
+         if not isinstance(self.op_desc.get(ATTR), list):
+             return
+         for item in self.op_desc[ATTR]:
+             if item[NAME] == SHAPE:
+                 item["ori_value"] = item[VALUE]
+                 item[VALUE] = self.output_desc[0][SHAPE]
+
+
+ class Tile(OpInfer):
+     """Tile op."""
+
+     def supported_format(self):
+         return ["ND,ND"]
+
+     def infer_shape(self):
+         """Tile op keeps ND format, so the output shape will not be changed"""
+         self.output_desc[0][SHAPE] = self.output_desc[0][ORI_SHAPE]
+
+     def infer_ori_shape(self):
+         shape = self.input_desc[0][ORI_SHAPE]
+         multiples = self.input_desc[1][VALUE]
+         if len(multiples) < len(shape):
+             raise ValueError("The length of attr 'multiples' must be >= the length of input shape, but got attr "
+                              "'multiples': {}, input shape: {}".format(multiples, shape))
+         pad_shape = [1] * (len(multiples) - len(shape)) + shape
+         out_shape = []
+         for i, m in enumerate(multiples):
+             out_shape.append(m * pad_shape[i])
+         self.output_desc[0][ORI_SHAPE] = out_shape
+
+
+ class PagedAttention(OpInfer):
+     """PagedAttention"""
+
+     def supported_format(self):
+         return ["ND,ND,ND,ND,ND,ND"]
+
+     def infer_shape(self):
+         """PagedAttention op keeps ND format, so the output shape will not be changed"""
+         self.output_desc[0]["shape"] = self.output_desc[0]["ori_shape"]
+
+     def infer_ori_shape(self):
+         self.output_desc[0]["ori_shape"] = self.input_desc[0]["ori_shape"]
+
+
+ class ReshapeAndCache(OpInfer):
+     """ReshapeAndCache"""
+
+     def supported_format(self):
+         return ["ND,ND,ND,ND,ND,ND"]
+
+     def infer_shape(self):
+         """ReshapeAndCache op keeps ND format, so the output shape will not be changed"""
+         self.output_desc[0]["shape"] = self.output_desc[0]["ori_shape"]
+
+     def infer_ori_shape(self):
+         self.output_desc[0]["ori_shape"] = self.input_desc[0]["ori_shape"]
+
+
+ class PagedAttentionMask(PagedAttention):
+     """PagedAttentionMask"""
+
+     def supported_format(self):
+         return ["ND,ND,ND,ND,ND,ND,ND"]
+
+
+ # GE will convert dtype bool to int8, and ReLU will be expanded to a Greater op in the expander,
+ # and the dtype of the Greater op is bool, which is incompatible with int8.
+ # As a result akg will raise an error when parsing a Greater op with dtype int8.
+ # ExpandDims and Squeeze ops will be expanded into Reshape ops in the expander,
+ # but in the dynamic shape scenario, the meaning of -1 in a Reshape op differs from -1 in ExpandDims and Squeeze.
+ # So this will lead to infer shape errors.
+ # To solve these problems we need to cluster these ops into a subgraph and update the info file here.
+ prims = {
+     "Abs": Elemwise,
+     "Neg": Elemwise,
+     "Sqrt": Elemwise,
+     "Rsqrt": Elemwise,
+     "Reciprocal": Elemwise,
+     "FastGeLU": Elemwise,
+     "Round": Elemwise,
+     "Assign": ElemwiseBinaryNoBroadcast,
+     "Add": ElemwiseBinary,
+     "Sub": ElemwiseBinary,
+     "Mul": ElemwiseBinary,
+     "Div": ElemwiseBinary,
+     "Mod": ElemwiseBinary,
+     "RealDiv": ElemwiseBinary,
+     "Maximum": ElemwiseBinary,
+     "Minimum": ElemwiseBinary,
+     "MatMul": MatMul,
+     "BatchMatMul": BatchMatMul,
+     "ReduceSum": Reduce,
+     "Reshape": Reshape,
+     "ExpandDims": ExpandDim,
+     "Squeeze": Squeeze,
+     "BroadcastTo": BroadcastTo,
+     "Tile": Tile,
+     "Log": Elemwise,
+     "Exp": Elemwise,
+     "Pow": Elemwise,
+     "Sign": Elemwise,
+     "ReLU": Elemwise,
+     "Tanh": Elemwise,
+     "ReduceMax": Reduce,
+     "ReduceMin": Reduce,
+     "Cast": Cast,
+     "PagedAttention": PagedAttention,
+     "PagedAttentionMask": PagedAttentionMask,
+     "ReshapeAndCache": ReshapeAndCache,
+ }
+
+
+ def convert_to_default_format(desc):
+     """Convert to DefaultFormat"""
+     default_format = ["ND", "NCHW", "NHWC", "HWCN", DEFAULT_FORMAT]
+     for _, input_desc in enumerate(desc[INPUT_DESC]):
+         if input_desc[0][FORMAT] in default_format:
+             input_desc[0][FORMAT] = DEFAULT_FORMAT
+             if not input_desc[0][SHAPE]:
+                 input_desc[0][SHAPE] = [1]
+     for _, op_desc in enumerate(desc[OP_DESC]):
+         for _, input_desc in enumerate(op_desc[INPUT_DESC]):
+             if input_desc[0][FORMAT] in default_format:
+                 input_desc[0][FORMAT] = DEFAULT_FORMAT
+                 if not input_desc[0][SHAPE]:
+                     input_desc[0][SHAPE] = [1]
+         for _, output_desc in enumerate(op_desc[OUTPUT_DESC]):
+             if output_desc[FORMAT] in default_format:
+                 output_desc[FORMAT] = DEFAULT_FORMAT
+                 if not output_desc[SHAPE]:
+                     output_desc[SHAPE] = [1]
+     for _, output_desc in enumerate(desc[OUTPUT_DESC]):
+         if output_desc[FORMAT] in default_format:
+             output_desc[FORMAT] = DEFAULT_FORMAT
+             if not output_desc[SHAPE]:
+                 output_desc[SHAPE] = [1]
+
+
+ def update_global_input_desc(info_desc, args):
+     """Update the global input of the fused info file"""
+
+     def _convert_tbe_type(tbe_type, ori_type):
+         if tbe_type == "float":
+             return FLOAT32
+         if tbe_type == "int8" and ori_type == "bool":
+             # GE passes int8 here if the data type is bool, but we must return bool back to GE,
+             # otherwise GE will raise an error "current op does not support bool"
+             return ori_type
+         return tbe_type
+
+     def _covert_tbe_shape(tbe_shape):
+         if not tbe_shape:
+             return [1]
+         return copy_shape(tbe_shape)
+
+     if isinstance(info_desc.get(INPUT_DESC), list):
+         for i, desc in enumerate(info_desc[INPUT_DESC]):
+             desc[0][ORI_DATA_TYPE] = desc[0][DATA_TYPE]
+             desc[0][DATA_TYPE] = _convert_tbe_type(args[i]["dtype"], desc[0][ORI_DATA_TYPE])
+             desc[0][ORI_FORMAT] = args[i].get(ORI_FORMAT, desc[0][FORMAT])
+             desc[0][FORMAT] = args[i][FORMAT]
+             desc[0][ORI_SHAPE] = _covert_tbe_shape(args[i].get(ORI_SHAPE, desc[0][SHAPE]))
+             desc[0][SHAPE] = list(args[i][SHAPE])
+
+
+ def update_global_output_desc(info_desc, tensor_desc):
+     """Update the global output of the fused info file"""
+     for i, desc in enumerate(info_desc[OUTPUT_DESC]):
+         tensor_name = desc[TENSOR_NAME]
+         if tensor_name not in tensor_desc:
+             raise RuntimeError("tensor '{}' not exist in op_desc".format(tensor_name))
+         info_desc[OUTPUT_DESC][i] = tensor_desc[tensor_name]
+
+
+ def update_op_input_desc(op_desc, tensor_desc):
+     """Update the input of operator"""
+     if not isinstance(op_desc.get(INPUT_DESC), list):
+         return
+     inputs_type_orig = []
+     inputs_type = []
+     const_inputs_idx = []
+     for i, desc in enumerate(op_desc[INPUT_DESC]):
+         for j, item in enumerate(desc):
+             if VALUE in item:
+                 inputs_type_orig.append(None)
+                 inputs_type.append(None)
+                 const_inputs_idx.append(i)
+                 item[ORI_DATA_TYPE] = item[DATA_TYPE]
+                 item[ORI_FORMAT] = item[FORMAT]
+                 item[ORI_SHAPE] = copy_shape(item[SHAPE])
+             else:
+                 inputs_type_orig.append(item[DATA_TYPE])
+                 tensor_name = item[TENSOR_NAME]
+                 if tensor_name not in tensor_desc:
+                     raise RuntimeError("tensor '{}' used without initialization".format(tensor_name))
+                 # update op input
+                 desc[j] = tensor_desc[tensor_name]
+                 inputs_type.append(tensor_desc[tensor_name][DATA_TYPE])
+     # update op const input's data type
+     for _, idx in enumerate(const_inputs_idx):
+         const_value_type = op_desc[INPUT_DESC][idx][0][DATA_TYPE]
+         if const_value_type in inputs_type_orig:
+             op_desc[INPUT_DESC][idx][0][DATA_TYPE] = inputs_type[inputs_type_orig.index(const_value_type)]
+         # cache op const input
+         tensor_desc[op_desc[INPUT_DESC][idx][0][TENSOR_NAME]] = op_desc[INPUT_DESC][idx][0]
+
+
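Reviewer note, not part of the diff: the const-input pass above retypes literal scalars to follow the tensor input whose original dtype they matched, so a constant rides along when the op is re-selected at a different precision. Sketched on made-up descriptors:

    # before: const {"value": 2.0, "data_type": "float32"}; a tensor input that was
    #         originally float32 has been re-selected as float16
    # after:  the const's data_type is rewritten to "float16" via the
    #         inputs_type_orig -> inputs_type index lookup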
+ def cache_input_tensors(tensor_desc, input_desc):
+     """Cache input tensor desc"""
+     if isinstance(input_desc, list):
+         for desc in input_desc:
+             for item in desc:
+                 tensor_desc[item[TENSOR_NAME]] = item
+
+
+ def cache_output_tensors(tensor_desc, output_desc):
+     """Cache output tensor desc"""
+     for item in output_desc:
+         tensor_desc[item[TENSOR_NAME]] = item
+
+
+ def save(filename, contents):
+     """Save to file"""
+     with os.fdopen(os.open(os.path.realpath(filename), os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o660), 'w') as f:
+         f.write(contents)
+
+
+ def update_akg_info(args, info_path, kernel_name=None):
+     """Update akg info based on the current inputs provided by GE"""
+     with open(os.path.realpath(info_path), 'r') as f:
+         info_str = f.read()
+     desc = json.loads(info_str)
+     desc["op_ori"] = desc[OP]
+     desc[OP] = kernel_name if kernel_name else desc[OP]
+     tensor_desc = {}  # {tensor_name: tensor_desc}
+
+     # Update input_desc
+     update_global_input_desc(desc, args)
+     # cache global input
+     cache_input_tensors(tensor_desc, desc.get(INPUT_DESC))
+     # Update info global config
+     InfoGlobalConfig.enable_cce_lib = desc.get("enable_cce_lib")
+     target_info = desc.get("target_info")
+     if target_info is not None:
+         InfoGlobalConfig.ascend_arch = target_info.get("arch")
+
+     # Update op_desc
+     for _, op_desc in enumerate(desc[OP_DESC]):
+         update_op_input_desc(op_desc, tensor_desc)
+         op_name = op_desc[NAME]
+         if op_name not in prims:
+             raise KeyError("Not supported op: {}".format(op_name))
+         prim = prims.get(op_name)(op_desc)
+         prim.update()
+         # cache op output
+         cache_output_tensors(tensor_desc, op_desc[OUTPUT_DESC])
+
+     # Update output_desc
+     update_global_output_desc(desc, tensor_desc)
+
+     # Update data format to DefaultFormat
+     convert_to_default_format(desc)
+
+     # GE backend must use old CCE
+     desc["backend"] = "GE"
+
+     return desc
+
+
867
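The accesses above imply the rough shape of a fused info file. The skeleton below is illustrative only: the key strings mirror what the `INPUT_DESC`/`OP_DESC`-style constants presumably expand to, and every value is invented:

```python
# Illustrative skeleton only -- field names follow the accesses made by
# update_akg_info() above, not an authoritative schema.
fused_info = {
    "op": "fused_op_1",                    # desc[OP]; replaced by the GE kernel_name
    "enable_cce_lib": False,               # copied into InfoGlobalConfig
    "target_info": {"arch": "ascend910"},  # optional; "arch" feeds InfoGlobalConfig (placeholder value)
    "input_desc": [                        # list of lists: one inner list per global input
        [{"tensor_name": "input_0", "data_type": "float16",
          "format": "DefaultFormat", "shape": [16, 16]}],
    ],
    "op_desc": [                           # one entry per fused operator
        {"name": "Abs",
         "input_desc": [[{"tensor_name": "input_0", "data_type": "float16",
                          "format": "DefaultFormat", "shape": [16, 16]}]],
         "output_desc": [{"tensor_name": "output_0", "data_type": "float16",
                          "format": "DefaultFormat", "shape": [16, 16]}]},
    ],
    "output_desc": [{"tensor_name": "output_0", "data_type": "float16",
                     "format": "DefaultFormat", "shape": [16, 16]}],
}
```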
+ def save_updated_akg_info(*args):
+     """Save the updated akg info."""
+     info_path = args[-2]
+     kernel_name = args[-1]
+     if not isinstance(info_path, str):
+         # in this case, kernel_name was not passed by GE, so skip compiling
+         return ""
+     updated_desc = update_akg_info(args, info_path, kernel_name)
+     real_info_path = os.path.join(os.path.realpath(os.path.dirname(info_path)), kernel_name + ".info")
+     # Save the updated info file
+     save(real_info_path, json.dumps(updated_desc))
+     return real_info_path
+
+
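The `args[-2]`/`args[-1]` unpacking reflects GE's calling convention of appending the info path and kernel name after the runtime input descriptors. A hypothetical call shape (all values invented):

```python
# Hypothetical call shape (not from the source): GE appends the info file
# path and the kernel name after the per-input runtime descriptors.
runtime_inputs = [
    {"dtype": "float16", "format": "NCHW", "shape": (16, 16)},
]
args = (*runtime_inputs, "/tmp/fused_op_1.info", "fused_op_1_kernel")

info_path, kernel_name = args[-2], args[-1]
print(info_path, kernel_name)
```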
+ def create_dirs(*dirs):
+     """Create directories."""
+     for d in dirs:
+         if not os.path.isdir(d):
+             try:
+                 os.makedirs(d)
+             except OSError as err:
+                 # errno 17: the directory already exists (e.g. created by a concurrent process)
+                 if err.errno == 17:
+                     pass
+                 else:
+                     raise
+
+
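On Python 3.2 and later, the try/except dance above collapses to a single call that tolerates the same already-exists race:

```python
import os

def create_dirs_simple(*dirs):
    # Equivalent behavior: exist_ok=True swallows the "directory already
    # exists" error handled manually above.
    for d in dirs:
        os.makedirs(d, exist_ok=True)

create_dirs_simple("/tmp/akg_demo/kernel_meta", "/tmp/akg_demo/composite")
```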
+ def copy_file(src_path, dst_path):
+     """Copy file to dst."""
+     try:
+         if os.path.isfile(dst_path):
+             os.remove(dst_path)
+     except OSError:
+         pass
+
+     try:
+         shutil.copy(src_path, dst_path)
+     except PermissionError:
+         # If dst_path already exists and only has READ permission
+         pass
+
+
+ def _compile_subprocess(kernel_meta_dirs, info_path, is_lite=True, compile_backend=None, attrs=None):
+     """Use a new process to compile the info."""
+     kernel_meta_parent_dir, kernel_meta_dir = kernel_meta_dirs
+     my_env = os.environ.copy()  # copy so the parent process's environment is not mutated
+     my_env["MS_COMPILER_CACHE_PATH"] = kernel_meta_parent_dir
+     my_env["KERNEL_META_DIR"] = kernel_meta_dir
+     compiler = os.path.join(os.path.split(os.path.realpath(__file__))[0], "compiler.py")
+     if is_lite:
+         run_args = [sys.executable, compiler, info_path]
+     else:
+         run_args = [sys.executable, compiler, info_path, compile_backend, attrs, kernel_meta_parent_dir]
+     compile_result = subprocess.run(run_args, text=True, check=False, capture_output=True, env=my_env)
+     return compile_result
+
+
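`subprocess.run(..., check=False, capture_output=True)` returns a `CompletedProcess` whose `returncode`, `stdout`, and `stderr` the callers below inspect themselves. A self-contained sketch of the same pattern, with a trivial child command standing in for `compiler.py`:

```python
import os
import subprocess
import sys

result = subprocess.run(
    [sys.executable, "-c", "print('compile ok')"],  # stand-in for compiler.py
    text=True, check=False, capture_output=True,
    env={**os.environ, "KERNEL_META_DIR": "kernel_meta"},
)
if result.returncode:
    raise RuntimeError("compile failed: {} {}".format(result.stdout.strip(), result.stderr.strip()))
print(result.stdout.strip())  # compile ok
```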
+ def search_supported_types_formats(info):
+     """Get the supported data types and formats of the fused info file"""
+
+     class DfsSearcher:
+         """Depth-first search over the per-op supported types and formats"""
+
+         def __init__(self, top_io_names, ops_desc):
+             self.supported_types = []
+             self.supported_formats = []
+             self.top_io_names = top_io_names
+             self.tensor_types = {}
+             self.tensor_formats = {}
+             self.ops_desc = ops_desc
+             self.cache = []
+
+         def set_current_format(self, cur_format, io_names):
+             """set tensor formats"""
+             for i, fmt in enumerate(cur_format):
+                 if self.tensor_formats.get(io_names[i], fmt) != fmt:
+                     return False
+                 self.tensor_formats[io_names[i]] = fmt
+             return True
+
+         def set_current_type(self, cur_type, io_names):
+             """set tensor data types"""
+             for i, data_type in enumerate(cur_type):
+                 if self.tensor_types.get(io_names[i], data_type) != data_type:
+                     return False
+                 self.tensor_types[io_names[i]] = data_type
+             return True
+
+         def get_desc(self, opid):
+             """get the cached (formats, types, io_names) desc of an op"""
+             if opid < len(self.cache):
+                 return self.cache[opid]
+             desc = self.ops_desc[opid]
+             io_names = [item[TENSOR_NAME] for input_desc in desc[INPUT_DESC] for item in input_desc]
+             io_names.append(desc[OUTPUT_DESC][0][TENSOR_NAME])
+             op_name = desc[NAME]
+             if op_name not in prims:
+                 raise KeyError("Unsupported op: {}".format(op_name))
+             prim = prims.get(op_name)(desc)
+             io_formats = [f.split(",") for f in prim.supported_format()]
+             io_types = [t.split(",") for t in prim.supported_type()]
+             self.cache.append((io_formats, io_types, tuple(io_names)))
+             return self.cache[-1]
+
+         def search_types(self, opid):
+             """search the supported types"""
+             if opid == len(self.ops_desc):
+                 top_tensor_types = tuple(self.tensor_types.get(t) for t in self.top_io_names)
+                 self.supported_types.append(top_tensor_types)
+                 return
+             _, op_io_types, io_names = self.get_desc(opid)
+             for cur_type in op_io_types:
+                 bak_tensor_types = copy.deepcopy(self.tensor_types)
+                 if self.set_current_type(cur_type, io_names):
+                     self.search_types(opid + 1)
+                 self.tensor_types = bak_tensor_types
+
+         def search_formats(self, opid):
+             """search the supported formats"""
+             if opid == len(self.ops_desc):
+                 top_tensor_formats = tuple(self.tensor_formats.get(t) for t in self.top_io_names)
+                 self.supported_formats.append(top_tensor_formats)
+                 return
+             op_io_formats, _, io_names = self.get_desc(opid)
+             for cur_format in op_io_formats:
+                 bak_tensor_formats = copy.deepcopy(self.tensor_formats)
+                 if self.set_current_format(cur_format, io_names):
+                     self.search_formats(opid + 1)
+                 self.tensor_formats = bak_tensor_formats
+
+     def remove_dup(data):
+         res = []
+         data_str = []
+         for t in data:
+             t_str = ",".join(t)
+             if t_str not in data_str:
+                 data_str.append(t_str)
+                 res.append(t)
+         return res
+
+     top_io_names = [t[0][TENSOR_NAME] for t in info[INPUT_DESC]] + [t[TENSOR_NAME] for t in info[OUTPUT_DESC]]
+     handle = DfsSearcher(top_io_names, info[OP_DESC])
+     handle.search_types(0)
+     handle.search_formats(0)
+     return remove_dup(handle.supported_types), remove_dup(handle.supported_formats)
+
+
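The two DFS passes enumerate one type (or format) choice per op and backtrack whenever two ops disagree on a tensor they share. A standalone miniature of the type search, with the per-op option tables inlined instead of coming from `prims`:

```python
# Standalone miniature of the DFS above: two ops share tensor "t1", so only
# type assignments that agree on "t1" survive. Each op row is (io_names, options).
ops = [
    (("t0", "t1"), [("float16", "float16"), ("float32", "float32")]),  # op A: t0 -> t1
    (("t1", "t2"), [("float16", "float16")]),                          # op B: t1 -> t2
]
top_io_names = ("t0", "t2")
supported = []

def search(opid, assigned):
    if opid == len(ops):
        supported.append(tuple(assigned[t] for t in top_io_names))
        return
    io_names, options = ops[opid]
    for option in options:
        trial = dict(assigned)  # backtracking: each option works on its own copy
        if all(trial.setdefault(n, t) == t for n, t in zip(io_names, option)):
            search(opid + 1, trial)

search(0, {})
print(supported)  # [('float16', 'float16')] -- float32 for op A conflicts at t1
```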
+ def op_select_format(*args, **kwargs):
+     """Entrance for format/data type selection; will be invoked by GE"""
+     info_path = args[-1]
+     desc = update_akg_info(args, info_path)
+     supported_io_type, supported_io_format = search_supported_types_formats(desc)
+     if not supported_io_type or not supported_io_format:
+         raise RuntimeError("Select format failed for info: {}".format(info_path))
+     input_num = len(desc[INPUT_DESC])
+     output_num = len(desc[OUTPUT_DESC])
+     param_list = []
+     for i in range(input_num + output_num):
+         dtype_list = [item[i] for item in supported_io_type] * len(supported_io_format)
+         format_list = functools.reduce(lambda x, y: x + y,
+                                        [[item[i]] * len(supported_io_type) for item in supported_io_format])
+         classify = "input" + str(i) if i < input_num else "output" + str(i - input_num)
+         name = "x" + str(i) if i < input_num else "y" + str(i - input_num)
+         param = gen_param(classify=classify,
+                           name=name,
+                           datatype=",".join(dtype_list),
+                           format=",".join(format_list))
+         param_list.append(param)
+     param_dynamic_in_json = get_dynamic_param_in_json(param_list)
+     return param_dynamic_in_json
+
+
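The `dtype_list`/`format_list` construction builds the full dtype × format cross product per tensor: dtypes vary fastest and formats slowest, so position i in both joined strings describes one supported combination. A runnable miniature with invented values:

```python
import functools

# Sketch of the dtype x format cross product built above, for one tensor
# (index i) across two type combinations and two format combinations.
supported_io_type = [("float16",), ("float32",)]
supported_io_format = [("NCHW",), ("DefaultFormat",)]

i = 0  # tensor index
dtype_list = [item[i] for item in supported_io_type] * len(supported_io_format)
format_list = functools.reduce(lambda x, y: x + y,
                               [[item[i]] * len(supported_io_type) for item in supported_io_format])
print(",".join(dtype_list))   # float16,float32,float16,float32
print(",".join(format_list))  # NCHW,NCHW,DefaultFormat,DefaultFormat
```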
+ def custom(*args, **kwargs):
+     """Entrance for akg info compiling; will be invoked by GE"""
+     kernel_name = args[-1]
+     real_info_path = save_updated_akg_info(*args)
+     if not real_info_path:
+         return
+     kernel_meta_parent_dir = get_current_build_config("kernel_meta_parent_dir")
+     kernel_meta_dir = "kernel_meta"
+     compile_result = _compile_subprocess([kernel_meta_parent_dir, kernel_meta_dir], real_info_path, is_lite=True)
+     json_path = os.path.join(kernel_meta_parent_dir, kernel_meta_dir, kernel_name + JSON_SUFFIX)
+     if compile_result.returncode or not os.path.exists(json_path):
+         raise RuntimeError("Compile {} failed! Detailed compile message: {}, {}"
+                            .format(kernel_name, compile_result.stdout.strip(), compile_result.stderr.strip()))
+
+
+ def custom_train(*args, **kwargs):
+     """Entrance for akg info compiling (training); will be invoked by GE"""
+
+     def _get_optimized_info_path():
+         """Get the info optimized by akg."""
+         target_info = "target_info"
+         file_path = os.path.join(composite_graph_dir, kernel_name + ".info")
+         if not os.path.isfile(file_path):
+             return real_info_path
+         with open(os.path.realpath(real_info_path), 'r') as f:
+             desc = json.loads(f.read())
+         if target_info in desc:
+             with open(os.path.realpath(file_path), 'r') as fo:
+                 info_desc = json.loads(fo.read())
+             info_desc[target_info] = desc[target_info]
+             save(file_path, json.dumps(info_desc))
+         return file_path
+
+     info_path = args[-2]
+     kernel_name = args[-1]
+     real_info_path = save_updated_akg_info(*args)
+     if not real_info_path:
+         return
+     info_dir = os.path.realpath(os.path.dirname(info_path))
+     kernel_meta_parent_dir = get_current_build_config("kernel_meta_parent_dir")
+     kernel_meta = "kernel_meta"
+     kernel_meta_dir = os.path.join(kernel_meta_parent_dir, kernel_meta)
+     akg_compile_dir = os.path.join(info_dir, "akg")
+     tbe_compile_dir = os.path.join(info_dir, "tbe")
+     composite_graph_dir = os.path.join(info_dir, "composite")  # save akg optimized info
+     akg_kernel_meta_dir = os.path.join(akg_compile_dir, kernel_meta)  # save akg compile result
+     tbe_kernel_meta_dir = os.path.join(tbe_compile_dir, kernel_meta)  # save tbe compile result
+     create_dirs(kernel_meta_dir, composite_graph_dir, akg_kernel_meta_dir, tbe_kernel_meta_dir)
+     # Compile with AKG
+     attr = {"dump_composite_graph": composite_graph_dir, "optimize_for_tbe": True}
+     attrs = json.dumps(attr)
+     akg_compile_result = _compile_subprocess([akg_compile_dir, kernel_meta], real_info_path,
+                                              is_lite=False, compile_backend="AKG", attrs=attrs)
+     json_path = os.path.join(akg_kernel_meta_dir, kernel_name + JSON_SUFFIX)
+     o_path = os.path.join(akg_kernel_meta_dir, kernel_name + O_SUFFIX)
+     if not os.path.exists(json_path):
+         # Fall back to compiling with TBE
+         optimized_info_path = _get_optimized_info_path()
+         tbe_compile_result = _compile_subprocess([tbe_compile_dir, kernel_meta], optimized_info_path,
+                                                  is_lite=False, compile_backend="TBE", attrs=attrs)
+         json_path = os.path.join(tbe_kernel_meta_dir, kernel_name + JSON_SUFFIX)
+         o_path = os.path.join(tbe_kernel_meta_dir, kernel_name + O_SUFFIX)
+         if not os.path.exists(json_path):
+             raise RuntimeError("Compile {} failed! Detailed akg compile message: {}, {}\n"
+                                "Detailed tbe compile message: {}, {}"
+                                .format(kernel_name,
+                                        akg_compile_result.stdout.strip(), akg_compile_result.stderr.strip(),
+                                        tbe_compile_result.stdout.strip(), tbe_compile_result.stderr.strip()))
+     copy_file(json_path, os.path.join(kernel_meta_dir, kernel_name + JSON_SUFFIX))
+     copy_file(o_path, os.path.join(kernel_meta_dir, kernel_name + O_SUFFIX))
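Note the control flow here: AKG is always tried first, and TBE is attempted only if AKG left no `.json` artifact; success is judged by the artifact's presence on disk rather than by the subprocess return code. A generic sketch of that pattern (hypothetical helper, not part of this module):

```python
import os

def compile_with_fallback(run_primary, run_fallback, primary_artifact, fallback_artifact):
    # Generic shape of the AKG-then-TBE flow above: each backend's success is
    # detected by the artifact it leaves on disk, not by its return code.
    primary_log = run_primary()
    if os.path.exists(primary_artifact):
        return primary_artifact
    fallback_log = run_fallback()
    if os.path.exists(fallback_artifact):
        return fallback_artifact
    raise RuntimeError("both backends failed: {} / {}".format(primary_log, fallback_log))
```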