mindspore 2.1.0__cp39-none-any.whl → 2.2.11__cp39-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (578) hide show
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_akg/akg/build_module.py +5 -6
  4. mindspore/_akg/akg/composite/build_module.py +139 -22
  5. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  6. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  7. mindspore/_akg/akg/tvm/api.py +4 -3
  8. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  9. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  10. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  11. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  12. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  13. mindspore/_akg/akg/tvm/build_module.py +16 -1
  14. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  15. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  16. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  17. mindspore/_akg/akg/tvm/module.py +1 -2
  18. mindspore/_akg/akg/tvm/stmt.py +2 -2
  19. mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
  20. mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
  21. mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
  22. mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
  23. mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
  24. mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
  25. mindspore/_akg/akg/utils/composite_op_helper.py +16 -12
  26. mindspore/_akg/akg/utils/dump_ascend_meta.py +22 -3
  27. mindspore/_akg/akg/utils/kernel_exec.py +98 -274
  28. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  29. mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
  30. mindspore/_akg/akg/utils/util.py +56 -1
  31. mindspore/_c_dataengine.cpython-39-aarch64-linux-gnu.so +0 -0
  32. mindspore/_c_expression.cpython-39-aarch64-linux-gnu.so +0 -0
  33. mindspore/_c_mindrecord.cpython-39-aarch64-linux-gnu.so +0 -0
  34. mindspore/_check_jit_forbidden_api.py +3 -1
  35. mindspore/_checkparam.py +23 -29
  36. mindspore/_extends/graph_kernel/__init__.py +0 -1
  37. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  38. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  39. mindspore/_extends/graph_kernel/splitter.py +4 -11
  40. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  41. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
  42. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  43. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  44. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  45. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
  46. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  47. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  48. mindspore/_extends/parse/__init__.py +13 -15
  49. mindspore/_extends/parse/namespace.py +7 -33
  50. mindspore/_extends/parse/parser.py +67 -72
  51. mindspore/_extends/parse/resources.py +1 -1
  52. mindspore/_extends/parse/standard_method.py +86 -106
  53. mindspore/_extends/parse/trope.py +1 -1
  54. mindspore/_extends/remote/kernel_build_server.py +25 -7
  55. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  56. mindspore/_install_custom.py +43 -0
  57. mindspore/_mindspore_offline_debug.cpython-39-aarch64-linux-gnu.so +0 -0
  58. mindspore/amp.py +47 -11
  59. mindspore/bin/cache_admin +0 -0
  60. mindspore/bin/cache_server +0 -0
  61. mindspore/boost/boost.py +1 -8
  62. mindspore/boost/boost_cell_wrapper.py +3 -2
  63. mindspore/boost/grad_accumulation.py +1 -1
  64. mindspore/boost/group_loss_scale_manager.py +8 -7
  65. mindspore/common/__init__.py +5 -3
  66. mindspore/common/_jit_fallback_utils.py +6 -0
  67. mindspore/common/_register_for_adapter.py +2 -0
  68. mindspore/common/_register_for_tensor.py +2 -2
  69. mindspore/common/_stub_tensor.py +13 -0
  70. mindspore/common/_utils.py +29 -0
  71. mindspore/common/api.py +174 -259
  72. mindspore/common/auto_dynamic_shape.py +494 -0
  73. mindspore/common/dtype.py +18 -11
  74. mindspore/common/dump.py +6 -4
  75. mindspore/common/initializer.py +14 -14
  76. mindspore/common/jit_config.py +33 -15
  77. mindspore/common/lazy_inline.py +126 -7
  78. mindspore/common/mindir_util.py +101 -0
  79. mindspore/common/parameter.py +51 -41
  80. mindspore/common/seed.py +4 -4
  81. mindspore/common/sparse_tensor.py +13 -14
  82. mindspore/common/tensor.py +243 -165
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +83 -4
  85. mindspore/communication/management.py +152 -84
  86. mindspore/config/op_info.config +14 -3
  87. mindspore/config/super_bar_config.json +4 -2
  88. mindspore/context.py +152 -61
  89. mindspore/dataset/__init__.py +5 -5
  90. mindspore/dataset/audio/__init__.py +2 -2
  91. mindspore/dataset/audio/transforms.py +52 -52
  92. mindspore/dataset/callback/ds_callback.py +16 -2
  93. mindspore/dataset/core/config.py +68 -51
  94. mindspore/dataset/engine/cache_client.py +33 -7
  95. mindspore/dataset/engine/datasets.py +250 -112
  96. mindspore/dataset/engine/datasets_audio.py +43 -211
  97. mindspore/dataset/engine/datasets_standard_format.py +16 -35
  98. mindspore/dataset/engine/datasets_text.py +43 -67
  99. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  100. mindspore/dataset/engine/datasets_vision.py +219 -1029
  101. mindspore/dataset/engine/iterators.py +11 -4
  102. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  103. mindspore/dataset/engine/obs/util.py +3 -0
  104. mindspore/dataset/engine/samplers.py +1 -1
  105. mindspore/dataset/engine/validators.py +19 -5
  106. mindspore/dataset/text/__init__.py +3 -3
  107. mindspore/dataset/text/transforms.py +101 -127
  108. mindspore/dataset/text/utils.py +205 -138
  109. mindspore/dataset/transforms/__init__.py +1 -1
  110. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  111. mindspore/dataset/transforms/transforms.py +95 -40
  112. mindspore/dataset/utils/browse_dataset.py +8 -2
  113. mindspore/dataset/utils/line_reader.py +17 -19
  114. mindspore/dataset/vision/__init__.py +3 -3
  115. mindspore/dataset/vision/c_transforms.py +6 -3
  116. mindspore/dataset/vision/transforms.py +409 -287
  117. mindspore/dataset/vision/utils.py +13 -14
  118. mindspore/dataset/vision/validators.py +11 -1
  119. mindspore/experimental/map_parameter.py +14 -0
  120. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  121. mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
  122. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  123. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  124. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  125. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  126. mindspore/gen_ops.py +273 -0
  127. mindspore/include/OWNERS +0 -1
  128. mindspore/include/api/data_type.h +2 -1
  129. mindspore/include/api/graph.h +0 -15
  130. mindspore/include/api/kernel.h +2 -0
  131. mindspore/include/api/kernel_api.h +37 -12
  132. mindspore/include/api/model.h +17 -14
  133. mindspore/include/api/status.h +8 -3
  134. mindspore/include/api/types.h +37 -4
  135. mindspore/include/c_api/ms/abstract.h +67 -0
  136. mindspore/include/c_api/ms/attribute.h +197 -0
  137. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  138. mindspore/include/c_api/ms/base/macros.h +32 -0
  139. mindspore/include/c_api/ms/base/status.h +33 -0
  140. mindspore/include/c_api/ms/base/types.h +282 -0
  141. mindspore/include/c_api/ms/context.h +102 -0
  142. mindspore/include/c_api/ms/graph.h +160 -0
  143. mindspore/include/c_api/ms/node.h +606 -0
  144. mindspore/include/c_api/ms/tensor.h +161 -0
  145. mindspore/include/c_api/ms/value.h +84 -0
  146. mindspore/include/dataset/constants.h +6 -5
  147. mindspore/include/dataset/execute.h +23 -13
  148. mindspore/include/dataset/text.h +26 -26
  149. mindspore/include/dataset/transforms.h +13 -13
  150. mindspore/include/dataset/vision.h +60 -60
  151. mindspore/include/dataset/vision_ascend.h +5 -6
  152. mindspore/include/dataset/vision_lite.h +17 -17
  153. mindspore/include/mindapi/base/type_id.h +1 -0
  154. mindspore/include/mindapi/base/types.h +1 -0
  155. mindspore/lib/libdnnl.so.2 +0 -0
  156. mindspore/lib/libjemalloc.so.2 +0 -0
  157. mindspore/lib/libmindspore.so +0 -0
  158. mindspore/lib/libmindspore_backend.so +0 -0
  159. mindspore/lib/libmindspore_common.so +0 -0
  160. mindspore/lib/libmindspore_core.so +0 -0
  161. mindspore/lib/libmindspore_glog.so.0 +0 -0
  162. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  163. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  164. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  165. mindspore/lib/libmindspore_shared_lib.so +0 -0
  166. mindspore/lib/libnnacl.so +0 -0
  167. mindspore/lib/libopencv_core.so.4.5 +0 -0
  168. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  169. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  170. mindspore/lib/libps_cache.so +0 -0
  171. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
  172. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
  173. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
  174. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
  175. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  176. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  177. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  178. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  179. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  180. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  181. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  182. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  183. mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
  184. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  185. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  186. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8998 -0
  187. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  188. mindspore/lib/plugin/ascend/libakg.so +0 -0
  189. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  190. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  191. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  192. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  193. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  194. mindspore/lib/plugin/cpu/libakg.so +0 -0
  195. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  196. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  197. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  198. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  199. mindspore/nn/__init__.py +0 -2
  200. mindspore/nn/cell.py +313 -74
  201. mindspore/nn/dynamic_lr.py +21 -21
  202. mindspore/nn/layer/activation.py +22 -30
  203. mindspore/nn/layer/basic.py +15 -13
  204. mindspore/nn/layer/channel_shuffle.py +1 -1
  205. mindspore/nn/layer/container.py +271 -9
  206. mindspore/nn/layer/conv.py +323 -204
  207. mindspore/nn/layer/dense.py +8 -5
  208. mindspore/nn/layer/embedding.py +33 -27
  209. mindspore/nn/layer/flash_attention.py +61 -95
  210. mindspore/nn/layer/image.py +8 -6
  211. mindspore/nn/layer/math.py +16 -25
  212. mindspore/nn/layer/normalization.py +107 -66
  213. mindspore/nn/layer/padding.py +1 -1
  214. mindspore/nn/layer/pooling.py +131 -109
  215. mindspore/nn/layer/rnn_cells.py +27 -22
  216. mindspore/nn/layer/rnns.py +13 -16
  217. mindspore/nn/layer/thor_layer.py +1 -1
  218. mindspore/nn/layer/transformer.py +221 -154
  219. mindspore/nn/learning_rate_schedule.py +9 -1
  220. mindspore/nn/loss/loss.py +235 -174
  221. mindspore/nn/optim/ada_grad.py +2 -1
  222. mindspore/nn/optim/adadelta.py +1 -0
  223. mindspore/nn/optim/adafactor.py +2 -1
  224. mindspore/nn/optim/adam.py +7 -4
  225. mindspore/nn/optim/adamax.py +3 -2
  226. mindspore/nn/optim/adasum.py +2 -2
  227. mindspore/nn/optim/asgd.py +2 -3
  228. mindspore/nn/optim/ftrl.py +6 -5
  229. mindspore/nn/optim/lamb.py +7 -4
  230. mindspore/nn/optim/lars.py +1 -1
  231. mindspore/nn/optim/lazyadam.py +5 -3
  232. mindspore/nn/optim/momentum.py +2 -1
  233. mindspore/nn/optim/optimizer.py +53 -4
  234. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  235. mindspore/nn/optim/rmsprop.py +4 -3
  236. mindspore/nn/optim/rprop.py +23 -12
  237. mindspore/nn/optim/sgd.py +26 -11
  238. mindspore/nn/optim/thor.py +9 -7
  239. mindspore/nn/probability/bijector/bijector.py +5 -5
  240. mindspore/nn/probability/bijector/power_transform.py +27 -27
  241. mindspore/nn/probability/bijector/softplus.py +3 -3
  242. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  243. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  244. mindspore/nn/probability/distribution/beta.py +3 -3
  245. mindspore/nn/probability/distribution/categorical.py +7 -7
  246. mindspore/nn/probability/distribution/cauchy.py +0 -1
  247. mindspore/nn/probability/distribution/distribution.py +3 -3
  248. mindspore/nn/probability/distribution/gamma.py +3 -3
  249. mindspore/nn/probability/distribution/geometric.py +4 -4
  250. mindspore/nn/probability/distribution/gumbel.py +4 -4
  251. mindspore/nn/probability/distribution/log_normal.py +2 -2
  252. mindspore/nn/probability/distribution/logistic.py +2 -2
  253. mindspore/nn/probability/distribution/poisson.py +4 -4
  254. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  255. mindspore/nn/probability/distribution/uniform.py +6 -6
  256. mindspore/nn/wrap/__init__.py +4 -2
  257. mindspore/nn/wrap/cell_wrapper.py +87 -34
  258. mindspore/nn/wrap/grad_reducer.py +8 -5
  259. mindspore/nn/wrap/loss_scale.py +105 -42
  260. mindspore/numpy/array_creations.py +1 -2
  261. mindspore/numpy/array_ops.py +3 -2
  262. mindspore/numpy/utils_const.py +5 -5
  263. mindspore/offline_debug/convert_async.py +2 -2
  264. mindspore/ops/_grad_experimental/__init__.py +0 -5
  265. mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
  266. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  267. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  268. mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
  269. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  270. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
  271. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  272. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  273. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  274. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  275. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  276. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  277. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  278. mindspore/ops/_op_impl/{_custom_op/flash_attention/constants.py → aicpu/eps.py} +18 -27
  279. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  280. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
  281. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  282. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  283. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  284. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  285. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  286. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  287. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  288. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  289. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  290. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  291. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  292. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  293. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  294. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  295. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  296. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  297. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  298. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  299. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  300. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  301. mindspore/ops/_primitive_cache.py +1 -1
  302. mindspore/ops/_tracefunc.py +45 -13
  303. mindspore/ops/_utils/utils.py +6 -1
  304. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  305. mindspore/ops/_vmap/vmap_base.py +3 -3
  306. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  307. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  308. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  309. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  310. mindspore/ops/arg_dtype_cast.py +54 -0
  311. mindspore/ops/composite/base.py +37 -10
  312. mindspore/ops/composite/math_ops.py +5 -4
  313. mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
  314. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  315. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  316. mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
  317. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  318. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  319. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  320. mindspore/ops/deprecated.py +304 -0
  321. mindspore/ops/function/__init__.py +4 -1
  322. mindspore/ops/function/array_func.py +174 -193
  323. mindspore/ops/function/clip_func.py +81 -13
  324. mindspore/ops/function/debug_func.py +1 -1
  325. mindspore/ops/function/grad/grad_func.py +18 -9
  326. mindspore/ops/function/image_func.py +10 -4
  327. mindspore/ops/function/linalg_func.py +5 -5
  328. mindspore/ops/function/math_func.py +575 -386
  329. mindspore/ops/function/nn_func.py +568 -260
  330. mindspore/ops/function/random_func.py +88 -57
  331. mindspore/ops/function/sparse_func.py +1 -1
  332. mindspore/ops/function/sparse_unary_func.py +14 -12
  333. mindspore/ops/function/vmap_func.py +6 -5
  334. mindspore/ops/functional.py +15 -10
  335. mindspore/ops/op_info_register.py +244 -25
  336. mindspore/ops/operations/__init__.py +31 -19
  337. mindspore/ops/operations/_grad_ops.py +71 -7
  338. mindspore/ops/operations/_inner_ops.py +350 -17
  339. mindspore/ops/operations/_quant_ops.py +4 -8
  340. mindspore/ops/operations/_sequence_ops.py +42 -0
  341. mindspore/ops/operations/array_ops.py +68 -282
  342. mindspore/ops/operations/comm_ops.py +107 -59
  343. mindspore/ops/operations/custom_ops.py +94 -70
  344. mindspore/ops/operations/debug_ops.py +8 -4
  345. mindspore/ops/operations/image_ops.py +18 -12
  346. mindspore/ops/operations/inner_ops.py +26 -3
  347. mindspore/ops/operations/math_ops.py +192 -144
  348. mindspore/ops/operations/nn_ops.py +857 -489
  349. mindspore/ops/operations/other_ops.py +0 -22
  350. mindspore/ops/operations/random_ops.py +53 -111
  351. mindspore/ops/operations/sparse_ops.py +3 -1
  352. mindspore/ops/primitive.py +24 -18
  353. mindspore/parallel/_auto_parallel_context.py +68 -8
  354. mindspore/parallel/_cost_model_context.py +2 -2
  355. mindspore/parallel/_offload_context.py +17 -3
  356. mindspore/parallel/_parallel_serialization.py +12 -5
  357. mindspore/parallel/_ps_context.py +12 -0
  358. mindspore/parallel/_tensor.py +18 -13
  359. mindspore/parallel/_transformer/layers.py +5 -3
  360. mindspore/parallel/_transformer/loss.py +1 -0
  361. mindspore/parallel/_transformer/moe.py +2 -2
  362. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  363. mindspore/parallel/_transformer/transformer.py +23 -3
  364. mindspore/parallel/_utils.py +11 -7
  365. mindspore/parallel/algo_parameter_config.py +85 -5
  366. mindspore/parallel/checkpoint_transform.py +19 -12
  367. mindspore/parallel/shard.py +21 -14
  368. mindspore/profiler/common/struct_type.py +3 -3
  369. mindspore/profiler/common/util.py +4 -2
  370. mindspore/profiler/envprofiling.py +1 -1
  371. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  372. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  373. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  374. mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
  375. mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
  376. mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
  377. mindspore/profiler/parser/ascend_op_generator.py +6 -6
  378. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  379. mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
  380. mindspore/profiler/parser/base_timeline_generator.py +10 -8
  381. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
  382. mindspore/profiler/parser/flops_parser.py +15 -11
  383. mindspore/profiler/parser/framework_parser.py +38 -22
  384. mindspore/profiler/parser/hccl_parser.py +16 -12
  385. mindspore/profiler/parser/integrator.py +22 -11
  386. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  387. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  388. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  389. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  390. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  391. mindspore/profiler/parser/optime_parser.py +1 -1
  392. mindspore/profiler/parser/profiler_info.py +21 -2
  393. mindspore/profiler/parser/step_trace_parser.py +11 -14
  394. mindspore/profiler/profiling.py +179 -89
  395. mindspore/rewrite/api/node.py +102 -19
  396. mindspore/rewrite/api/node_type.py +5 -1
  397. mindspore/rewrite/api/pattern_engine.py +1 -1
  398. mindspore/rewrite/api/scoped_value.py +9 -17
  399. mindspore/rewrite/api/symbol_tree.py +131 -47
  400. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  401. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  402. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  403. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  404. mindspore/rewrite/common/rewrite_elog.py +5 -1
  405. mindspore/rewrite/namer.py +33 -24
  406. mindspore/rewrite/namespace.py +14 -5
  407. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  408. mindspore/rewrite/node/call_function.py +79 -0
  409. mindspore/rewrite/node/cell_container.py +135 -0
  410. mindspore/rewrite/node/control_flow.py +88 -0
  411. mindspore/rewrite/{node.py → node/node.py} +273 -234
  412. mindspore/rewrite/node/node_manager.py +254 -0
  413. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  414. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  415. mindspore/rewrite/parsers/assign_parser.py +216 -221
  416. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  417. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  418. mindspore/rewrite/parsers/constant_parser.py +9 -6
  419. mindspore/rewrite/parsers/container_parser.py +9 -7
  420. mindspore/rewrite/parsers/for_parser.py +42 -21
  421. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  422. mindspore/rewrite/parsers/if_parser.py +28 -24
  423. mindspore/rewrite/parsers/module_parser.py +196 -25
  424. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  425. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  426. mindspore/rewrite/parsers/return_parser.py +6 -6
  427. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  428. mindspore/rewrite/sparsify/utils.py +1 -1
  429. mindspore/rewrite/symbol_tree.py +523 -578
  430. mindspore/rewrite/symbol_tree_builder.py +9 -193
  431. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  432. mindspore/run_check/_check_version.py +6 -4
  433. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  434. mindspore/safeguard/rewrite_obfuscation.py +541 -0
  435. mindspore/scipy/linalg.py +1 -1
  436. mindspore/scipy/ops.py +55 -5
  437. mindspore/scipy/optimize/__init__.py +3 -2
  438. mindspore/scipy/optimize/linear_sum_assignment.py +38 -33
  439. mindspore/scipy/optimize/minimize.py +7 -3
  440. mindspore/train/_utils.py +7 -3
  441. mindspore/train/amp.py +323 -123
  442. mindspore/train/anf_ir_pb2.py +14 -2
  443. mindspore/train/callback/_backup_and_restore.py +2 -12
  444. mindspore/train/callback/_callback.py +29 -4
  445. mindspore/train/callback/_checkpoint.py +23 -8
  446. mindspore/train/callback/_early_stop.py +2 -2
  447. mindspore/train/callback/_landscape.py +4 -4
  448. mindspore/train/callback/_loss_monitor.py +2 -2
  449. mindspore/train/callback/_on_request_exit.py +2 -2
  450. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  451. mindspore/train/callback/_summary_collector.py +15 -8
  452. mindspore/train/callback/_time_monitor.py +58 -5
  453. mindspore/train/data_sink.py +5 -11
  454. mindspore/train/dataset_helper.py +84 -57
  455. mindspore/train/loss_scale_manager.py +2 -2
  456. mindspore/train/metrics/__init__.py +3 -3
  457. mindspore/train/metrics/cosine_similarity.py +1 -1
  458. mindspore/train/metrics/hausdorff_distance.py +3 -2
  459. mindspore/train/metrics/mean_surface_distance.py +3 -2
  460. mindspore/train/metrics/metric.py +39 -19
  461. mindspore/train/metrics/roc.py +2 -2
  462. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  463. mindspore/train/mind_ir_pb2.py +85 -36
  464. mindspore/train/model.py +187 -47
  465. mindspore/train/serialization.py +487 -161
  466. mindspore/train/summary/_summary_adapter.py +1 -1
  467. mindspore/train/summary/_writer_pool.py +3 -2
  468. mindspore/train/summary/summary_record.py +37 -17
  469. mindspore/train/train_thor/convert_utils.py +3 -3
  470. mindspore/train/train_thor/dataset_helper.py +1 -1
  471. mindspore/version.py +1 -1
  472. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +8 -8
  473. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +477 -528
  474. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -1
  475. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  476. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  477. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  478. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  479. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  480. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  481. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  482. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  483. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  484. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  485. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  486. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  487. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  488. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  489. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  490. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  491. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  492. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  493. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  494. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  495. mindspore/_extends/graph_kernel/expander.py +0 -80
  496. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  497. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  498. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  499. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  500. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  501. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  502. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  503. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  504. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  505. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  506. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  507. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  508. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  509. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  510. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  511. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  512. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  513. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  514. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  515. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  516. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  517. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  518. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  519. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  520. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  521. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  522. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  523. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  524. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  525. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  526. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  527. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  528. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  529. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  530. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  531. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  532. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  533. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  534. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  535. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  536. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  537. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  538. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  539. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  540. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  541. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  542. mindspore/dataset/datapreprocess/__init__.py +0 -20
  543. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  544. mindspore/include/api/net.h +0 -142
  545. mindspore/nn/lr_scheduler.py +0 -262
  546. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  547. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  548. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  549. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  550. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  551. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -350
  552. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -409
  553. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -578
  554. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -199
  555. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -446
  556. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  557. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
  558. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
  559. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
  560. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  561. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  562. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  563. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  564. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  565. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  566. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  567. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  568. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  569. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  570. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  571. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  572. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  573. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  574. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  575. mindspore/rewrite/node_visitor.py +0 -44
  576. /mindspore/{ops/_op_impl/_custom_op/flash_attention → _akg/akg/utils/ascend_profilier}/__init__.py +0 -0
  577. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
  578. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
@@ -1,62 +0,0 @@
1
- # Copyright 2021-2022 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for cdiv"""
16
- from mindspore._extends.graph_kernel.model.model import DataFormat as DF
17
- from mindspore._extends.graph_kernel.expanders._utils import Expander, ExpanderInfoValidator as VLD
18
-
19
-
20
- @VLD.add_format(DF.DEFAULT, DF.DEFAULT)
21
- class CDiv(Expander):
22
- """CDiv expander"""
23
-
24
- def _expand(self, graph_builder):
25
- """CDiv Implementation"""
26
- input_x, input_y = self.inputs
27
- if input_x.dtype == input_y.dtype:
28
- x_real = graph_builder.emit('CReal', [input_x])
29
- y_real = graph_builder.emit('CReal', [input_y])
30
- x_imag = graph_builder.emit('CImag', [input_x])
31
- y_imag = graph_builder.emit('CImag', [input_y])
32
- squre_y_real = graph_builder.emit('Mul', [y_real, y_real])
33
- squre_y_imag = graph_builder.emit('Mul', [y_imag, y_imag])
34
- final_denominator = graph_builder.emit('Add', [squre_y_real, squre_y_imag])
35
- x_real_mul_y_real = graph_builder.emit('Mul', [x_real, y_real])
36
- x_imag_mul_y_imag = graph_builder.emit('Mul', [x_imag, y_imag])
37
- x_real_mul_y_imag = graph_builder.emit('Mul', [x_real, y_imag])
38
- x_imag_mul_y_real = graph_builder.emit('Mul', [x_imag, y_real])
39
- final_numerator_real = graph_builder.emit('Add', [x_real_mul_y_real, x_imag_mul_y_imag])
40
- final_numerator_imag = graph_builder.emit('Sub', [x_imag_mul_y_real, x_real_mul_y_imag])
41
- result_real = graph_builder.emit('RealDiv', [final_numerator_real, final_denominator])
42
- result_imag = graph_builder.emit('RealDiv', [final_numerator_imag, final_denominator])
43
- result = graph_builder.emit('Complex', [result_real, result_imag])
44
- elif input_x.dtype == "complex64" or input_x.dtype == "complex128":
45
- x_real = graph_builder.emit('CReal', [input_x])
46
- x_imag = graph_builder.emit('CImag', [input_x])
47
- x_real_div_y = graph_builder.emit('RealDiv', [x_real, input_y])
48
- x_imag_div_y = graph_builder.emit('RealDiv', [x_imag, input_y])
49
- result = graph_builder.emit('Complex', [x_real_div_y, x_imag_div_y])
50
- elif input_y.dtype == "complex64" or input_y.dtype == "complex128":
51
- y_real = graph_builder.emit('CReal', [input_y])
52
- y_imag = graph_builder.emit('CImag', [input_y])
53
- neg_y_imag = graph_builder.emit('Neg', [y_imag])
54
- squre_y_real = graph_builder.emit('Mul', [y_real, y_real])
55
- squre_y_imag = graph_builder.emit('Mul', [y_imag, y_imag])
56
- final_denominator = graph_builder.emit('Add', [squre_y_real, squre_y_imag])
57
- x_mul_y_real = graph_builder.emit('Mul', [input_x, y_real])
58
- x_mul_neg_y_imag = graph_builder.emit('Mul', [input_x, neg_y_imag])
59
- y_real_div_x = graph_builder.emit('RealDiv', [x_mul_y_real, final_denominator])
60
- y_imag_div_x = graph_builder.emit('RealDiv', [x_mul_neg_y_imag, final_denominator])
61
- result = graph_builder.emit('Complex', [y_real_div_x, y_imag_div_x])
62
- return result
@@ -1,52 +0,0 @@
1
- # Copyright 2021-2022 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for cmul"""
16
- from mindspore._extends.graph_kernel.model.model import DataFormat as DF
17
- from mindspore._extends.graph_kernel.expanders._utils import Expander, ExpanderInfoValidator as VLD
18
-
19
-
20
- @VLD.add_format(DF.DEFAULT, DF.DEFAULT)
21
- class CMul(Expander):
22
- """CMul expander"""
23
-
24
- def _expand(self, graph_builder):
25
- """CMul Implementation"""
26
- input_x, input_y = self.inputs
27
- if input_x.dtype == input_y.dtype:
28
- x_real = graph_builder.emit('CReal', [input_x])
29
- y_real = graph_builder.emit('CReal', [input_y])
30
- x_imag = graph_builder.emit('CImag', [input_x])
31
- y_imag = graph_builder.emit('CImag', [input_y])
32
- x_real_mul_y_real = graph_builder.emit('Mul', [x_real, y_real])
33
- x_imag_mul_y_imag = graph_builder.emit('Mul', [x_imag, y_imag])
34
- x_real_mul_y_imag = graph_builder.emit('Mul', [x_real, y_imag])
35
- x_imag_mul_y_real = graph_builder.emit('Mul', [x_imag, y_real])
36
- result_real = graph_builder.emit('Sub', [x_real_mul_y_real, x_imag_mul_y_imag])
37
- result_imag = graph_builder.emit('Add', [x_real_mul_y_imag, x_imag_mul_y_real])
38
- result = graph_builder.emit('Complex', [result_real, result_imag])
39
- elif input_x.dtype == "complex64" or input_x.dtype == "complex128":
40
- x_real = graph_builder.emit('CReal', [input_x])
41
- x_imag = graph_builder.emit('CImag', [input_x])
42
- x_real_mul_y = graph_builder.emit('Mul', [x_real, input_y])
43
- x_imag_mul_y = graph_builder.emit('Mul', [x_imag, input_y])
44
- result = graph_builder.emit('Complex', [x_real_mul_y, x_imag_mul_y])
45
- elif input_y.dtype == "complex64" or input_y.dtype == "complex128":
46
- y_real = graph_builder.emit('CReal', [input_y])
47
- y_imag = graph_builder.emit('CImag', [input_y])
48
- y_real_mul_x = graph_builder.emit('Mul', [y_real, input_x])
49
- y_imag_mul_x = graph_builder.emit('Mul', [y_imag, input_x])
50
- result = graph_builder.emit('Complex', [y_real_mul_x, y_imag_mul_x])
51
-
52
- return result
@@ -1,62 +0,0 @@
1
- # Copyright 2022 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for crealdiv"""
16
- from mindspore._extends.graph_kernel.model.model import DataFormat as DF
17
- from mindspore._extends.graph_kernel.expanders._utils import Expander, ExpanderInfoValidator as VLD
18
-
19
-
20
- @VLD.add_format(DF.DEFAULT, DF.DEFAULT)
21
- class CRealDiv(Expander):
22
- """CRealDiv expander"""
23
-
24
- def _expand(self, graph_builder):
25
- """CRealDiv Implementation"""
26
- input_x, input_y = self.inputs
27
- if input_x.dtype == input_y.dtype:
28
- x_real = graph_builder.emit('CReal', [input_x])
29
- y_real = graph_builder.emit('CReal', [input_y])
30
- x_imag = graph_builder.emit('CImag', [input_x])
31
- y_imag = graph_builder.emit('CImag', [input_y])
32
- squre_y_real = graph_builder.emit('Mul', [y_real, y_real])
33
- squre_y_imag = graph_builder.emit('Mul', [y_imag, y_imag])
34
- final_denominator = graph_builder.emit('Add', [squre_y_real, squre_y_imag])
35
- x_real_mul_y_real = graph_builder.emit('Mul', [x_real, y_real])
36
- x_imag_mul_y_imag = graph_builder.emit('Mul', [x_imag, y_imag])
37
- x_real_mul_y_imag = graph_builder.emit('Mul', [x_real, y_imag])
38
- x_imag_mul_y_real = graph_builder.emit('Mul', [x_imag, y_real])
39
- final_numerator_real = graph_builder.emit('Add', [x_real_mul_y_real, x_imag_mul_y_imag])
40
- final_numerator_imag = graph_builder.emit('Sub', [x_imag_mul_y_real, x_real_mul_y_imag])
41
- result_real = graph_builder.emit('RealDiv', [final_numerator_real, final_denominator])
42
- result_imag = graph_builder.emit('RealDiv', [final_numerator_imag, final_denominator])
43
- result = graph_builder.emit('Complex', [result_real, result_imag])
44
- elif input_x.dtype == "complex64" or input_x.dtype == "complex128":
45
- x_real = graph_builder.emit('CReal', [input_x])
46
- x_imag = graph_builder.emit('CImag', [input_x])
47
- x_real_div_y = graph_builder.emit('RealDiv', [x_real, input_y])
48
- x_imag_div_y = graph_builder.emit('RealDiv', [x_imag, input_y])
49
- result = graph_builder.emit('Complex', [x_real_div_y, x_imag_div_y])
50
- elif input_y.dtype == "complex64" or input_y.dtype == "complex128":
51
- y_real = graph_builder.emit('CReal', [input_y])
52
- y_imag = graph_builder.emit('CImag', [input_y])
53
- neg_y_imag = graph_builder.emit('Neg', [y_imag])
54
- squre_y_real = graph_builder.emit('Mul', [y_real, y_real])
55
- squre_y_imag = graph_builder.emit('Mul', [y_imag, y_imag])
56
- final_denominator = graph_builder.emit('Add', [squre_y_real, squre_y_imag])
57
- x_mul_y_real = graph_builder.emit('Mul', [input_x, y_real])
58
- x_mul_neg_y_imag = graph_builder.emit('Mul', [input_x, neg_y_imag])
59
- y_real_div_x = graph_builder.emit('RealDiv', [x_mul_y_real, final_denominator])
60
- y_imag_div_x = graph_builder.emit('RealDiv', [x_mul_neg_y_imag, final_denominator])
61
- result = graph_builder.emit('Complex', [y_real_div_x, y_imag_div_x])
62
- return result
@@ -1,45 +0,0 @@
1
- # Copyright 2021-2022 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for csub"""
16
- from mindspore._extends.graph_kernel.model.model import DataFormat as DF
17
- from mindspore._extends.graph_kernel.expanders._utils import Expander, ExpanderInfoValidator as VLD
18
-
19
-
20
- @VLD.add_format(DF.DEFAULT, DF.DEFAULT)
21
- class CSub(Expander):
22
- """CSub expander"""
23
-
24
- def _expand(self, graph_builder):
25
- input_x, input_y = self.inputs
26
- if input_x.dtype == input_y.dtype:
27
- x_real = graph_builder.emit('CReal', [input_x])
28
- y_real = graph_builder.emit('CReal', [input_y])
29
- x_imag = graph_builder.emit('CImag', [input_x])
30
- y_imag = graph_builder.emit('CImag', [input_y])
31
- result_real = graph_builder.emit('Sub', [x_real, y_real])
32
- result_imag = graph_builder.emit('Sub', [x_imag, y_imag])
33
- result = graph_builder.emit('Complex', [result_real, result_imag])
34
- elif input_x.dtype == "complex64" or input_x.dtype == "complex128":
35
- x_real = graph_builder.emit('CReal', [input_x])
36
- x_imag = graph_builder.emit('CImag', [input_x])
37
- x_real_sub_y = graph_builder.emit('Sub', [x_real, input_y])
38
- result = graph_builder.emit('Complex', [x_real_sub_y, x_imag])
39
- elif input_y.dtype == "complex64" or input_y.dtype == "complex128":
40
- y_real = graph_builder.emit('CReal', [input_y])
41
- y_imag = graph_builder.emit('CImag', [input_y])
42
- x_sub_y_real = graph_builder.emit('Sub', [input_x, y_real])
43
- y_imag = graph_builder.emit('Neg', [y_imag])
44
- result = graph_builder.emit('Complex', [x_sub_y_real, y_imag])
45
- return result
@@ -1,200 +0,0 @@
1
- # Copyright 2021-2022 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for Conv2D"""
16
- from mindspore._extends.graph_kernel.model.op_infer import check_format_any, check_nd, conv_had_pad
17
- from mindspore._extends.graph_kernel.model.model import DataFormat as DF
18
- from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
19
- from ._utils import Expander, ExpanderInfoValidator as VLD
20
-
21
-
22
- @VLD.add_format(DF.DEFAULT, DF.DEFAULT)
23
- @VLD.add_format(DF.NHWC, DF.NHWC)
24
- @VLD.check_attrs('format', 'pad_list', 'pad_mode', 'groups', 'group', 'kernel_size', 'stride', 'dilation')
25
- class Conv2D(Expander):
26
- """
27
- Conv2D expander
28
-
29
- Currently, only Conv2D that meets several conditions can be expanded, other cases will be skipped.
30
- Conditions to expand:
31
- inputs are NHWC format and float16.
32
- attr groups and group are 1.
33
- attr dilation are all 1.
34
- N channel of inputs > 16.
35
- C channel of inputs > 8.
36
- output N*H*W are multiplies of 128.
37
- """
38
- M_ALIGN = 32
39
- N_ALIGN = 32
40
- K_ALIGN = 16
41
- K_LIMIT = 800
42
- MNK_LIMIT = 3 * (10 ** 10)
43
- N0_CHANNEL_ALIGN = 32
44
- N1_CHANNEL_ALIGN = 32
45
- C_CHANNEL_ALIGN = 16
46
- OUT_NHW_ALIGN = 128
47
-
48
- def __init__(self, expand_info):
49
- super().__init__(expand_info)
50
- self.dst_type = self.outputs[0]['data_type']
51
- self.dst_format = self.outputs[0]['format']
52
- self.has_pad = False
53
- self.can_optimize_to_matmul = False
54
- self.shape_0_pad = self.inputs[0]['shape']
55
- self.shape_1_pad = self.inputs[1]['shape']
56
- self.m = 0
57
- self.n = 0
58
- self.k = 0
59
-
60
- def _optimize_to_matmul(self):
61
- stride = self.attrs['stride']
62
- dilation = self.attrs['dilation']
63
- _, h, w, _ = self.inputs[1]['shape']
64
- if h == 1 and w == 1 and stride == [1, 1, 1, 1] and dilation == [1, 1, 1, 1] and \
65
- self.m % self.M_ALIGN == 0 and self.n % self.N_ALIGN == 0 and self.k % self.K_ALIGN == 0:
66
- return True
67
- return False
68
-
69
- def _common_check(self):
70
- """common check for inputs and attrs"""
71
- type_0 = self.inputs[0]['data_type']
72
- type_1 = self.inputs[1]['data_type']
73
- if type_0 != "float16" or type_1 != "float16":
74
- raise GKException("For 'Conv2D', inputs data type should be both float16, but got {} and {}"
75
- .format(type_0, type_1))
76
-
77
- formats = [self.inputs[0]['format'], self.inputs[1]['format'], self.attrs['format']]
78
- check_format_any(formats, DF.NHWC)
79
-
80
- groups = self.attrs['groups']
81
- group = self.attrs['group']
82
- if groups != 1 or group != 1:
83
- raise GKException("For 'Conv2D', value of attr 'groups' and 'group' should be both 1, but got {} and {}."
84
- .format(groups, group))
85
-
86
- dilation = self.attrs['dilation']
87
- check_nd(dilation, 4)
88
- if dilation != [1, 1, 1, 1]:
89
- raise GKException("For 'Conv2D', value of attr 'dilation' should be [1, 1, 1, 1], but got {}"
90
- .format(dilation))
91
-
92
- def _check(self):
93
- self._common_check()
94
-
95
- pad_list = self.attrs['pad_list']
96
- check_nd(pad_list, 4)
97
- self.has_pad = conv_had_pad(pad_list, self.attrs['pad_mode'])
98
-
99
- shape_0 = self.inputs[0]['shape']
100
- shape_1 = self.inputs[1]['shape']
101
- stride = self.attrs['stride']
102
- check_nd(shape_0, 4)
103
- check_nd(shape_1, 4)
104
- check_nd(stride, 4)
105
- n0, h0, w0, c0 = shape_0
106
- n1, h1, w1, c1 = shape_1
107
- if (n0 % self.N0_CHANNEL_ALIGN) != 0:
108
- raise GKException("For 'Conv2D', N channel of first input should be multiples of {}, but got {}"
109
- .format(self.N0_CHANNEL_ALIGN, n0))
110
- if (n1 % self.N1_CHANNEL_ALIGN) != 0:
111
- raise GKException("For 'Conv2D', N channel of second input should be multiples of {}, but got {}"
112
- .format(self.N1_CHANNEL_ALIGN, n1))
113
- if c0 != c1 or (c0 % self.C_CHANNEL_ALIGN) != 0:
114
- raise GKException("For 'Conv2D', C channel of inputs should be same and also be multiples of {}, but got "
115
- "{} and {}".format(self.C_CHANNEL_ALIGN, c0, c1))
116
- # n0 pad
117
- n0 = ((n0 + self.N0_CHANNEL_ALIGN - 1) //
118
- self.N0_CHANNEL_ALIGN) * self.N0_CHANNEL_ALIGN
119
- # h0, w0 pad
120
- if self.has_pad:
121
- h0 = h0 + pad_list[0] + pad_list[1]
122
- w0 = w0 + pad_list[2] + pad_list[3]
123
- # c0, c1 pad
124
- c0 = ((c0 + self.C_CHANNEL_ALIGN - 1) // self.C_CHANNEL_ALIGN) * self.C_CHANNEL_ALIGN
125
- c1 = c0
126
- # n1 pad
127
- n1 = ((n1 + self.N1_CHANNEL_ALIGN - 1) //
128
- self.N1_CHANNEL_ALIGN) * self.N1_CHANNEL_ALIGN
129
-
130
- # check if can optimize to matmul
131
- self.m, self.n, self.k = n0 * h0 * w0, n1, c1
132
- self.can_optimize_to_matmul = self._optimize_to_matmul()
133
-
134
- # requirements
135
- if self.can_optimize_to_matmul:
136
- if self.k > self.K_LIMIT:
137
- raise GKException("For 'Conv2D', if transformed to 'MatMul', C0 should not be larger than {}, but got "
138
- "{}".format(self.K_LIMIT, self.k))
139
- if self.m * self.n * self.k >= self.MNK_LIMIT:
140
- raise GKException("For 'Conv2D', if transformed to 'MatMul', The total size should not be larger than "
141
- "{}, but got {}".format(self.MNK_LIMIT, self.m * self.n * self.k))
142
- else:
143
- out_h, out_w = (h0 - h1) // stride[-2] + 1, (w0 - w1) // stride[-1] + 1
144
- if ((n0 * out_h * out_w) % self.OUT_NHW_ALIGN) != 0:
145
- raise GKException("For 'Conv2D', N({}) * H({}) * W({}) of output should be multiplies of {}"
146
- .format(n0, out_h, out_w, self.OUT_NHW_ALIGN))
147
- if stride != [1, 1, 2, 2]:
148
- raise GKException("For 'Conv2D', value of attr 'stride' should be [1, 1, 2, 2], but got {}"
149
- .format(stride))
150
-
151
- self.shape_0_pad = [n0, h0, w0, c0]
152
- self.shape_1_pad = [n1, h1, w1, c1]
153
-
154
- def _expand(self, graph_builder):
155
- input_0 = self.inputs[0]
156
- input_1 = self.inputs[1]
157
- n0, _, _, c0 = input_0.shape
158
- n1, _, _, c1 = input_1.shape
159
- n0_p, h0_p, w0_p, c0_p = self.shape_0_pad
160
- n1_p, _, _, c1_p = self.shape_1_pad
161
-
162
- pad_value = 0
163
- # input0 pad
164
- input_0_pad_before = [0, 0, 0, 0]
165
- input_0_pad_after = [0, 0, 0, 0]
166
- if self.has_pad:
167
- pad_list = self.attrs['pad_list']
168
- input_0_pad_before = [0, pad_list[0], pad_list[2], 0]
169
- input_0_pad_after = [0, pad_list[1], pad_list[3], 0]
170
- input_0_pad_after[0] = n0_p - n0
171
- input_0_pad_after[3] = c0_p - c0
172
- if input_0_pad_before != [0, 0, 0, 0] or input_0_pad_after != [0, 0, 0, 0]:
173
- input_0 = graph_builder.emit('PadAkg', [input_0], attrs={'head': input_0_pad_before,
174
- 'tail': input_0_pad_after,
175
- 'pad_val': pad_value})
176
- # input1 pad
177
- input_1_pad_after = [n1_p - n1, 0, 0, c1_p - c1]
178
- if input_1_pad_after != [0, 0, 0, 0]:
179
- input_1 = graph_builder.emit('PadAkg', [input_1], attrs={'head': [0, 0, 0, 0],
180
- 'tail': input_1_pad_after,
181
- 'pad_val': pad_value})
182
- if self.can_optimize_to_matmul:
183
- a = graph_builder.emit('Reshape', [input_0], attrs={'shape': [self.m, self.k]})
184
- b = graph_builder.emit('Reshape', [input_1], attrs={'shape': [self.n, self.k]})
185
- c = graph_builder.emit('MatMul', [a, b], attrs={'transpose_a': False,
186
- 'transpose_b': True,
187
- 'dst_type': self.dst_type})
188
- result = graph_builder.emit('Reshape', [c], attrs={'shape': [n0_p, h0_p, w0_p, n1_p],
189
- 'format': self.dst_format})
190
- else:
191
- attrs = self.attrs
192
- attrs['pad_list'] = [0, 0, 0, 0]
193
- attrs['dst_type'] = self.dst_type
194
- result = graph_builder.emit('Conv2D', [input_0, input_1], attrs=attrs)
195
- # unpad
196
- unpad_after = [input_0_pad_after[0], 0, 0, input_1_pad_after[0]]
197
- if unpad_after != [0, 0, 0, 0]:
198
- result = graph_builder.emit('UnPadAkg', [result], attrs={'tail': unpad_after})
199
-
200
- return result
@@ -1,30 +0,0 @@
1
- # Copyright 2020-2021 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for DropoutGrad"""
16
- from ._utils import Expander, ExpanderInfoValidator as VLD
17
-
18
-
19
- @VLD.check_all_formats_same
20
- @VLD.check_attrs('keep_prob')
21
- class DropoutGrad(Expander):
22
- """DropoutGrad expander"""
23
-
24
- def _expand(self, graph_builder):
25
- input_dy, input_mask = self.inputs
26
- keep_prob = self.attrs['keep_prob']
27
- r_keep_prob = graph_builder.value(input_dy.dtype, 1.0 / keep_prob)
28
- result = graph_builder.emit('Mul', [input_dy, r_keep_prob])
29
- result = graph_builder.emit('Mul', [result, input_mask])
30
- return result
@@ -1,50 +0,0 @@
1
- # Copyright 2021-2022 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for equal_count"""
16
- from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
17
- from ._utils import Expander, ExpanderInfoValidator as VLD
18
-
19
-
20
- @VLD.check_all_formats_same
21
- class EqualCount(Expander):
22
- """EqualCount expander"""
23
-
24
- def __init__(self, expand_info):
25
- super().__init__(expand_info)
26
- self.shape_x = self.inputs[0]['shape']
27
- self.shape_y = self.inputs[1]['shape']
28
- self.dtype_x = self.inputs[0]['data_type']
29
- self.dtype_y = self.inputs[1]['data_type']
30
-
31
- def _check(self):
32
- if self.shape_x != self.shape_y:
33
- raise GKException("For 'EqualCount', the inputs shape should be same, but got {} and {}"
34
- .format(self.shape_x, self.shape_y))
35
- if self.dtype_x != self.dtype_y:
36
- raise GKException("For 'EqualCount', the inputs data type should be same, but got {} and {}"
37
- .format(self.dtype_x, self.dtype_y))
38
-
39
- def _expand(self, graph_builder):
40
- input_x = self.inputs[0]
41
- input_y = self.inputs[1]
42
-
43
- eql_val = graph_builder.emit('Equal', [input_x, input_y])
44
- cast_val = graph_builder.emit('Cast', [eql_val], attrs={'dst_type': 'float32'})
45
- axis = list(range(len(input_x.shape)))
46
- result = graph_builder.emit('ReduceSum', [cast_val], attrs={'reduce_axis': axis, 'keep_dims': False})
47
-
48
- if result.dtype != input_x.dtype:
49
- result = graph_builder.emit('Cast', [result], attrs={'dst_type': input_x.dtype})
50
- return result
@@ -1,35 +0,0 @@
1
- # Copyright 2021 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for erfc"""
16
- from ._utils import Expander
17
-
18
-
19
- class Erfc(Expander):
20
- """Erfc expander"""
21
-
22
- def _expand(self, graph_builder):
23
- input_x = self.inputs[0]
24
- result = None
25
- if input_x.dtype == "float16":
26
- const_one = graph_builder.value("float32", 1)
27
- input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': "float32"})
28
- erf_result = graph_builder.emit('Erf', [input_x])
29
- result = graph_builder.emit('Sub', [const_one, erf_result])
30
- result = graph_builder.emit('Cast', [result], attrs={'dst_type': "float16"})
31
- return result
32
- const_one = graph_builder.value(input_x.dtype, 1)
33
- erf_result = graph_builder.emit('Erf', [input_x])
34
- result = graph_builder.emit('Sub', [const_one, erf_result])
35
- return result
@@ -1,50 +0,0 @@
1
- # Copyright 2021-2022 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for expand_dims"""
16
- from ._utils import Expander, ExpanderInfoValidator as VLD
17
-
18
-
19
- @VLD.check_attrs('axis')
20
- class ExpandDims(Expander):
21
- """ExpandDims expander"""
22
-
23
- def _expand(self, graph_builder):
24
- input_x = self.inputs[0]
25
- shape = self.infer_shape(input_x.shape, self.attrs['axis'])
26
- result = graph_builder.emit('Reshape', [input_x], attrs={'shape': shape})
27
-
28
- return result
29
-
30
- @staticmethod
31
- def infer_shape(shape, axis):
32
- """infer shape for expand_dims"""
33
- def insert_axis(shape, axis):
34
- if not isinstance(axis, int) or axis > len(shape) or axis < -len(shape) - 1:
35
- raise ValueError("For 'ExpandDims', value of attr 'axis' should be of type int and in the range [{}, "
36
- "{}], but got {} with type {}".format(-len(shape) - 1, len(shape), axis, type(axis)))
37
- if axis >= 0:
38
- shape.insert(axis, 1)
39
- else:
40
- shape.insert(axis + len(shape) + 1, 1)
41
- return shape
42
- out_shape = shape[:]
43
- if isinstance(axis, int):
44
- return insert_axis(out_shape, axis)
45
- if isinstance(axis, (list, tuple)):
46
- for i in axis:
47
- out_shape = insert_axis(out_shape, i)
48
- return out_shape
49
- raise ValueError("For 'ExpandDims', type of attr 'axis' should be one of ['int', 'list', 'tuple'], but got {} "
50
- "with type {}".format(axis, type(axis)))
@@ -1,44 +0,0 @@
1
- # Copyright 2020-2021 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for fused_adam"""
16
- from ._utils import Expander, ExpanderInfoValidator as VLD
17
-
18
-
19
- @VLD.check_all_formats_same
20
- class FusedAdam(Expander):
21
- """FusedAdam expander"""
22
-
23
- def _expand(self, graph_builder):
24
- beta_1, one_sub_beta_1, beta_2, one_sub_beta_2, eps, lr, param, m, v, gradient = self.inputs
25
-
26
- beta_1_mul_m = graph_builder.emit('Mul', [beta_1, m])
27
- one_sub_beta_1_mul_grad = graph_builder.emit('Mul', [one_sub_beta_1, gradient])
28
- next_m = graph_builder.emit('Add', [beta_1_mul_m, one_sub_beta_1_mul_grad])
29
- beta_2_mul_v = graph_builder.emit('Mul', [beta_2, v])
30
- grad_square = graph_builder.emit('Mul', [gradient, gradient])
31
- one_sub_beta_2_mul_grad_square = graph_builder.emit('Mul', [one_sub_beta_2, grad_square])
32
- next_v = graph_builder.emit('Add', [beta_2_mul_v, one_sub_beta_2_mul_grad_square])
33
- sqrt_next_v = graph_builder.emit('Sqrt', [next_v])
34
- sqrt_next_v_add_eps = graph_builder.emit('Add', [sqrt_next_v, eps])
35
- update = graph_builder.emit('RealDiv', [next_m, sqrt_next_v_add_eps])
36
- update_with_lr = graph_builder.emit('Mul', [lr, update])
37
- next_para = graph_builder.emit('Sub', [param, update_with_lr])
38
-
39
- param_result = graph_builder.emit(
40
- 'InplaceAssign', [param, next_para, next_para], attrs={'fake_output': True})
41
- param_result = graph_builder.emit('InplaceAssign', [m, next_m, param_result], attrs={'fake_output': True})
42
- param_result = graph_builder.emit('InplaceAssign', [v, next_v, param_result], attrs={'fake_output': True})
43
-
44
- return param_result