mindspore 2.1.0__cp38-none-any.whl → 2.2.11__cp38-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (578)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_akg/akg/build_module.py +5 -6
  4. mindspore/_akg/akg/composite/build_module.py +139 -22
  5. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  6. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  7. mindspore/_akg/akg/tvm/api.py +4 -3
  8. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  9. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  10. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  11. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  12. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  13. mindspore/_akg/akg/tvm/build_module.py +16 -1
  14. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  15. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  16. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  17. mindspore/_akg/akg/tvm/module.py +1 -2
  18. mindspore/_akg/akg/tvm/stmt.py +2 -2
  19. mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
  20. mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
  21. mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
  22. mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
  23. mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
  24. mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
  25. mindspore/_akg/akg/utils/composite_op_helper.py +16 -12
  26. mindspore/_akg/akg/utils/dump_ascend_meta.py +22 -3
  27. mindspore/_akg/akg/utils/kernel_exec.py +98 -274
  28. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  29. mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
  30. mindspore/_akg/akg/utils/util.py +56 -1
  31. mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
  32. mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
  33. mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
  34. mindspore/_check_jit_forbidden_api.py +3 -1
  35. mindspore/_checkparam.py +23 -29
  36. mindspore/_extends/graph_kernel/__init__.py +0 -1
  37. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  38. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  39. mindspore/_extends/graph_kernel/splitter.py +4 -11
  40. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  41. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
  42. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  43. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  44. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  45. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
  46. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  47. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  48. mindspore/_extends/parse/__init__.py +13 -15
  49. mindspore/_extends/parse/namespace.py +7 -33
  50. mindspore/_extends/parse/parser.py +67 -72
  51. mindspore/_extends/parse/resources.py +1 -1
  52. mindspore/_extends/parse/standard_method.py +86 -106
  53. mindspore/_extends/parse/trope.py +1 -1
  54. mindspore/_extends/remote/kernel_build_server.py +25 -7
  55. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  56. mindspore/_install_custom.py +43 -0
  57. mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
  58. mindspore/amp.py +47 -11
  59. mindspore/bin/cache_admin +0 -0
  60. mindspore/bin/cache_server +0 -0
  61. mindspore/boost/boost.py +1 -8
  62. mindspore/boost/boost_cell_wrapper.py +3 -2
  63. mindspore/boost/grad_accumulation.py +1 -1
  64. mindspore/boost/group_loss_scale_manager.py +8 -7
  65. mindspore/common/__init__.py +5 -3
  66. mindspore/common/_jit_fallback_utils.py +6 -0
  67. mindspore/common/_register_for_adapter.py +2 -0
  68. mindspore/common/_register_for_tensor.py +2 -2
  69. mindspore/common/_stub_tensor.py +13 -0
  70. mindspore/common/_utils.py +29 -0
  71. mindspore/common/api.py +174 -259
  72. mindspore/common/auto_dynamic_shape.py +494 -0
  73. mindspore/common/dtype.py +18 -11
  74. mindspore/common/dump.py +6 -4
  75. mindspore/common/initializer.py +14 -14
  76. mindspore/common/jit_config.py +33 -15
  77. mindspore/common/lazy_inline.py +126 -7
  78. mindspore/common/mindir_util.py +101 -0
  79. mindspore/common/parameter.py +51 -41
  80. mindspore/common/seed.py +4 -4
  81. mindspore/common/sparse_tensor.py +13 -14
  82. mindspore/common/tensor.py +243 -165
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +83 -4
  85. mindspore/communication/management.py +152 -84
  86. mindspore/config/op_info.config +14 -3
  87. mindspore/config/super_bar_config.json +4 -2
  88. mindspore/context.py +152 -61
  89. mindspore/dataset/__init__.py +5 -5
  90. mindspore/dataset/audio/__init__.py +2 -2
  91. mindspore/dataset/audio/transforms.py +52 -52
  92. mindspore/dataset/callback/ds_callback.py +16 -2
  93. mindspore/dataset/core/config.py +68 -51
  94. mindspore/dataset/engine/cache_client.py +33 -7
  95. mindspore/dataset/engine/datasets.py +250 -112
  96. mindspore/dataset/engine/datasets_audio.py +43 -211
  97. mindspore/dataset/engine/datasets_standard_format.py +16 -35
  98. mindspore/dataset/engine/datasets_text.py +43 -67
  99. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  100. mindspore/dataset/engine/datasets_vision.py +219 -1029
  101. mindspore/dataset/engine/iterators.py +11 -4
  102. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  103. mindspore/dataset/engine/obs/util.py +3 -0
  104. mindspore/dataset/engine/samplers.py +1 -1
  105. mindspore/dataset/engine/validators.py +19 -5
  106. mindspore/dataset/text/__init__.py +3 -3
  107. mindspore/dataset/text/transforms.py +101 -127
  108. mindspore/dataset/text/utils.py +205 -138
  109. mindspore/dataset/transforms/__init__.py +1 -1
  110. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  111. mindspore/dataset/transforms/transforms.py +95 -40
  112. mindspore/dataset/utils/browse_dataset.py +8 -2
  113. mindspore/dataset/utils/line_reader.py +17 -19
  114. mindspore/dataset/vision/__init__.py +3 -3
  115. mindspore/dataset/vision/c_transforms.py +6 -3
  116. mindspore/dataset/vision/transforms.py +409 -287
  117. mindspore/dataset/vision/utils.py +13 -14
  118. mindspore/dataset/vision/validators.py +11 -1
  119. mindspore/experimental/map_parameter.py +14 -0
  120. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  121. mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
  122. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  123. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  124. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  125. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  126. mindspore/gen_ops.py +273 -0
  127. mindspore/include/OWNERS +0 -1
  128. mindspore/include/api/data_type.h +2 -1
  129. mindspore/include/api/graph.h +0 -15
  130. mindspore/include/api/kernel.h +2 -0
  131. mindspore/include/api/kernel_api.h +37 -12
  132. mindspore/include/api/model.h +17 -14
  133. mindspore/include/api/status.h +8 -3
  134. mindspore/include/api/types.h +37 -4
  135. mindspore/include/c_api/ms/abstract.h +67 -0
  136. mindspore/include/c_api/ms/attribute.h +197 -0
  137. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  138. mindspore/include/c_api/ms/base/macros.h +32 -0
  139. mindspore/include/c_api/ms/base/status.h +33 -0
  140. mindspore/include/c_api/ms/base/types.h +282 -0
  141. mindspore/include/c_api/ms/context.h +102 -0
  142. mindspore/include/c_api/ms/graph.h +160 -0
  143. mindspore/include/c_api/ms/node.h +606 -0
  144. mindspore/include/c_api/ms/tensor.h +161 -0
  145. mindspore/include/c_api/ms/value.h +84 -0
  146. mindspore/include/dataset/constants.h +6 -5
  147. mindspore/include/dataset/execute.h +23 -13
  148. mindspore/include/dataset/text.h +26 -26
  149. mindspore/include/dataset/transforms.h +13 -13
  150. mindspore/include/dataset/vision.h +60 -60
  151. mindspore/include/dataset/vision_ascend.h +5 -6
  152. mindspore/include/dataset/vision_lite.h +17 -17
  153. mindspore/include/mindapi/base/type_id.h +1 -0
  154. mindspore/include/mindapi/base/types.h +1 -0
  155. mindspore/lib/libdnnl.so.2 +0 -0
  156. mindspore/lib/libjemalloc.so.2 +0 -0
  157. mindspore/lib/libmindspore.so +0 -0
  158. mindspore/lib/libmindspore_backend.so +0 -0
  159. mindspore/lib/libmindspore_common.so +0 -0
  160. mindspore/lib/libmindspore_core.so +0 -0
  161. mindspore/lib/libmindspore_glog.so.0 +0 -0
  162. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  163. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  164. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  165. mindspore/lib/libmindspore_shared_lib.so +0 -0
  166. mindspore/lib/libnnacl.so +0 -0
  167. mindspore/lib/libopencv_core.so.4.5 +0 -0
  168. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  169. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  170. mindspore/lib/libps_cache.so +0 -0
  171. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
  172. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
  173. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
  174. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
  175. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  176. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  177. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  178. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  179. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  180. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  181. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  182. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  183. mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
  184. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  185. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  186. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8998 -0
  187. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  188. mindspore/lib/plugin/ascend/libakg.so +0 -0
  189. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  190. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  191. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  192. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  193. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  194. mindspore/lib/plugin/cpu/libakg.so +0 -0
  195. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  196. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  197. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  198. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  199. mindspore/nn/__init__.py +0 -2
  200. mindspore/nn/cell.py +313 -74
  201. mindspore/nn/dynamic_lr.py +21 -21
  202. mindspore/nn/layer/activation.py +22 -30
  203. mindspore/nn/layer/basic.py +15 -13
  204. mindspore/nn/layer/channel_shuffle.py +1 -1
  205. mindspore/nn/layer/container.py +271 -9
  206. mindspore/nn/layer/conv.py +323 -204
  207. mindspore/nn/layer/dense.py +8 -5
  208. mindspore/nn/layer/embedding.py +33 -27
  209. mindspore/nn/layer/flash_attention.py +61 -95
  210. mindspore/nn/layer/image.py +8 -6
  211. mindspore/nn/layer/math.py +16 -25
  212. mindspore/nn/layer/normalization.py +107 -66
  213. mindspore/nn/layer/padding.py +1 -1
  214. mindspore/nn/layer/pooling.py +131 -109
  215. mindspore/nn/layer/rnn_cells.py +27 -22
  216. mindspore/nn/layer/rnns.py +13 -16
  217. mindspore/nn/layer/thor_layer.py +1 -1
  218. mindspore/nn/layer/transformer.py +221 -154
  219. mindspore/nn/learning_rate_schedule.py +9 -1
  220. mindspore/nn/loss/loss.py +235 -174
  221. mindspore/nn/optim/ada_grad.py +2 -1
  222. mindspore/nn/optim/adadelta.py +1 -0
  223. mindspore/nn/optim/adafactor.py +2 -1
  224. mindspore/nn/optim/adam.py +7 -4
  225. mindspore/nn/optim/adamax.py +3 -2
  226. mindspore/nn/optim/adasum.py +2 -2
  227. mindspore/nn/optim/asgd.py +2 -3
  228. mindspore/nn/optim/ftrl.py +6 -5
  229. mindspore/nn/optim/lamb.py +7 -4
  230. mindspore/nn/optim/lars.py +1 -1
  231. mindspore/nn/optim/lazyadam.py +5 -3
  232. mindspore/nn/optim/momentum.py +2 -1
  233. mindspore/nn/optim/optimizer.py +53 -4
  234. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  235. mindspore/nn/optim/rmsprop.py +4 -3
  236. mindspore/nn/optim/rprop.py +23 -12
  237. mindspore/nn/optim/sgd.py +26 -11
  238. mindspore/nn/optim/thor.py +9 -7
  239. mindspore/nn/probability/bijector/bijector.py +5 -5
  240. mindspore/nn/probability/bijector/power_transform.py +27 -27
  241. mindspore/nn/probability/bijector/softplus.py +3 -3
  242. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  243. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  244. mindspore/nn/probability/distribution/beta.py +3 -3
  245. mindspore/nn/probability/distribution/categorical.py +7 -7
  246. mindspore/nn/probability/distribution/cauchy.py +0 -1
  247. mindspore/nn/probability/distribution/distribution.py +3 -3
  248. mindspore/nn/probability/distribution/gamma.py +3 -3
  249. mindspore/nn/probability/distribution/geometric.py +4 -4
  250. mindspore/nn/probability/distribution/gumbel.py +4 -4
  251. mindspore/nn/probability/distribution/log_normal.py +2 -2
  252. mindspore/nn/probability/distribution/logistic.py +2 -2
  253. mindspore/nn/probability/distribution/poisson.py +4 -4
  254. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  255. mindspore/nn/probability/distribution/uniform.py +6 -6
  256. mindspore/nn/wrap/__init__.py +4 -2
  257. mindspore/nn/wrap/cell_wrapper.py +87 -34
  258. mindspore/nn/wrap/grad_reducer.py +8 -5
  259. mindspore/nn/wrap/loss_scale.py +105 -42
  260. mindspore/numpy/array_creations.py +1 -2
  261. mindspore/numpy/array_ops.py +3 -2
  262. mindspore/numpy/utils_const.py +5 -5
  263. mindspore/offline_debug/convert_async.py +2 -2
  264. mindspore/ops/_grad_experimental/__init__.py +0 -5
  265. mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
  266. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  267. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  268. mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
  269. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  270. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
  271. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  272. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  273. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  274. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  275. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  276. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  277. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  278. mindspore/ops/_op_impl/{_custom_op/flash_attention/constants.py → aicpu/eps.py} +18 -27
  279. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  280. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
  281. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  282. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  283. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  284. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  285. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  286. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  287. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  288. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  289. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  290. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  291. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  292. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  293. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  294. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  295. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  296. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  297. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  298. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  299. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  300. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  301. mindspore/ops/_primitive_cache.py +1 -1
  302. mindspore/ops/_tracefunc.py +45 -13
  303. mindspore/ops/_utils/utils.py +6 -1
  304. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  305. mindspore/ops/_vmap/vmap_base.py +3 -3
  306. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  307. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  308. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  309. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  310. mindspore/ops/arg_dtype_cast.py +54 -0
  311. mindspore/ops/composite/base.py +37 -10
  312. mindspore/ops/composite/math_ops.py +5 -4
  313. mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
  314. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  315. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  316. mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
  317. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  318. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  319. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  320. mindspore/ops/deprecated.py +304 -0
  321. mindspore/ops/function/__init__.py +4 -1
  322. mindspore/ops/function/array_func.py +174 -193
  323. mindspore/ops/function/clip_func.py +81 -13
  324. mindspore/ops/function/debug_func.py +1 -1
  325. mindspore/ops/function/grad/grad_func.py +18 -9
  326. mindspore/ops/function/image_func.py +10 -4
  327. mindspore/ops/function/linalg_func.py +5 -5
  328. mindspore/ops/function/math_func.py +575 -386
  329. mindspore/ops/function/nn_func.py +568 -260
  330. mindspore/ops/function/random_func.py +88 -57
  331. mindspore/ops/function/sparse_func.py +1 -1
  332. mindspore/ops/function/sparse_unary_func.py +14 -12
  333. mindspore/ops/function/vmap_func.py +6 -5
  334. mindspore/ops/functional.py +15 -10
  335. mindspore/ops/op_info_register.py +244 -25
  336. mindspore/ops/operations/__init__.py +31 -19
  337. mindspore/ops/operations/_grad_ops.py +71 -7
  338. mindspore/ops/operations/_inner_ops.py +350 -17
  339. mindspore/ops/operations/_quant_ops.py +4 -8
  340. mindspore/ops/operations/_sequence_ops.py +42 -0
  341. mindspore/ops/operations/array_ops.py +68 -282
  342. mindspore/ops/operations/comm_ops.py +107 -59
  343. mindspore/ops/operations/custom_ops.py +94 -70
  344. mindspore/ops/operations/debug_ops.py +8 -4
  345. mindspore/ops/operations/image_ops.py +18 -12
  346. mindspore/ops/operations/inner_ops.py +26 -3
  347. mindspore/ops/operations/math_ops.py +192 -144
  348. mindspore/ops/operations/nn_ops.py +857 -489
  349. mindspore/ops/operations/other_ops.py +0 -22
  350. mindspore/ops/operations/random_ops.py +53 -111
  351. mindspore/ops/operations/sparse_ops.py +3 -1
  352. mindspore/ops/primitive.py +24 -18
  353. mindspore/parallel/_auto_parallel_context.py +68 -8
  354. mindspore/parallel/_cost_model_context.py +2 -2
  355. mindspore/parallel/_offload_context.py +17 -3
  356. mindspore/parallel/_parallel_serialization.py +12 -5
  357. mindspore/parallel/_ps_context.py +12 -0
  358. mindspore/parallel/_tensor.py +18 -13
  359. mindspore/parallel/_transformer/layers.py +5 -3
  360. mindspore/parallel/_transformer/loss.py +1 -0
  361. mindspore/parallel/_transformer/moe.py +2 -2
  362. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  363. mindspore/parallel/_transformer/transformer.py +23 -3
  364. mindspore/parallel/_utils.py +11 -7
  365. mindspore/parallel/algo_parameter_config.py +85 -5
  366. mindspore/parallel/checkpoint_transform.py +19 -12
  367. mindspore/parallel/shard.py +21 -14
  368. mindspore/profiler/common/struct_type.py +3 -3
  369. mindspore/profiler/common/util.py +4 -2
  370. mindspore/profiler/envprofiling.py +1 -1
  371. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  372. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  373. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  374. mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
  375. mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
  376. mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
  377. mindspore/profiler/parser/ascend_op_generator.py +6 -6
  378. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  379. mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
  380. mindspore/profiler/parser/base_timeline_generator.py +10 -8
  381. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
  382. mindspore/profiler/parser/flops_parser.py +15 -11
  383. mindspore/profiler/parser/framework_parser.py +38 -22
  384. mindspore/profiler/parser/hccl_parser.py +16 -12
  385. mindspore/profiler/parser/integrator.py +22 -11
  386. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  387. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  388. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  389. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  390. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  391. mindspore/profiler/parser/optime_parser.py +1 -1
  392. mindspore/profiler/parser/profiler_info.py +21 -2
  393. mindspore/profiler/parser/step_trace_parser.py +11 -14
  394. mindspore/profiler/profiling.py +179 -89
  395. mindspore/rewrite/api/node.py +102 -19
  396. mindspore/rewrite/api/node_type.py +5 -1
  397. mindspore/rewrite/api/pattern_engine.py +1 -1
  398. mindspore/rewrite/api/scoped_value.py +9 -17
  399. mindspore/rewrite/api/symbol_tree.py +131 -47
  400. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  401. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  402. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  403. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  404. mindspore/rewrite/common/rewrite_elog.py +5 -1
  405. mindspore/rewrite/namer.py +33 -24
  406. mindspore/rewrite/namespace.py +14 -5
  407. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  408. mindspore/rewrite/node/call_function.py +79 -0
  409. mindspore/rewrite/node/cell_container.py +135 -0
  410. mindspore/rewrite/node/control_flow.py +88 -0
  411. mindspore/rewrite/{node.py → node/node.py} +273 -234
  412. mindspore/rewrite/node/node_manager.py +254 -0
  413. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  414. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  415. mindspore/rewrite/parsers/assign_parser.py +216 -221
  416. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  417. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  418. mindspore/rewrite/parsers/constant_parser.py +9 -6
  419. mindspore/rewrite/parsers/container_parser.py +9 -7
  420. mindspore/rewrite/parsers/for_parser.py +42 -21
  421. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  422. mindspore/rewrite/parsers/if_parser.py +28 -24
  423. mindspore/rewrite/parsers/module_parser.py +196 -25
  424. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  425. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  426. mindspore/rewrite/parsers/return_parser.py +6 -6
  427. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  428. mindspore/rewrite/sparsify/utils.py +1 -1
  429. mindspore/rewrite/symbol_tree.py +523 -578
  430. mindspore/rewrite/symbol_tree_builder.py +9 -193
  431. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  432. mindspore/run_check/_check_version.py +6 -4
  433. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  434. mindspore/safeguard/rewrite_obfuscation.py +541 -0
  435. mindspore/scipy/linalg.py +1 -1
  436. mindspore/scipy/ops.py +55 -5
  437. mindspore/scipy/optimize/__init__.py +3 -2
  438. mindspore/scipy/optimize/linear_sum_assignment.py +38 -33
  439. mindspore/scipy/optimize/minimize.py +7 -3
  440. mindspore/train/_utils.py +7 -3
  441. mindspore/train/amp.py +323 -123
  442. mindspore/train/anf_ir_pb2.py +14 -2
  443. mindspore/train/callback/_backup_and_restore.py +2 -12
  444. mindspore/train/callback/_callback.py +29 -4
  445. mindspore/train/callback/_checkpoint.py +23 -8
  446. mindspore/train/callback/_early_stop.py +2 -2
  447. mindspore/train/callback/_landscape.py +4 -4
  448. mindspore/train/callback/_loss_monitor.py +2 -2
  449. mindspore/train/callback/_on_request_exit.py +2 -2
  450. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  451. mindspore/train/callback/_summary_collector.py +15 -8
  452. mindspore/train/callback/_time_monitor.py +58 -5
  453. mindspore/train/data_sink.py +5 -11
  454. mindspore/train/dataset_helper.py +84 -57
  455. mindspore/train/loss_scale_manager.py +2 -2
  456. mindspore/train/metrics/__init__.py +3 -3
  457. mindspore/train/metrics/cosine_similarity.py +1 -1
  458. mindspore/train/metrics/hausdorff_distance.py +3 -2
  459. mindspore/train/metrics/mean_surface_distance.py +3 -2
  460. mindspore/train/metrics/metric.py +39 -19
  461. mindspore/train/metrics/roc.py +2 -2
  462. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  463. mindspore/train/mind_ir_pb2.py +85 -36
  464. mindspore/train/model.py +187 -47
  465. mindspore/train/serialization.py +487 -161
  466. mindspore/train/summary/_summary_adapter.py +1 -1
  467. mindspore/train/summary/_writer_pool.py +3 -2
  468. mindspore/train/summary/summary_record.py +37 -17
  469. mindspore/train/train_thor/convert_utils.py +3 -3
  470. mindspore/train/train_thor/dataset_helper.py +1 -1
  471. mindspore/version.py +1 -1
  472. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +8 -8
  473. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +477 -528
  474. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -1
  475. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  476. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  477. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  478. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  479. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  480. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  481. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  482. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  483. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  484. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  485. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  486. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  487. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  488. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  489. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  490. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  491. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  492. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  493. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  494. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  495. mindspore/_extends/graph_kernel/expander.py +0 -80
  496. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  497. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  498. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  499. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  500. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  501. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  502. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  503. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  504. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  505. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  506. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  507. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  508. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  509. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  510. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  511. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  512. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  513. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  514. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  515. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  516. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  517. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  518. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  519. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  520. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  521. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  522. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  523. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  524. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  525. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  526. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  527. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  528. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  529. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  530. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  531. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  532. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  533. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  534. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  535. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  536. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  537. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  538. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  539. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  540. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  541. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  542. mindspore/dataset/datapreprocess/__init__.py +0 -20
  543. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  544. mindspore/include/api/net.h +0 -142
  545. mindspore/nn/lr_scheduler.py +0 -262
  546. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  547. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  548. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  549. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  550. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  551. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -350
  552. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -409
  553. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -578
  554. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -199
  555. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -446
  556. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  557. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
  558. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
  559. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
  560. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  561. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  562. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  563. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  564. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  565. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  566. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  567. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  568. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  569. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  570. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  571. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  572. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  573. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  574. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  575. mindspore/rewrite/node_visitor.py +0 -44
  576. /mindspore/{ops/_op_impl/_custom_op/flash_attention → _akg/akg/utils/ascend_profilier}/__init__.py +0 -0
  577. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
  578. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
@@ -1,80 +0,0 @@
1
- # Copyright 2020-2022 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for minimum_grad"""
16
- from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
17
- from ._utils import Expander, ExpanderInfoValidator as VLD
18
-
19
-
20
@VLD.check_all_formats_same
class MinimumGrad(Expander):
    """Expander that rewrites MinimumGrad into basic graph-kernel ops.

    The three inputs are unpacked as ``(x, y, dout)`` and the expansion
    always yields the pair ``(dx, dy)``.
    """

    def _check(self):
        """Reject the configuration in which neither gradient is requested.

        Raises:
            GKException: if attrs 'grad_x' and 'grad_y' are both falsy.
        """
        grad_x = self.attrs.get('grad_x', True)
        grad_y = self.attrs.get('grad_y', True)
        if not grad_x and not grad_y:
            # The error fires exactly when both flags are False, so the message
            # must say that this combination is the one that is not allowed.
            raise GKException("For 'MinimumGrad', value of attr 'grad_x' and 'grad_y' should not be both False, "
                              "but got {} and {}".format(grad_x, grad_y))
        # Zero-argument super() does not depend on the module-level name
        # 'MinimumGrad', which the class decorator may rebind to another object.
        return super()._check()

    def _expand(self, graph_builder):
        """Emit the ops computing ``dx`` and ``dy``.

        Args:
            graph_builder: builder used to emit the ops.

        Returns:
            Tuple ``(dx_out, dy_out)``; both are produced regardless of the
            'grad_x' / 'grad_y' attrs.
        """
        input_x, input_y, input_dout = self.inputs

        # mask is 1 where x <= y and 0 elsewhere, in the dtype of x.
        mask = graph_builder.emit('LessEqual', [input_x, input_y])
        mask = graph_builder.emit('Cast', [mask], attrs={'dst_type': input_x.dtype})
        dx = graph_builder.emit('Mul', [mask, input_dout])
        dy = graph_builder.emit('Sub', [input_dout, dx])

        # Elementwise ops may have broadcast the inputs, so each gradient has
        # to be reduced back to the shape of the input it belongs to.
        # Both axis lists are computed before any further op is emitted.
        reduce_axis_x = self.get_reduce_axis(input_x.shape, dx.shape)
        reduce_axis_y = self.get_reduce_axis(input_y.shape, dy.shape)
        dx_out = self._reduce_to_shape(graph_builder, dx, reduce_axis_x, input_x.shape)
        dy_out = self._reduce_to_shape(graph_builder, dy, reduce_axis_y, input_y.shape)

        # output two results, regardless of grad_x and grad_y
        return dx_out, dy_out

    @staticmethod
    def _reduce_to_shape(graph_builder, grad, reduce_axis, target_shape):
        """Sum ``grad`` over ``reduce_axis`` and reshape it to ``target_shape`` if needed."""
        if not reduce_axis:
            return grad
        reduced = graph_builder.emit('ReduceSum', [grad], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
        if reduced.shape != target_shape:
            return graph_builder.emit('Reshape', [reduced], attrs={'shape': target_shape})
        return reduced

    @staticmethod
    def get_reduce_axis(original_shape, broadcast_shape):
        """Compute the axes along which a broadcast result must be summed.

        Args:
            original_shape: shape before broadcasting (list or tuple).
            broadcast_shape: shape after broadcasting.

        Returns:
            List of axis indices (relative to ``broadcast_shape``) where the
            left-padded original dimension is 1 but the broadcast one is not.

        Raises:
            ValueError: if ``original_shape`` has more dimensions than
                ``broadcast_shape`` or the two shapes cannot broadcast.
        """
        if len(original_shape) > len(broadcast_shape):
            raise ValueError("For 'MinimumGrad', the length of original_shape should be less than or equal to the "
                             "length of broadcast_shape, but got {} and {}".format(original_shape, broadcast_shape))

        # list(...) lets tuple shapes through; list + tuple would raise TypeError.
        tmp_shape = [1] * (len(broadcast_shape) - len(original_shape)) + list(original_shape)
        reduce_axis = []
        for idx, (own_dim, bcast_dim) in enumerate(zip(tmp_shape, broadcast_shape)):
            if own_dim == bcast_dim:
                continue
            if own_dim != 1:
                raise ValueError("For 'MinimumGrad', original_shape {} and broadcast_shape {} can not broadcast."
                                 .format(original_shape, broadcast_shape))
            reduce_axis.append(idx)
        return reduce_axis
@@ -1,26 +0,0 @@
1
- # Copyright 2021 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for OnesLike"""
16
- from ._utils import Expander
17
-
18
-
19
class OnesLike(Expander):
    """Expand OnesLike into a constant 1 broadcast to the input's shape."""

    def _expand(self, graph_builder):
        """Emit ``BroadcastTo`` of the scalar 1, using the first input's dtype and shape."""
        reference = self.inputs[0]
        one = graph_builder.value(reference.dtype, 1)
        return graph_builder.emit('BroadcastTo', [one], attrs={'shape': reference.shape})
@@ -1,43 +0,0 @@
1
- # Copyright 2020-2021 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for reduce_mean"""
16
- from mindspore._extends.graph_kernel.model.model import DataFormat as DF
17
- from ._utils import Expander, ExpanderInfoValidator as VLD
18
-
19
-
20
@VLD.add_format(DF.DEFAULT)
@VLD.check_attrs('axis', 'keep_dims')
class ReduceMean(Expander):
    """Expand ReduceMean into ReduceSum followed by RealDiv."""

    def _expand(self, graph_builder):
        """Emit ``ReduceSum(x) / N`` where N is the product of the reduced dimensions.

        A scalar 'axis' attr is wrapped into a one-element tuple; an empty
        tuple or list means that every axis of the input is reduced.
        """
        data = self.inputs[0]
        axis = self.attrs['axis']
        keep_dims = self.attrs['keep_dims']

        if isinstance(axis, (tuple, list)):
            if not axis:
                axis = list(range(len(data.shape)))
        else:
            axis = (axis,)

        # Accumulate as float, starting from 1.0.
        element_count = 1.0
        for dim in axis:
            element_count *= data.shape[dim]
        divisor = graph_builder.value(data.dtype, element_count)

        total = graph_builder.emit('ReduceSum', [data], attrs={'reduce_axis': axis, 'keep_dims': keep_dims})
        return graph_builder.emit('RealDiv', [total, divisor])
@@ -1,32 +0,0 @@
1
- # Copyright 2021 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for relu_grad"""
16
- from ._utils import Expander, ExpanderInfoValidator as VLD
17
-
18
-
19
@VLD.check_all_formats_same
class ReluGrad(Expander):
    """Expand ReluGrad into Greater, Cast and Mul."""

    def _expand(self, graph_builder):
        """Emit ``inputs[0] * cast(inputs[1] > 0)``.

        The mask is compared against a zero of the second input's dtype and
        cast to the first input's dtype before the multiplication.
        """
        scaled, gate = self.inputs[0], self.inputs[1]

        zero = graph_builder.value(gate.dtype, 0)
        mask = graph_builder.emit('Greater', [gate, zero])
        mask = graph_builder.emit('Cast', [mask], attrs={'dst_type': scaled.dtype})
        return graph_builder.emit('Mul', [mask, scaled])
@@ -1,41 +0,0 @@
1
- # Copyright 2021-2022 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for SigmoidCrossEntropyWithLogits"""
16
- from ._utils import Expander, ExpanderInfoValidator as VLD
17
-
18
-
19
@VLD.check_all_formats_same
class SigmoidCrossEntropyWithLogits(Expander):
    """Expand SigmoidCrossEntropyWithLogits into elementwise ops."""

    def _expand(self, graph_builder):
        """Emit the rearranged form of the loss.

        The textbook formula
            -(labels * log(sigmoid(logits)) + (1 - labels) * log(1 - sigmoid(logits)))
        is emitted in the equivalent form
            max(logits, 0) - logits * labels + log(1 + exp(-abs(logits)))
        which is used to ensure stability and avoid overflow.
        """
        logits, labels = self.inputs
        emit = graph_builder.emit

        one = graph_builder.value(logits.dtype, 1.0)
        zero = graph_builder.value(logits.dtype, 0.0)

        relu_logits = emit('Maximum', [logits, zero])
        cross_term = emit('Mul', [logits, labels])

        # log(1 + exp(-|logits|))
        magnitude = emit('Abs', [logits])
        decay = emit('Exp', [emit('Neg', [magnitude])])
        log_term = emit('Log', [emit('Add', [one, decay])])

        partial = emit('Sub', [relu_logits, cross_term])
        return emit('Add', [partial, log_term])
@@ -1,35 +0,0 @@
1
- # Copyright 2021 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for SigmoidCrossEntropyWithLogitsGrad"""
16
- from ._utils import Expander, ExpanderInfoValidator as VLD
17
-
18
-
19
@VLD.check_all_formats_same
class SigmoidCrossEntropyWithLogitsGrad(Expander):
    """Expand SigmoidCrossEntropyWithLogitsGrad into elementwise ops."""

    def _expand(self, graph_builder):
        """Emit ``(sigmoid(logits) - label) * dout``.

        sigmoid is spelled out as ``1 / (1 + exp(-logits))``.
        """
        logits, label, dout = self.inputs
        emit = graph_builder.emit

        one = graph_builder.value(logits.dtype, 1.0)
        exp_neg_logits = emit('Exp', [emit('Neg', [logits])])
        denominator = emit('Add', [one, exp_neg_logits])
        sigmoid = emit('RealDiv', [one, denominator])
        diff = emit('Sub', [sigmoid, label])
        return emit('Mul', [diff, dout])
@@ -1,31 +0,0 @@
1
- # Copyright 2021 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for SigmoidGrad"""
16
- from ._utils import Expander, ExpanderInfoValidator as VLD
17
-
18
-
19
@VLD.check_all_formats_same
class SigmoidGrad(Expander):
    """Expand SigmoidGrad into Sub and Mul."""

    def _expand(self, graph_builder):
        """Emit ``(1 - y) * (y * dy)`` for the inputs ``(y, dy)``."""
        y, dy = self.inputs
        one = graph_builder.value(y.dtype, 1.0)
        complement = graph_builder.emit('Sub', [one, y])
        scaled_grad = graph_builder.emit('Mul', [y, dy])
        return graph_builder.emit('Mul', [complement, scaled_grad])
@@ -1,35 +0,0 @@
1
- # Copyright 2021-2022 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for slice"""
16
- from ._utils import Expander, ExpanderInfoValidator as VLD
17
-
18
-
19
@VLD.check_attrs('begin', 'size')
class Slice(Expander):
    """Expand Slice into a unit-stride StridedSlice."""

    def _expand(self, graph_builder):
        """Emit ``StridedSlice`` with ``end = begin + size`` and all strides equal to 1.

        The output tensor is created with shape 'size' and the dtype and data
        format of the input.
        """
        source = self.inputs[0]
        begin = self.attrs['begin']
        size = self.attrs['size']

        # NOTE(review): a 'size' entry of -1 is used as-is here (end = begin - 1);
        # confirm that callers never pass -1 in 'size'.
        end = [start + size[dim] for dim, start in enumerate(begin)]
        strides = [1] * len(begin)

        output = graph_builder.tensor(size, source.dtype, source.data_format)
        graph_builder.op('StridedSlice', output, [source], attrs={'begin': begin, 'end': end, 'strides': strides})
        return output
@@ -1,42 +0,0 @@
1
- # Copyright 2021 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for SoftmaxCrossEntropyWithLogits"""
16
- from mindspore._extends.graph_kernel.model.model import DataFormat as DF
17
- from ._utils import Expander, ExpanderInfoValidator as VLD
18
-
19
-
20
@VLD.add_format(DF.DEFAULT, DF.DEFAULT)
class SoftmaxCrossEntropyWithLogits(Expander):
    """Expand SoftmaxCrossEntropyWithLogits into reduce and elementwise ops."""

    def _expand(self, graph_builder):
        """Emit the loss and the gradient with respect to the logits.

        The loss is ``-reduce_sum(label * log(softmax(logits)))`` over the last
        axis; the gradient is ``softmax(logits) - label``.

        Returns:
            Tuple ``(loss, dlogits)``.
        """
        logits, label = self.inputs
        axis = (-1,)

        # Softmax with the row maximum subtracted before exponentiation.
        max_x = graph_builder.emit('ReduceMax', [logits], attrs={'reduce_axis': axis, 'keep_dims': True})
        data_sub = graph_builder.emit('Sub', [logits, max_x])
        data_exp = graph_builder.emit('Exp', [data_sub])
        data_expsum = graph_builder.emit('ReduceSum', [data_exp], attrs={'reduce_axis': axis, 'keep_dims': True})
        data_softmax = graph_builder.emit('RealDiv', [data_exp, data_expsum])

        # Clamp to a small positive value so Log never receives zero.
        const_eps = graph_builder.value(logits.dtype, 0.000001)
        data_softmax_safety = graph_builder.emit("Maximum", [data_softmax, const_eps])
        softmax_log = graph_builder.emit('Log', [data_softmax_safety])
        label_mul_log = graph_builder.emit('Mul', [label, softmax_log])

        # Bind the reduction to its own name only; the original chained
        # assignment also overwrote data_expsum with this unrelated tensor.
        label_log_sum = graph_builder.emit('ReduceSum', [label_mul_log], attrs={
            'reduce_axis': axis, 'keep_dims': False})
        loss = graph_builder.emit('Neg', [label_log_sum])

        # The gradient uses the unclamped softmax.
        dlogits = graph_builder.emit('Sub', [data_softmax, label])
        return loss, dlogits
@@ -1,41 +0,0 @@
1
- # Copyright 2021-2022 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for SoftmaxGradExt"""
16
- from mindspore._extends.graph_kernel.model.model import DataFormat as DF
17
- from ._utils import Expander, ExpanderInfoValidator as VLD
18
- from ._utils import get_reduce_axis_shape
19
-
20
-
21
@VLD.add_format(DF.FRAC_NZ, DF.FRAC_NZ, DF.DEFAULT)
@VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
@VLD.check_attrs('axis')
class SoftmaxGradExt(Expander):
    """Expand SoftmaxGradExt into Mul, ReduceSum and Sub."""

    def _expand(self, graph_builder):
        """Emit ``(x - reduce_sum(x * y)) * y * z`` for the inputs ``(x, y, z)``.

        The reduction keeps its dims; for FRAC_NZ inputs the reduced tensor is
        reshaped to the shape returned by ``get_reduce_axis_shape``.
        """
        x, y, z = self.inputs
        emit = graph_builder.emit

        reduce_axis, ori_reduced_shape = get_reduce_axis_shape(x.shape, x.data_format, self.attrs['axis'])

        product = emit('Mul', [x, y])
        reduce_attrs = {'reduce_axis': reduce_axis, 'keep_dims': True, 'reduce_output_fuse': True}
        summed = emit('ReduceSum', [product], attrs=reduce_attrs)
        if x.data_format == DF.FRAC_NZ:
            summed = emit('Reshape', [summed], attrs={'shape': ori_reduced_shape})

        centered = emit('Sub', [x, summed])
        scaled = emit('Mul', [centered, y])
        return emit('Mul', [scaled, z])
@@ -1,28 +0,0 @@
1
- # Copyright 2020-2021 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for Softsign"""
16
- from ._utils import Expander
17
-
18
-
19
class Softsign(Expander):
    """Expand Softsign into Abs, Add and RealDiv."""

    def _expand(self, graph_builder):
        """Emit ``x / (abs(x) + 1)`` for the first input."""
        data = self.inputs[0]
        magnitude = graph_builder.emit('Abs', [data])
        one = graph_builder.value(data.dtype, 1.0)
        denominator = graph_builder.emit('Add', [magnitude, one])
        return graph_builder.emit('RealDiv', [data, denominator])
@@ -1,29 +0,0 @@
1
- # Copyright 2020-2021 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for sqrtgrad"""
16
- from ._utils import Expander, ExpanderInfoValidator as VLD
17
-
18
-
19
@VLD.check_all_formats_same
class SqrtGrad(Expander):
    """Expand SqrtGrad into Mul and RealDiv."""

    def _expand(self, graph_builder):
        """Emit ``dout / (2 * x)`` for the inputs ``(x, dout)``."""
        x, dout = self.inputs
        two = graph_builder.value(x.dtype, 2)
        divisor = graph_builder.emit('Mul', [x, two])
        return graph_builder.emit('RealDiv', [dout, divisor])
@@ -1,44 +0,0 @@
1
- # Copyright 2021-2022 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for SquareSumAll"""
16
- from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
17
- from ._utils import Expander
18
-
19
-
20
class SquareSumAll(Expander):
    """Expand SquareSumAll into Mul and ReduceSum for each of its two inputs."""

    def _check(self):
        """Require exactly two inputs; raise GKException otherwise."""
        count = len(self.inputs)
        if count != 2:
            raise GKException("For 'SquareSumAll', the inputs number should be 2, but got {}.".format(count))

    def _expand(self, graph_builder):
        """Emit the sum of squares of each input over all axes of the first input's shape."""
        first, second = self.inputs[0], self.inputs[1]

        # Both reductions use the axis list derived from the first input.
        all_axes = list(range(len(first.shape)))
        reduce_attrs = {'reduce_axis': all_axes, 'keep_dims': False}

        first_sq = graph_builder.emit('Mul', [first, first])
        second_sq = graph_builder.emit('Mul', [second, second])
        first_sum = graph_builder.emit('ReduceSum', [first_sq], attrs=reduce_attrs)
        # A fresh dict with the same contents, so the two ops never share an attrs object.
        second_sum = graph_builder.emit('ReduceSum', [second_sq], attrs=dict(reduce_attrs))
        return first_sum, second_sum
@@ -1,37 +0,0 @@
1
- # Copyright 2021-2022 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for SquareSumV1"""
16
- from mindspore._extends.graph_kernel.model.model import DataFormat as DF
17
- from ._utils import Expander, ExpanderInfoValidator as VLD
18
- from ._utils import get_reduce_axis_shape
19
-
20
-
21
@VLD.add_format(DF.FRAC_NZ)
@VLD.add_format(DF.DEFAULT)
@VLD.check_attrs('axis')
class SquareSumV1(Expander):
    """Expand SquareSumV1 into Mul and ReduceSum."""

    def _expand(self, graph_builder):
        """Emit ``reduce_sum(x * x)`` over the 'axis' attr without keeping dims.

        For FRAC_NZ inputs the result is reshaped to the shape returned by
        ``get_reduce_axis_shape``.
        """
        data = self.inputs[0]
        reduce_axis, ori_reduced_shape = get_reduce_axis_shape(data.shape, data.data_format, self.attrs['axis'])

        squared = graph_builder.emit('Mul', [data, data])
        total = graph_builder.emit('ReduceSum', [squared], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
        if data.data_format != DF.FRAC_NZ:
            return total
        return graph_builder.emit('Reshape', [total], attrs={'shape': ori_reduced_shape})
@@ -1,43 +0,0 @@
1
- # Copyright 2021 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for squared_difference"""
16
- from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
17
- from ._utils import Expander, ExpanderInfoValidator as VLD
18
-
19
-
20
@VLD.check_all_formats_same
class SquaredDifference(Expander):
    """Expand SquaredDifference into Sub and Mul."""

    def __init__(self, expand_info):
        """Record the 'data_type' entries of the first two inputs for use in ``_check``."""
        super().__init__(expand_info)
        first, second = self.inputs[0], self.inputs[1]
        self.dtype_x = first['data_type']
        self.dtype_y = second['data_type']

    def _check(self):
        """Reject float64 inputs and inputs whose data types differ.

        Raises:
            GKException: if either data type is "float64", or the two differ.
        """
        # The float64 test comes first so that it takes precedence over the mismatch test.
        if "float64" in (self.dtype_x, self.dtype_y):
            raise GKException("For 'SquaredDifference', the inputs data type must not be float64")
        if self.dtype_x == self.dtype_y:
            return
        raise GKException("For 'SquaredDifference', the inputs data type should be same, but got {} and {}"
                          .format(self.dtype_x, self.dtype_y))

    def _expand(self, graph_builder):
        """Emit ``(x - y) * (x - y)`` for the first two inputs."""
        lhs, rhs = self.inputs[0], self.inputs[1]
        diff = graph_builder.emit('Sub', [lhs, rhs])
        return graph_builder.emit('Mul', [diff, diff])
@@ -1,31 +0,0 @@
1
- # Copyright 2020-2021 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for tanh_grad"""
16
- from ._utils import Expander, ExpanderInfoValidator as VLD
17
-
18
-
19
@VLD.check_all_formats_same
class TanhGrad(Expander):
    """Expand TanhGrad into Mul and Sub."""

    def _expand(self, graph_builder):
        """Emit ``dy * (1 - y * y)`` for the inputs ``(y, dy)``."""
        y, dy = self.inputs

        one = graph_builder.value(y.dtype, 1)
        y_squared = graph_builder.emit('Mul', [y, y])
        slope = graph_builder.emit('Sub', [one, y_squared])
        return graph_builder.emit('Mul', [dy, slope])