mindspore 2.1.0__cp38-none-any.whl → 2.2.10__cp38-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their public registries, and is provided for informational purposes only.

Potentially problematic release: this version of mindspore has been flagged as potentially problematic.

Files changed (569)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_akg/akg/build_module.py +5 -6
  4. mindspore/_akg/akg/composite/build_module.py +46 -19
  5. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  6. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  7. mindspore/_akg/akg/tvm/api.py +4 -3
  8. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  9. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  10. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  11. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  12. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  13. mindspore/_akg/akg/tvm/build_module.py +16 -1
  14. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  15. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  16. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  17. mindspore/_akg/akg/tvm/module.py +1 -2
  18. mindspore/_akg/akg/tvm/stmt.py +2 -2
  19. mindspore/_akg/akg/utils/ascend_profilier/__init__.py +0 -0
  20. mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
  21. mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
  22. mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
  23. mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
  24. mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
  25. mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
  26. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  27. mindspore/_akg/akg/utils/kernel_exec.py +98 -274
  28. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  29. mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
  30. mindspore/_akg/akg/utils/util.py +38 -0
  31. mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
  32. mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
  33. mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
  34. mindspore/_check_jit_forbidden_api.py +3 -1
  35. mindspore/_checkparam.py +23 -29
  36. mindspore/_extends/graph_kernel/__init__.py +0 -1
  37. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  38. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  39. mindspore/_extends/graph_kernel/splitter.py +4 -11
  40. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  41. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
  42. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  43. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  44. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  45. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
  46. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  47. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  48. mindspore/_extends/parse/__init__.py +12 -15
  49. mindspore/_extends/parse/namespace.py +7 -33
  50. mindspore/_extends/parse/parser.py +61 -71
  51. mindspore/_extends/parse/resources.py +1 -1
  52. mindspore/_extends/parse/standard_method.py +74 -104
  53. mindspore/_extends/parse/trope.py +1 -1
  54. mindspore/_extends/remote/kernel_build_server.py +25 -7
  55. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  56. mindspore/_install_custom.py +43 -0
  57. mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
  58. mindspore/amp.py +47 -11
  59. mindspore/bin/cache_admin +0 -0
  60. mindspore/bin/cache_server +0 -0
  61. mindspore/boost/boost.py +1 -8
  62. mindspore/boost/boost_cell_wrapper.py +3 -2
  63. mindspore/boost/grad_accumulation.py +1 -1
  64. mindspore/boost/group_loss_scale_manager.py +8 -7
  65. mindspore/common/__init__.py +5 -3
  66. mindspore/common/_jit_fallback_utils.py +6 -0
  67. mindspore/common/_register_for_adapter.py +2 -0
  68. mindspore/common/_register_for_tensor.py +2 -2
  69. mindspore/common/_stub_tensor.py +13 -0
  70. mindspore/common/_utils.py +13 -0
  71. mindspore/common/api.py +174 -259
  72. mindspore/common/auto_dynamic_shape.py +494 -0
  73. mindspore/common/dtype.py +18 -11
  74. mindspore/common/dump.py +6 -4
  75. mindspore/common/initializer.py +14 -14
  76. mindspore/common/jit_config.py +33 -15
  77. mindspore/common/lazy_inline.py +126 -7
  78. mindspore/common/mindir_util.py +101 -0
  79. mindspore/common/parameter.py +51 -41
  80. mindspore/common/seed.py +4 -4
  81. mindspore/common/sparse_tensor.py +13 -14
  82. mindspore/common/tensor.py +243 -165
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +83 -4
  85. mindspore/communication/management.py +152 -84
  86. mindspore/config/op_info.config +14 -3
  87. mindspore/config/super_bar_config.json +4 -2
  88. mindspore/context.py +152 -61
  89. mindspore/dataset/__init__.py +5 -5
  90. mindspore/dataset/audio/__init__.py +2 -2
  91. mindspore/dataset/audio/transforms.py +52 -52
  92. mindspore/dataset/callback/ds_callback.py +16 -2
  93. mindspore/dataset/core/config.py +68 -51
  94. mindspore/dataset/engine/cache_client.py +28 -5
  95. mindspore/dataset/engine/datasets.py +250 -112
  96. mindspore/dataset/engine/datasets_audio.py +43 -211
  97. mindspore/dataset/engine/datasets_standard_format.py +16 -35
  98. mindspore/dataset/engine/datasets_text.py +43 -67
  99. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  100. mindspore/dataset/engine/datasets_vision.py +219 -1029
  101. mindspore/dataset/engine/iterators.py +11 -4
  102. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  103. mindspore/dataset/engine/obs/util.py +3 -0
  104. mindspore/dataset/engine/samplers.py +1 -1
  105. mindspore/dataset/engine/validators.py +19 -5
  106. mindspore/dataset/text/__init__.py +3 -3
  107. mindspore/dataset/text/transforms.py +101 -127
  108. mindspore/dataset/text/utils.py +205 -138
  109. mindspore/dataset/transforms/__init__.py +1 -1
  110. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  111. mindspore/dataset/transforms/transforms.py +95 -40
  112. mindspore/dataset/utils/browse_dataset.py +8 -2
  113. mindspore/dataset/utils/line_reader.py +17 -19
  114. mindspore/dataset/vision/__init__.py +3 -3
  115. mindspore/dataset/vision/c_transforms.py +6 -3
  116. mindspore/dataset/vision/transforms.py +409 -287
  117. mindspore/dataset/vision/utils.py +13 -14
  118. mindspore/dataset/vision/validators.py +11 -1
  119. mindspore/experimental/map_parameter.py +14 -0
  120. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  121. mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
  122. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  123. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  124. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  125. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  126. mindspore/gen_ops.py +273 -0
  127. mindspore/include/OWNERS +0 -1
  128. mindspore/include/api/data_type.h +2 -1
  129. mindspore/include/api/graph.h +0 -15
  130. mindspore/include/api/kernel.h +2 -0
  131. mindspore/include/api/kernel_api.h +37 -12
  132. mindspore/include/api/model.h +17 -14
  133. mindspore/include/api/status.h +8 -3
  134. mindspore/include/api/types.h +37 -4
  135. mindspore/include/c_api/ms/abstract.h +67 -0
  136. mindspore/include/c_api/ms/attribute.h +197 -0
  137. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  138. mindspore/include/c_api/ms/base/macros.h +32 -0
  139. mindspore/include/c_api/ms/base/status.h +33 -0
  140. mindspore/include/c_api/ms/base/types.h +282 -0
  141. mindspore/include/c_api/ms/context.h +102 -0
  142. mindspore/include/c_api/ms/graph.h +160 -0
  143. mindspore/include/c_api/ms/node.h +606 -0
  144. mindspore/include/c_api/ms/tensor.h +161 -0
  145. mindspore/include/c_api/ms/value.h +84 -0
  146. mindspore/include/dataset/constants.h +6 -5
  147. mindspore/include/dataset/execute.h +23 -13
  148. mindspore/include/dataset/text.h +26 -26
  149. mindspore/include/dataset/transforms.h +13 -13
  150. mindspore/include/dataset/vision.h +60 -60
  151. mindspore/include/dataset/vision_ascend.h +5 -6
  152. mindspore/include/dataset/vision_lite.h +17 -17
  153. mindspore/include/mindapi/base/type_id.h +1 -0
  154. mindspore/include/mindapi/base/types.h +1 -0
  155. mindspore/lib/libdnnl.so.2 +0 -0
  156. mindspore/lib/libjemalloc.so.2 +0 -0
  157. mindspore/lib/libmindspore.so +0 -0
  158. mindspore/lib/libmindspore_backend.so +0 -0
  159. mindspore/lib/libmindspore_common.so +0 -0
  160. mindspore/lib/libmindspore_core.so +0 -0
  161. mindspore/lib/libmindspore_glog.so.0 +0 -0
  162. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  163. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  164. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  165. mindspore/lib/libmindspore_shared_lib.so +0 -0
  166. mindspore/lib/libnnacl.so +0 -0
  167. mindspore/lib/libopencv_core.so.4.5 +0 -0
  168. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  169. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  170. mindspore/lib/libps_cache.so +0 -0
  171. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
  172. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
  173. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
  174. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
  175. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  176. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  177. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  178. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  179. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  180. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  181. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  182. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  183. mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
  184. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  185. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  186. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8928 -0
  187. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  188. mindspore/lib/plugin/ascend/libakg.so +0 -0
  189. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  190. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  191. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  192. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  193. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  194. mindspore/lib/plugin/cpu/libakg.so +0 -0
  195. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  196. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  197. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  198. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  199. mindspore/nn/__init__.py +0 -2
  200. mindspore/nn/cell.py +313 -74
  201. mindspore/nn/dynamic_lr.py +21 -21
  202. mindspore/nn/layer/activation.py +22 -30
  203. mindspore/nn/layer/basic.py +15 -13
  204. mindspore/nn/layer/channel_shuffle.py +1 -1
  205. mindspore/nn/layer/container.py +271 -9
  206. mindspore/nn/layer/conv.py +323 -204
  207. mindspore/nn/layer/dense.py +8 -5
  208. mindspore/nn/layer/embedding.py +33 -27
  209. mindspore/nn/layer/flash_attention.py +141 -88
  210. mindspore/nn/layer/image.py +8 -6
  211. mindspore/nn/layer/math.py +16 -25
  212. mindspore/nn/layer/normalization.py +107 -66
  213. mindspore/nn/layer/padding.py +1 -1
  214. mindspore/nn/layer/pooling.py +131 -109
  215. mindspore/nn/layer/rnn_cells.py +27 -22
  216. mindspore/nn/layer/rnns.py +13 -16
  217. mindspore/nn/layer/thor_layer.py +1 -1
  218. mindspore/nn/layer/transformer.py +221 -154
  219. mindspore/nn/learning_rate_schedule.py +9 -1
  220. mindspore/nn/loss/loss.py +235 -174
  221. mindspore/nn/optim/ada_grad.py +2 -1
  222. mindspore/nn/optim/adadelta.py +1 -0
  223. mindspore/nn/optim/adafactor.py +2 -1
  224. mindspore/nn/optim/adam.py +7 -4
  225. mindspore/nn/optim/adamax.py +3 -2
  226. mindspore/nn/optim/adasum.py +2 -2
  227. mindspore/nn/optim/asgd.py +2 -3
  228. mindspore/nn/optim/ftrl.py +6 -5
  229. mindspore/nn/optim/lamb.py +7 -4
  230. mindspore/nn/optim/lars.py +1 -1
  231. mindspore/nn/optim/lazyadam.py +5 -3
  232. mindspore/nn/optim/momentum.py +2 -1
  233. mindspore/nn/optim/optimizer.py +53 -4
  234. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  235. mindspore/nn/optim/rmsprop.py +4 -3
  236. mindspore/nn/optim/rprop.py +23 -12
  237. mindspore/nn/optim/sgd.py +26 -11
  238. mindspore/nn/optim/thor.py +9 -7
  239. mindspore/nn/probability/bijector/bijector.py +5 -5
  240. mindspore/nn/probability/bijector/power_transform.py +27 -27
  241. mindspore/nn/probability/bijector/softplus.py +3 -3
  242. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  243. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  244. mindspore/nn/probability/distribution/beta.py +3 -3
  245. mindspore/nn/probability/distribution/categorical.py +7 -7
  246. mindspore/nn/probability/distribution/cauchy.py +0 -1
  247. mindspore/nn/probability/distribution/distribution.py +3 -3
  248. mindspore/nn/probability/distribution/gamma.py +3 -3
  249. mindspore/nn/probability/distribution/geometric.py +4 -4
  250. mindspore/nn/probability/distribution/gumbel.py +4 -4
  251. mindspore/nn/probability/distribution/log_normal.py +2 -2
  252. mindspore/nn/probability/distribution/logistic.py +2 -2
  253. mindspore/nn/probability/distribution/poisson.py +4 -4
  254. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  255. mindspore/nn/probability/distribution/uniform.py +6 -6
  256. mindspore/nn/wrap/cell_wrapper.py +84 -34
  257. mindspore/nn/wrap/grad_reducer.py +8 -5
  258. mindspore/nn/wrap/loss_scale.py +105 -42
  259. mindspore/numpy/array_creations.py +1 -2
  260. mindspore/numpy/array_ops.py +3 -2
  261. mindspore/numpy/utils_const.py +5 -5
  262. mindspore/offline_debug/convert_async.py +2 -2
  263. mindspore/ops/_grad_experimental/__init__.py +0 -5
  264. mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
  265. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  266. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  267. mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
  268. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  269. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
  270. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  271. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  272. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  273. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
  274. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
  275. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
  276. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
  277. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
  278. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
  279. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  280. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  281. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  282. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  283. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  284. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  285. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  286. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  287. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  288. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  289. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  290. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  291. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  292. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  293. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  294. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  295. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  296. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  297. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  298. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  299. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  300. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  301. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  302. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  303. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  304. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  305. mindspore/ops/_primitive_cache.py +1 -1
  306. mindspore/ops/_tracefunc.py +45 -13
  307. mindspore/ops/_utils/utils.py +6 -1
  308. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  309. mindspore/ops/_vmap/vmap_base.py +3 -3
  310. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  311. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  312. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  313. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  314. mindspore/ops/arg_dtype_cast.py +54 -0
  315. mindspore/ops/composite/base.py +37 -10
  316. mindspore/ops/composite/math_ops.py +5 -4
  317. mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
  318. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  319. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  320. mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
  321. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  322. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  323. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  324. mindspore/ops/deprecated.py +304 -0
  325. mindspore/ops/function/__init__.py +4 -1
  326. mindspore/ops/function/array_func.py +174 -193
  327. mindspore/ops/function/clip_func.py +81 -13
  328. mindspore/ops/function/debug_func.py +1 -1
  329. mindspore/ops/function/grad/grad_func.py +18 -9
  330. mindspore/ops/function/image_func.py +10 -4
  331. mindspore/ops/function/linalg_func.py +5 -5
  332. mindspore/ops/function/math_func.py +575 -386
  333. mindspore/ops/function/nn_func.py +568 -260
  334. mindspore/ops/function/random_func.py +88 -57
  335. mindspore/ops/function/sparse_func.py +1 -1
  336. mindspore/ops/function/sparse_unary_func.py +14 -12
  337. mindspore/ops/function/vmap_func.py +6 -5
  338. mindspore/ops/functional.py +15 -10
  339. mindspore/ops/op_info_register.py +244 -25
  340. mindspore/ops/operations/__init__.py +28 -19
  341. mindspore/ops/operations/_grad_ops.py +72 -7
  342. mindspore/ops/operations/_inner_ops.py +350 -17
  343. mindspore/ops/operations/_quant_ops.py +4 -8
  344. mindspore/ops/operations/_sequence_ops.py +42 -0
  345. mindspore/ops/operations/array_ops.py +68 -282
  346. mindspore/ops/operations/comm_ops.py +107 -59
  347. mindspore/ops/operations/custom_ops.py +94 -70
  348. mindspore/ops/operations/debug_ops.py +8 -4
  349. mindspore/ops/operations/image_ops.py +18 -12
  350. mindspore/ops/operations/inner_ops.py +26 -3
  351. mindspore/ops/operations/math_ops.py +189 -141
  352. mindspore/ops/operations/nn_ops.py +794 -489
  353. mindspore/ops/operations/other_ops.py +0 -22
  354. mindspore/ops/operations/random_ops.py +53 -111
  355. mindspore/ops/operations/sparse_ops.py +3 -1
  356. mindspore/ops/primitive.py +24 -18
  357. mindspore/parallel/_auto_parallel_context.py +68 -8
  358. mindspore/parallel/_cost_model_context.py +2 -2
  359. mindspore/parallel/_offload_context.py +17 -3
  360. mindspore/parallel/_parallel_serialization.py +12 -5
  361. mindspore/parallel/_ps_context.py +12 -0
  362. mindspore/parallel/_tensor.py +18 -13
  363. mindspore/parallel/_transformer/layers.py +5 -3
  364. mindspore/parallel/_transformer/loss.py +1 -0
  365. mindspore/parallel/_transformer/moe.py +2 -2
  366. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  367. mindspore/parallel/_transformer/transformer.py +23 -3
  368. mindspore/parallel/_utils.py +11 -7
  369. mindspore/parallel/algo_parameter_config.py +85 -5
  370. mindspore/parallel/checkpoint_transform.py +19 -12
  371. mindspore/parallel/shard.py +21 -14
  372. mindspore/profiler/common/struct_type.py +3 -3
  373. mindspore/profiler/common/util.py +4 -2
  374. mindspore/profiler/envprofiling.py +1 -1
  375. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  376. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  377. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  378. mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
  379. mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
  380. mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
  381. mindspore/profiler/parser/ascend_op_generator.py +6 -6
  382. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  383. mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
  384. mindspore/profiler/parser/base_timeline_generator.py +10 -8
  385. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
  386. mindspore/profiler/parser/flops_parser.py +15 -11
  387. mindspore/profiler/parser/framework_parser.py +38 -22
  388. mindspore/profiler/parser/hccl_parser.py +16 -12
  389. mindspore/profiler/parser/integrator.py +22 -11
  390. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  391. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  392. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  393. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  394. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  395. mindspore/profiler/parser/optime_parser.py +1 -1
  396. mindspore/profiler/parser/profiler_info.py +21 -2
  397. mindspore/profiler/parser/step_trace_parser.py +11 -14
  398. mindspore/profiler/profiling.py +179 -89
  399. mindspore/rewrite/api/node.py +102 -19
  400. mindspore/rewrite/api/node_type.py +5 -1
  401. mindspore/rewrite/api/pattern_engine.py +1 -1
  402. mindspore/rewrite/api/scoped_value.py +9 -17
  403. mindspore/rewrite/api/symbol_tree.py +131 -47
  404. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  405. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  406. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  407. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  408. mindspore/rewrite/common/rewrite_elog.py +5 -1
  409. mindspore/rewrite/namer.py +33 -24
  410. mindspore/rewrite/namespace.py +14 -5
  411. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  412. mindspore/rewrite/node/call_function.py +79 -0
  413. mindspore/rewrite/node/cell_container.py +135 -0
  414. mindspore/rewrite/node/control_flow.py +88 -0
  415. mindspore/rewrite/{node.py → node/node.py} +273 -234
  416. mindspore/rewrite/node/node_manager.py +254 -0
  417. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  418. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  419. mindspore/rewrite/parsers/assign_parser.py +216 -221
  420. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  421. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  422. mindspore/rewrite/parsers/constant_parser.py +9 -6
  423. mindspore/rewrite/parsers/container_parser.py +9 -7
  424. mindspore/rewrite/parsers/for_parser.py +36 -15
  425. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  426. mindspore/rewrite/parsers/if_parser.py +28 -24
  427. mindspore/rewrite/parsers/module_parser.py +196 -25
  428. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  429. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  430. mindspore/rewrite/parsers/return_parser.py +6 -6
  431. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  432. mindspore/rewrite/sparsify/utils.py +1 -1
  433. mindspore/rewrite/symbol_tree.py +523 -578
  434. mindspore/rewrite/symbol_tree_builder.py +9 -193
  435. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  436. mindspore/run_check/_check_version.py +6 -4
  437. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  438. mindspore/safeguard/rewrite_obfuscation.py +541 -0
  439. mindspore/scipy/linalg.py +1 -1
  440. mindspore/scipy/optimize/minimize.py +7 -3
  441. mindspore/train/_utils.py +7 -3
  442. mindspore/train/amp.py +323 -123
  443. mindspore/train/anf_ir_pb2.py +14 -2
  444. mindspore/train/callback/_backup_and_restore.py +2 -12
  445. mindspore/train/callback/_callback.py +29 -4
  446. mindspore/train/callback/_checkpoint.py +23 -8
  447. mindspore/train/callback/_early_stop.py +2 -2
  448. mindspore/train/callback/_landscape.py +4 -4
  449. mindspore/train/callback/_loss_monitor.py +2 -2
  450. mindspore/train/callback/_on_request_exit.py +2 -2
  451. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  452. mindspore/train/callback/_summary_collector.py +15 -8
  453. mindspore/train/callback/_time_monitor.py +58 -5
  454. mindspore/train/data_sink.py +5 -11
  455. mindspore/train/dataset_helper.py +84 -57
  456. mindspore/train/loss_scale_manager.py +2 -2
  457. mindspore/train/metrics/__init__.py +3 -3
  458. mindspore/train/metrics/cosine_similarity.py +1 -1
  459. mindspore/train/metrics/hausdorff_distance.py +3 -2
  460. mindspore/train/metrics/mean_surface_distance.py +3 -2
  461. mindspore/train/metrics/metric.py +39 -19
  462. mindspore/train/metrics/roc.py +2 -2
  463. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  464. mindspore/train/mind_ir_pb2.py +85 -36
  465. mindspore/train/model.py +187 -47
  466. mindspore/train/serialization.py +487 -161
  467. mindspore/train/summary/_summary_adapter.py +1 -1
  468. mindspore/train/summary/_writer_pool.py +3 -2
  469. mindspore/train/summary/summary_record.py +37 -17
  470. mindspore/train/train_thor/convert_utils.py +3 -3
  471. mindspore/train/train_thor/dataset_helper.py +1 -1
  472. mindspore/version.py +1 -1
  473. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/METADATA +6 -7
  474. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/RECORD +477 -517
  475. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/entry_points.txt +0 -1
  476. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  477. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  478. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  479. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  480. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  481. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  482. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  483. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  484. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  485. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  486. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  487. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  488. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  489. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  490. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  491. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  492. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  493. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  494. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  495. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  496. mindspore/_extends/graph_kernel/expander.py +0 -80
  497. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  498. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  499. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  500. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  501. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  502. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  503. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  504. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  505. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  506. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  507. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  508. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  509. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  510. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  511. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  512. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  513. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  514. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  515. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  516. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  517. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  518. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  519. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  520. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  521. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  522. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  523. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  524. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  525. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  526. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  527. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  528. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  529. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  530. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  531. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  532. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  533. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  534. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  535. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  536. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  537. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  538. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  539. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  540. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  541. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  542. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  543. mindspore/dataset/datapreprocess/__init__.py +0 -20
  544. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  545. mindspore/include/api/net.h +0 -142
  546. mindspore/nn/lr_scheduler.py +0 -262
  547. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  548. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  549. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  550. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  551. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  552. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  553. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  554. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  555. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  556. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  557. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  558. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  559. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  560. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  561. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  562. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  563. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  564. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  565. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  566. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  567. mindspore/rewrite/node_visitor.py +0 -44
  568. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/WHEEL +0 -0
  569. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/top_level.txt +0 -0
mindspore/_extends/graph_kernel/model/graph_split.py:

@@ -83,23 +83,23 @@ class CommonPattern:
     def reshape(dom):
         """fuse strategy for reshape dom"""
         if dom.pattern != PrimLib.RESHAPE:
-            return []
+            return [], False
         min_area, forward_fuse = None, False
         for a, _ in dom.out_relations.items():
-            if a.pattern <= PrimLib.BROADCAST and dom.check_acyclic(a) and \
-                    (min_area is None or a.pattern < min_area.pattern):
-                min_area = a
+            if a.pattern <= PrimLib.BROADCAST and dom.check_acyclic(a):
+                if min_area is None or a.pattern < min_area.pattern:
+                    min_area = a
         for a, _ in dom.in_relations.items():
-            if a.pattern <= PrimLib.BROADCAST and a.check_acyclic(dom) and \
-                    (min_area is None or a.pattern < min_area.pattern):
-                min_area, forward_fuse = a, True
-        return ([min_area], forward_fuse) if min_area else []
+            if a.pattern <= PrimLib.BROADCAST and a.check_acyclic(dom):
+                if min_area is None or a.pattern < min_area.pattern:
+                    min_area, forward_fuse = a, True
+        return ([min_area], forward_fuse) if min_area else ([], False)

     @staticmethod
     def isolate_reshape(dom):
         """fuse strategy for isolate reshape dom"""
         if dom.pattern != PrimLib.RESHAPE or len(dom.ops) != 1:
-            return []
+            return [], False
         for a, _ in dom.out_relations.items():
             if a.mode == GraphSplitByPattern.Area.MODE_COMPOSITE and dom.check_acyclic(a):
                 return [a], False
@@ -107,59 +107,61 @@ class CommonPattern:
             if a.mode == GraphSplitByPattern.Area.MODE_COMPOSITE and a.pattern <= PrimLib.BROADCAST and \
                     a.check_acyclic(dom):
                 return [a], True
-        return []
+        return [], False

     @staticmethod
     def elemwise_depth(dom):
         """fuse strategy in depth for elemwise dom"""
         if dom.pattern != PrimLib.ELEMWISE or len(dom.in_relations) != 1:
-            return []
+            return [], False
         a, r = list(dom.in_relations.items())[0]
-        if a.pattern > PrimLib.ELEMWISE or len(a.out_relations) != 1 or r > PrimLib.ELEMWISE or \
-                tensor_size(a.dom_op().output) != tensor_size(dom.dom_op().output):
-            return []
+        if a.pattern > PrimLib.ELEMWISE or len(a.out_relations) != 1 or r > PrimLib.ELEMWISE:
+            return [], False
+        if tensor_size(a.dom_op().output) != tensor_size(dom.dom_op().output):
+            return [], False
         return [a], True

     @staticmethod
     def elemwise_width(dom):
         """fuse strategy in width for elemwise dom"""
         if dom.pattern != PrimLib.ELEMWISE:
-            return []
+            return [], False
         fused = []
         for a, r in dom.in_relations.items():
-            if a.pattern <= PrimLib.ELEMWISE and r <= PrimLib.ELEMWISE and a.check_acyclic(dom) and \
-                    tensor_size(a.dom_op().output) == tensor_size(dom.dom_op().output):
-                fused.append(a)
+            if a.pattern <= PrimLib.ELEMWISE and r <= PrimLib.ELEMWISE and a.check_acyclic(dom):
+                if tensor_size(a.dom_op().output) == tensor_size(dom.dom_op().output):
+                    fused.append(a)
         return fused, True

     @staticmethod
     def broadcast_depth(dom):
         """fuse strategy in depth for broadcast dom"""
         if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or len(dom.in_relations) != 1:
-            return []
+            return [], False
         a, r = list(dom.in_relations.items())[0]
-        if a.pattern > PrimLib.BROADCAST or len(a.out_relations) != 1 or r > PrimLib.ELEMWISE or \
-                tensor_size(a.dom_op().output) != tensor_size(dom.dom_op().output):
-            return []
+        if a.pattern > PrimLib.BROADCAST or len(a.out_relations) != 1 or r > PrimLib.ELEMWISE:
+            return [], False
+        if tensor_size(a.dom_op().output) != tensor_size(dom.dom_op().output):
+            return [], False
         return [a], True

     @staticmethod
     def broadcast_width(dom):
         """fuse strategy in width for broadcast dom"""
         if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST):
-            return []
+            return [], False
         fused = []
         for a, r in dom.in_relations.items():
-            if a.pattern <= PrimLib.BROADCAST and r <= PrimLib.ELEMWISE and a.check_acyclic(dom) and \
-                    tensor_size(a.dom_op().output) == tensor_size(dom.dom_op().output):
-                fused.append(a)
+            if a.pattern <= PrimLib.BROADCAST and r <= PrimLib.ELEMWISE and a.check_acyclic(dom):
+                if tensor_size(a.dom_op().output) == tensor_size(dom.dom_op().output):
+                    fused.append(a)
         return fused, True

     @staticmethod
     def assign(dom):
         """fuse strategy for assign dom"""
         if len(dom.ops) != 1 or dom.dom_op().prim != "Assign":
-            return []
+            return [], False
         fused = []
         for a, _ in dom.in_relations.items():
             fused.append(a)
@@ -711,8 +713,9 @@ class GraphSplitByPattern:
         for i in range(len(areas) - 1):
             dom = areas[i]
             for a in areas[i + 1:]:
-                if dom.check_acyclic(a) and a.check_acyclic(dom) and \
-                        selector(dom, a) and self.limit_area_size(dom, [a], 64) and dom.fuse_confirm(a):
+                can_fuse = dom.check_acyclic(a) and a.check_acyclic(dom) and selector(dom, a) \
+                    and self.limit_area_size(dom, [a], 64) and dom.fuse_confirm(a)
+                if can_fuse:
                     dom.fuse(a)
                     self.set_area_map(a.ops, dom)
                     self.areas.remove(a)
@@ -844,7 +847,7 @@ class GraphSplitByPattern:
         while stack:
             op = stack.pop()
             if len(op.inputs) > 1 or PrimLib.iter_type(op) > PrimLib.BROADCAST or len(ops) > max_weight:
-                return []
+                return [], []
             ops.append(op)
             for t in op.inputs:
                 if t.op in area.ops:
@@ -878,8 +881,8 @@ class GraphSplitByPattern:
             return []
         result = []
        for op in borders:
-            if prods[op]:
-                prod_ops, inputs = prods[op]
+            prod_ops, inputs = prods[op]
+            if prod_ops:
                 if sum([t.get_size() for t in inputs]) <= op.output.get_size():
                     pred = self.area_map.get(inputs[0].op) if inputs and inputs[0].op else None
                     result.append([pred, prod_ops[::-1]])
@@ -938,23 +941,25 @@ class GraphSplitGpu(GraphSplitByPattern):
             return a.pattern > PrimLib.REDUCE or r > PrimLib.BROADCAST

         def _broadcast_bwd_depth(dom):
-            if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or len(dom.out_relations) != 1 or \
-                    dom.is_output or len(dom.ops) > self.BROADCAST_FUSE_DEPTH:
-                return []
+            if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or len(dom.out_relations) != 1:
+                return [], False
+            if dom.is_output or len(dom.ops) > self.BROADCAST_FUSE_DEPTH:
+                return [], False
             a, r = list(dom.out_relations.items())[0]
             if _broadcast_pat_exclude(dom, a, r) or len(a.in_relations) != 1:
-                return []
+                return [], False
             return [a], False

         def _broadcast_bwd_width(dom):
             if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or \
                     dom.is_output or len(dom.ops) > self.BROADCAST_FUSE_DEPTH:
-                return []
+                return [], False
             fused = []
             for a, r in dom.out_relations.items():
-                if _broadcast_pat_exclude(dom, a, r) or not dom.check_acyclic(a) or \
-                        (fused and tensor_size(fused[0].dom_op().output) != tensor_size(a.dom_op().output)):
-                    return []
+                if _broadcast_pat_exclude(dom, a, r) or not dom.check_acyclic(a):
+                    return [], False
+                if fused and tensor_size(fused[0].dom_op().output) != tensor_size(a.dom_op().output):
+                    return [], False
                 fused.append(a)
             return fused, False

@@ -965,25 +970,25 @@ class GraphSplitGpu(GraphSplitByPattern):

         def _reduce_depth(dom):
             if dom.pattern != PrimLib.REDUCE or len(dom.in_relations) != 1:
-                return []
+                return [], False
             a, r = list(dom.in_relations.items())[0]
-            if dom.ops[0].inputs[0].dtype == "float16" and a.is_output and len(a.ops) >= 10 and \
-                    _is_atomic_add_available(dom):
-                # to evade the precision problem.
-                return []
+            if dom.ops[0].inputs[0].dtype == "float16" and a.is_output:
+                if len(a.ops) >= 10 and _is_atomic_add_available(dom):
+                    # to evade the precision problem.
+                    return [], False
             if _reduce_pat_exclude(dom, a, r) or len(a.out_relations) != 1:
                 return []
             return [a], True

         def _reduce_width(dom):
             if dom.pattern != PrimLib.REDUCE:
-                return []
+                return [], False
             fused = []
             for a, r in dom.in_relations.items():
-                if dom.ops[0].inputs[0].dtype == "float16" and a.is_output and len(a.ops) >= 10 and \
-                        _is_atomic_add_available(dom):
-                    # to evade the precision problem.
-                    continue
+                if dom.ops[0].inputs[0].dtype == "float16" and a.is_output:
+                    if len(a.ops) >= 10 and _is_atomic_add_available(dom):
+                        # to evade the precision problem.
+                        continue
                 if not _reduce_pat_exclude(dom, a, r) and a.check_acyclic(dom):
                     fused.append(a)
             return fused, True
@@ -1016,15 +1021,15 @@ class GraphSplitGpu(GraphSplitByPattern):

         def _reduce_output(dom):
             if dom.pattern != PrimLib.REDUCE:
-                return []
+                return [], False
             if _may_multi_filter(dom.ops):
-                return []
+                return [], False
             if _is_atomic_add_available(dom):
-                return []
+                return [], False
             is_all_reduce = tensor_size(dom.ops[0].output) == 1
             # excluded large size all reduce
             if is_all_reduce and dom.ops[0].inputs and tensor_size(dom.ops[0].inputs[0]) > 1024 * 12:
-                return []
+                return [], False

             fused = []
             for a, r in dom.out_relations.items():
@@ -1034,11 +1039,11 @@ class GraphSplitGpu(GraphSplitByPattern):

         def _reduce_stitch(dom):
             if dom.pattern != PrimLib.REDUCE:
-                return []
+                return [], False
             if tensor_size(dom.ops[0].output) == 1:
-                return []
+                return [], False
             if tensor_size(dom.ops[0].inputs[0]) < 1024 * 12:
-                return []
+                return [], False

             fused = []
             for a, r in dom.out_relations.items():
@@ -1055,7 +1060,7 @@ class GraphSplitGpu(GraphSplitByPattern):

         def _transpose(dom):
             if len(dom.ops) != 1 or dom.ops[0].prim != "Transpose":
-                return []
+                return [], False
             fused = []
             for a, _ in dom.in_relations.items():
                 if a.pattern <= PrimLib.BROADCAST and a.check_acyclic(dom) and len(a.ops) <= self.TRANSPOSE_FUSE_DEPTH:
@@ -1064,7 +1069,7 @@ class GraphSplitGpu(GraphSplitByPattern):

         def _strided_slice(dom):
             if dom.dom_op().prim != "StridedSlice":
-                return []
+                return [], False
             fused = []
             for a, _ in dom.in_relations.items():
                 if a.pattern <= PrimLib.BROADCAST and a.check_acyclic(dom) and \
@@ -1075,7 +1080,7 @@ class GraphSplitGpu(GraphSplitByPattern):
         def _gather_output(dom, reduce_fusion=False):
             gather_prims = ("Gather", "GatherNd", "CSRGather")
             if not dom.dom_op().prim in gather_prims:
-                return []
+                return [], False

             def _reduce_exclude(op, axis_list):
                 """ Whether this operator should be excluded.
@@ -1173,7 +1178,7 @@ class GraphSplitGpu(GraphSplitByPattern):
             for a, _ in dom.out_relations.items():
                 if _shape_consistent(gather_prims, appected_areas, dom, a) and dom.check_acyclic(a):
                     return [a], False
-            return []
+            return [], False

         def _broadcast_tot(dom):
             """Fuse rule for TensorScatterAdd and UnsortedSegmentSum."""
@@ -1182,13 +1187,13 @@ class GraphSplitGpu(GraphSplitByPattern):
                 return bool(set(op1.inputs) & set(op2.inputs))

             if len(dom.ops) != 1:
-                return []
+                return [], False

             # Only fuse the first input for `TensorScatterAdd`` and the first and second input for `UnsortedSegmentSum`.
             fuse_arg = {"TensorScatterAdd": slice(1, None), "UnsortedSegmentSum": slice(0, 2)}
             arg_idx = fuse_arg.get(dom.dom_op().prim, -1)
             if arg_idx == -1:
-                return []
+                return [], False
             fuse_tensor = dom.dom_op().inputs[arg_idx]

             for a, _ in dom.in_relations.items():
@@ -1200,27 +1205,30 @@ class GraphSplitGpu(GraphSplitByPattern):
                 # Rule 2: Fuse op(reshape/elementwise/broadcast) in specified position inputs.
                 if a.pattern <= PrimLib.BROADCAST and any((op.output in fuse_tensor for op in a.ops)):
                     return [a], True
-            return []
+            return [], False

         def _broadcast_onehot(dom, fwd=True):
             """Fuse rule for OneHot."""
             if dom.dom_op().prim != "OneHot":
-                return []
+                return [], False

             fused = []
             neighbours = dom.in_relations.items() if fwd else dom.out_relations.items()
             for a, _ in neighbours:
                 if a.pattern <= PrimLib.BROADCAST:
-                    if (fwd and a.check_acyclic(dom) and len(a.out_relations) == 1 and not a.is_output) or \
-                            (not fwd and dom.check_acyclic(a)):
-                        fused.append(a)
+                    if fwd:
+                        if a.check_acyclic(dom) and len(a.out_relations) == 1 and not a.is_output:
+                            fused.append(a)
+                    else:
+                        if dom.check_acyclic(a):
+                            fused.append(a)

             return fused, fwd

         def _elemwise_elemany(dom):
             """Fuse rule for elemany."""
             if dom.dom_op().prim != "ElemAny":
-                return []
+                return [], False

             fused = []
             for a, r in dom.in_relations.items():
@@ -1233,21 +1241,21 @@ class GraphSplitGpu(GraphSplitByPattern):
             """Fuse rule for injective """
             injective_ops = {"Transpose", "StridedSlice"}
             if dom.dom_op().prim not in injective_ops:
-                return []
+                return [], False
             to_ops = dom.dom_op().output.to_ops
             if dom.is_output or len(to_ops) != 1 or len(dom.out_relations) != 1:
-                return []
+                return [], False
             to_area = list(dom.out_relations.keys())[0]
             if (to_area.pattern >= PrimLib.REDUCE and to_area.dom_op().prim not in injective_ops) or \
                     to_ops[0] not in to_area.ops:
-                return []
+                return [], False
             if len(to_area.ops) > self.TRANSPOSE_FUSE_DEPTH:
-                return []
+                return [], False
             return [to_area], False

         def _h_broadcast(dom, a):
             if dom.pattern > PrimLib.BROADCAST:
-                return []
+                return [], False
             return a.pattern <= PrimLib.BROADCAST and dom.ops[0].output.shape == a.ops[0].output.shape

         def _h_reduce(dom, a):
@@ -1274,7 +1282,7 @@ class GraphSplitGpu(GraphSplitByPattern):
             fuse_arg = {"CSRReduceSum": slice(1, 3), "CSRGather": slice(2, 3)}
             arg_idx = fuse_arg.get(dom.dom_op().prim, -1)
             if arg_idx == -1:
-                return []
+                return [], False
             fuse_tensor = dom.dom_op().inputs[arg_idx]
             for a, _ in dom.in_relations.items():
                 if (a.dom_op().prim == "CSRGather" and a.dom_op().prim == dom.dom_op().prim and
                 if a.pattern <= PrimLib.BROADCAST and dom.check_acyclic(a) and \
                         any([op.output in fuse_tensor for op in a.ops]):
                     return [a], True
-            return []
+            return [], False

         def _fuse_loop():
             self.fuse(CommonPattern.reshape)
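Taken together, the graph_split.py hunks above change every fuse-strategy helper so that the "no fusion" path returns an explicit ([], False) pair (or [], [] where two lists are expected) instead of a bare empty list, keeping the return shape identical on every branch. The following sketch only illustrates that convention change; the function names and the plain integers standing in for Area objects are invented for the example and are not part of MindSpore's API.

# Sketch only: illustrates the return-convention change with integers standing in for Area objects.

def old_style_strategy(dom, candidates):
    """2.1.0 style: a bare [] signals 'nothing to fuse', so callers must special-case it."""
    fused = [a for a in candidates if a <= dom]  # placeholder for the real pattern checks
    if not fused:
        return []
    return fused, True


def new_style_strategy(dom, candidates):
    """2.2.10 style: every path returns (areas, forward_fuse), so callers can unpack unconditionally."""
    fused = [a for a in candidates if a <= dom]  # placeholder for the real pattern checks
    if not fused:
        return [], False
    return fused, True


areas, forward = new_style_strategy(3, [1, 2, 5])
print(areas, forward)  # [1, 2] True
areas, forward = new_style_strategy(0, [1, 2, 5])
print(areas, forward)  # [] False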
mindspore/_extends/graph_kernel/model/model_builder.py:

@@ -13,9 +13,6 @@
 # limitations under the License.
 # ===========================================================================
 """GraphKernel model builder"""
-
-import copy
-from . import op_infer
 from .model import Tensor, Value, Operator, Graph, AlignShape


@@ -95,18 +92,6 @@ class GraphBuilder:
         node.all_inputs = inputs
         self.current.graph.add(node)

-    def emit(self, prim, inputs, name=None, attrs=None):
-        """Emit a new operation"""
-        if attrs is None:
-            attrs = {}
-        if isinstance(inputs, (Tensor, Value)):
-            inputs = [inputs]
-        tensor_inputs = [t for t in inputs if isinstance(t, (Tensor, Value))]
-        out_shape, out_dtype, out_format = op_infer.infer(prim, tensor_inputs, attrs)
-        output = self.tensor(out_shape, out_dtype, out_format, name)
-        self.op(prim, output, inputs, attrs)
-        return output
-
     def get(self):
         """Get graphs"""
         return self.graphs
@@ -169,15 +154,18 @@ class CompositeGraph:
             for op in desc['op_desc']:
                 inputs = [self.tensors.get(d['tensor_name'], None) for x in op['input_desc']
                           for d in x if 'value' not in d]
+                if op['name'] in ('ReduceSum', 'ReduceMax', 'ReduceMin'):
+                    axis = op['input_desc'][1][0]['value']
+                    if isinstance(axis, int):
+                        axis = [axis]
+                    if not op['attr']:
+                        attr = [{'name': 'axis', 'dtype': 'listInt', 'value': axis}]
+                        op['attr'] = attr
+                    else:
+                        op['attr'].append({'name': 'axis', 'dtype': 'listInt', 'value': axis})
                 out_desc = op['output_desc']
                 name, shape, dtype, data_format = out_desc[0]['tensor_name'], out_desc[
                     0]['shape'], out_desc[0]['data_type'], out_desc[0]['format']
-                if op['name'] == 'InplaceAssign':
-                    inputs[0].add_buddy(inputs[1])
-                    inputs[1].para_type = Tensor.PARA_OUTPUT
-                    output = inputs[2]
-                    self.tensors[name] = output
-                    continue
                 output = self.tensors.get(name, None)
                 if not output:
                     output = builder.tensor(shape, dtype, data_format, name=name)
@@ -186,46 +174,17 @@ class CompositeGraph:
             self.graph = builder.get()[0]
             self.desc = desc

-    def _pre_dump(self, outputs):
-        """restore name to before load"""
-        inplace_assign = {}  # y_name, output_name
-        inplace_assign_z = None
-        for op in self.desc['op_desc']:
-            if op['name'] == 'InplaceAssign':
-                inplace_assign[op['input_desc'][1][0]['tensor_name']] = op['output_desc'][0]['tensor_name']
-        if inplace_assign:
-            for t in outputs:
-                if t.name not in inplace_assign:
-                    inplace_assign_z = t
-        return inplace_assign, inplace_assign_z

     def dump(self, subgraph):
         """Dump Graph to json"""
         desc = {}
         inputs, outputs = subgraph.deduce_parameters()
         graph_ops = set(subgraph.ops)
-        inplace_assign, inplace_assign_z = self._pre_dump(outputs)

         def dump_output(t):
-            if t.name in inplace_assign:
-                z = inplace_assign_z if inplace_assign_z is not None else self.tensors.get(t.name, None)
-                return {'data_type': z.dtype, 'shape': z.shape, 'tensor_name': inplace_assign.get(t.name)}
             return {'data_type': t.dtype, 'shape': t.shape, 'tensor_name': t.name}

         def dump_op_desc(d):
-            if d['name'] == 'InplaceAssign':
-                y = d['input_desc'][1][0]['tensor_name']
-                if self.tensors[y].op in graph_ops:
-                    z, fake = (inplace_assign_z, False) if inplace_assign_z is not None else (self.tensors.get(y), True)
-                    inplace_desc = copy.deepcopy(d)
-                    inplace_desc['attr'] = {'name': 'fake_output', 'value': fake}
-                    z_desc, out_desc = inplace_desc['input_desc'][2][0], inplace_desc['output_desc'][0]
-                    z_desc['shape'] = z.shape
-                    z_desc['data_type'] = z.dtype
-                    z_desc['tensor_name'] = z.name
-                    out_desc['shape'] = z.shape
-                    out_desc['data_type'] = z.dtype
-                    return inplace_desc
             op = self.tensors[d['output_desc'][0]['tensor_name']].op
             if op in graph_ops or op in subgraph.recompute_ops:
                 return d
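The CompositeGraph.load hunk above now copies the reduce axis of ReduceSum/ReduceMax/ReduceMin from the op's second input into an explicit 'axis' attribute. A small self-contained sketch of that normalization on a toy op_desc dictionary follows; the values are made up, and only the dict layout mirrors the JSON shown in the hunk.

# Sketch: move a reduce op's axis from input_desc into an 'axis' attribute,
# mirroring the logic added in CompositeGraph.load. The op dict below is a toy example.
op = {
    'name': 'ReduceSum',
    'input_desc': [[{'tensor_name': 'x'}], [{'value': 1}]],  # axis arrives as the second input
    'attr': [],
}

if op['name'] in ('ReduceSum', 'ReduceMax', 'ReduceMin'):
    axis = op['input_desc'][1][0]['value']
    if isinstance(axis, int):
        axis = [axis]  # the attribute is always a list of ints
    if not op['attr']:
        op['attr'] = [{'name': 'axis', 'dtype': 'listInt', 'value': axis}]
    else:
        op['attr'].append({'name': 'axis', 'dtype': 'listInt', 'value': axis})

print(op['attr'])  # [{'name': 'axis', 'dtype': 'listInt', 'value': [1]}]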
mindspore/_extends/graph_kernel/splitter.py:

@@ -36,7 +36,6 @@ def split_with_json(json_str, flags_str):
         subgraphs, graph_mode = model.split(comp.graph, target, flags)
         is_multi_graph = len(subgraphs) > 1
         graph_list = list(map(comp.dump, subgraphs))
-        _reset_graphmode_for_inplaceassign(graph_list, graph_mode)
         result = {"multi_graph": is_multi_graph,
                   "graph_desc": graph_list,
                   "graph_mode": graph_mode}
@@ -51,8 +50,9 @@ def split_with_json(json_str, flags_str):
 def _load_repository(graph, flags):
     """Load repository if exists"""
     def check_repo(op, best_split, op_desc):
-        if not isinstance(best_split, dict) or "group_num" not in best_split or "graph_mode" not in best_split \
-                or "split_result" not in best_split:
+        if not isinstance(best_split, dict):
+            return False
+        if "group_num" not in best_split or "graph_mode" not in best_split or "split_result" not in best_split:
             logger.warning("The graph split repository of {} should be a dict which contains 'group_num', 'graph_mode' "
                            "and 'split_result' field, but got {}".format(op, best_split))
             return False
@@ -114,19 +114,12 @@ def _load_repository(graph, flags):
     return result


-def _reset_graphmode_for_inplaceassign(graph_list, graph_mode):
-    """Operator with InplaceAssign should always be composite op"""
-    for i, g in enumerate(graph_list):
-        if any((op['name'] == 'InplaceAssign' for op in g['op_desc'])):
-            graph_mode[i] = 'composite'
-
-
 def _dump_split_info(use_repo, graph_str, graph, subgraphs, graph_mode, graph_list):
     """Dump split info as text"""
     graph_kernel_dump_path = "graph_kernel_dump"
     utils.create_dir(graph_kernel_dump_path)
     filename = os.path.join(graph_kernel_dump_path, "graph_kernel_split_mode.%d.txt" % os.getpid())
-    with os.fdopen(os.open(filename, os.O_WRONLY | os.O_CREAT), "a+") as f:
+    with os.fdopen(os.open(filename, os.O_WRONLY | os.O_CREAT, 0o600), "a+") as f:
         f.write("********** main graph: {} **********\n".format(graph.name))
         f.write("input json:\n{}\n".format(graph_str))
         f.write("graph desc:\n{}\n".format(str(graph)))