mindspore 2.1.0__cp37-none-any.whl → 2.2.11__cp37-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (577)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_akg/akg/build_module.py +5 -6
  4. mindspore/_akg/akg/composite/build_module.py +139 -22
  5. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  6. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  7. mindspore/_akg/akg/tvm/api.py +4 -3
  8. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  9. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  10. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  11. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  12. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  13. mindspore/_akg/akg/tvm/build_module.py +16 -1
  14. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  15. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  16. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  17. mindspore/_akg/akg/tvm/module.py +1 -2
  18. mindspore/_akg/akg/tvm/stmt.py +2 -2
  19. mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
  20. mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
  21. mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
  22. mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
  23. mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
  24. mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
  25. mindspore/_akg/akg/utils/composite_op_helper.py +16 -12
  26. mindspore/_akg/akg/utils/dump_ascend_meta.py +22 -3
  27. mindspore/_akg/akg/utils/kernel_exec.py +98 -274
  28. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  29. mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
  30. mindspore/_akg/akg/utils/util.py +56 -1
  31. mindspore/_c_dataengine.cpython-37m-aarch64-linux-gnu.so +0 -0
  32. mindspore/_c_expression.cpython-37m-aarch64-linux-gnu.so +0 -0
  33. mindspore/_c_mindrecord.cpython-37m-aarch64-linux-gnu.so +0 -0
  34. mindspore/_check_jit_forbidden_api.py +3 -1
  35. mindspore/_checkparam.py +23 -29
  36. mindspore/_extends/graph_kernel/__init__.py +0 -1
  37. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  38. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  39. mindspore/_extends/graph_kernel/splitter.py +4 -11
  40. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  41. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
  42. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  43. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  44. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  45. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
  46. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  47. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  48. mindspore/_extends/parse/__init__.py +13 -15
  49. mindspore/_extends/parse/namespace.py +7 -33
  50. mindspore/_extends/parse/parser.py +67 -72
  51. mindspore/_extends/parse/resources.py +1 -1
  52. mindspore/_extends/parse/standard_method.py +86 -106
  53. mindspore/_extends/parse/trope.py +1 -1
  54. mindspore/_extends/remote/kernel_build_server.py +25 -7
  55. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  56. mindspore/_install_custom.py +43 -0
  57. mindspore/_mindspore_offline_debug.cpython-37m-aarch64-linux-gnu.so +0 -0
  58. mindspore/amp.py +47 -11
  59. mindspore/bin/cache_admin +0 -0
  60. mindspore/bin/cache_server +0 -0
  61. mindspore/boost/boost.py +1 -8
  62. mindspore/boost/boost_cell_wrapper.py +3 -2
  63. mindspore/boost/grad_accumulation.py +1 -1
  64. mindspore/boost/group_loss_scale_manager.py +8 -7
  65. mindspore/common/__init__.py +5 -3
  66. mindspore/common/_jit_fallback_utils.py +6 -0
  67. mindspore/common/_register_for_adapter.py +2 -0
  68. mindspore/common/_register_for_tensor.py +2 -2
  69. mindspore/common/_stub_tensor.py +13 -0
  70. mindspore/common/_utils.py +29 -0
  71. mindspore/common/api.py +174 -259
  72. mindspore/common/auto_dynamic_shape.py +494 -0
  73. mindspore/common/dtype.py +18 -11
  74. mindspore/common/dump.py +6 -4
  75. mindspore/common/initializer.py +14 -14
  76. mindspore/common/jit_config.py +33 -15
  77. mindspore/common/lazy_inline.py +126 -7
  78. mindspore/common/mindir_util.py +101 -0
  79. mindspore/common/parameter.py +51 -41
  80. mindspore/common/seed.py +4 -4
  81. mindspore/common/sparse_tensor.py +13 -14
  82. mindspore/common/tensor.py +243 -165
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +83 -4
  85. mindspore/communication/management.py +152 -84
  86. mindspore/config/op_info.config +14 -3
  87. mindspore/config/super_bar_config.json +4 -2
  88. mindspore/context.py +152 -61
  89. mindspore/dataset/__init__.py +5 -5
  90. mindspore/dataset/audio/__init__.py +2 -2
  91. mindspore/dataset/audio/transforms.py +52 -52
  92. mindspore/dataset/callback/ds_callback.py +16 -2
  93. mindspore/dataset/core/config.py +68 -51
  94. mindspore/dataset/engine/cache_client.py +33 -7
  95. mindspore/dataset/engine/datasets.py +250 -112
  96. mindspore/dataset/engine/datasets_audio.py +43 -211
  97. mindspore/dataset/engine/datasets_standard_format.py +16 -35
  98. mindspore/dataset/engine/datasets_text.py +43 -67
  99. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  100. mindspore/dataset/engine/datasets_vision.py +219 -1029
  101. mindspore/dataset/engine/iterators.py +11 -4
  102. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  103. mindspore/dataset/engine/obs/util.py +3 -0
  104. mindspore/dataset/engine/samplers.py +1 -1
  105. mindspore/dataset/engine/validators.py +19 -5
  106. mindspore/dataset/text/__init__.py +3 -3
  107. mindspore/dataset/text/transforms.py +101 -127
  108. mindspore/dataset/text/utils.py +205 -138
  109. mindspore/dataset/transforms/__init__.py +1 -1
  110. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  111. mindspore/dataset/transforms/transforms.py +95 -40
  112. mindspore/dataset/utils/browse_dataset.py +8 -2
  113. mindspore/dataset/utils/line_reader.py +17 -19
  114. mindspore/dataset/vision/__init__.py +3 -3
  115. mindspore/dataset/vision/c_transforms.py +6 -3
  116. mindspore/dataset/vision/transforms.py +409 -287
  117. mindspore/dataset/vision/utils.py +13 -14
  118. mindspore/dataset/vision/validators.py +11 -1
  119. mindspore/experimental/map_parameter.py +14 -0
  120. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  121. mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
  122. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  123. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  124. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  125. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  126. mindspore/gen_ops.py +273 -0
  127. mindspore/include/OWNERS +0 -1
  128. mindspore/include/api/data_type.h +2 -1
  129. mindspore/include/api/graph.h +0 -15
  130. mindspore/include/api/kernel.h +2 -0
  131. mindspore/include/api/kernel_api.h +37 -12
  132. mindspore/include/api/model.h +17 -14
  133. mindspore/include/api/status.h +8 -3
  134. mindspore/include/api/types.h +37 -4
  135. mindspore/include/c_api/ms/abstract.h +67 -0
  136. mindspore/include/c_api/ms/attribute.h +197 -0
  137. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  138. mindspore/include/c_api/ms/base/macros.h +32 -0
  139. mindspore/include/c_api/ms/base/status.h +33 -0
  140. mindspore/include/c_api/ms/base/types.h +282 -0
  141. mindspore/include/c_api/ms/context.h +102 -0
  142. mindspore/include/c_api/ms/graph.h +160 -0
  143. mindspore/include/c_api/ms/node.h +606 -0
  144. mindspore/include/c_api/ms/tensor.h +161 -0
  145. mindspore/include/c_api/ms/value.h +84 -0
  146. mindspore/include/dataset/constants.h +6 -5
  147. mindspore/include/dataset/execute.h +23 -13
  148. mindspore/include/dataset/text.h +26 -26
  149. mindspore/include/dataset/transforms.h +13 -13
  150. mindspore/include/dataset/vision.h +60 -60
  151. mindspore/include/dataset/vision_ascend.h +5 -6
  152. mindspore/include/dataset/vision_lite.h +17 -17
  153. mindspore/include/mindapi/base/type_id.h +1 -0
  154. mindspore/include/mindapi/base/types.h +1 -0
  155. mindspore/lib/libdnnl.so.2 +0 -0
  156. mindspore/lib/libjemalloc.so.2 +0 -0
  157. mindspore/lib/libmindspore.so +0 -0
  158. mindspore/lib/libmindspore_backend.so +0 -0
  159. mindspore/lib/libmindspore_common.so +0 -0
  160. mindspore/lib/libmindspore_core.so +0 -0
  161. mindspore/lib/libmindspore_glog.so.0 +0 -0
  162. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  163. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  164. mindspore/lib/libmindspore_shared_lib.so +0 -0
  165. mindspore/lib/libnnacl.so +0 -0
  166. mindspore/lib/libopencv_core.so.4.5 +0 -0
  167. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  168. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  169. mindspore/lib/libps_cache.so +0 -0
  170. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
  171. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
  172. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
  173. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
  174. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  175. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  176. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  177. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  178. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  179. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  180. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  181. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  182. mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
  183. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  184. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  185. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8998 -0
  186. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  187. mindspore/lib/plugin/ascend/libakg.so +0 -0
  188. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  189. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  190. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  191. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  192. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  193. mindspore/lib/plugin/cpu/libakg.so +0 -0
  194. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  195. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  196. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  197. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  198. mindspore/nn/__init__.py +0 -2
  199. mindspore/nn/cell.py +313 -74
  200. mindspore/nn/dynamic_lr.py +21 -21
  201. mindspore/nn/layer/activation.py +22 -30
  202. mindspore/nn/layer/basic.py +15 -13
  203. mindspore/nn/layer/channel_shuffle.py +1 -1
  204. mindspore/nn/layer/container.py +271 -9
  205. mindspore/nn/layer/conv.py +323 -204
  206. mindspore/nn/layer/dense.py +8 -5
  207. mindspore/nn/layer/embedding.py +33 -27
  208. mindspore/nn/layer/flash_attention.py +61 -95
  209. mindspore/nn/layer/image.py +8 -6
  210. mindspore/nn/layer/math.py +16 -25
  211. mindspore/nn/layer/normalization.py +107 -66
  212. mindspore/nn/layer/padding.py +1 -1
  213. mindspore/nn/layer/pooling.py +131 -109
  214. mindspore/nn/layer/rnn_cells.py +27 -22
  215. mindspore/nn/layer/rnns.py +13 -16
  216. mindspore/nn/layer/thor_layer.py +1 -1
  217. mindspore/nn/layer/transformer.py +221 -154
  218. mindspore/nn/learning_rate_schedule.py +9 -1
  219. mindspore/nn/loss/loss.py +235 -174
  220. mindspore/nn/optim/ada_grad.py +2 -1
  221. mindspore/nn/optim/adadelta.py +1 -0
  222. mindspore/nn/optim/adafactor.py +2 -1
  223. mindspore/nn/optim/adam.py +7 -4
  224. mindspore/nn/optim/adamax.py +3 -2
  225. mindspore/nn/optim/adasum.py +2 -2
  226. mindspore/nn/optim/asgd.py +2 -3
  227. mindspore/nn/optim/ftrl.py +6 -5
  228. mindspore/nn/optim/lamb.py +7 -4
  229. mindspore/nn/optim/lars.py +1 -1
  230. mindspore/nn/optim/lazyadam.py +5 -3
  231. mindspore/nn/optim/momentum.py +2 -1
  232. mindspore/nn/optim/optimizer.py +53 -4
  233. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  234. mindspore/nn/optim/rmsprop.py +4 -3
  235. mindspore/nn/optim/rprop.py +23 -12
  236. mindspore/nn/optim/sgd.py +26 -11
  237. mindspore/nn/optim/thor.py +9 -7
  238. mindspore/nn/probability/bijector/bijector.py +5 -5
  239. mindspore/nn/probability/bijector/power_transform.py +27 -27
  240. mindspore/nn/probability/bijector/softplus.py +3 -3
  241. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  242. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  243. mindspore/nn/probability/distribution/beta.py +3 -3
  244. mindspore/nn/probability/distribution/categorical.py +7 -7
  245. mindspore/nn/probability/distribution/cauchy.py +0 -1
  246. mindspore/nn/probability/distribution/distribution.py +3 -3
  247. mindspore/nn/probability/distribution/gamma.py +3 -3
  248. mindspore/nn/probability/distribution/geometric.py +4 -4
  249. mindspore/nn/probability/distribution/gumbel.py +4 -4
  250. mindspore/nn/probability/distribution/log_normal.py +2 -2
  251. mindspore/nn/probability/distribution/logistic.py +2 -2
  252. mindspore/nn/probability/distribution/poisson.py +4 -4
  253. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  254. mindspore/nn/probability/distribution/uniform.py +6 -6
  255. mindspore/nn/wrap/__init__.py +4 -2
  256. mindspore/nn/wrap/cell_wrapper.py +87 -34
  257. mindspore/nn/wrap/grad_reducer.py +8 -5
  258. mindspore/nn/wrap/loss_scale.py +105 -42
  259. mindspore/numpy/array_creations.py +1 -2
  260. mindspore/numpy/array_ops.py +3 -2
  261. mindspore/numpy/utils_const.py +5 -5
  262. mindspore/offline_debug/convert_async.py +2 -2
  263. mindspore/ops/_grad_experimental/__init__.py +0 -5
  264. mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
  265. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  266. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  267. mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
  268. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  269. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
  270. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  271. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  272. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  273. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  274. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  275. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  276. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  277. mindspore/ops/_op_impl/{_custom_op/flash_attention/constants.py → aicpu/eps.py} +18 -27
  278. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  279. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
  280. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  281. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  282. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  283. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  284. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  285. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  286. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  287. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  288. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  289. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  290. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  291. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  292. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  293. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  294. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  295. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  296. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  297. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  298. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  299. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  300. mindspore/ops/_primitive_cache.py +1 -1
  301. mindspore/ops/_tracefunc.py +45 -13
  302. mindspore/ops/_utils/utils.py +6 -1
  303. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  304. mindspore/ops/_vmap/vmap_base.py +3 -3
  305. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  306. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  307. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  308. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  309. mindspore/ops/arg_dtype_cast.py +54 -0
  310. mindspore/ops/composite/base.py +37 -10
  311. mindspore/ops/composite/math_ops.py +5 -4
  312. mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
  313. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  314. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  315. mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
  316. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  317. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  318. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  319. mindspore/ops/deprecated.py +304 -0
  320. mindspore/ops/function/__init__.py +4 -1
  321. mindspore/ops/function/array_func.py +174 -193
  322. mindspore/ops/function/clip_func.py +81 -13
  323. mindspore/ops/function/debug_func.py +1 -1
  324. mindspore/ops/function/grad/grad_func.py +18 -9
  325. mindspore/ops/function/image_func.py +10 -4
  326. mindspore/ops/function/linalg_func.py +5 -5
  327. mindspore/ops/function/math_func.py +575 -386
  328. mindspore/ops/function/nn_func.py +568 -260
  329. mindspore/ops/function/random_func.py +88 -57
  330. mindspore/ops/function/sparse_func.py +1 -1
  331. mindspore/ops/function/sparse_unary_func.py +14 -12
  332. mindspore/ops/function/vmap_func.py +6 -5
  333. mindspore/ops/functional.py +15 -10
  334. mindspore/ops/op_info_register.py +244 -25
  335. mindspore/ops/operations/__init__.py +31 -19
  336. mindspore/ops/operations/_grad_ops.py +71 -7
  337. mindspore/ops/operations/_inner_ops.py +350 -17
  338. mindspore/ops/operations/_quant_ops.py +4 -8
  339. mindspore/ops/operations/_sequence_ops.py +42 -0
  340. mindspore/ops/operations/array_ops.py +68 -282
  341. mindspore/ops/operations/comm_ops.py +107 -59
  342. mindspore/ops/operations/custom_ops.py +94 -70
  343. mindspore/ops/operations/debug_ops.py +8 -4
  344. mindspore/ops/operations/image_ops.py +18 -12
  345. mindspore/ops/operations/inner_ops.py +26 -3
  346. mindspore/ops/operations/math_ops.py +192 -144
  347. mindspore/ops/operations/nn_ops.py +857 -489
  348. mindspore/ops/operations/other_ops.py +0 -22
  349. mindspore/ops/operations/random_ops.py +53 -111
  350. mindspore/ops/operations/sparse_ops.py +3 -1
  351. mindspore/ops/primitive.py +24 -18
  352. mindspore/parallel/_auto_parallel_context.py +68 -8
  353. mindspore/parallel/_cost_model_context.py +2 -2
  354. mindspore/parallel/_offload_context.py +17 -3
  355. mindspore/parallel/_parallel_serialization.py +12 -5
  356. mindspore/parallel/_ps_context.py +12 -0
  357. mindspore/parallel/_tensor.py +18 -13
  358. mindspore/parallel/_transformer/layers.py +5 -3
  359. mindspore/parallel/_transformer/loss.py +1 -0
  360. mindspore/parallel/_transformer/moe.py +2 -2
  361. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  362. mindspore/parallel/_transformer/transformer.py +23 -3
  363. mindspore/parallel/_utils.py +11 -7
  364. mindspore/parallel/algo_parameter_config.py +85 -5
  365. mindspore/parallel/checkpoint_transform.py +19 -12
  366. mindspore/parallel/shard.py +21 -14
  367. mindspore/profiler/common/struct_type.py +3 -3
  368. mindspore/profiler/common/util.py +4 -2
  369. mindspore/profiler/envprofiling.py +1 -1
  370. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  371. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  372. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  373. mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
  374. mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
  375. mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
  376. mindspore/profiler/parser/ascend_op_generator.py +6 -6
  377. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  378. mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
  379. mindspore/profiler/parser/base_timeline_generator.py +10 -8
  380. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
  381. mindspore/profiler/parser/flops_parser.py +15 -11
  382. mindspore/profiler/parser/framework_parser.py +38 -22
  383. mindspore/profiler/parser/hccl_parser.py +16 -12
  384. mindspore/profiler/parser/integrator.py +22 -11
  385. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  386. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  387. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  388. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  389. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  390. mindspore/profiler/parser/optime_parser.py +1 -1
  391. mindspore/profiler/parser/profiler_info.py +21 -2
  392. mindspore/profiler/parser/step_trace_parser.py +11 -14
  393. mindspore/profiler/profiling.py +179 -89
  394. mindspore/rewrite/api/node.py +102 -19
  395. mindspore/rewrite/api/node_type.py +5 -1
  396. mindspore/rewrite/api/pattern_engine.py +1 -1
  397. mindspore/rewrite/api/scoped_value.py +9 -17
  398. mindspore/rewrite/api/symbol_tree.py +131 -47
  399. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  400. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  401. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  402. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  403. mindspore/rewrite/common/rewrite_elog.py +5 -1
  404. mindspore/rewrite/namer.py +33 -24
  405. mindspore/rewrite/namespace.py +14 -5
  406. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  407. mindspore/rewrite/node/call_function.py +79 -0
  408. mindspore/rewrite/node/cell_container.py +135 -0
  409. mindspore/rewrite/node/control_flow.py +88 -0
  410. mindspore/rewrite/{node.py → node/node.py} +273 -234
  411. mindspore/rewrite/node/node_manager.py +254 -0
  412. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  413. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  414. mindspore/rewrite/parsers/assign_parser.py +216 -221
  415. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  416. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  417. mindspore/rewrite/parsers/constant_parser.py +9 -6
  418. mindspore/rewrite/parsers/container_parser.py +9 -7
  419. mindspore/rewrite/parsers/for_parser.py +42 -21
  420. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  421. mindspore/rewrite/parsers/if_parser.py +28 -24
  422. mindspore/rewrite/parsers/module_parser.py +196 -25
  423. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  424. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  425. mindspore/rewrite/parsers/return_parser.py +6 -6
  426. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  427. mindspore/rewrite/sparsify/utils.py +1 -1
  428. mindspore/rewrite/symbol_tree.py +523 -578
  429. mindspore/rewrite/symbol_tree_builder.py +9 -193
  430. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  431. mindspore/run_check/_check_version.py +6 -4
  432. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  433. mindspore/safeguard/rewrite_obfuscation.py +541 -0
  434. mindspore/scipy/linalg.py +1 -1
  435. mindspore/scipy/ops.py +55 -5
  436. mindspore/scipy/optimize/__init__.py +3 -2
  437. mindspore/scipy/optimize/linear_sum_assignment.py +38 -33
  438. mindspore/scipy/optimize/minimize.py +7 -3
  439. mindspore/train/_utils.py +7 -3
  440. mindspore/train/amp.py +323 -123
  441. mindspore/train/anf_ir_pb2.py +14 -2
  442. mindspore/train/callback/_backup_and_restore.py +2 -12
  443. mindspore/train/callback/_callback.py +29 -4
  444. mindspore/train/callback/_checkpoint.py +23 -8
  445. mindspore/train/callback/_early_stop.py +2 -2
  446. mindspore/train/callback/_landscape.py +4 -4
  447. mindspore/train/callback/_loss_monitor.py +2 -2
  448. mindspore/train/callback/_on_request_exit.py +2 -2
  449. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  450. mindspore/train/callback/_summary_collector.py +15 -8
  451. mindspore/train/callback/_time_monitor.py +58 -5
  452. mindspore/train/data_sink.py +5 -11
  453. mindspore/train/dataset_helper.py +84 -57
  454. mindspore/train/loss_scale_manager.py +2 -2
  455. mindspore/train/metrics/__init__.py +3 -3
  456. mindspore/train/metrics/cosine_similarity.py +1 -1
  457. mindspore/train/metrics/hausdorff_distance.py +3 -2
  458. mindspore/train/metrics/mean_surface_distance.py +3 -2
  459. mindspore/train/metrics/metric.py +39 -19
  460. mindspore/train/metrics/roc.py +2 -2
  461. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  462. mindspore/train/mind_ir_pb2.py +85 -36
  463. mindspore/train/model.py +187 -47
  464. mindspore/train/serialization.py +487 -161
  465. mindspore/train/summary/_summary_adapter.py +1 -1
  466. mindspore/train/summary/_writer_pool.py +3 -2
  467. mindspore/train/summary/summary_record.py +37 -17
  468. mindspore/train/train_thor/convert_utils.py +3 -3
  469. mindspore/train/train_thor/dataset_helper.py +1 -1
  470. mindspore/version.py +1 -1
  471. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +8 -8
  472. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +476 -527
  473. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -1
  474. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  475. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  476. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  477. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  478. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  479. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  480. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  481. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  482. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  483. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  484. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  485. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  486. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  487. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  488. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  489. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  490. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  491. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  492. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  493. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  494. mindspore/_extends/graph_kernel/expander.py +0 -80
  495. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  496. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  497. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  498. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  499. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  500. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  501. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  502. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  503. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  504. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  505. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  506. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  507. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  508. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  509. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  510. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  511. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  512. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  513. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  514. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  515. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  516. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  517. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  518. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  519. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  520. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  521. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  522. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  523. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  524. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  525. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  526. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  527. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  528. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  529. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  530. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  531. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  532. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  533. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  534. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  535. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  536. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  537. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  538. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  539. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  540. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  541. mindspore/dataset/datapreprocess/__init__.py +0 -20
  542. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  543. mindspore/include/api/net.h +0 -142
  544. mindspore/nn/lr_scheduler.py +0 -262
  545. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  546. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  547. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  548. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  549. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  550. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -350
  551. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -409
  552. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -578
  553. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -199
  554. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -446
  555. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  556. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
  557. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
  558. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
  559. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  560. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  561. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  562. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  563. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  564. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  565. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  566. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  567. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  568. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  569. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  570. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  571. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  572. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  573. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  574. mindspore/rewrite/node_visitor.py +0 -44
  575. mindspore/{ops/_op_impl/_custom_op/flash_attention → _akg/akg/utils/ascend_profilier}/__init__.py +0 -0
  576. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
  577. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
@@ -1,80 +0,0 @@
1
- # Copyright 2020-2022 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for minimum_grad"""
16
- from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
17
- from ._utils import Expander, ExpanderInfoValidator as VLD
18
-
19
-
20
- @VLD.check_all_formats_same
21
- class MinimumGrad(Expander):
22
- """MinimumGrad expander"""
23
-
24
- def _check(self):
25
- if not self.attrs.get('grad_x', True) and not self.attrs.get('grad_y', True):
26
- raise GKException("For 'MinimumGrad', value of attr 'grad_x' and 'grad_y' should be False, but got {} and "
27
- "{}".format(self.attrs.get('grad_x'), self.attrs.get('grad_y')))
28
- return super(MinimumGrad, self)._check()
29
-
30
- def _expand(self, graph_builder):
31
- input_x, input_y, input_dout = self.inputs
32
-
33
- le_result = graph_builder.emit('LessEqual', [input_x, input_y])
34
- le_result = graph_builder.emit('Cast', [le_result], attrs={'dst_type': input_x.dtype})
35
- dx = graph_builder.emit('Mul', [le_result, input_dout])
36
- dy = graph_builder.emit('Sub', [input_dout, dx])
37
-
38
- # for the MinimumGrad op, each output shape should be equal to the corresponding input shape,
39
- # but the elementwise operations may broadcast the input shapes,
40
- # so the output shape may differ from the original input shape and the output must be reduced to match it
41
- reduce_axis_x = self.get_reduce_axis(input_x.shape, dx.shape)
42
- reduce_axis_y = self.get_reduce_axis(input_y.shape, dy.shape)
43
- if reduce_axis_x:
44
- dx_reduce = graph_builder.emit('ReduceSum', [dx], attrs={'reduce_axis': reduce_axis_x, 'keep_dims': False})
45
- if dx_reduce.shape != input_x.shape:
46
- dx_out = graph_builder.emit('Reshape', [dx_reduce], attrs={'shape': input_x.shape})
47
- else:
48
- dx_out = dx_reduce
49
- else:
50
- dx_out = dx
51
-
52
- if reduce_axis_y:
53
- dy_reduce = graph_builder.emit('ReduceSum', [dy], attrs={'reduce_axis': reduce_axis_y, 'keep_dims': False})
54
- if dy_reduce.shape != input_y.shape:
55
- dy_out = graph_builder.emit('Reshape', [dy_reduce], attrs={'shape': input_y.shape})
56
- else:
57
- dy_out = dy_reduce
58
- else:
59
- dy_out = dy
60
-
61
- # output two results, regardless of grad_x and grad_y
62
- return dx_out, dy_out
63
-
64
- @staticmethod
65
- def get_reduce_axis(original_shape, broadcast_shape):
66
- """compute reduce axis for final output_shape"""
67
- if len(original_shape) > len(broadcast_shape):
68
- raise ValueError("For 'MinimumGrad', the length of original_shape should be less than or equal to the "
69
- "length of broadcast_shape, but got {} and {}".format(original_shape, broadcast_shape))
70
-
71
- tmp_shape = [1] * (len(broadcast_shape) - len(original_shape)) + original_shape
72
- reduce_axis = []
73
- for idx, _ in enumerate(tmp_shape):
74
- if tmp_shape[idx] != broadcast_shape[idx]:
75
- if tmp_shape[idx] == 1:
76
- reduce_axis.append(idx)
77
- else:
78
- raise ValueError("For 'MinimumGrad', original_shape {} and broadcast_shape {} can not broadcast."
79
- .format(original_shape, broadcast_shape))
80
- return reduce_axis
@@ -1,26 +0,0 @@
1
- # Copyright 2021 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for OnesLike"""
16
- from ._utils import Expander
17
-
18
-
19
- class OnesLike(Expander):
20
- """OnesLike expander"""
21
-
22
- def _expand(self, graph_builder):
23
- input_x = self.inputs[0]
24
- const_one = graph_builder.value(input_x.dtype, 1)
25
- result = graph_builder.emit('BroadcastTo', [const_one], attrs={'shape': input_x.shape})
26
- return result
@@ -1,43 +0,0 @@
1
- # Copyright 2020-2021 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for reduce_mean"""
16
- from mindspore._extends.graph_kernel.model.model import DataFormat as DF
17
- from ._utils import Expander, ExpanderInfoValidator as VLD
18
-
19
-
20
- @VLD.add_format(DF.DEFAULT)
21
- @VLD.check_attrs('axis', 'keep_dims')
22
- class ReduceMean(Expander):
23
- """ReduceMean expander"""
24
-
25
- def _expand(self, graph_builder):
26
- x = self.inputs[0]
27
- axis = self.attrs['axis']
28
- keep_dims = self.attrs['keep_dims']
29
-
30
- if not isinstance(axis, (tuple, list)):
31
- axis = (axis,)
32
- elif not axis:
33
- axis = list(range(len(x.shape)))
34
- reduce_size = 1.0
35
- for idx in axis:
36
- reduce_size *= x.shape[idx]
37
-
38
- reduce_size_value = graph_builder.value(x.dtype, reduce_size)
39
-
40
- sum_x = graph_builder.emit('ReduceSum', [x], attrs={'reduce_axis': axis, 'keep_dims': keep_dims})
41
- result = graph_builder.emit('RealDiv', [sum_x, reduce_size_value])
42
-
43
- return result
@@ -1,32 +0,0 @@
1
- # Copyright 2021 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for relu_grad"""
16
- from ._utils import Expander, ExpanderInfoValidator as VLD
17
-
18
-
19
- @VLD.check_all_formats_same
20
- class ReluGrad(Expander):
21
- """ReLU expander"""
22
-
23
- def _expand(self, graph_builder):
24
- input_x = self.inputs[0]
25
- input_y = self.inputs[1]
26
-
27
- const_zero = graph_builder.value(input_y.dtype, 0)
28
- ge_result = graph_builder.emit('Greater', [input_y, const_zero])
29
- ge_result = graph_builder.emit('Cast', [ge_result], attrs={'dst_type': input_x.dtype})
30
- result = graph_builder.emit('Mul', [ge_result, input_x])
31
-
32
- return result
@@ -1,41 +0,0 @@
1
- # Copyright 2021-2022 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for SigmoidCrossEntropyWithLogits"""
16
- from ._utils import Expander, ExpanderInfoValidator as VLD
17
-
18
-
19
- @VLD.check_all_formats_same
20
- class SigmoidCrossEntropyWithLogits(Expander):
21
- """SigmoidCrossEntropyWithLogits expander"""
22
-
23
- def _expand(self, graph_builder):
24
- logits, labels = self.inputs
25
- # Calculate sigmoid_cross_entropy_with_logits(logits, labels)
26
- # formula of sigmoid_cross_entropy_with_logits is:
27
- # -(labels * log(sigmoid(logits)) + (1 - labels) * log(1 - sigmoid(logits)))
28
- # To ensure stability and avoid overflow, the formula is rewritten in the equivalent form:
29
- # max(logits, 0) - logits * labels + log(1 + exp(-abs(logits)))
30
- const_one = graph_builder.value(logits.dtype, 1.0)
31
- const_zero = graph_builder.value(logits.dtype, 0.0)
32
- max_logits = graph_builder.emit('Maximum', [logits, const_zero])
33
- logits_mul_labels = graph_builder.emit('Mul', [logits, labels])
34
- abs_logits = graph_builder.emit('Abs', [logits])
35
- neg_abs_logits = graph_builder.emit('Neg', [abs_logits])
36
- exp_neg_abs_logits = graph_builder.emit('Exp', [neg_abs_logits])
37
- one_add_exp_neg_abs_logits = graph_builder.emit('Add', [const_one, exp_neg_abs_logits])
38
- log_one_add_exp_neg_abs_logits = graph_builder.emit('Log', [one_add_exp_neg_abs_logits])
39
- res_tmp = graph_builder.emit('Sub', [max_logits, logits_mul_labels])
40
- res = graph_builder.emit('Add', [res_tmp, log_one_add_exp_neg_abs_logits])
41
- return res
@@ -1,35 +0,0 @@
1
- # Copyright 2021 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for SigmoidCrossEntropyWithLogitsGrad"""
16
- from ._utils import Expander, ExpanderInfoValidator as VLD
17
-
18
-
19
- @VLD.check_all_formats_same
20
- class SigmoidCrossEntropyWithLogitsGrad(Expander):
21
- """SigmoidCrossEntropyWithLogitsGrad expander"""
22
-
23
- def _expand(self, graph_builder):
24
- logits, label, dout = self.inputs
25
- # Calculate sigmoid_cross_entropy_with_logits_grad(logits, label, dout)
26
- # formula of sigmoid_cross_entropy_with_logits_grad is:
27
- # (sigmoid(logits) - label) * dout
28
- const_one = graph_builder.value(logits.dtype, 1.0)
29
- neg_x = graph_builder.emit('Neg', [logits])
30
- exp_neg_x = graph_builder.emit('Exp', [neg_x])
31
- add_exp = graph_builder.emit('Add', [const_one, exp_neg_x])
32
- sigmoid_res = graph_builder.emit('RealDiv', [const_one, add_exp])
33
- sigmoid_res_sub_label = graph_builder.emit('Sub', [sigmoid_res, label])
34
- res = graph_builder.emit('Mul', [sigmoid_res_sub_label, dout])
35
- return res
@@ -1,31 +0,0 @@
1
- # Copyright 2021 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for SigmoidGrad"""
16
- from ._utils import Expander, ExpanderInfoValidator as VLD
17
-
18
-
19
- @VLD.check_all_formats_same
20
- class SigmoidGrad(Expander):
21
- """SigmoidGrad expander"""
22
-
23
- def _expand(self, graph_builder):
24
- input_y, dy = self.inputs
25
- # Calculate sigmoid_grad(y, dy)
26
- # formula of sigmoid_grad is: (1 - y) * y * dy
27
- const_one = graph_builder.value(input_y.dtype, 1.0)
28
- one_mins_y = graph_builder.emit('Sub', [const_one, input_y])
29
- y_mul_dy = graph_builder.emit('Mul', [input_y, dy])
30
- res = graph_builder.emit('Mul', [one_mins_y, y_mul_dy])
31
- return res
@@ -1,35 +0,0 @@
1
- # Copyright 2021-2022 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for slice"""
16
- from ._utils import Expander, ExpanderInfoValidator as VLD
17
-
18
-
19
- @VLD.check_attrs('begin', 'size')
20
- class Slice(Expander):
21
- """Slice expander"""
22
-
23
- def _expand(self, graph_builder):
24
- input_x = self.inputs[0]
25
- begin = self.attrs['begin']
26
- size = self.attrs['size']
27
- end = []
28
- strides = []
29
- for i, begin_idx in enumerate(begin):
30
- strides.append(1)
31
- end.append(begin_idx + size[i])
32
- output = graph_builder.tensor(size, input_x.dtype, input_x.data_format)
33
- graph_builder.op('StridedSlice', output, [input_x], attrs={'begin': begin, 'end': end, 'strides': strides})
34
-
35
- return output
@@ -1,42 +0,0 @@
1
- # Copyright 2021 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for SoftmaxCrossEntropyWithLogits"""
16
- from mindspore._extends.graph_kernel.model.model import DataFormat as DF
17
- from ._utils import Expander, ExpanderInfoValidator as VLD
18
-
19
-
20
- @VLD.add_format(DF.DEFAULT, DF.DEFAULT)
21
- class SoftmaxCrossEntropyWithLogits(Expander):
22
- """SoftmaxCrossEntropyWithLogits expander"""
23
-
24
- def _expand(self, graph_builder):
25
- logits, label = self.inputs
26
- # Calculate softmax_cross_entropy_with_logits(logits, label)
27
- # formula of softmax_cross_entropy_with_logits is: -reduce_sum(label * log(softmax(logits)))
28
- axis = (-1,)
29
- max_x = graph_builder.emit('ReduceMax', [logits], attrs={'reduce_axis': axis, 'keep_dims': True})
30
- data_sub = graph_builder.emit('Sub', [logits, max_x])
31
- data_exp = graph_builder.emit('Exp', [data_sub])
32
- data_expsum = graph_builder.emit('ReduceSum', [data_exp], attrs={'reduce_axis': axis, 'keep_dims': True})
33
- data_softmax = graph_builder.emit('RealDiv', [data_exp, data_expsum])
34
- const_eps = graph_builder.value(logits.dtype, 0.000001)
35
- data_softmax_safety = graph_builder.emit("Maximum", [data_softmax, const_eps])
36
- softmax_log = graph_builder.emit('Log', [data_softmax_safety])
37
- label_mul_log = graph_builder.emit('Mul', [label, softmax_log])
38
- tmp_res = data_expsum = graph_builder.emit('ReduceSum', [label_mul_log], attrs={
39
- 'reduce_axis': axis, 'keep_dims': False})
40
- loss = graph_builder.emit('Neg', [tmp_res])
41
- dlogits = graph_builder.emit('Sub', [data_softmax, label])
42
- return loss, dlogits
@@ -1,41 +0,0 @@
1
- # Copyright 2021-2022 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for SoftmaxGradExt"""
16
- from mindspore._extends.graph_kernel.model.model import DataFormat as DF
17
- from ._utils import Expander, ExpanderInfoValidator as VLD
18
- from ._utils import get_reduce_axis_shape
19
-
20
-
21
- @VLD.add_format(DF.FRAC_NZ, DF.FRAC_NZ, DF.DEFAULT)
22
- @VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
23
- @VLD.check_attrs('axis')
24
- class SoftmaxGradExt(Expander):
25
- """SoftmaxGradExt expander"""
26
-
27
- def _expand(self, graph_builder):
28
- x, y, z = self.inputs
29
- axis = self.attrs['axis']
30
-
31
- reduce_axis, ori_reduced_shape = get_reduce_axis_shape(x.shape, x.data_format, axis)
32
-
33
- data_mul = graph_builder.emit('Mul', [x, y])
34
- data_sum = graph_builder.emit('ReduceSum', [data_mul],
35
- attrs={'reduce_axis': reduce_axis, 'keep_dims': True, 'reduce_output_fuse': True})
36
- if x.data_format == DF.FRAC_NZ:
37
- data_sum = graph_builder.emit('Reshape', [data_sum], attrs={'shape': ori_reduced_shape})
38
- data_sub = graph_builder.emit('Sub', [x, data_sum])
39
- data_mul2 = graph_builder.emit('Mul', [data_sub, y])
40
- result = graph_builder.emit('Mul', [data_mul2, z])
41
- return result
@@ -1,28 +0,0 @@
1
- # Copyright 2020-2021 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for Softsign"""
16
- from ._utils import Expander
17
-
18
-
19
- class Softsign(Expander):
20
- """Softsign expander"""
21
-
22
- def _expand(self, graph_builder):
23
- x = self.inputs[0]
24
- x_abs = graph_builder.emit('Abs', [x])
25
- one = graph_builder.value(x.dtype, 1.0)
26
- sum_x = graph_builder.emit('Add', [x_abs, one])
27
- result = graph_builder.emit('RealDiv', [x, sum_x])
28
- return result
@@ -1,29 +0,0 @@
1
- # Copyright 2020-2021 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for sqrtgrad"""
16
- from ._utils import Expander, ExpanderInfoValidator as VLD
17
-
18
-
19
- @VLD.check_all_formats_same
20
- class SqrtGrad(Expander):
21
- """SqrtGrad expander"""
22
-
23
- def _expand(self, graph_builder):
24
- # formula of sqrt_grad is dout / (2 * x)
25
- x, dout = self.inputs
26
- const_two = graph_builder.value(x.dtype, 2)
27
- dividend = graph_builder.emit('Mul', [x, const_two])
28
- result = graph_builder.emit('RealDiv', [dout, dividend])
29
- return result
@@ -1,44 +0,0 @@
1
- # Copyright 2021-2022 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for SquareSumAll"""
16
- from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
17
- from ._utils import Expander
18
-
19
-
20
- class SquareSumAll(Expander):
21
- """SquareSumAll expander"""
22
-
23
- def _check(self):
24
- """check inputs"""
25
- input_num = len(self.inputs)
26
- if input_num != 2:
27
- raise GKException("For 'SquareSumAll', the inputs number should be 2, but got {}.".format(input_num))
28
-
29
- def _expand(self, graph_builder):
30
- """do expand"""
31
- x0 = self.inputs[0]
32
- x1 = self.inputs[1]
33
-
34
- ori_shape = x0.shape
35
- axis = []
36
- for i, _ in enumerate(ori_shape):
37
- axis.append(i)
38
-
39
- square_res0 = graph_builder.emit('Mul', [x0, x0])
40
- square_res1 = graph_builder.emit('Mul', [x1, x1])
41
- result0 = graph_builder.emit('ReduceSum', [square_res0], attrs={'reduce_axis': axis, 'keep_dims': False})
42
- result1 = graph_builder.emit('ReduceSum', [square_res1], attrs={'reduce_axis': axis, 'keep_dims': False})
43
-
44
- return result0, result1
@@ -1,37 +0,0 @@
1
- # Copyright 2021-2022 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for SquareSumV1"""
16
- from mindspore._extends.graph_kernel.model.model import DataFormat as DF
17
- from ._utils import Expander, ExpanderInfoValidator as VLD
18
- from ._utils import get_reduce_axis_shape
19
-
20
-
21
- @VLD.add_format(DF.FRAC_NZ)
22
- @VLD.add_format(DF.DEFAULT)
23
- @VLD.check_attrs('axis')
24
- class SquareSumV1(Expander):
25
- """SquareSumV1 expander"""
26
-
27
- def _expand(self, graph_builder):
28
- x = self.inputs[0]
29
- axis = self.attrs['axis']
30
-
31
- reduce_axis, ori_reduced_shape = get_reduce_axis_shape(x.shape, x.data_format, axis)
32
-
33
- square_res = graph_builder.emit('Mul', [x, x])
34
- result = graph_builder.emit('ReduceSum', [square_res], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
35
- if x.data_format == DF.FRAC_NZ:
36
- result = graph_builder.emit('Reshape', [result], attrs={'shape': ori_reduced_shape})
37
- return result
@@ -1,43 +0,0 @@
1
- # Copyright 2021 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for squared_difference"""
16
- from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
17
- from ._utils import Expander, ExpanderInfoValidator as VLD
18
-
19
-
20
- @VLD.check_all_formats_same
21
- class SquaredDifference(Expander):
22
- """SquaredDifference expander"""
23
-
24
- def __init__(self, expand_info):
25
- super().__init__(expand_info)
26
- self.dtype_x = self.inputs[0]['data_type']
27
- self.dtype_y = self.inputs[1]['data_type']
28
-
29
- def _check(self):
30
- if self.dtype_x == "float64" or self.dtype_y == "float64":
31
- raise GKException("For 'SquaredDifference', the inputs data type must not be float64")
32
- if self.dtype_x != self.dtype_y:
33
- raise GKException("For 'SquaredDifference', the inputs data type should be same, but got {} and {}"
34
- .format(self.dtype_x, self.dtype_y))
35
-
36
- def _expand(self, graph_builder):
37
- input_x = self.inputs[0]
38
- input_y = self.inputs[1]
39
-
40
- sub_val = graph_builder.emit('Sub', [input_x, input_y])
41
- result = graph_builder.emit('Mul', [sub_val, sub_val])
42
-
43
- return result
@@ -1,31 +0,0 @@
1
- # Copyright 2020-2021 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for tanh_grad"""
16
- from ._utils import Expander, ExpanderInfoValidator as VLD
17
-
18
-
19
- @VLD.check_all_formats_same
20
- class TanhGrad(Expander):
21
- """TanhGrad expander"""
22
-
23
- def _expand(self, graph_builder):
24
- input_y, input_dy = self.inputs
25
-
26
- const_one = graph_builder.value(input_y.dtype, 1)
27
- double_y = graph_builder.emit('Mul', [input_y, input_y])
28
- one_sub_double_y = graph_builder.emit('Sub', [const_one, double_y])
29
- result = graph_builder.emit('Mul', [input_dy, one_sub_double_y])
30
-
31
- return result