mindspore 2.1.0__cp39-cp39-win_amd64.whl → 2.2.11__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of mindspore might be problematic.

Files changed (488)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  6. mindspore/_check_jit_forbidden_api.py +3 -1
  7. mindspore/_checkparam.py +23 -29
  8. mindspore/_extends/graph_kernel/__init__.py +0 -1
  9. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  10. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  11. mindspore/_extends/graph_kernel/splitter.py +4 -11
  12. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  13. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
  14. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  15. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  16. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  17. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
  18. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  19. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  20. mindspore/_extends/parse/__init__.py +13 -15
  21. mindspore/_extends/parse/namespace.py +7 -33
  22. mindspore/_extends/parse/parser.py +67 -72
  23. mindspore/_extends/parse/resources.py +1 -1
  24. mindspore/_extends/parse/standard_method.py +86 -106
  25. mindspore/_extends/parse/trope.py +1 -1
  26. mindspore/_extends/remote/kernel_build_server.py +25 -7
  27. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  28. mindspore/_install_custom.py +43 -0
  29. mindspore/amp.py +47 -11
  30. mindspore/boost/boost.py +1 -8
  31. mindspore/boost/boost_cell_wrapper.py +3 -2
  32. mindspore/boost/grad_accumulation.py +1 -1
  33. mindspore/boost/group_loss_scale_manager.py +8 -7
  34. mindspore/common/__init__.py +5 -3
  35. mindspore/common/_jit_fallback_utils.py +6 -0
  36. mindspore/common/_register_for_adapter.py +2 -0
  37. mindspore/common/_register_for_tensor.py +2 -2
  38. mindspore/common/_stub_tensor.py +13 -0
  39. mindspore/common/_utils.py +29 -0
  40. mindspore/common/api.py +174 -259
  41. mindspore/common/auto_dynamic_shape.py +494 -0
  42. mindspore/common/dtype.py +18 -11
  43. mindspore/common/dump.py +6 -4
  44. mindspore/common/initializer.py +14 -14
  45. mindspore/common/jit_config.py +33 -15
  46. mindspore/common/lazy_inline.py +126 -7
  47. mindspore/common/mindir_util.py +101 -0
  48. mindspore/common/parameter.py +51 -41
  49. mindspore/common/seed.py +4 -4
  50. mindspore/common/sparse_tensor.py +13 -14
  51. mindspore/common/tensor.py +243 -165
  52. mindspore/communication/__init__.py +7 -4
  53. mindspore/communication/_comm_helper.py +83 -4
  54. mindspore/communication/management.py +152 -84
  55. mindspore/config/op_info.config +14 -3
  56. mindspore/context.py +152 -61
  57. mindspore/dataset/__init__.py +5 -5
  58. mindspore/dataset/audio/__init__.py +2 -2
  59. mindspore/dataset/audio/transforms.py +52 -52
  60. mindspore/dataset/callback/ds_callback.py +16 -2
  61. mindspore/dataset/core/config.py +68 -51
  62. mindspore/dataset/engine/cache_client.py +33 -7
  63. mindspore/dataset/engine/datasets.py +250 -112
  64. mindspore/dataset/engine/datasets_audio.py +43 -211
  65. mindspore/dataset/engine/datasets_standard_format.py +16 -35
  66. mindspore/dataset/engine/datasets_text.py +43 -67
  67. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  68. mindspore/dataset/engine/datasets_vision.py +219 -1029
  69. mindspore/dataset/engine/iterators.py +11 -4
  70. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  71. mindspore/dataset/engine/obs/util.py +3 -0
  72. mindspore/dataset/engine/samplers.py +1 -1
  73. mindspore/dataset/engine/validators.py +19 -5
  74. mindspore/dataset/text/__init__.py +3 -3
  75. mindspore/dataset/text/transforms.py +101 -127
  76. mindspore/dataset/text/utils.py +205 -138
  77. mindspore/dataset/transforms/__init__.py +1 -1
  78. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  79. mindspore/dataset/transforms/transforms.py +95 -40
  80. mindspore/dataset/utils/browse_dataset.py +8 -2
  81. mindspore/dataset/utils/line_reader.py +17 -19
  82. mindspore/dataset/vision/__init__.py +3 -3
  83. mindspore/dataset/vision/c_transforms.py +6 -3
  84. mindspore/dataset/vision/transforms.py +409 -287
  85. mindspore/dataset/vision/utils.py +13 -14
  86. mindspore/dataset/vision/validators.py +11 -1
  87. mindspore/dnnl.dll +0 -0
  88. mindspore/experimental/map_parameter.py +14 -0
  89. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  90. mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
  91. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  92. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  93. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  94. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  95. mindspore/gen_ops.py +273 -0
  96. mindspore/include/OWNERS +0 -1
  97. mindspore/include/api/data_type.h +2 -1
  98. mindspore/include/api/graph.h +0 -15
  99. mindspore/include/api/kernel.h +2 -0
  100. mindspore/include/api/kernel_api.h +37 -12
  101. mindspore/include/api/model.h +17 -14
  102. mindspore/include/api/status.h +8 -3
  103. mindspore/include/api/types.h +37 -4
  104. mindspore/include/c_api/ms/abstract.h +67 -0
  105. mindspore/include/c_api/ms/attribute.h +197 -0
  106. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  107. mindspore/include/c_api/ms/base/macros.h +32 -0
  108. mindspore/include/c_api/ms/base/status.h +33 -0
  109. mindspore/include/c_api/ms/base/types.h +282 -0
  110. mindspore/include/c_api/ms/context.h +102 -0
  111. mindspore/include/c_api/ms/graph.h +160 -0
  112. mindspore/include/c_api/ms/node.h +606 -0
  113. mindspore/include/c_api/ms/tensor.h +161 -0
  114. mindspore/include/c_api/ms/value.h +84 -0
  115. mindspore/include/dataset/constants.h +6 -5
  116. mindspore/include/dataset/execute.h +23 -13
  117. mindspore/include/dataset/text.h +26 -26
  118. mindspore/include/dataset/transforms.h +13 -13
  119. mindspore/include/dataset/vision.h +60 -60
  120. mindspore/include/dataset/vision_ascend.h +5 -6
  121. mindspore/include/dataset/vision_lite.h +17 -17
  122. mindspore/jpeg62.dll +0 -0
  123. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  124. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  125. mindspore/mindspore_backend.dll +0 -0
  126. mindspore/mindspore_common.dll +0 -0
  127. mindspore/mindspore_core.dll +0 -0
  128. mindspore/mindspore_glog.dll +0 -0
  129. mindspore/mindspore_shared_lib.dll +0 -0
  130. mindspore/nn/__init__.py +0 -2
  131. mindspore/nn/cell.py +313 -74
  132. mindspore/nn/dynamic_lr.py +21 -21
  133. mindspore/nn/layer/activation.py +22 -30
  134. mindspore/nn/layer/basic.py +15 -13
  135. mindspore/nn/layer/channel_shuffle.py +1 -1
  136. mindspore/nn/layer/container.py +271 -9
  137. mindspore/nn/layer/conv.py +323 -204
  138. mindspore/nn/layer/dense.py +8 -5
  139. mindspore/nn/layer/embedding.py +33 -27
  140. mindspore/nn/layer/flash_attention.py +61 -95
  141. mindspore/nn/layer/image.py +8 -6
  142. mindspore/nn/layer/math.py +16 -25
  143. mindspore/nn/layer/normalization.py +107 -66
  144. mindspore/nn/layer/padding.py +1 -1
  145. mindspore/nn/layer/pooling.py +131 -109
  146. mindspore/nn/layer/rnn_cells.py +27 -22
  147. mindspore/nn/layer/rnns.py +13 -16
  148. mindspore/nn/layer/thor_layer.py +1 -1
  149. mindspore/nn/layer/transformer.py +221 -154
  150. mindspore/nn/learning_rate_schedule.py +9 -1
  151. mindspore/nn/loss/loss.py +235 -174
  152. mindspore/nn/optim/ada_grad.py +2 -1
  153. mindspore/nn/optim/adadelta.py +1 -0
  154. mindspore/nn/optim/adafactor.py +2 -1
  155. mindspore/nn/optim/adam.py +7 -4
  156. mindspore/nn/optim/adamax.py +3 -2
  157. mindspore/nn/optim/adasum.py +2 -2
  158. mindspore/nn/optim/asgd.py +2 -3
  159. mindspore/nn/optim/ftrl.py +6 -5
  160. mindspore/nn/optim/lamb.py +7 -4
  161. mindspore/nn/optim/lars.py +1 -1
  162. mindspore/nn/optim/lazyadam.py +5 -3
  163. mindspore/nn/optim/momentum.py +2 -1
  164. mindspore/nn/optim/optimizer.py +53 -4
  165. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  166. mindspore/nn/optim/rmsprop.py +4 -3
  167. mindspore/nn/optim/rprop.py +23 -12
  168. mindspore/nn/optim/sgd.py +26 -11
  169. mindspore/nn/optim/thor.py +9 -7
  170. mindspore/nn/probability/bijector/bijector.py +5 -5
  171. mindspore/nn/probability/bijector/power_transform.py +27 -27
  172. mindspore/nn/probability/bijector/softplus.py +3 -3
  173. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  174. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  175. mindspore/nn/probability/distribution/beta.py +3 -3
  176. mindspore/nn/probability/distribution/categorical.py +7 -7
  177. mindspore/nn/probability/distribution/cauchy.py +0 -1
  178. mindspore/nn/probability/distribution/distribution.py +3 -3
  179. mindspore/nn/probability/distribution/gamma.py +3 -3
  180. mindspore/nn/probability/distribution/geometric.py +4 -4
  181. mindspore/nn/probability/distribution/gumbel.py +4 -4
  182. mindspore/nn/probability/distribution/log_normal.py +2 -2
  183. mindspore/nn/probability/distribution/logistic.py +2 -2
  184. mindspore/nn/probability/distribution/poisson.py +4 -4
  185. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  186. mindspore/nn/probability/distribution/uniform.py +6 -6
  187. mindspore/nn/wrap/__init__.py +4 -2
  188. mindspore/nn/wrap/cell_wrapper.py +87 -34
  189. mindspore/nn/wrap/grad_reducer.py +8 -5
  190. mindspore/nn/wrap/loss_scale.py +105 -42
  191. mindspore/numpy/array_creations.py +1 -2
  192. mindspore/numpy/array_ops.py +3 -2
  193. mindspore/numpy/utils_const.py +5 -5
  194. mindspore/opencv_core452.dll +0 -0
  195. mindspore/opencv_imgcodecs452.dll +0 -0
  196. mindspore/opencv_imgproc452.dll +0 -0
  197. mindspore/ops/_grad_experimental/__init__.py +0 -5
  198. mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
  199. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  200. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  201. mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
  202. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  203. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
  204. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  205. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  206. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  207. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  208. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  209. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  210. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  211. mindspore/ops/_op_impl/{_custom_op/flash_attention/constants.py → aicpu/eps.py} +18 -27
  212. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  213. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
  214. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  215. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  216. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  217. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  218. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  219. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  220. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  221. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  222. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  223. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  224. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  225. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  226. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  227. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  228. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  229. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  230. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  231. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  232. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  233. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  234. mindspore/ops/_primitive_cache.py +1 -1
  235. mindspore/ops/_tracefunc.py +45 -13
  236. mindspore/ops/_utils/utils.py +6 -1
  237. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  238. mindspore/ops/_vmap/vmap_base.py +3 -3
  239. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  240. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  241. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  242. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  243. mindspore/ops/arg_dtype_cast.py +54 -0
  244. mindspore/ops/composite/base.py +37 -10
  245. mindspore/ops/composite/math_ops.py +5 -4
  246. mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
  247. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  248. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  249. mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
  250. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  251. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  252. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  253. mindspore/ops/deprecated.py +304 -0
  254. mindspore/ops/function/__init__.py +4 -1
  255. mindspore/ops/function/array_func.py +174 -193
  256. mindspore/ops/function/clip_func.py +81 -13
  257. mindspore/ops/function/debug_func.py +1 -1
  258. mindspore/ops/function/grad/grad_func.py +18 -9
  259. mindspore/ops/function/image_func.py +10 -4
  260. mindspore/ops/function/linalg_func.py +5 -5
  261. mindspore/ops/function/math_func.py +575 -386
  262. mindspore/ops/function/nn_func.py +568 -260
  263. mindspore/ops/function/random_func.py +88 -57
  264. mindspore/ops/function/sparse_func.py +1 -1
  265. mindspore/ops/function/sparse_unary_func.py +14 -12
  266. mindspore/ops/function/vmap_func.py +6 -5
  267. mindspore/ops/functional.py +15 -10
  268. mindspore/ops/op_info_register.py +244 -25
  269. mindspore/ops/operations/__init__.py +31 -19
  270. mindspore/ops/operations/_grad_ops.py +71 -7
  271. mindspore/ops/operations/_inner_ops.py +350 -17
  272. mindspore/ops/operations/_quant_ops.py +4 -8
  273. mindspore/ops/operations/_sequence_ops.py +42 -0
  274. mindspore/ops/operations/array_ops.py +68 -282
  275. mindspore/ops/operations/comm_ops.py +107 -59
  276. mindspore/ops/operations/custom_ops.py +94 -70
  277. mindspore/ops/operations/debug_ops.py +8 -4
  278. mindspore/ops/operations/image_ops.py +18 -12
  279. mindspore/ops/operations/inner_ops.py +26 -3
  280. mindspore/ops/operations/math_ops.py +192 -144
  281. mindspore/ops/operations/nn_ops.py +857 -489
  282. mindspore/ops/operations/other_ops.py +0 -22
  283. mindspore/ops/operations/random_ops.py +53 -111
  284. mindspore/ops/operations/sparse_ops.py +3 -1
  285. mindspore/ops/primitive.py +24 -18
  286. mindspore/parallel/_auto_parallel_context.py +68 -8
  287. mindspore/parallel/_cost_model_context.py +2 -2
  288. mindspore/parallel/_offload_context.py +17 -3
  289. mindspore/parallel/_parallel_serialization.py +12 -5
  290. mindspore/parallel/_ps_context.py +12 -0
  291. mindspore/parallel/_tensor.py +18 -13
  292. mindspore/parallel/_transformer/layers.py +5 -3
  293. mindspore/parallel/_transformer/loss.py +1 -0
  294. mindspore/parallel/_transformer/moe.py +2 -2
  295. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  296. mindspore/parallel/_transformer/transformer.py +23 -3
  297. mindspore/parallel/_utils.py +11 -7
  298. mindspore/parallel/algo_parameter_config.py +85 -5
  299. mindspore/parallel/checkpoint_transform.py +19 -12
  300. mindspore/parallel/shard.py +21 -14
  301. mindspore/profiler/common/struct_type.py +3 -3
  302. mindspore/profiler/common/util.py +4 -2
  303. mindspore/profiler/envprofiling.py +1 -1
  304. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  305. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  306. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  307. mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
  308. mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
  309. mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
  310. mindspore/profiler/parser/ascend_op_generator.py +6 -6
  311. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  312. mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
  313. mindspore/profiler/parser/base_timeline_generator.py +10 -8
  314. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
  315. mindspore/profiler/parser/flops_parser.py +15 -11
  316. mindspore/profiler/parser/framework_parser.py +38 -22
  317. mindspore/profiler/parser/hccl_parser.py +16 -12
  318. mindspore/profiler/parser/integrator.py +22 -11
  319. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  320. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  321. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  322. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  323. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  324. mindspore/profiler/parser/optime_parser.py +1 -1
  325. mindspore/profiler/parser/profiler_info.py +21 -2
  326. mindspore/profiler/parser/step_trace_parser.py +11 -14
  327. mindspore/profiler/profiling.py +179 -89
  328. mindspore/rewrite/api/node.py +102 -19
  329. mindspore/rewrite/api/node_type.py +5 -1
  330. mindspore/rewrite/api/pattern_engine.py +1 -1
  331. mindspore/rewrite/api/scoped_value.py +9 -17
  332. mindspore/rewrite/api/symbol_tree.py +131 -47
  333. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  334. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  335. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  336. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  337. mindspore/rewrite/common/rewrite_elog.py +5 -1
  338. mindspore/rewrite/namer.py +33 -24
  339. mindspore/rewrite/namespace.py +14 -5
  340. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  341. mindspore/rewrite/node/call_function.py +79 -0
  342. mindspore/rewrite/node/cell_container.py +135 -0
  343. mindspore/rewrite/node/control_flow.py +88 -0
  344. mindspore/rewrite/{node.py → node/node.py} +273 -234
  345. mindspore/rewrite/node/node_manager.py +254 -0
  346. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  347. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  348. mindspore/rewrite/parsers/assign_parser.py +216 -221
  349. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  350. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  351. mindspore/rewrite/parsers/constant_parser.py +9 -6
  352. mindspore/rewrite/parsers/container_parser.py +9 -7
  353. mindspore/rewrite/parsers/for_parser.py +42 -21
  354. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  355. mindspore/rewrite/parsers/if_parser.py +28 -24
  356. mindspore/rewrite/parsers/module_parser.py +196 -25
  357. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  358. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  359. mindspore/rewrite/parsers/return_parser.py +6 -6
  360. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  361. mindspore/rewrite/sparsify/utils.py +1 -1
  362. mindspore/rewrite/symbol_tree.py +523 -578
  363. mindspore/rewrite/symbol_tree_builder.py +9 -193
  364. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  365. mindspore/run_check/_check_version.py +6 -4
  366. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  367. mindspore/safeguard/rewrite_obfuscation.py +541 -0
  368. mindspore/tinyxml2.dll +0 -0
  369. mindspore/train/_utils.py +7 -3
  370. mindspore/train/amp.py +323 -123
  371. mindspore/train/anf_ir_pb2.py +14 -2
  372. mindspore/train/callback/_backup_and_restore.py +2 -12
  373. mindspore/train/callback/_callback.py +29 -4
  374. mindspore/train/callback/_checkpoint.py +23 -8
  375. mindspore/train/callback/_early_stop.py +2 -2
  376. mindspore/train/callback/_landscape.py +4 -4
  377. mindspore/train/callback/_loss_monitor.py +2 -2
  378. mindspore/train/callback/_on_request_exit.py +2 -2
  379. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  380. mindspore/train/callback/_summary_collector.py +15 -8
  381. mindspore/train/callback/_time_monitor.py +58 -5
  382. mindspore/train/data_sink.py +5 -11
  383. mindspore/train/dataset_helper.py +84 -57
  384. mindspore/train/loss_scale_manager.py +2 -2
  385. mindspore/train/metrics/__init__.py +3 -3
  386. mindspore/train/metrics/cosine_similarity.py +1 -1
  387. mindspore/train/metrics/hausdorff_distance.py +3 -2
  388. mindspore/train/metrics/mean_surface_distance.py +3 -2
  389. mindspore/train/metrics/metric.py +39 -19
  390. mindspore/train/metrics/roc.py +2 -2
  391. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  392. mindspore/train/mind_ir_pb2.py +85 -36
  393. mindspore/train/model.py +187 -47
  394. mindspore/train/serialization.py +487 -161
  395. mindspore/train/summary/_summary_adapter.py +1 -1
  396. mindspore/train/summary/_writer_pool.py +3 -2
  397. mindspore/train/summary/summary_record.py +37 -17
  398. mindspore/train/train_thor/convert_utils.py +3 -3
  399. mindspore/train/train_thor/dataset_helper.py +1 -1
  400. mindspore/turbojpeg.dll +0 -0
  401. mindspore/version.py +1 -1
  402. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +7 -4
  403. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +406 -463
  404. mindspore/_extends/graph_kernel/expander.py +0 -80
  405. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  406. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  407. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  408. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  409. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  410. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  411. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  412. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  413. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  414. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  415. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  416. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  417. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  418. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  419. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  420. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  421. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  422. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  423. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  424. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  425. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  426. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  427. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  428. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  429. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  430. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  431. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  432. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  433. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  434. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  435. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  436. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  437. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  438. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  439. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  440. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  441. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  442. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  443. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  444. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  445. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  446. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  447. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  448. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  449. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  450. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  451. mindspore/dataset/datapreprocess/__init__.py +0 -20
  452. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  453. mindspore/include/api/net.h +0 -142
  454. mindspore/nn/lr_scheduler.py +0 -262
  455. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  456. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  457. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  458. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  459. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  460. mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
  461. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -350
  462. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -409
  463. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -578
  464. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -199
  465. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -446
  466. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  467. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
  468. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
  469. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
  470. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  471. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  472. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  473. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  474. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  475. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  476. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  477. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  478. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  479. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  480. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  481. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  482. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  483. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  484. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  485. mindspore/rewrite/node_visitor.py +0 -44
  486. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
  487. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -0
  488. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
@@ -13,9 +13,6 @@
 # limitations under the License.
 # ===========================================================================
 """GraphKernel model builder"""
-
-import copy
-from . import op_infer
 from .model import Tensor, Value, Operator, Graph, AlignShape
 
 
@@ -95,18 +92,6 @@ class GraphBuilder:
         node.all_inputs = inputs
         self.current.graph.add(node)
 
-    def emit(self, prim, inputs, name=None, attrs=None):
-        """Emit a new operation"""
-        if attrs is None:
-            attrs = {}
-        if isinstance(inputs, (Tensor, Value)):
-            inputs = [inputs]
-        tensor_inputs = [t for t in inputs if isinstance(t, (Tensor, Value))]
-        out_shape, out_dtype, out_format = op_infer.infer(prim, tensor_inputs, attrs)
-        output = self.tensor(out_shape, out_dtype, out_format, name)
-        self.op(prim, output, inputs, attrs)
-        return output
-
     def get(self):
         """Get graphs"""
         return self.graphs
@@ -169,15 +154,18 @@ class CompositeGraph:
             for op in desc['op_desc']:
                 inputs = [self.tensors.get(d['tensor_name'], None) for x in op['input_desc']
                           for d in x if 'value' not in d]
+                if op['name'] in ('ReduceSum', 'ReduceMax', 'ReduceMin'):
+                    axis = op['input_desc'][1][0]['value']
+                    if isinstance(axis, int):
+                        axis = [axis]
+                    if not op['attr']:
+                        attr = [{'name': 'axis', 'dtype': 'listInt', 'value': axis}]
+                        op['attr'] = attr
+                    else:
+                        op['attr'].append({'name': 'axis', 'dtype': 'listInt', 'value': axis})
                 out_desc = op['output_desc']
                 name, shape, dtype, data_format = out_desc[0]['tensor_name'], out_desc[
                     0]['shape'], out_desc[0]['data_type'], out_desc[0]['format']
-                if op['name'] == 'InplaceAssign':
-                    inputs[0].add_buddy(inputs[1])
-                    inputs[1].para_type = Tensor.PARA_OUTPUT
-                    output = inputs[2]
-                    self.tensors[name] = output
-                    continue
                 output = self.tensors.get(name, None)
                 if not output:
                     output = builder.tensor(shape, dtype, data_format, name=name)
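
The new branch for ReduceSum/ReduceMax/ReduceMin promotes the reduction axis from the operator's second input into a listInt attribute, wrapping a scalar axis in a list first. A minimal sketch of that normalization in isolation (the op dict below is a made-up kernel description, not taken from the package):

    # Sketch: normalize a reduce op's axis input into a 'listInt' attribute.
    # 'op' mimics the JSON op_desc entries in the diff; all values are made up.
    op = {
        'name': 'ReduceSum',
        'input_desc': [[{'tensor_name': 'x'}], [{'value': 1}]],  # axis arrives as an input
        'attr': None,
    }

    axis = op['input_desc'][1][0]['value']
    if isinstance(axis, int):          # a scalar axis becomes a one-element list
        axis = [axis]
    entry = {'name': 'axis', 'dtype': 'listInt', 'value': axis}
    if not op['attr']:
        op['attr'] = [entry]
    else:
        op['attr'].append(entry)

    print(op['attr'])  # [{'name': 'axis', 'dtype': 'listInt', 'value': [1]}]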
@@ -186,46 +174,17 @@ class CompositeGraph:
         self.graph = builder.get()[0]
         self.desc = desc
 
-    def _pre_dump(self, outputs):
-        """restore name to before load"""
-        inplace_assign = {}  # y_name, output_name
-        inplace_assign_z = None
-        for op in self.desc['op_desc']:
-            if op['name'] == 'InplaceAssign':
-                inplace_assign[op['input_desc'][1][0]['tensor_name']] = op['output_desc'][0]['tensor_name']
-        if inplace_assign:
-            for t in outputs:
-                if t.name not in inplace_assign:
-                    inplace_assign_z = t
-        return inplace_assign, inplace_assign_z
 
     def dump(self, subgraph):
         """Dump Graph to json"""
         desc = {}
         inputs, outputs = subgraph.deduce_parameters()
         graph_ops = set(subgraph.ops)
-        inplace_assign, inplace_assign_z = self._pre_dump(outputs)
 
         def dump_output(t):
-            if t.name in inplace_assign:
-                z = inplace_assign_z if inplace_assign_z is not None else self.tensors.get(t.name, None)
-                return {'data_type': z.dtype, 'shape': z.shape, 'tensor_name': inplace_assign.get(t.name)}
             return {'data_type': t.dtype, 'shape': t.shape, 'tensor_name': t.name}
 
         def dump_op_desc(d):
-            if d['name'] == 'InplaceAssign':
-                y = d['input_desc'][1][0]['tensor_name']
-                if self.tensors[y].op in graph_ops:
-                    z, fake = (inplace_assign_z, False) if inplace_assign_z is not None else (self.tensors.get(y), True)
-                    inplace_desc = copy.deepcopy(d)
-                    inplace_desc['attr'] = {'name': 'fake_output', 'value': fake}
-                    z_desc, out_desc = inplace_desc['input_desc'][2][0], inplace_desc['output_desc'][0]
-                    z_desc['shape'] = z.shape
-                    z_desc['data_type'] = z.dtype
-                    z_desc['tensor_name'] = z.name
-                    out_desc['shape'] = z.shape
-                    out_desc['data_type'] = z.dtype
-                    return inplace_desc
             op = self.tensors[d['output_desc'][0]['tensor_name']].op
             if op in graph_ops or op in subgraph.recompute_ops:
                 return d
@@ -36,7 +36,6 @@ def split_with_json(json_str, flags_str):
         subgraphs, graph_mode = model.split(comp.graph, target, flags)
         is_multi_graph = len(subgraphs) > 1
         graph_list = list(map(comp.dump, subgraphs))
-        _reset_graphmode_for_inplaceassign(graph_list, graph_mode)
         result = {"multi_graph": is_multi_graph,
                   "graph_desc": graph_list,
                   "graph_mode": graph_mode}
@@ -51,8 +50,9 @@ def split_with_json(json_str, flags_str):
 def _load_repository(graph, flags):
     """Load repository if exists"""
     def check_repo(op, best_split, op_desc):
-        if not isinstance(best_split, dict) or "group_num" not in best_split or "graph_mode" not in best_split \
-                or "split_result" not in best_split:
+        if not isinstance(best_split, dict):
+            return False
+        if "group_num" not in best_split or "graph_mode" not in best_split or "split_result" not in best_split:
             logger.warning("The graph split repository of {} should be a dict which contains 'group_num', 'graph_mode' "
                            "and 'split_result' field, but got {}".format(op, best_split))
             return False
@@ -114,19 +114,12 @@ def _load_repository(graph, flags):
     return result
 
 
-def _reset_graphmode_for_inplaceassign(graph_list, graph_mode):
-    """Operator with InplaceAssign should always be composite op"""
-    for i, g in enumerate(graph_list):
-        if any((op['name'] == 'InplaceAssign' for op in g['op_desc'])):
-            graph_mode[i] = 'composite'
-
-
 def _dump_split_info(use_repo, graph_str, graph, subgraphs, graph_mode, graph_list):
     """Dump split info as text"""
     graph_kernel_dump_path = "graph_kernel_dump"
     utils.create_dir(graph_kernel_dump_path)
     filename = os.path.join(graph_kernel_dump_path, "graph_kernel_split_mode.%d.txt" % os.getpid())
-    with os.fdopen(os.open(filename, os.O_WRONLY | os.O_CREAT), "a+") as f:
+    with os.fdopen(os.open(filename, os.O_WRONLY | os.O_CREAT, 0o600), "a+") as f:
         f.write("********** main graph: {} **********\n".format(graph.name))
         f.write("input json:\n{}\n".format(graph_str))
         f.write("graph desc:\n{}\n".format(str(graph)))
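
The only functional change in this hunk is the explicit 0o600 mode passed to os.open (the same hardening appears again in the update_json hunk further down), so dump files are created owner read/write only. A standalone sketch of the pattern, with an illustrative path:

    import os

    # Open-or-create a file with owner-only permissions (rw-------).
    # The third argument to os.open applies (minus the umask) only when the
    # file is created; without it the default is a permissive 0o777.
    path = "graph_kernel_split_mode.txt"  # illustrative path
    fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, "w") as f:
        f.write("split info\n")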
@@ -113,30 +113,115 @@ def create_akg_parallel_process(process_num, wait_time, platform):
     return AkgProcess(process_num, wait_time, platform)
 
 
-class AkgProcess:
-    """akg kernel parallel process"""
+def _is_input_shape_dynamic(desc_d):
+    input_lists = desc_d.get("input_desc", [])
+    if input_lists is None:
+        return True
+    for input_desc in input_lists:
+        shape = input_desc[0].get("shape", ())
+        if -1 in shape or -2 in shape:
+            return True
+    return False
 
-    def __init__(self, process_num, wait_time, platform):
+
+def _compile_akg_v2_task_default(json_strs, attrs, driver):
+    """
+    compile func called in single process
+
+    Parameters:
+        json_strs: list. List contains multiple kernel infos, suitable for json compile api.
+    """
+    log_level = get_log_level(attrs)
+    kernel_meta_dir = os.path.join(get_kernel_meta_parent_dir(attrs), "akg_kernel_meta")
+    for json_str in json_strs:
+        json_desc = json.loads(json_str)
+        op_name = json_desc["op"]
+        info_path = os.path.join(kernel_meta_dir, op_name + ".info")
+        if not os.path.isfile(info_path):
+            raise FileNotFoundError(f"Can not compile non-existing file \"{info_path}\"")
+        # Compile json str with AKG
+        bisheng_cpp_path = os.getenv("BISHENG_CPP_PATH", default="")
+        compiler = driver(input_file=info_path, output_dir=kernel_meta_dir, bisheng_tools_dir=bisheng_cpp_path,
+                          dynamic_shape=_is_input_shape_dynamic(json_desc))
+        try:
+            compiler.compile()
+        except RuntimeError as exc:
+            if log_level == "ERROR":
+                raise ValueError(f"Compile error, json str: {json_str}! build attrs: {attrs}") from exc
+            logger.info(f"Will try to split, json str: {json_str}! build attrs: {attrs}")
+
+
+def create_akg_v2_parallel_process(process_num, wait_time, platform):
+    """
+    create Akg V2 Parallel Compiler object
+
+    Returns:
+        AKG V2 ParallelCompiler
+    """
+    return AkgV2Process(process_num, wait_time, platform)
+
+
+class AkgProcessBase:
+    """base class for akg kernel parallel process"""
+
+    def __init__(self, name, process_num, wait_time, platform):
         """
         Args:
             process_num: int. processes number
             wait_time: int. max time the function blocked
         """
         if not isinstance(process_num, int):
-            raise ValueError("AKG kernel compiling process number must be of type int, but got {} with type {}"
-                             .format(process_num, type(wait_time)))
+            raise ValueError(
+                f"{name} kernel compiling process number must be of type int"
+                ", but got {process_num} with type {type(wait_time)}")
         if not isinstance(wait_time, int):
-            raise ValueError("AKG kernel compiling wait time must be of type int, but got {} with type {}"
-                             .format(wait_time, type(wait_time)))
+            raise ValueError(
+                f"{name} kernel compiling wait time must be of type int,"
+                " but got {wait_time} with type {type(wait_time)}")
         if process_num == 0:
             process_num = 1
         max_proc_num = 16
+        self.name = name
         self.process_num = min([cpu_count(), max_proc_num, process_num])
         self.args = list([] for _ in range(self.process_num))
         self.wait_time = wait_time
         self.platform = platform
         self.argc = 0
 
+    def compile(self, attrs=None):
+        """
+        compile kernel by multi processes
+        Return:
+            True for all compile success, False for some failed.
+        """
+        del attrs
+        raise NotImplementedError
+
+    def accept_json(self, json_str):
+        """
+        accept json data before compile
+        Args:
+            json_str: str. kernel info.
+        """
+        if not isinstance(json_str, str):
+            raise ValueError(
+                f"In {self.name} kernel compiling, the kernel json must be of type str"
+                ", but got {json_str} with type { type(json_str)}")
+        self.args[self.argc % self.process_num].append(json_str)
+        self.argc += 1
+
+
+class AkgProcess(AkgProcessBase):
+    """akg kernel parallel process"""
+
+    def __init__(self, process_num, wait_time, platform):
+        """
+        Args:
+            process_num: int. processes number
+            wait_time: int. max time the function blocked
+        """
+        super(AkgProcess, self).__init__("AKG", process_num, wait_time, platform)
+
     def compile(self, attrs=None):
         """
         compile kernel by multi processes
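
In _is_input_shape_dynamic above, a -1 or -2 anywhere in an input shape flags the kernel as dynamic (-1 conventionally marks an unknown dimension and -2 an unknown rank on Ascend). A quick check with made-up kernel descriptions:

    def _is_input_shape_dynamic(desc_d):
        # Same logic as the helper added in the diff above.
        input_lists = desc_d.get("input_desc", [])
        if input_lists is None:
            return True
        for input_desc in input_lists:
            shape = input_desc[0].get("shape", ())
            if -1 in shape or -2 in shape:
                return True
        return False

    # Made-up kernel descriptions for illustration:
    static_desc = {"input_desc": [[{"shape": [32, 64]}]]}
    dynamic_desc = {"input_desc": [[{"shape": [-1, 64]}]]}
    print(_is_input_shape_dynamic(static_desc))   # False
    print(_is_input_shape_dynamic(dynamic_desc))  # True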
@@ -162,14 +247,36 @@ class AkgProcess:
             res.get(timeout=self.wait_time)
         return True
 
-    def accept_json(self, json_str):
+
+class AkgV2Process(AkgProcessBase):
+    """akg v2 kernel parallel process"""
+
+    def __init__(self, process_num, wait_time, platform):
         """
-        accept json data before compile
         Args:
-            json_str: str. kernel info.
+            process_num: int. processes number
+            wait_time: int. max time the function blocked
         """
-        if not isinstance(json_str, str):
-            raise ValueError("In AKG kernel compiling, the kernel json must be of type str, but got {} with type {}"
-                             .format(json, type(json)))
-        self.args[self.argc % self.process_num].append(json_str)
-        self.argc += 1
+        super(AkgV2Process, self).__init__("AKG V2", process_num, wait_time, platform)
+
+    def compile(self, attrs=None):
+        """
+        compile kernel by multi processes
+        Return:
+            True for all compile success, False for some failed.
+        """
+        if self.argc == 0:
+            raise ValueError("In AKG V2 kernel compiling, the number of kernel json that need to be compiled can "
+                             "not be zero.")
+        akg_v2_path = os.getenv("AKG_V2_PATH", default="")
+        if akg_v2_path == "":
+            raise ValueError(
+                "The path to akg v2 compiler is not specified. Set the path to the compiler in AKG_V2_PATH")
+        sys.path.append(akg_v2_path)
+        p = __import__("akg_v2", globals(), locals())
+        driver = getattr(p, "AkgV2Driver")
+        args = list((arg, attrs, driver) for arg in self.args)
+        with Pool(processes=self.process_num) as pool:
+            res = pool.starmap_async(_compile_akg_v2_task_default, args)
+            res.get(timeout=self.wait_time)
+        return True
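
Taken together, a caller would construct the V2 process object, feed it kernel JSON strings (accept_json shards them round-robin across process_num buckets), and run compile, which imports the driver from AKG_V2_PATH and dispatches the buckets via Pool.starmap_async. A hedged usage sketch; the kernel infos and the AKG_V2_PATH value are placeholders, and it assumes create_akg_v2_parallel_process is importable from the patched module:

    import json
    import os

    # Hypothetical driver script. AKG_V2_PATH must point at a directory
    # containing the akg_v2 module with an AkgV2Driver class.
    os.environ["AKG_V2_PATH"] = "/path/to/akg_v2"

    proc = create_akg_v2_parallel_process(process_num=4, wait_time=300, platform="ASCEND")
    for kernel in ({"op": "Fused_Add_1"}, {"op": "Fused_Cast_2"}):  # placeholder kernel infos
        proc.accept_json(json.dumps(kernel))  # sharded round-robin across buckets
    proc.compile(attrs=None)  # raises if compilation exceeds wait_time seconds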
@@ -1,4 +1,4 @@
-# Copyright 2021 Huawei Technologies Co., Ltd
+# Copyright 2021-2023 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -25,6 +25,20 @@ from tbe.common.buildcfg import build_config
 from tbe.dsl import auto_schedule
 from tbe.dsl import build as tbe_build
 import tbe.common.context.op_context as op_context
+from impl.dynamic.add import _add_check_format, _infer_shape
+
+SHAPE = "shape"
+FORMAT = "format"
+DATA_TYPE = "data_type"
+NEW_SHAPE = "new_shape"
+ORI_SHAPE = "ori_shape"
+ORI_FORMAT = "ori_format"
+DST_TYPE = "dst_type"
+DST_ORI_SHAPE = "dst_ori_shape"
+INPUT_DESC = "input_desc"
+OUTPUT_DESC = "output_desc"
+ENABLE_VECTOR_2X = "enable_vector_2x"
+ENABLE_GROUP_INPLACE = "enable_group_inplace"
 
 
 def initialize(kernel_meta_parent_dir):
@@ -49,14 +63,14 @@ def update_config(config, op_names):
     change_type_dict = {"MatMul": (True, False),
                         "BatchMatMul": (True, False)}
     config["bool_storage_as_1bit"] = True
-    config["enable_group_inplace"] = False
-    config["enable_vector_2x"] = True
+    config[ENABLE_GROUP_INPLACE] = False
+    config[ENABLE_VECTOR_2X] = True
     for op in op_names:
         if op in bool_storage_as_1bit_oplist:
             config["bool_storage_as_1bit"] = False
         enable_group_inplace, enable_vector_2x = change_type_dict.get(op, (False, True))
-        config["enable_group_inplace"] = config["enable_group_inplace"] or enable_group_inplace
-        config["enable_vector_2x"] = config["enable_vector_2x"] and enable_vector_2x
+        config[ENABLE_GROUP_INPLACE] = config[ENABLE_GROUP_INPLACE] or enable_group_inplace
+        config[ENABLE_VECTOR_2X] = config[ENABLE_VECTOR_2X] and enable_vector_2x
 
 
 def add_new_shape(names, shapes, new_shapes, inputs):
@@ -70,11 +84,11 @@ def add_new_shape(names, shapes, new_shapes, inputs):
            continue
        if name not in inputs:
            raise RuntimeError("Can not support reshape on output tensor {}".format(name))
-        if NEW_SHAPE not in inputs[name]:
-            inputs[name]["new_shape"] = new_shapes[i]
-        elif new_shapes[i] != inputs[name]["new_shape"]:
+        if NEW_SHAPE not in inputs[name]:
+            inputs[name][NEW_SHAPE] = new_shapes[i]
+        elif new_shapes[i] != inputs[name][NEW_SHAPE]:
            raise RuntimeError("Find different new_shape {} and {} for {}"
-                               .format(inputs[name]["new_shape"], new_shapes[i], name))
+                               .format(inputs[name][NEW_SHAPE], new_shapes[i], name))
 
 
 class TransShape:
@@ -93,22 +107,21 @@
             if v.get("value") is not None:
                 continue
             names.append(k)
-            shapes.append(v["shape"])
-            ori_shapes.append(v["ori_shape"] if v.get("ori_shape") else None)
-            formats.append(v["format"])
-            ori_formats.append(v["ori_format"])
+            shapes.append(v[SHAPE])
+            ori_shapes.append(v[ORI_SHAPE] if v.get(ORI_SHAPE) else None)
+            formats.append(v[FORMAT])
+            ori_formats.append(v[ORI_FORMAT])
         if len(shapes) == 2 and len(shapes[0]) != len(shapes[1]):
-            from impl.add import _add_check_format, _infer_shape
-            format_pattern = _add_check_format({"shape": shapes[0], "format": formats[0]},
-                                               {"shape": shapes[1], "format": formats[1]})
+            format_pattern = _add_check_format({SHAPE: shapes[0], FORMAT: formats[0]},
+                                               {SHAPE: shapes[1], FORMAT: formats[1]})
             ori_shape0 = ori_shapes[0] if ori_shapes[0] is not None else infer_ori_shape(
                 shapes[0], formats[0], ori_formats[0])
             ori_shape1 = ori_shapes[1] if ori_shapes[1] is not None else infer_ori_shape(
                 shapes[1], formats[1], ori_formats[1])
             new_shapes = [None, None]
             new_shapes[0], new_shapes[1] = _infer_shape(format_pattern,
-                                                        {"shape": shapes[0], "ori_shape": ori_shape0},
-                                                        {"shape": shapes[1], "ori_shape": ori_shape1})
+                                                        {SHAPE: shapes[0], ORI_SHAPE: ori_shape0},
+                                                        {SHAPE: shapes[1], ORI_SHAPE: ori_shape1})
             new_shapes[0], new_shapes[1], _ = shape_util.broadcast_shapes(new_shapes[0], new_shapes[1],
                                                                           param_name_input1="input0",
                                                                           param_name_input2="input1")
@@ -119,7 +132,7 @@
         """deal with batch_matmul."""
         for k, v in op_inputs.items():
             # batch dimension of BatchMatMul must be fused to 1D
-            shape = v["shape"]
+            shape = v[SHAPE]
             if len(shape) > 5:
                 new_shape = [functools.reduce(lambda x, y: x * y, shape[:-4])] + shape[-4:]
                 add_new_shape(k, shape, new_shape, inputs)
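
The BatchMatMul branch above fuses all leading batch dimensions into one, keeping only the trailing four format dimensions. A small numeric illustration of the reshape rule (the shape is made up):

    import functools

    # Fuse leading batch dims into one, keeping the trailing 4 format dims,
    # mirroring the BatchMatMul handling in TransShape above.
    shape = [2, 3, 4, 16, 32, 16, 16]  # made-up 7-D shape
    if len(shape) > 5:
        new_shape = [functools.reduce(lambda x, y: x * y, shape[:-4])] + shape[-4:]
    print(new_shape)  # [24, 16, 32, 16, 16]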
@@ -135,6 +148,10 @@
 
 def infer_ori_shape(shape, cur_format, ori_format):
     """Given current format and shape, infer the shape with ori_format."""
+
+    def _shape_error(current_shape, current_format):
+        raise ValueError("Invalid shape {} for format {}".format(current_shape, current_format))
+
     if cur_format == ori_format:
         return shape
     default_formats = ["DefaultFormat", "ND", "NCHW"]
@@ -145,7 +162,7 @@
     if cur_format == "FRACTAL_NZ" and ori_format in default_formats:
         dims = len(shape)
         if dims < 4:
-            raise ValueError("Invalid shape {} for format {}".format(shape, cur_format))
+            _shape_error(shape, cur_format)
         ori_shape = shape[:dims - 4]
         m = shape[-3] * shape[-2]
         n = shape[-4] * shape[-1]
@@ -155,13 +172,13 @@
 
     if cur_format == "NC1HWC0" and ori_format in default_formats:
         if len(shape) != 5:
-            raise ValueError("Invalid shape {} for format {}".format(shape, cur_format))
+            _shape_error(shape, cur_format)
         ori_shape = [shape[0], shape[1] * shape[4], shape[2], shape[3]]
         return ori_shape
 
     if cur_format == "NHWC" and ori_format in default_formats:
         if len(shape) != 4:
-            raise ValueError("Invalid shape {} for format {}".format(shape, cur_format))
+            _shape_error(shape, cur_format)
         ori_shape = [shape[0], shape[3], shape[1], shape[2]]
         return ori_shape
 
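For reference, these hunks touch infer_ori_shape's per-format rules: FRACTAL_NZ recovers m = shape[-3] * shape[-2] and n = shape[-4] * shape[-1], NC1HWC0 folds C1 and C0 back into a single channel dimension, and NHWC is permuted to NCHW. A worked example of the NC1HWC0 rule with made-up numbers:

    # NC1HWC0 -> NCHW: channels are stored split as C = C1 * C0, so fold them back.
    shape = [8, 4, 14, 14, 16]  # made-up [N, C1, H, W, C0]
    assert len(shape) == 5
    ori_shape = [shape[0], shape[1] * shape[4], shape[2], shape[3]]
    print(ori_shape)  # [8, 64, 14, 14]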
@@ -202,7 +219,7 @@ def get_input_desc(input_desc):
     res = {}
     for desc in input_desc:
         for item in desc:
-            item["shape"] = [1] if not item["shape"] else item["shape"]
+            item[SHAPE] = [1] if not item[SHAPE] else item[SHAPE]
             res[item["tensor_name"]] = item
     return res
 
@@ -215,7 +232,7 @@ def get_inputs_tensor(input_desc, all_tensors):
         name = item["tensor_name"]
         if item.get("value") is not None:
             # const value
-            all_tensors[name] = tvm.const(item["value"], item["data_type"])
+            all_tensors[name] = tvm.const(item["value"], item[DATA_TYPE])
         if all_tensors.get(name) is None:
             raise ValueError("Tensor [{}] not found.".format(name))
         inputs.append(all_tensors[name])
@@ -237,17 +254,17 @@ def get_op_attrs(op, fusion_op_name):
     op_name = op["name"]
     op_attrs = get_attr_dict(op.get("attr"))
     if op_name == "BatchMatMul":
-        op_attrs["dst_type"] = op["output_desc"][0]["data_type"]
-        op_attrs["dst_ori_shape"] = op["output_desc"][0].get("ori_shape")
-        if op_attrs.get("dst_ori_shape") is None:
-            op_attrs["dst_ori_shape"] = infer_ori_shape(op["output_desc"][0]["shape"],
-                                                        op["output_desc"][0]["format"],
-                                                        op["output_desc"][0]["ori_format"])
+        op_attrs[DST_TYPE] = op[OUTPUT_DESC][0][DATA_TYPE]
+        op_attrs[DST_ORI_SHAPE] = op[OUTPUT_DESC][0].get(ORI_SHAPE)
+        if op_attrs.get(DST_ORI_SHAPE) is None:
+            op_attrs[DST_ORI_SHAPE] = infer_ori_shape(op[OUTPUT_DESC][0][SHAPE],
+                                                      op[OUTPUT_DESC][0][FORMAT],
+                                                      op[OUTPUT_DESC][0][ORI_FORMAT])
     elif op_name == "MatMul":
-        op_attrs["dst_type"] = op["output_desc"][0]["data_type"]
-        op_attrs["dst_format"] = op["output_desc"][0]["format"]
+        op_attrs[DST_TYPE] = op[OUTPUT_DESC][0][DATA_TYPE]
+        op_attrs["dst_format"] = op[OUTPUT_DESC][0][FORMAT]
     elif op_name == "Cast":
-        op_attrs["dst_type"] = op["output_desc"][0]["data_type"]
+        op_attrs[DST_TYPE] = op[OUTPUT_DESC][0][DATA_TYPE]
     op_attrs["fusion_op_name"] = fusion_op_name
     return op_attrs
 
@@ -256,17 +273,17 @@ def create_placeholders(inputs):
     """Create placeholders."""
     tensors = {}
     for k, v in inputs.items():
-        dtype = v["data_type"]
+        dtype = v[DATA_TYPE]
         if dtype == "bool":
             dtype = "int8"
-        shape = v["shape"]
-        if "new_shape" in v:
-            shape = v["new_shape"]
+        shape = v[SHAPE]
+        if NEW_SHAPE in v:
+            shape = v[NEW_SHAPE]
         attr = {
-            "format": v.get("format"),
+            FORMAT: v.get(FORMAT),
             "sub_format": v.get("sub_format", ""),
-            "ori_shape": v.get("ori_shape"),
-            "ori_format": v.get("ori_format"),
+            ORI_SHAPE: v.get(ORI_SHAPE),
+            ORI_FORMAT: v.get(ORI_FORMAT),
             "addr_type": v.get("addr_type", 0),
             "valid_shape": v.get("valid_shape", []),
             "slice_offset": v.get("slice_offset", []),
@@ -276,8 +293,8 @@
             "L1_valid_size": v.get("L1_valid_size", -1),
             "range": v.get("range", [])
         }
-        if attr.get("ori_shape") is None:
-            attr["ori_shape"] = infer_ori_shape(v.get("shape"), v.get("format"), attr.get("ori_format"))
+        if attr.get(ORI_SHAPE) is None:
+            attr[ORI_SHAPE] = infer_ori_shape(v.get(SHAPE), v.get(FORMAT), attr.get(ORI_FORMAT))
         tensors[k] = tvm.placeholder(shape=shape, name=k, dtype=dtype, attrs=attr)
     return tensors
 
@@ -289,8 +306,8 @@
     base_shape = -1
     for _, v in inputs.items():
        if base_shape == -1:
-            base_shape = v["shape"]
-        if v["shape"] != base_shape:
+            base_shape = v[SHAPE]
+        if v[SHAPE] != base_shape:
            return False
    return True
 
@@ -298,17 +315,17 @@
 def create_input_tensors(json_dict):
     """Create input placeholders."""
     fold_dim = True
-    inputs = get_input_desc(json_dict.get("input_desc", []))
+    inputs = get_input_desc(json_dict.get(INPUT_DESC, []))
     for op in json_dict["op_desc"]:
         op_name = op["name"]
         pattern = get_op_reg_info(op_name, "pattern")
-        op_inputs = get_input_desc(op.get("input_desc", []))
+        op_inputs = get_input_desc(op.get(INPUT_DESC, []))
         TransShape.run(op_name, pattern, op_inputs, inputs)
         if pattern != OpPattern.ELEMWISE or not same_shape(op_inputs):
             fold_dim = False
     if fold_dim:
         for k, v in inputs.items():
-            shape = v["shape"]
+            shape = v[SHAPE]
             new_shape = [functools.reduce(lambda x, y: x * y, shape[:])]
             add_new_shape(k, shape, new_shape, inputs)
     return create_placeholders(inputs)
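
When every op is elementwise and all inputs share a shape, the fold_dim path above flattens each input to 1-D via the functools.reduce product. A quick numeric check (the shape is made up):

    import functools

    shape = [8, 16, 32]  # made-up input shape
    new_shape = [functools.reduce(lambda x, y: x * y, shape[:])]
    print(new_shape)  # [4096]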
@@ -324,28 +341,28 @@
 
 
 def update_format(json_dict):
-    """Some format like DefaultFormat is not recognized in TBE, need to covert these formats."""
+    """Some format like DefaultFormat is not recognized in TBE, need to convert these formats."""
 
     def _update_input_format(input_desc):
         for desc in input_desc:
             for item in desc:
-                if item["format"] == "DefaultFormat":
-                    item["format"] = "ND"
-                if item.get("ori_format") is None or item["ori_format"] == "DefaultFormat":
-                    item["ori_format"] = "NCHW"
+                if item[FORMAT] == "DefaultFormat":
+                    item[FORMAT] = "ND"
+                if item.get(ORI_FORMAT) is None or item[ORI_FORMAT] == "DefaultFormat":
+                    item[ORI_FORMAT] = "NCHW"
 
     def _update_output_format(output_desc):
         for item in output_desc:
-            if item["format"] == "DefaultFormat":
-                item["format"] = "ND"
-            if item.get("ori_format") is None or item["ori_format"] == "DefaultFormat":
-                item["ori_format"] = "NCHW"
+            if item[FORMAT] == "DefaultFormat":
+                item[FORMAT] = "ND"
+            if item.get(ORI_FORMAT) is None or item[ORI_FORMAT] == "DefaultFormat":
+                item[ORI_FORMAT] = "NCHW"
 
-    _update_input_format(json_dict.get("input_desc", []))
-    _update_output_format(json_dict["output_desc"])
+    _update_input_format(json_dict.get(INPUT_DESC, []))
+    _update_output_format(json_dict[OUTPUT_DESC])
     for op in json_dict["op_desc"]:
-        _update_input_format(op.get("input_desc", []))
-        _update_output_format(op["output_desc"])
+        _update_input_format(op.get(INPUT_DESC, []))
+        _update_output_format(op[OUTPUT_DESC])
 
 
 def gen_args_remap(orig_inputs_name, orig_outputs_name, inputs_name, outputs_name, inplace_names):
@@ -410,7 +427,7 @@ def update_json(json_dict, inputs_name, outputs_name, inplace_names, kernel_meta
         pass
     # generate new .json
     try:
-        with os.fdopen(os.open(json_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o660), 'w') as fi:
+        with os.fdopen(os.open(json_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600), 'w') as fi:
            json.dump(json_dict, fi, sort_keys=True, indent=4, separators=(',', ':'))
    except OSError:
        pass
@@ -441,8 +458,8 @@ def build(json_str, kernel_meta_parent_dir):
     """Build kernel."""
     json_dict = json.loads(json_str)
     update_format(json_dict)
-    inputs_name = get_inputs_name(json_dict.get("input_desc", []))
-    outputs_name, inplace_names = get_outputs_info(json_dict["output_desc"])
+    inputs_name = get_inputs_name(json_dict.get(INPUT_DESC, []))
+    outputs_name, inplace_names = get_outputs_info(json_dict[OUTPUT_DESC])
     op_names = get_all_op_name(json_dict["op_desc"])
     fusion_op_name = create_fusion_op_name(op_names)
 
@@ -458,7 +475,7 @@
     for op in json_dict["op_desc"]:
         op_name = op["name"]
         # get op input tensor
-        op_inputs = get_inputs_tensor(op.get("input_desc", []), all_tensors)
+        op_inputs = get_inputs_tensor(op.get(INPUT_DESC, []), all_tensors)
         # get op attrs
         op_attrs = get_op_attrs(op, fusion_op_name)
         # op compute
@@ -466,10 +483,10 @@
         # update op output tensor
         if not isinstance(op_outputs, (list, tuple)):
             op_outputs = [op_outputs]
-        if len(op["output_desc"]) != len(op_outputs):
+        if len(op[OUTPUT_DESC]) != len(op_outputs):
             raise ValueError("len(op[\"output_desc\"] is not equal to the number of real output tensors in op[{}]: "
-                             "{} vs {}".format(op_name, len(op["output_desc"]), len(op_outputs)))
-        for i, desc in enumerate(op["output_desc"]):
+                             "{} vs {}".format(op_name, len(op[OUTPUT_DESC]), len(op_outputs)))
+        for i, desc in enumerate(op[OUTPUT_DESC]):
             all_tensors[desc["tensor_name"]] = op_outputs[i]
 
         # Collect input, output tensors
@@ -310,7 +310,8 @@ def _log(x, attrs=None):
     if base <= 0 and not math.isclose(base, -1.0, rel_tol=1e-8, abs_tol=0.0):
         raise ValueError("base must be strictly positive or -1, but got {}".format(base))
     from impl.log import log_compute
-    return log_compute(x, None, base, scale, shift, kernel_name=attrs["fusion_op_name"])
+    output_desc = {"dtype": x.dtype, "shape": x.shape}
+    return log_compute(x, output_desc, base, scale, shift, kernel_name=attrs["fusion_op_name"])
 
 
 @reg_op("Maximum", pattern=OpPattern.ELEMWISE)
@@ -349,7 +350,8 @@ def _mul(x0, x1, attrs=None):
         return tbe.dsl.vmuls(x1, x0)
     x0, x1 = _broadcast(x0, x1)
     from impl.mul import mul_compute
-    return mul_compute(x0, x1, None, kernel_name=attrs["fusion_op_name"])
+    output_desc = {"dtype": x0.dtype, "shape": x0.shape}
+    return mul_compute(x0, x1, output_desc, kernel_name=attrs["fusion_op_name"])
 
 
 @reg_op("Neg", pattern=OpPattern.ELEMWISE)
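
Both _log and _mul previously passed None where the TBE compute functions accept an output description; the change builds a minimal dict from the input tensor's dtype and shape instead. A standalone sketch of that construction (the tensor class is a stand-in, not a real TVM placeholder):

    # Stand-in for a TVM/TBE tensor exposing dtype and shape attributes.
    class _FakeTensor:
        dtype = "float16"
        shape = (16, 16)

    x = _FakeTensor()
    output_desc = {"dtype": x.dtype, "shape": x.shape}  # what the diff now passes
    print(output_desc)  # {'dtype': 'float16', 'shape': (16, 16)}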