mindspore 2.1.0__cp37-cp37m-manylinux1_x86_64.whl → 2.2.11__cp37-cp37m-manylinux1_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (589)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_akg/akg/build_module.py +5 -6
  4. mindspore/_akg/akg/composite/build_module.py +139 -22
  5. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  6. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  7. mindspore/_akg/akg/tvm/api.py +4 -3
  8. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  9. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  10. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  11. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  12. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  13. mindspore/_akg/akg/tvm/build_module.py +16 -1
  14. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  15. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  16. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  17. mindspore/_akg/akg/tvm/module.py +1 -2
  18. mindspore/_akg/akg/tvm/stmt.py +2 -2
  19. mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
  20. mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
  21. mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
  22. mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
  23. mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
  24. mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
  25. mindspore/_akg/akg/utils/composite_op_helper.py +16 -12
  26. mindspore/_akg/akg/utils/dump_ascend_meta.py +22 -3
  27. mindspore/_akg/akg/utils/kernel_exec.py +98 -274
  28. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  29. mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
  30. mindspore/_akg/akg/utils/util.py +56 -1
  31. mindspore/_c_dataengine.cpython-37m-x86_64-linux-gnu.so +0 -0
  32. mindspore/_c_expression.cpython-37m-x86_64-linux-gnu.so +0 -0
  33. mindspore/_c_mindrecord.cpython-37m-x86_64-linux-gnu.so +0 -0
  34. mindspore/_check_jit_forbidden_api.py +3 -1
  35. mindspore/_checkparam.py +23 -29
  36. mindspore/_extends/graph_kernel/__init__.py +0 -1
  37. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  38. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  39. mindspore/_extends/graph_kernel/splitter.py +4 -11
  40. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  41. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
  42. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  43. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  44. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  45. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
  46. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  47. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  48. mindspore/_extends/parse/__init__.py +13 -15
  49. mindspore/_extends/parse/namespace.py +7 -33
  50. mindspore/_extends/parse/parser.py +67 -72
  51. mindspore/_extends/parse/resources.py +1 -1
  52. mindspore/_extends/parse/standard_method.py +86 -106
  53. mindspore/_extends/parse/trope.py +1 -1
  54. mindspore/_extends/remote/kernel_build_server.py +25 -7
  55. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  56. mindspore/_install_custom.py +43 -0
  57. mindspore/_mindspore_offline_debug.cpython-37m-x86_64-linux-gnu.so +0 -0
  58. mindspore/amp.py +47 -11
  59. mindspore/bin/cache_admin +0 -0
  60. mindspore/bin/cache_server +0 -0
  61. mindspore/boost/boost.py +1 -8
  62. mindspore/boost/boost_cell_wrapper.py +3 -2
  63. mindspore/boost/grad_accumulation.py +1 -1
  64. mindspore/boost/group_loss_scale_manager.py +8 -7
  65. mindspore/common/__init__.py +5 -3
  66. mindspore/common/_jit_fallback_utils.py +6 -0
  67. mindspore/common/_register_for_adapter.py +2 -0
  68. mindspore/common/_register_for_tensor.py +2 -2
  69. mindspore/common/_stub_tensor.py +13 -0
  70. mindspore/common/_utils.py +29 -0
  71. mindspore/common/api.py +174 -259
  72. mindspore/common/auto_dynamic_shape.py +494 -0
  73. mindspore/common/dtype.py +18 -11
  74. mindspore/common/dump.py +6 -4
  75. mindspore/common/initializer.py +14 -14
  76. mindspore/common/jit_config.py +33 -15
  77. mindspore/common/lazy_inline.py +126 -7
  78. mindspore/common/mindir_util.py +101 -0
  79. mindspore/common/parameter.py +51 -41
  80. mindspore/common/seed.py +4 -4
  81. mindspore/common/sparse_tensor.py +13 -14
  82. mindspore/common/tensor.py +243 -165
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +83 -4
  85. mindspore/communication/management.py +152 -84
  86. mindspore/config/op_info.config +14 -3
  87. mindspore/config/super_bar_config.json +4 -2
  88. mindspore/context.py +152 -61
  89. mindspore/dataset/__init__.py +5 -5
  90. mindspore/dataset/audio/__init__.py +2 -2
  91. mindspore/dataset/audio/transforms.py +52 -52
  92. mindspore/dataset/callback/ds_callback.py +16 -2
  93. mindspore/dataset/core/config.py +68 -51
  94. mindspore/dataset/engine/cache_client.py +33 -7
  95. mindspore/dataset/engine/datasets.py +250 -112
  96. mindspore/dataset/engine/datasets_audio.py +43 -211
  97. mindspore/dataset/engine/datasets_standard_format.py +16 -35
  98. mindspore/dataset/engine/datasets_text.py +43 -67
  99. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  100. mindspore/dataset/engine/datasets_vision.py +219 -1029
  101. mindspore/dataset/engine/iterators.py +11 -4
  102. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  103. mindspore/dataset/engine/obs/util.py +3 -0
  104. mindspore/dataset/engine/samplers.py +1 -1
  105. mindspore/dataset/engine/validators.py +19 -5
  106. mindspore/dataset/text/__init__.py +3 -3
  107. mindspore/dataset/text/transforms.py +101 -127
  108. mindspore/dataset/text/utils.py +205 -138
  109. mindspore/dataset/transforms/__init__.py +1 -1
  110. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  111. mindspore/dataset/transforms/transforms.py +95 -40
  112. mindspore/dataset/utils/browse_dataset.py +8 -2
  113. mindspore/dataset/utils/line_reader.py +17 -19
  114. mindspore/dataset/vision/__init__.py +3 -3
  115. mindspore/dataset/vision/c_transforms.py +6 -3
  116. mindspore/dataset/vision/transforms.py +409 -287
  117. mindspore/dataset/vision/utils.py +13 -14
  118. mindspore/dataset/vision/validators.py +11 -1
  119. mindspore/experimental/map_parameter.py +14 -0
  120. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  121. mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
  122. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  123. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  124. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  125. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  126. mindspore/gen_ops.py +273 -0
  127. mindspore/include/OWNERS +0 -1
  128. mindspore/include/api/data_type.h +2 -1
  129. mindspore/include/api/graph.h +0 -15
  130. mindspore/include/api/kernel.h +2 -0
  131. mindspore/include/api/kernel_api.h +37 -12
  132. mindspore/include/api/model.h +17 -14
  133. mindspore/include/api/status.h +8 -3
  134. mindspore/include/api/types.h +37 -4
  135. mindspore/include/c_api/ms/abstract.h +67 -0
  136. mindspore/include/c_api/ms/attribute.h +197 -0
  137. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  138. mindspore/include/c_api/ms/base/macros.h +32 -0
  139. mindspore/include/c_api/ms/base/status.h +33 -0
  140. mindspore/include/c_api/ms/base/types.h +282 -0
  141. mindspore/include/c_api/ms/context.h +102 -0
  142. mindspore/include/c_api/ms/graph.h +160 -0
  143. mindspore/include/c_api/ms/node.h +606 -0
  144. mindspore/include/c_api/ms/tensor.h +161 -0
  145. mindspore/include/c_api/ms/value.h +84 -0
  146. mindspore/include/dataset/constants.h +6 -5
  147. mindspore/include/dataset/execute.h +23 -13
  148. mindspore/include/dataset/text.h +26 -26
  149. mindspore/include/dataset/transforms.h +13 -13
  150. mindspore/include/dataset/vision.h +60 -60
  151. mindspore/include/dataset/vision_ascend.h +5 -6
  152. mindspore/include/dataset/vision_lite.h +17 -17
  153. mindspore/include/mindapi/base/type_id.h +1 -0
  154. mindspore/include/mindapi/base/types.h +1 -0
  155. mindspore/lib/libdnnl.so.2 +0 -0
  156. mindspore/lib/libjemalloc.so.2 +0 -0
  157. mindspore/lib/libmindspore.so +0 -0
  158. mindspore/lib/libmindspore_backend.so +0 -0
  159. mindspore/lib/libmindspore_common.so +0 -0
  160. mindspore/lib/libmindspore_core.so +0 -0
  161. mindspore/lib/libmindspore_glog.so.0 +0 -0
  162. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  163. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  164. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  165. mindspore/lib/libmindspore_shared_lib.so +0 -0
  166. mindspore/lib/libnnacl.so +0 -0
  167. mindspore/lib/libopencv_core.so.4.5 +0 -0
  168. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  169. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  170. mindspore/lib/libps_cache.so +0 -0
  171. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
  172. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
  173. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
  174. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
  175. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  176. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  177. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  178. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  179. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  180. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  181. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  182. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  183. mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
  184. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  185. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  186. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8998 -0
  187. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  188. mindspore/lib/plugin/ascend/libakg.so +0 -0
  189. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  190. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  191. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  192. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  193. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  194. mindspore/lib/plugin/cpu/libakg.so +0 -0
  195. mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
  196. mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
  197. mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
  198. mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
  199. mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
  200. mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
  201. mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
  202. mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
  203. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  204. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  205. mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
  206. mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
  207. mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
  208. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  209. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  210. mindspore/nn/__init__.py +0 -2
  211. mindspore/nn/cell.py +313 -74
  212. mindspore/nn/dynamic_lr.py +21 -21
  213. mindspore/nn/layer/activation.py +22 -30
  214. mindspore/nn/layer/basic.py +15 -13
  215. mindspore/nn/layer/channel_shuffle.py +1 -1
  216. mindspore/nn/layer/container.py +271 -9
  217. mindspore/nn/layer/conv.py +323 -204
  218. mindspore/nn/layer/dense.py +8 -5
  219. mindspore/nn/layer/embedding.py +33 -27
  220. mindspore/nn/layer/flash_attention.py +61 -95
  221. mindspore/nn/layer/image.py +8 -6
  222. mindspore/nn/layer/math.py +16 -25
  223. mindspore/nn/layer/normalization.py +107 -66
  224. mindspore/nn/layer/padding.py +1 -1
  225. mindspore/nn/layer/pooling.py +131 -109
  226. mindspore/nn/layer/rnn_cells.py +27 -22
  227. mindspore/nn/layer/rnns.py +13 -16
  228. mindspore/nn/layer/thor_layer.py +1 -1
  229. mindspore/nn/layer/transformer.py +221 -154
  230. mindspore/nn/learning_rate_schedule.py +9 -1
  231. mindspore/nn/loss/loss.py +235 -174
  232. mindspore/nn/optim/ada_grad.py +2 -1
  233. mindspore/nn/optim/adadelta.py +1 -0
  234. mindspore/nn/optim/adafactor.py +2 -1
  235. mindspore/nn/optim/adam.py +7 -4
  236. mindspore/nn/optim/adamax.py +3 -2
  237. mindspore/nn/optim/adasum.py +2 -2
  238. mindspore/nn/optim/asgd.py +2 -3
  239. mindspore/nn/optim/ftrl.py +6 -5
  240. mindspore/nn/optim/lamb.py +7 -4
  241. mindspore/nn/optim/lars.py +1 -1
  242. mindspore/nn/optim/lazyadam.py +5 -3
  243. mindspore/nn/optim/momentum.py +2 -1
  244. mindspore/nn/optim/optimizer.py +53 -4
  245. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  246. mindspore/nn/optim/rmsprop.py +4 -3
  247. mindspore/nn/optim/rprop.py +23 -12
  248. mindspore/nn/optim/sgd.py +26 -11
  249. mindspore/nn/optim/thor.py +9 -7
  250. mindspore/nn/probability/bijector/bijector.py +5 -5
  251. mindspore/nn/probability/bijector/power_transform.py +27 -27
  252. mindspore/nn/probability/bijector/softplus.py +3 -3
  253. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  254. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  255. mindspore/nn/probability/distribution/beta.py +3 -3
  256. mindspore/nn/probability/distribution/categorical.py +7 -7
  257. mindspore/nn/probability/distribution/cauchy.py +0 -1
  258. mindspore/nn/probability/distribution/distribution.py +3 -3
  259. mindspore/nn/probability/distribution/gamma.py +3 -3
  260. mindspore/nn/probability/distribution/geometric.py +4 -4
  261. mindspore/nn/probability/distribution/gumbel.py +4 -4
  262. mindspore/nn/probability/distribution/log_normal.py +2 -2
  263. mindspore/nn/probability/distribution/logistic.py +2 -2
  264. mindspore/nn/probability/distribution/poisson.py +4 -4
  265. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  266. mindspore/nn/probability/distribution/uniform.py +6 -6
  267. mindspore/nn/wrap/__init__.py +4 -2
  268. mindspore/nn/wrap/cell_wrapper.py +87 -34
  269. mindspore/nn/wrap/grad_reducer.py +8 -5
  270. mindspore/nn/wrap/loss_scale.py +105 -42
  271. mindspore/numpy/array_creations.py +1 -2
  272. mindspore/numpy/array_ops.py +3 -2
  273. mindspore/numpy/utils_const.py +5 -5
  274. mindspore/offline_debug/convert_async.py +2 -2
  275. mindspore/ops/_grad_experimental/__init__.py +0 -5
  276. mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
  277. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  278. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  279. mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
  280. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  281. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
  282. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  283. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  284. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  285. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  286. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  287. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  288. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  289. mindspore/ops/_op_impl/{_custom_op/flash_attention/constants.py → aicpu/eps.py} +18 -27
  290. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  291. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
  292. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  293. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  294. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  295. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  296. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  297. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  298. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  299. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  300. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  301. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  302. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  303. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  304. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  305. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  306. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  307. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  308. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  309. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  310. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  311. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  312. mindspore/ops/_primitive_cache.py +1 -1
  313. mindspore/ops/_tracefunc.py +45 -13
  314. mindspore/ops/_utils/utils.py +6 -1
  315. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  316. mindspore/ops/_vmap/vmap_base.py +3 -3
  317. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  318. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  319. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  320. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  321. mindspore/ops/arg_dtype_cast.py +54 -0
  322. mindspore/ops/composite/base.py +37 -10
  323. mindspore/ops/composite/math_ops.py +5 -4
  324. mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
  325. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  326. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  327. mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
  328. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  329. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  330. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  331. mindspore/ops/deprecated.py +304 -0
  332. mindspore/ops/function/__init__.py +4 -1
  333. mindspore/ops/function/array_func.py +174 -193
  334. mindspore/ops/function/clip_func.py +81 -13
  335. mindspore/ops/function/debug_func.py +1 -1
  336. mindspore/ops/function/grad/grad_func.py +18 -9
  337. mindspore/ops/function/image_func.py +10 -4
  338. mindspore/ops/function/linalg_func.py +5 -5
  339. mindspore/ops/function/math_func.py +575 -386
  340. mindspore/ops/function/nn_func.py +568 -260
  341. mindspore/ops/function/random_func.py +88 -57
  342. mindspore/ops/function/sparse_func.py +1 -1
  343. mindspore/ops/function/sparse_unary_func.py +14 -12
  344. mindspore/ops/function/vmap_func.py +6 -5
  345. mindspore/ops/functional.py +15 -10
  346. mindspore/ops/op_info_register.py +244 -25
  347. mindspore/ops/operations/__init__.py +31 -19
  348. mindspore/ops/operations/_grad_ops.py +71 -7
  349. mindspore/ops/operations/_inner_ops.py +350 -17
  350. mindspore/ops/operations/_quant_ops.py +4 -8
  351. mindspore/ops/operations/_sequence_ops.py +42 -0
  352. mindspore/ops/operations/array_ops.py +68 -282
  353. mindspore/ops/operations/comm_ops.py +107 -59
  354. mindspore/ops/operations/custom_ops.py +94 -70
  355. mindspore/ops/operations/debug_ops.py +8 -4
  356. mindspore/ops/operations/image_ops.py +18 -12
  357. mindspore/ops/operations/inner_ops.py +26 -3
  358. mindspore/ops/operations/math_ops.py +192 -144
  359. mindspore/ops/operations/nn_ops.py +857 -489
  360. mindspore/ops/operations/other_ops.py +0 -22
  361. mindspore/ops/operations/random_ops.py +53 -111
  362. mindspore/ops/operations/sparse_ops.py +3 -1
  363. mindspore/ops/primitive.py +24 -18
  364. mindspore/parallel/_auto_parallel_context.py +68 -8
  365. mindspore/parallel/_cost_model_context.py +2 -2
  366. mindspore/parallel/_offload_context.py +17 -3
  367. mindspore/parallel/_parallel_serialization.py +12 -5
  368. mindspore/parallel/_ps_context.py +12 -0
  369. mindspore/parallel/_tensor.py +18 -13
  370. mindspore/parallel/_transformer/layers.py +5 -3
  371. mindspore/parallel/_transformer/loss.py +1 -0
  372. mindspore/parallel/_transformer/moe.py +2 -2
  373. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  374. mindspore/parallel/_transformer/transformer.py +23 -3
  375. mindspore/parallel/_utils.py +11 -7
  376. mindspore/parallel/algo_parameter_config.py +85 -5
  377. mindspore/parallel/checkpoint_transform.py +19 -12
  378. mindspore/parallel/shard.py +21 -14
  379. mindspore/profiler/common/struct_type.py +3 -3
  380. mindspore/profiler/common/util.py +4 -2
  381. mindspore/profiler/envprofiling.py +1 -1
  382. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  383. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  384. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  385. mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
  386. mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
  387. mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
  388. mindspore/profiler/parser/ascend_op_generator.py +6 -6
  389. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  390. mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
  391. mindspore/profiler/parser/base_timeline_generator.py +10 -8
  392. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
  393. mindspore/profiler/parser/flops_parser.py +15 -11
  394. mindspore/profiler/parser/framework_parser.py +38 -22
  395. mindspore/profiler/parser/hccl_parser.py +16 -12
  396. mindspore/profiler/parser/integrator.py +22 -11
  397. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  398. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  399. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  400. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  401. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  402. mindspore/profiler/parser/optime_parser.py +1 -1
  403. mindspore/profiler/parser/profiler_info.py +21 -2
  404. mindspore/profiler/parser/step_trace_parser.py +11 -14
  405. mindspore/profiler/profiling.py +179 -89
  406. mindspore/rewrite/api/node.py +102 -19
  407. mindspore/rewrite/api/node_type.py +5 -1
  408. mindspore/rewrite/api/pattern_engine.py +1 -1
  409. mindspore/rewrite/api/scoped_value.py +9 -17
  410. mindspore/rewrite/api/symbol_tree.py +131 -47
  411. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  412. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  413. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  414. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  415. mindspore/rewrite/common/rewrite_elog.py +5 -1
  416. mindspore/rewrite/namer.py +33 -24
  417. mindspore/rewrite/namespace.py +14 -5
  418. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  419. mindspore/rewrite/node/call_function.py +79 -0
  420. mindspore/rewrite/node/cell_container.py +135 -0
  421. mindspore/rewrite/node/control_flow.py +88 -0
  422. mindspore/rewrite/{node.py → node/node.py} +273 -234
  423. mindspore/rewrite/node/node_manager.py +254 -0
  424. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  425. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  426. mindspore/rewrite/parsers/assign_parser.py +216 -221
  427. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  428. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  429. mindspore/rewrite/parsers/constant_parser.py +9 -6
  430. mindspore/rewrite/parsers/container_parser.py +9 -7
  431. mindspore/rewrite/parsers/for_parser.py +42 -21
  432. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  433. mindspore/rewrite/parsers/if_parser.py +28 -24
  434. mindspore/rewrite/parsers/module_parser.py +196 -25
  435. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  436. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  437. mindspore/rewrite/parsers/return_parser.py +6 -6
  438. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  439. mindspore/rewrite/sparsify/utils.py +1 -1
  440. mindspore/rewrite/symbol_tree.py +523 -578
  441. mindspore/rewrite/symbol_tree_builder.py +9 -193
  442. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  443. mindspore/run_check/_check_version.py +6 -4
  444. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  445. mindspore/safeguard/rewrite_obfuscation.py +541 -0
  446. mindspore/scipy/linalg.py +1 -1
  447. mindspore/scipy/ops.py +55 -5
  448. mindspore/scipy/optimize/__init__.py +3 -2
  449. mindspore/scipy/optimize/linear_sum_assignment.py +38 -33
  450. mindspore/scipy/optimize/minimize.py +7 -3
  451. mindspore/train/_utils.py +7 -3
  452. mindspore/train/amp.py +323 -123
  453. mindspore/train/anf_ir_pb2.py +14 -2
  454. mindspore/train/callback/_backup_and_restore.py +2 -12
  455. mindspore/train/callback/_callback.py +29 -4
  456. mindspore/train/callback/_checkpoint.py +23 -8
  457. mindspore/train/callback/_early_stop.py +2 -2
  458. mindspore/train/callback/_landscape.py +4 -4
  459. mindspore/train/callback/_loss_monitor.py +2 -2
  460. mindspore/train/callback/_on_request_exit.py +2 -2
  461. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  462. mindspore/train/callback/_summary_collector.py +15 -8
  463. mindspore/train/callback/_time_monitor.py +58 -5
  464. mindspore/train/data_sink.py +5 -11
  465. mindspore/train/dataset_helper.py +84 -57
  466. mindspore/train/loss_scale_manager.py +2 -2
  467. mindspore/train/metrics/__init__.py +3 -3
  468. mindspore/train/metrics/cosine_similarity.py +1 -1
  469. mindspore/train/metrics/hausdorff_distance.py +3 -2
  470. mindspore/train/metrics/mean_surface_distance.py +3 -2
  471. mindspore/train/metrics/metric.py +39 -19
  472. mindspore/train/metrics/roc.py +2 -2
  473. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  474. mindspore/train/mind_ir_pb2.py +85 -36
  475. mindspore/train/model.py +187 -47
  476. mindspore/train/serialization.py +487 -161
  477. mindspore/train/summary/_summary_adapter.py +1 -1
  478. mindspore/train/summary/_writer_pool.py +3 -2
  479. mindspore/train/summary/summary_record.py +37 -17
  480. mindspore/train/train_thor/convert_utils.py +3 -3
  481. mindspore/train/train_thor/dataset_helper.py +1 -1
  482. mindspore/version.py +1 -1
  483. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +8 -8
  484. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +488 -539
  485. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -1
  486. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  487. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  488. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  489. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  490. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  491. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  492. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  493. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  494. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  495. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  496. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  497. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  498. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  499. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  500. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  501. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  502. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  503. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  504. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  505. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  506. mindspore/_extends/graph_kernel/expander.py +0 -80
  507. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  508. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  509. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  510. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  511. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  512. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  513. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  514. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  515. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  516. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  517. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  518. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  519. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  520. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  521. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  522. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  523. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  524. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  525. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  526. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  527. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  528. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  529. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  530. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  531. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  532. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  533. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  534. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  535. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  536. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  537. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  538. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  539. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  540. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  541. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  542. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  543. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  544. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  545. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  546. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  547. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  548. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  549. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  550. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  551. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  552. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  553. mindspore/dataset/datapreprocess/__init__.py +0 -20
  554. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  555. mindspore/include/api/net.h +0 -142
  556. mindspore/nn/lr_scheduler.py +0 -262
  557. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  558. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  559. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  560. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  561. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  562. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -350
  563. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -409
  564. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -578
  565. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -199
  566. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -446
  567. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  568. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
  569. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
  570. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
  571. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  572. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  573. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  574. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  575. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  576. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  577. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  578. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  579. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  580. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  581. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  582. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  583. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  584. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  585. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  586. mindspore/rewrite/node_visitor.py +0 -44
  587. /mindspore/{ops/_op_impl/_custom_op/flash_attention → _akg/akg/utils/ascend_profilier}/__init__.py +0 -0
  588. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
  589. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
@@ -1,269 +0,0 @@
- # Copyright 2021-2022 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ===========================================================================
- """GraphKernel expander utils"""
- from abc import ABCMeta, abstractmethod
- from mindspore._extends.graph_kernel.model import model_builder as builder
- from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
-
-
- class Expander(metaclass=ABCMeta):
-     """
-     Expander is the base class of expanders.
-
-     The method `_expand` should be overridden to implement the operator detail.
-     """
-     def __init__(self, expand_info):
-         self.name = expand_info["name"]
-         self.inputs = expand_info["input_desc"]
-         self.outputs = expand_info["output_desc"]
-         self.attrs = expand_info["attr"]
-         self.processor = expand_info["process"]
-
-     def run(self):
-         """
-         Expand the operator to a graph.
-
-         `GraphKernelUnsupportedException` would be raised if check failed.
-         """
-         self._check()
-         graph_builder = builder.GraphBuilder()
-         with graph_builder.graph_scope(self.name) as graph_scope:
-             # transform input_desc to Tensor
-             self.inputs = [graph_builder.tensor(inp['shape'], inp['data_type'], inp['format']) for inp in self.inputs]
-             graph_scope.set_input(*self.inputs)
-             outputs = self._expand(graph_builder)
-             if isinstance(outputs, (list, tuple)):
-                 self._check_output_same(outputs)
-                 graph_scope.set_output(*outputs)
-             else:
-                 self._check_output_same([outputs])
-                 graph_scope.set_output(outputs)
-
-         graph = graph_builder.get()[0]
-         graph.set_processor(self.processor)
-         return graph
-
-     def _check(self):
-         """Check inputs"""
-
-     def _check_output_same(self, outputs):
-         for index, value in enumerate(self.outputs):
-             if list(outputs[index].shape) != list(value['shape']):
-                 raise GKException("{} 's output shape {} is wrong. Expected:{}".format(
-                     self.__class__.__name__, list(outputs[index].shape), list(value['shape'])))
-             if outputs[index].dtype != value['data_type']:
-                 raise GKException("{} 's output data_type {} is wrong. Expected: {}".format(
-                     self.__class__.__name__, outputs[index].dtype, value['data_type']))
-             if outputs[index].data_format != value['format']:
-                 raise GKException("{} 's output format {} is wrong. Expected: {}".format(
-                     self.__class__.__name__, outputs[index].data_format, value['format']))
-
-     @abstractmethod
-     def _expand(self, graph_builder):
-         """Expand operator, this function should be overridden in subclass"""
-         raise Exception("_expand() is not implemented in {}".format(self.__class__.__name__))
-
-
- class ExpanderInfoValidator:
-     """ExpanderInfoValidator is the utility class which defines the validator decorator for expanders"""
-
-     def __init__(self):
-         """Init"""
-
-     @staticmethod
-     def _add_check_function(kls, func):
-         """
-         Rewrite the function `_check` in class Expander
-         to append the new `func` after the original checks.
-         """
-         old_check = getattr(kls, "_check")
-
-         def new_check(obj):
-             old_check(obj)
-             func(obj)
-
-         setattr(kls, "_check", new_check)
-
-     @staticmethod
-     def add_format(*input_format):
-         """
-         Add new supported format for the operator
-
-         this function will add a list `__supported_formats` into the expander,
-         saving the whitelist of formats that this op supports.
-         it also rewrites the `_check` function to check the formats.
-         """
-         format_list_name = "__supported_formats"
-
-         def _check_format(obj):
-             inp_formats = [inp['format'] for inp in obj.inputs]
-             for formats in getattr(obj, format_list_name):
-                 if len(formats) != len(inp_formats):
-                     raise GKException("For '{}', length of registered format is different from the length of inputs "
-                                       "format: {} vs {}".format(obj.name, len(formats), len(inp_formats)))
-                 if all((fmt == inp for fmt, inp in zip(formats, inp_formats))):
-                     return
-             raise GKException("Unregistered format ({}) for op {}".format(','.join(inp_formats), obj.name))
-
-         def wrapper(cls):
-             if not issubclass(cls, Expander):
-                 raise Exception("{} should be subclass of Expander.".format(cls.__name__))
-             if not hasattr(cls, format_list_name):
-                 setattr(cls, format_list_name, list())
-                 ExpanderInfoValidator._add_check_function(cls, _check_format)
-             getattr(cls, format_list_name).append(input_format)
-             return cls
-
-         return wrapper
-
-     @staticmethod
-     def check_all_formats_same(kls):
-         """Check that all formats are the same"""
-
-         # Ensure no args case can return a class
-         def _check(*args):
-             def _check_format(obj):
-                 inp_formats = [inp['format'] for inp in obj.inputs]
-                 if all((fmt == inp_formats[0] for fmt in inp_formats[1:])):
-                     return
-                 raise GKException("[check_all_formats_same] unmatched formats ({}) for op {}".format(
-                     ','.join(inp_formats), obj.name))
-
-             def wrapper(cls):
-                 if not issubclass(cls, Expander):
-                     raise Exception("{} should be subclass of Expander.".format(cls.__name__))
-                 ExpanderInfoValidator._add_check_function(cls, _check_format)
-                 return cls
-
-             return wrapper
-
-         return _check()(kls)
-
-     @staticmethod
-     def check_attrs(*args):
-         """Check the attrs exist"""
-
-         def _check_attr(obj):
-             for a in args:
-                 if a not in obj.attrs:
-                     raise GKException("attr '{}' does not exist. {}".format(a, obj.name))
-
-         def wrapper(cls):
-             if not issubclass(cls, Expander):
-                 raise Exception("{} should be subclass of Expander.".format(cls.__name__))
-             ExpanderInfoValidator._add_check_function(cls, _check_attr)
-             return cls
-
-         return wrapper
-
-
- def to_frac_z_axis(ori_shape, ori_axis):
-     """
-     judge the format is fractal NZ
-     Parameters
-     ----------
-     ori_shape: list or tuple
-         original shape of input
-     ori_axis: list or tuple
-         original axis of original shape to operate
-     Returns
-     -------
-     output: list
-         axis of the fractal Nz shape
-     """
-     frac_z_axis = list(ori_axis)
-     shape_len = len(ori_shape)
-     axis_count = len(frac_z_axis)
-     axis_negative_1 = shape_len - 1
-     axis_negative_2 = shape_len - 2
-     for i in range(axis_count):
-         axis_index = (frac_z_axis[i] + shape_len) % shape_len
-         if axis_index == axis_negative_1:
-             if frac_z_axis[i] > shape_len - 2:  # akg:[2,3] [1,4] tbe:[2,4] [1,3]
-                 frac_z_axis[i] = axis_index - 1
-                 frac_z_axis.append(axis_index + 2)
-             else:  # no case cover this branch now
-                 frac_z_axis[i] = axis_index - 1
-                 frac_z_axis.append(axis_index + 2)
-         elif axis_index == axis_negative_2:
-             frac_z_axis[i] = axis_index + 1
-             frac_z_axis.append(axis_index + 2)
-         else:
-             frac_z_axis[i] = axis_index
-     return frac_z_axis
-
-
- def infer_shape_from_fractalnz(fractal):
-     "get original shape from fractalnz shape"
-     shape = []
-     dims = len(fractal)
-     batch = dims - 4
-     for i in range(batch):
-         shape.append(fractal[i])
-     m = fractal[dims - 3] * fractal[dims - 2]
-     n = fractal[dims - 4] * fractal[dims - 1]
-     shape.append(m)
-     shape.append(n)
-     return shape
-
-
- def get_reduced_ori_shape(shape, axis):
-     "get shape after reduced which is based on original shape"
-     reduced_ori_shape = []
-     for i, value in enumerate(shape):
-         if i in axis:
-             reduced_ori_shape.append(1)
-         else:
-             reduced_ori_shape.append(value)
-     return reduced_ori_shape
-
-
- def get_reduce_axis_shape(shape, data_format, axis):
-     """
-     Get the reduce axis under format `data_format` and original reduced shape.
-     Parameters
-     ----------
-     shape: list or tuple
-         shape of input
-     data_format: str
-         data format of input
-     axis: None, int, list or tuple
-         reduce axis of the original shape
-     Returns
-     -------
-     reduce_axis: list
-         reduce axis of the `data_format` shape
-     ori_reduced_shape: list
-         original reduced shape
-     """
-     ori_shape = shape
-     if data_format == "FRACTAL_NZ":
-         ori_shape = infer_shape_from_fractalnz(shape)
-     if not axis:
-         axis = []
-         for i, _ in enumerate(ori_shape):
-             axis.append(i)
-     else:
-         if isinstance(axis, int):
-             axis = [axis]
-         for i, _ in enumerate(list(axis)):
-             if axis[i] < 0:
-                 axis[i] += len(ori_shape)
-
-     ori_reduced_shape = get_reduced_ori_shape(ori_shape, axis)
-     reduce_axis = axis
-     if data_format == "FRACTAL_NZ":
-         reduce_axis = to_frac_z_axis(ori_shape, axis)
-     return reduce_axis, ori_reduced_shape
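
The helpers deleted above (from mindspore/_extends/graph_kernel/expanders/_utils.py) map reduction axes between the original tensor layout and the Ascend FRACTAL_NZ layout. A minimal sketch of their behaviour, assuming mindspore 2.1.0 is installed (the module no longer exists in 2.2.11); the concrete shapes below are illustrative only:

    from mindspore._extends.graph_kernel.expanders._utils import (
        get_reduce_axis_shape, infer_shape_from_fractalnz)

    # A (48, 32) matrix stored as FRACTAL_NZ has shape (N1, M1, M0, N0) = (2, 3, 16, 16):
    # m = M1 * M0 = 3 * 16 = 48 and n = N1 * N0 = 2 * 16 = 32.
    assert infer_shape_from_fractalnz([2, 3, 16, 16]) == [48, 32]

    # Reducing the last original axis (-1) maps to axes (0, 3) of the fractal layout;
    # the reduced original shape keeps that axis with length 1.
    reduce_axis, ori_reduced_shape = get_reduce_axis_shape([2, 3, 16, 16], "FRACTAL_NZ", -1)
    assert reduce_axis == [0, 3]
    assert ori_reduced_shape == [48, 1]
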
@@ -1,33 +0,0 @@
- # Copyright 2021-2022 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ===========================================================================
- """generate json desc for addn"""
- from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
- from ._utils import Expander, ExpanderInfoValidator as VLD
-
-
- @VLD.check_all_formats_same
- class AddN(Expander):
-     """Expand AddN to multiple Adds"""
-
-     def _check(self):
-         if len(self.inputs) < 2:
-             raise GKException("For 'AddN', the inputs num should be greater than 1, but got {}"
-                               .format(len(self.inputs)))
-
-     def _expand(self, graph_builder):
-         result = self.inputs[0]
-         for inp in self.inputs[1:]:
-             result = graph_builder.emit('Add', [result, inp])
-         return result
@@ -1,152 +0,0 @@
- # Copyright 2021-2022 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ===========================================================================
- """generate json desc for BatchNorm"""
- from mindspore._extends.graph_kernel.model.model import DataFormat as DF
- from ._utils import Expander, ExpanderInfoValidator as VLD
- from .expand_dims import ExpandDims
-
-
- @VLD.add_format(DF.NHWC, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
- @VLD.add_format(DF.NCHW, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
- @VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
- @VLD.check_attrs('is_training', 'momentum', 'epsilon')
- class BatchNorm(Expander):
-     """BatchNorm expander"""
-
-     def _expand(self, graph_builder):
-         # get op info
-         input_x = self.inputs[0]
-         input_scale = self.inputs[1]
-         input_offset = self.inputs[2]
-         input_mean = self.inputs[3]
-         input_variance = self.inputs[4]
-         epsilon_v = graph_builder.value(input_scale.dtype, self.attrs['epsilon'])
-
-         input_x_ori_type = input_x.dtype
-         input_x_new_type = input_x.dtype
-         if input_x.dtype == "float16" and input_scale.dtype == "float32" and input_offset.dtype == "float32" and \
-                 input_mean.dtype == "float32" and input_variance.dtype == "float32":
-             input_x_new_type = "float32"
-         if input_x_new_type != input_x_ori_type:
-             input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': input_x_new_type})
-
-         if self.attrs['is_training']:
-             self.inputs[0] = input_x
-             res_y, mean_res, variance_res, mean_muls, y_sqrt_rec = self._bn_train(graph_builder)
-             if input_x_new_type != input_x_ori_type:
-                 res_y = graph_builder.emit('Cast', [res_y], attrs={'dst_type': input_x_ori_type})
-             return res_y, mean_res, variance_res, mean_muls, y_sqrt_rec
-         # infer mode
-         if input_x.data_format in (DF.DEFAULT, DF.NCHW):
-             input_mean = graph_builder.emit(
-                 'Reshape', [input_mean], attrs={'shape': ExpandDims.infer_shape(input_mean.shape, [-1, -1])})
-             input_scale = graph_builder.emit(
-                 'Reshape', [input_scale], attrs={'shape': ExpandDims.infer_shape(input_scale.shape, [-1, -1])})
-             input_offset = graph_builder.emit(
-                 'Reshape', [input_offset], attrs={'shape': ExpandDims.infer_shape(input_offset.shape, [-1, -1])})
-         x_sub = graph_builder.emit('Sub', [input_x, input_mean])
-         x_sub_mul = graph_builder.emit('Mul', [input_scale, x_sub])
-         var_add = graph_builder.emit('Add', [epsilon_v, input_variance])
-         var_add_sqrt = graph_builder.emit('Sqrt', [var_add])
-         if input_x.data_format in (DF.DEFAULT, DF.NCHW):
-             var_add_sqrt = graph_builder.emit(
-                 'Reshape', [var_add_sqrt], attrs={'shape': ExpandDims.infer_shape(var_add_sqrt.shape, [-1, -1])})
-         x_div = graph_builder.emit('RealDiv', [x_sub_mul, var_add_sqrt])
-         res_y = graph_builder.emit('Add', [input_offset, x_div])
-         if input_x_new_type != input_x_ori_type:
-             res_y = graph_builder.emit('Cast', [res_y], attrs={'dst_type': input_x_ori_type})
-         return res_y, var_add, var_add, var_add, var_add
-
-     def _bn_train(self, graph_builder):
-         """expand BatchNorm for training mode"""
-         input_x = self.inputs[0]
-         input_scale = self.inputs[1]
-         input_offset = self.inputs[2]
-         input_mean = self.inputs[3]
-         input_variance = self.inputs[4]
-         epsilon_v = graph_builder.value(input_scale.dtype, self.attrs['epsilon'])
-         reduce_axis = ()
-         shape_x = input_x.shape
-         if input_x.data_format == DF.NHWC:
-             reduce_axis = (0, 1, 2)
-             num = shape_x[0] * shape_x[1] * shape_x[2]
-         else:
-             reduce_axis = (0, 2, 3)
-             num = shape_x[0] * shape_x[2] * shape_x[3]
-         num_rec = 1.0 / num
-         num_rec_v = graph_builder.value(input_scale.dtype, num_rec)
-
-         # compute mean value of input_x
-         mean_sum = graph_builder.emit(
-             'ReduceSum', [input_x], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
-         mean_muls = graph_builder.emit('Mul', [mean_sum, num_rec_v])
-
-         # compute variance of input_x
-         if input_x.data_format in (DF.DEFAULT, DF.NCHW):
-             mean_muls_expand = graph_builder.emit(
-                 'Reshape', [mean_muls], attrs={'shape': ExpandDims.infer_shape(mean_muls.shape, [-1, -1])})
-         else:
-             mean_muls_expand = mean_muls
-         var_sub = graph_builder.emit('Sub', [input_x, mean_muls_expand])
-         var_mul = graph_builder.emit('Mul', [var_sub, var_sub])
-         var_sum = graph_builder.emit('ReduceSum', [var_mul], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
-         var_mul = graph_builder.emit('Mul', [var_sum, num_rec_v])
-
-         # y_sqrt_rec means 1 / sqrt(variance + epsilon), which is calculated in backward pass
-         scalar_one = 1.0
-         scalar_one_v = graph_builder.value(input_scale.dtype, scalar_one)
-         y_add = graph_builder.emit('Add', [var_mul, epsilon_v])
-         y_sqrt = graph_builder.emit('Sqrt', [y_add])
-         y_sqrt_rec = graph_builder.emit('RealDiv', [scalar_one_v, y_sqrt])
-
-         # compute res_y
-         tmp_sub = graph_builder.emit('Sub', [input_x, mean_muls_expand])
-         if input_x.data_format in (DF.DEFAULT, DF.NCHW):
-             y_sqrt_rec_expand = graph_builder.emit(
-                 'Reshape', [y_sqrt_rec], attrs={'shape': ExpandDims.infer_shape(y_sqrt_rec.shape, [-1, -1])})
-         else:
-             y_sqrt_rec_expand = y_sqrt_rec
-         y_norm = graph_builder.emit('Mul', [tmp_sub, y_sqrt_rec_expand])
-         if input_x.data_format in (DF.DEFAULT, DF.NCHW):
-             input_scale_expand = graph_builder.emit(
-                 'Reshape', [input_scale], attrs={'shape': ExpandDims.infer_shape(input_scale.shape, [-1, -1])})
-         else:
-             input_scale_expand = input_scale
-         res_y_mul = graph_builder.emit('Mul', [input_scale_expand, y_norm])
-         if input_x.data_format in (DF.DEFAULT, DF.NCHW):
-             input_offset_expand = graph_builder.emit(
-                 'Reshape', [input_offset], attrs={'shape': ExpandDims.infer_shape(input_offset.shape, [-1, -1])})
-         else:
-             input_offset_expand = input_offset
-         res_y = graph_builder.emit('Add', [res_y_mul, input_offset_expand])
-
-         # compute mean_res
-         momentum_sub = scalar_one - self.attrs['momentum']
-         momentum_v_sub = graph_builder.value(input_scale.dtype, momentum_sub)
-         new_running_mean_tmp = graph_builder.emit('Mul', [momentum_v_sub, input_mean])
-         momentum_v = graph_builder.value(input_scale.dtype, self.attrs['momentum'])
-         current_mean_tmp = graph_builder.emit('Mul', [momentum_v, mean_muls])
-         updated_moving_mean = graph_builder.emit('Add', [new_running_mean_tmp, current_mean_tmp])
-         mean_res = graph_builder.emit('Assign', [input_mean, updated_moving_mean])
-
-         # variance_res is calculated by sample variance, and need to multiply by num / (num - 1)
-         var_num = float(num) / (num - 1)
-         var_num_v = graph_builder.value(input_scale.dtype, var_num)
-         var_mul_update = graph_builder.emit('Mul', [var_num_v, var_mul])
-         new_running_var_tmp = graph_builder.emit('Mul', [momentum_v_sub, input_variance])
-         current_var_tmp = graph_builder.emit('Mul', [momentum_v, var_mul_update])
-         updated_moving_variance = graph_builder.emit('Add', [new_running_var_tmp, current_var_tmp])
-         variance_res = graph_builder.emit('Assign', [input_variance, updated_moving_variance])
-         return res_y, mean_res, variance_res, mean_muls, y_sqrt_rec
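
The inference branch of the BatchNorm expander above chains Sub, Mul, Add, Sqrt and RealDiv into the standard batch-norm inference formula y = offset + scale * (x - mean) / sqrt(variance + epsilon). A NumPy reference of that formula, given only as an illustrative sketch (the array shapes and epsilon value are assumptions, not taken from the diff):

    import numpy as np

    def batchnorm_infer_reference(x, scale, offset, mean, variance, epsilon=1e-5):
        # Mirrors the Sub/Mul/Add/Sqrt/RealDiv chain emitted by the infer-mode branch.
        return offset + scale * (x - mean) / np.sqrt(variance + epsilon)

    # NCHW input; scale/offset/mean/variance are per-channel and broadcast over (N, H, W),
    # which is what the Reshape calls (ExpandDims to trailing singleton dims) achieve above.
    x = np.random.randn(2, 3, 4, 4).astype(np.float32)
    scale = np.ones((1, 3, 1, 1), np.float32)
    offset = np.zeros((1, 3, 1, 1), np.float32)
    mean = x.mean(axis=(0, 2, 3), keepdims=True)
    variance = x.var(axis=(0, 2, 3), keepdims=True)
    y = batchnorm_infer_reference(x, scale, offset, mean, variance)
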
@@ -1,105 +0,0 @@
- # Copyright 2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ===========================================================================
- """generate json desc for BatchNormGrad"""
- from mindspore._extends.graph_kernel.model.model import DataFormat as DF
- from ._utils import Expander, ExpanderInfoValidator as VLD
- from .expand_dims import ExpandDims
-
-
- @VLD.add_format(DF.NHWC, DF.NHWC, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
- @VLD.add_format(DF.NCHW, DF.NCHW, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
- @VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
- @VLD.check_attrs('is_training', 'epsilon')
- class BatchNormGrad(Expander):
- """BatchNormGrad expander"""
-
- def _expand(self, graph_builder):
- # get op info
- input_dy = self.inputs[0]
- input_x = self.inputs[1]
- input_scale = self.inputs[2]
- input_save_mean = self.inputs[3]
- input_save_inv_variance = self.inputs[4]
-
- reduce_axis = ()
- shape_x = input_x.shape
- if input_x.data_format == DF.NHWC:
- reduce_axis = (0, 1, 2)
- num = shape_x[0] * shape_x[1] * shape_x[2]
- else:
- reduce_axis = (0, 2, 3)
- num = shape_x[0] * shape_x[2] * shape_x[3]
- ori_type = input_x.dtype
- if ori_type == 'float16':
- input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': 'float32'})
- if input_dy.dtype == 'float16':
- input_dy = graph_builder.emit('Cast', [input_dy], attrs={'dst_type': 'float32'})
- num_rec = -1.0 / num
- num_rec_v = graph_builder.value(input_scale.dtype, num_rec)
- dbeta = graph_builder.emit('ReduceSum', [input_dy], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
-
- # in training input_save_inv_variance means 1 / sqrt(variance + epsilon), which is calculated in forward pass
- if self.attrs['is_training']:
- inv_variance = input_save_inv_variance
- else:
- epsilon_v = graph_builder.value(input_scale.dtype, self.attrs['epsilon'])
- var_add = graph_builder.emit('Add', [input_save_inv_variance, epsilon_v])
- sqrt_var_eps = graph_builder.emit('Sqrt', [var_add])
- scalar_one = 1.0
- scalar_one_v = graph_builder.value(input_scale.dtype, scalar_one)
- inv_variance = graph_builder.emit('RealDiv', [scalar_one_v, sqrt_var_eps])
-
- # compute dgamma
- if input_x.data_format in (DF.DEFAULT, DF.NCHW):
- input_save_mean = graph_builder.emit(
- 'Reshape', [input_save_mean], attrs={'shape': ExpandDims.infer_shape(input_save_mean.shape, [-1, -1])})
- inv_variance = graph_builder.emit(
- 'Reshape', [inv_variance], attrs={'shape': ExpandDims.infer_shape(inv_variance.shape, [-1, -1])})
- input_scale = graph_builder.emit(
- 'Reshape', [input_scale], attrs={'shape': ExpandDims.infer_shape(input_scale.shape, [-1, -1])})
- x_sub_mean = graph_builder.emit('Sub', [input_x, input_save_mean])
- x_div = graph_builder.emit('Mul', [x_sub_mean, inv_variance])
- dgamma_param = graph_builder.emit('Mul', [input_dy, x_div])
- dgamma = graph_builder.emit(
- 'ReduceSum', [dgamma_param], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
-
- # compute dx
- if self.attrs['is_training']:
- tmp_b = graph_builder.emit('Mul', [num_rec_v, dbeta])
- if input_x.data_format in (DF.DEFAULT, DF.NCHW):
- dgamma_expand = graph_builder.emit(
- 'Reshape', [dgamma], attrs={'shape': ExpandDims.infer_shape(dgamma.shape, [-1, -1])})
- tmp_b = graph_builder.emit(
- 'Reshape', [tmp_b], attrs={'shape': ExpandDims.infer_shape(tmp_b.shape, [-1, -1])})
- else:
- dgamma_expand = dgamma
- x_sub_mean_dgamma_mul = graph_builder.emit('Mul', [x_div, dgamma_expand])
- tmp_c = graph_builder.emit('Mul', [num_rec_v, x_sub_mean_dgamma_mul])
- tmp_ab_add = graph_builder.emit('Add', [input_dy, tmp_b])
- tmp_abc_add = graph_builder.emit('Add', [tmp_ab_add, tmp_c])
- gamma_mul = graph_builder.emit('Mul', [input_scale, tmp_abc_add])
- dx = graph_builder.emit('Mul', [inv_variance, gamma_mul])
- else:
- y_scale = graph_builder.emit('Mul', [input_scale, input_dy])
- dx = graph_builder.emit('Mul', [inv_variance, y_scale])
- if ori_type == 'float16':
- dx = graph_builder.emit('Cast', [dx], attrs={'dst_type': 'float16'})
-
- # set output tensors' data_format
- dx.data_format = self.outputs[0]['format']
- dgamma.data_format = self.outputs[1]['format']
- dbeta.data_format = self.outputs[2]['format']
-
- return dx, dgamma, dbeta
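For reference, a minimal NumPy sketch of the gradients the removed BatchNormGrad expander assembles for the NCHW/DEFAULT layout (illustrative only; the names, shapes, and the explicit reshapes are assumptions, not MindSpore API):

import numpy as np

def batchnorm_grad_nchw(dy, x, scale, mean, inv_std, is_training=True):
    axis = (0, 2, 3)
    num = x.shape[0] * x.shape[2] * x.shape[3]
    c = (1, -1, 1, 1)                        # broadcast per-channel vectors over N, H, W
    x_hat = (x - mean.reshape(c)) * inv_std.reshape(c)
    dbeta = dy.sum(axis=axis)                # sum of the upstream gradient
    dgamma = (dy * x_hat).sum(axis=axis)     # upstream gradient correlated with the normalized input
    if is_training:
        # dx = scale * inv_std * (dy - mean(dy) - x_hat * mean(dy * x_hat)),
        # mirroring tmp_b = -dbeta / num and tmp_c = -x_hat * dgamma / num above
        dx = scale.reshape(c) * inv_std.reshape(c) * (
            dy - dbeta.reshape(c) / num - x_hat * dgamma.reshape(c) / num)
    else:
        dx = scale.reshape(c) * inv_std.reshape(c) * dy
    return dx, dgamma, dbeta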
@@ -1,33 +0,0 @@
- # Copyright 2020-2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ===========================================================================
- """generate json desc for ClipByNormNoDivSum"""
- from ._utils import Expander, ExpanderInfoValidator as VLD
-
-
- @VLD.check_all_formats_same
- class ClipByNormNoDivSum(Expander):
- """ClipByNormNoDivSum expander"""
-
- def _expand(self, graph_builder):
- input_x0, input_x1, input_x2, input_x3 = self.inputs
-
- # cal result
- greater_res = graph_builder.emit('Greater', [input_x0, input_x1])
- select_res0 = graph_builder.emit('Select', [greater_res, input_x0, input_x2])
- sqrt_res = graph_builder.emit('Sqrt', [select_res0])
- select_res1 = graph_builder.emit('Select', [greater_res, sqrt_res, input_x0])
- result = graph_builder.emit('Maximum', [select_res1, input_x3])
-
- return result
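For reference, a minimal NumPy sketch of what the removed ClipByNormNoDivSum expander computes (illustrative only, not MindSpore API):

import numpy as np

def clip_by_norm_no_div_sum(x0, x1, x2, x3):
    greater = x0 > x1
    selected = np.where(greater, x0, x2)   # substitute a safe value before the sqrt where x0 <= x1
    root = np.sqrt(selected)
    picked = np.where(greater, root, x0)   # keep the original value where the substitute was used
    return np.maximum(picked, x3)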
@@ -1,30 +0,0 @@
- # Copyright 2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ===========================================================================
- """generate json desc for cabs"""
- from mindspore._extends.graph_kernel.expanders._utils import Expander
-
-
- class CAbs(Expander):
- """CAbs expander"""
-
- def _expand(self, graph_builder):
- input_x = self.inputs[0]
- x_real = graph_builder.emit('CReal', [input_x])
- x_imag = graph_builder.emit('CImag', [input_x])
- squre_x_real = graph_builder.emit('Mul', [x_real, x_real])
- squre_x_imag = graph_builder.emit('Mul', [x_imag, x_imag])
- squre_sum = graph_builder.emit('Add', [squre_x_real, squre_x_imag])
- result = graph_builder.emit('Sqrt', [squre_sum])
- return result
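For reference, the removed CAbs expander computes the complex magnitude |z| = sqrt(real(z)^2 + imag(z)^2); a minimal NumPy sketch (illustrative only, not MindSpore API):

import numpy as np

def cabs(z):
    # magnitude from the real and imaginary parts, mirroring the CReal/CImag/Mul/Add/Sqrt chain above
    return np.sqrt(np.real(z) ** 2 + np.imag(z) ** 2)

cabs(np.array([3.0 + 4.0j]))  # -> array([5.])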
@@ -1,44 +0,0 @@
- # Copyright 2021-2022 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ===========================================================================
- """generate json desc for cadd"""
- from mindspore._extends.graph_kernel.model.model import DataFormat as DF
- from mindspore._extends.graph_kernel.expanders._utils import Expander, ExpanderInfoValidator as VLD
-
-
- @VLD.add_format(DF.DEFAULT, DF.DEFAULT)
- class CAdd(Expander):
- """CAdd expander"""
-
- def _expand(self, graph_builder):
- input_x, input_y = self.inputs
- if input_x.dtype == input_y.dtype:
- x_real = graph_builder.emit('CReal', [input_x])
- y_real = graph_builder.emit('CReal', [input_y])
- x_imag = graph_builder.emit('CImag', [input_x])
- y_imag = graph_builder.emit('CImag', [input_y])
- result_real = graph_builder.emit('Add', [x_real, y_real])
- result_imag = graph_builder.emit('Add', [x_imag, y_imag])
- result = graph_builder.emit('Complex', [result_real, result_imag])
- elif input_x.dtype == "complex64" or input_x.dtype == "complex128":
- x_real = graph_builder.emit('CReal', [input_x])
- x_imag = graph_builder.emit('CImag', [input_x])
- x_real_add_y = graph_builder.emit('Add', [x_real, input_y])
- result = graph_builder.emit('Complex', [x_real_add_y, x_imag])
- elif input_y.dtype == "complex64" or input_y.dtype == "complex128":
- y_real = graph_builder.emit('CReal', [input_y])
- y_imag = graph_builder.emit('CImag', [input_y])
- y_real_add_x = graph_builder.emit('Add', [y_real, input_x])
- result = graph_builder.emit('Complex', [y_real_add_x, y_imag])
- return result
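For reference, a minimal NumPy sketch of the three branches in the removed CAdd expander (illustrative only, not MindSpore API; the final raise is added by the sketch, whereas the original leaves the remaining case undefined):

import numpy as np

def cadd(x, y):
    if x.dtype == y.dtype:                 # both operands share a dtype: add part-wise
        return (np.real(x) + np.real(y)) + 1j * (np.imag(x) + np.imag(y))
    if np.iscomplexobj(x):                 # complex x, real y: only the real part changes
        return (np.real(x) + y) + 1j * np.imag(x)
    if np.iscomplexobj(y):                 # complex y, real x
        return (np.real(y) + x) + 1j * np.imag(y)
    raise TypeError("unsupported dtype combination")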