mindspore 2.1.0__cp37-cp37m-manylinux1_x86_64.whl → 2.2.11__cp37-cp37m-manylinux1_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (589) hide show
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_akg/akg/build_module.py +5 -6
  4. mindspore/_akg/akg/composite/build_module.py +139 -22
  5. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  6. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  7. mindspore/_akg/akg/tvm/api.py +4 -3
  8. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  9. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  10. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  11. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  12. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  13. mindspore/_akg/akg/tvm/build_module.py +16 -1
  14. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  15. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  16. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  17. mindspore/_akg/akg/tvm/module.py +1 -2
  18. mindspore/_akg/akg/tvm/stmt.py +2 -2
  19. mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
  20. mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
  21. mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
  22. mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
  23. mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
  24. mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
  25. mindspore/_akg/akg/utils/composite_op_helper.py +16 -12
  26. mindspore/_akg/akg/utils/dump_ascend_meta.py +22 -3
  27. mindspore/_akg/akg/utils/kernel_exec.py +98 -274
  28. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  29. mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
  30. mindspore/_akg/akg/utils/util.py +56 -1
  31. mindspore/_c_dataengine.cpython-37m-x86_64-linux-gnu.so +0 -0
  32. mindspore/_c_expression.cpython-37m-x86_64-linux-gnu.so +0 -0
  33. mindspore/_c_mindrecord.cpython-37m-x86_64-linux-gnu.so +0 -0
  34. mindspore/_check_jit_forbidden_api.py +3 -1
  35. mindspore/_checkparam.py +23 -29
  36. mindspore/_extends/graph_kernel/__init__.py +0 -1
  37. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  38. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  39. mindspore/_extends/graph_kernel/splitter.py +4 -11
  40. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  41. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
  42. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  43. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  44. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  45. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
  46. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  47. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  48. mindspore/_extends/parse/__init__.py +13 -15
  49. mindspore/_extends/parse/namespace.py +7 -33
  50. mindspore/_extends/parse/parser.py +67 -72
  51. mindspore/_extends/parse/resources.py +1 -1
  52. mindspore/_extends/parse/standard_method.py +86 -106
  53. mindspore/_extends/parse/trope.py +1 -1
  54. mindspore/_extends/remote/kernel_build_server.py +25 -7
  55. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  56. mindspore/_install_custom.py +43 -0
  57. mindspore/_mindspore_offline_debug.cpython-37m-x86_64-linux-gnu.so +0 -0
  58. mindspore/amp.py +47 -11
  59. mindspore/bin/cache_admin +0 -0
  60. mindspore/bin/cache_server +0 -0
  61. mindspore/boost/boost.py +1 -8
  62. mindspore/boost/boost_cell_wrapper.py +3 -2
  63. mindspore/boost/grad_accumulation.py +1 -1
  64. mindspore/boost/group_loss_scale_manager.py +8 -7
  65. mindspore/common/__init__.py +5 -3
  66. mindspore/common/_jit_fallback_utils.py +6 -0
  67. mindspore/common/_register_for_adapter.py +2 -0
  68. mindspore/common/_register_for_tensor.py +2 -2
  69. mindspore/common/_stub_tensor.py +13 -0
  70. mindspore/common/_utils.py +29 -0
  71. mindspore/common/api.py +174 -259
  72. mindspore/common/auto_dynamic_shape.py +494 -0
  73. mindspore/common/dtype.py +18 -11
  74. mindspore/common/dump.py +6 -4
  75. mindspore/common/initializer.py +14 -14
  76. mindspore/common/jit_config.py +33 -15
  77. mindspore/common/lazy_inline.py +126 -7
  78. mindspore/common/mindir_util.py +101 -0
  79. mindspore/common/parameter.py +51 -41
  80. mindspore/common/seed.py +4 -4
  81. mindspore/common/sparse_tensor.py +13 -14
  82. mindspore/common/tensor.py +243 -165
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +83 -4
  85. mindspore/communication/management.py +152 -84
  86. mindspore/config/op_info.config +14 -3
  87. mindspore/config/super_bar_config.json +4 -2
  88. mindspore/context.py +152 -61
  89. mindspore/dataset/__init__.py +5 -5
  90. mindspore/dataset/audio/__init__.py +2 -2
  91. mindspore/dataset/audio/transforms.py +52 -52
  92. mindspore/dataset/callback/ds_callback.py +16 -2
  93. mindspore/dataset/core/config.py +68 -51
  94. mindspore/dataset/engine/cache_client.py +33 -7
  95. mindspore/dataset/engine/datasets.py +250 -112
  96. mindspore/dataset/engine/datasets_audio.py +43 -211
  97. mindspore/dataset/engine/datasets_standard_format.py +16 -35
  98. mindspore/dataset/engine/datasets_text.py +43 -67
  99. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  100. mindspore/dataset/engine/datasets_vision.py +219 -1029
  101. mindspore/dataset/engine/iterators.py +11 -4
  102. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  103. mindspore/dataset/engine/obs/util.py +3 -0
  104. mindspore/dataset/engine/samplers.py +1 -1
  105. mindspore/dataset/engine/validators.py +19 -5
  106. mindspore/dataset/text/__init__.py +3 -3
  107. mindspore/dataset/text/transforms.py +101 -127
  108. mindspore/dataset/text/utils.py +205 -138
  109. mindspore/dataset/transforms/__init__.py +1 -1
  110. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  111. mindspore/dataset/transforms/transforms.py +95 -40
  112. mindspore/dataset/utils/browse_dataset.py +8 -2
  113. mindspore/dataset/utils/line_reader.py +17 -19
  114. mindspore/dataset/vision/__init__.py +3 -3
  115. mindspore/dataset/vision/c_transforms.py +6 -3
  116. mindspore/dataset/vision/transforms.py +409 -287
  117. mindspore/dataset/vision/utils.py +13 -14
  118. mindspore/dataset/vision/validators.py +11 -1
  119. mindspore/experimental/map_parameter.py +14 -0
  120. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  121. mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
  122. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  123. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  124. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  125. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  126. mindspore/gen_ops.py +273 -0
  127. mindspore/include/OWNERS +0 -1
  128. mindspore/include/api/data_type.h +2 -1
  129. mindspore/include/api/graph.h +0 -15
  130. mindspore/include/api/kernel.h +2 -0
  131. mindspore/include/api/kernel_api.h +37 -12
  132. mindspore/include/api/model.h +17 -14
  133. mindspore/include/api/status.h +8 -3
  134. mindspore/include/api/types.h +37 -4
  135. mindspore/include/c_api/ms/abstract.h +67 -0
  136. mindspore/include/c_api/ms/attribute.h +197 -0
  137. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  138. mindspore/include/c_api/ms/base/macros.h +32 -0
  139. mindspore/include/c_api/ms/base/status.h +33 -0
  140. mindspore/include/c_api/ms/base/types.h +282 -0
  141. mindspore/include/c_api/ms/context.h +102 -0
  142. mindspore/include/c_api/ms/graph.h +160 -0
  143. mindspore/include/c_api/ms/node.h +606 -0
  144. mindspore/include/c_api/ms/tensor.h +161 -0
  145. mindspore/include/c_api/ms/value.h +84 -0
  146. mindspore/include/dataset/constants.h +6 -5
  147. mindspore/include/dataset/execute.h +23 -13
  148. mindspore/include/dataset/text.h +26 -26
  149. mindspore/include/dataset/transforms.h +13 -13
  150. mindspore/include/dataset/vision.h +60 -60
  151. mindspore/include/dataset/vision_ascend.h +5 -6
  152. mindspore/include/dataset/vision_lite.h +17 -17
  153. mindspore/include/mindapi/base/type_id.h +1 -0
  154. mindspore/include/mindapi/base/types.h +1 -0
  155. mindspore/lib/libdnnl.so.2 +0 -0
  156. mindspore/lib/libjemalloc.so.2 +0 -0
  157. mindspore/lib/libmindspore.so +0 -0
  158. mindspore/lib/libmindspore_backend.so +0 -0
  159. mindspore/lib/libmindspore_common.so +0 -0
  160. mindspore/lib/libmindspore_core.so +0 -0
  161. mindspore/lib/libmindspore_glog.so.0 +0 -0
  162. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  163. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  164. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  165. mindspore/lib/libmindspore_shared_lib.so +0 -0
  166. mindspore/lib/libnnacl.so +0 -0
  167. mindspore/lib/libopencv_core.so.4.5 +0 -0
  168. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  169. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  170. mindspore/lib/libps_cache.so +0 -0
  171. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
  172. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
  173. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
  174. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
  175. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  176. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  177. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  178. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  179. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  180. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  181. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  182. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  183. mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
  184. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  185. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  186. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8998 -0
  187. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  188. mindspore/lib/plugin/ascend/libakg.so +0 -0
  189. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  190. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  191. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  192. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  193. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  194. mindspore/lib/plugin/cpu/libakg.so +0 -0
  195. mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
  196. mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
  197. mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
  198. mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
  199. mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
  200. mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
  201. mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
  202. mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
  203. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  204. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  205. mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
  206. mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
  207. mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
  208. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  209. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  210. mindspore/nn/__init__.py +0 -2
  211. mindspore/nn/cell.py +313 -74
  212. mindspore/nn/dynamic_lr.py +21 -21
  213. mindspore/nn/layer/activation.py +22 -30
  214. mindspore/nn/layer/basic.py +15 -13
  215. mindspore/nn/layer/channel_shuffle.py +1 -1
  216. mindspore/nn/layer/container.py +271 -9
  217. mindspore/nn/layer/conv.py +323 -204
  218. mindspore/nn/layer/dense.py +8 -5
  219. mindspore/nn/layer/embedding.py +33 -27
  220. mindspore/nn/layer/flash_attention.py +61 -95
  221. mindspore/nn/layer/image.py +8 -6
  222. mindspore/nn/layer/math.py +16 -25
  223. mindspore/nn/layer/normalization.py +107 -66
  224. mindspore/nn/layer/padding.py +1 -1
  225. mindspore/nn/layer/pooling.py +131 -109
  226. mindspore/nn/layer/rnn_cells.py +27 -22
  227. mindspore/nn/layer/rnns.py +13 -16
  228. mindspore/nn/layer/thor_layer.py +1 -1
  229. mindspore/nn/layer/transformer.py +221 -154
  230. mindspore/nn/learning_rate_schedule.py +9 -1
  231. mindspore/nn/loss/loss.py +235 -174
  232. mindspore/nn/optim/ada_grad.py +2 -1
  233. mindspore/nn/optim/adadelta.py +1 -0
  234. mindspore/nn/optim/adafactor.py +2 -1
  235. mindspore/nn/optim/adam.py +7 -4
  236. mindspore/nn/optim/adamax.py +3 -2
  237. mindspore/nn/optim/adasum.py +2 -2
  238. mindspore/nn/optim/asgd.py +2 -3
  239. mindspore/nn/optim/ftrl.py +6 -5
  240. mindspore/nn/optim/lamb.py +7 -4
  241. mindspore/nn/optim/lars.py +1 -1
  242. mindspore/nn/optim/lazyadam.py +5 -3
  243. mindspore/nn/optim/momentum.py +2 -1
  244. mindspore/nn/optim/optimizer.py +53 -4
  245. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  246. mindspore/nn/optim/rmsprop.py +4 -3
  247. mindspore/nn/optim/rprop.py +23 -12
  248. mindspore/nn/optim/sgd.py +26 -11
  249. mindspore/nn/optim/thor.py +9 -7
  250. mindspore/nn/probability/bijector/bijector.py +5 -5
  251. mindspore/nn/probability/bijector/power_transform.py +27 -27
  252. mindspore/nn/probability/bijector/softplus.py +3 -3
  253. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  254. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  255. mindspore/nn/probability/distribution/beta.py +3 -3
  256. mindspore/nn/probability/distribution/categorical.py +7 -7
  257. mindspore/nn/probability/distribution/cauchy.py +0 -1
  258. mindspore/nn/probability/distribution/distribution.py +3 -3
  259. mindspore/nn/probability/distribution/gamma.py +3 -3
  260. mindspore/nn/probability/distribution/geometric.py +4 -4
  261. mindspore/nn/probability/distribution/gumbel.py +4 -4
  262. mindspore/nn/probability/distribution/log_normal.py +2 -2
  263. mindspore/nn/probability/distribution/logistic.py +2 -2
  264. mindspore/nn/probability/distribution/poisson.py +4 -4
  265. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  266. mindspore/nn/probability/distribution/uniform.py +6 -6
  267. mindspore/nn/wrap/__init__.py +4 -2
  268. mindspore/nn/wrap/cell_wrapper.py +87 -34
  269. mindspore/nn/wrap/grad_reducer.py +8 -5
  270. mindspore/nn/wrap/loss_scale.py +105 -42
  271. mindspore/numpy/array_creations.py +1 -2
  272. mindspore/numpy/array_ops.py +3 -2
  273. mindspore/numpy/utils_const.py +5 -5
  274. mindspore/offline_debug/convert_async.py +2 -2
  275. mindspore/ops/_grad_experimental/__init__.py +0 -5
  276. mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
  277. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  278. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  279. mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
  280. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  281. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
  282. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  283. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  284. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  285. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  286. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  287. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  288. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  289. mindspore/ops/_op_impl/{_custom_op/flash_attention/constants.py → aicpu/eps.py} +18 -27
  290. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  291. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
  292. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  293. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  294. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  295. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  296. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  297. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  298. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  299. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  300. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  301. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  302. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  303. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  304. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  305. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  306. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  307. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  308. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  309. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  310. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  311. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  312. mindspore/ops/_primitive_cache.py +1 -1
  313. mindspore/ops/_tracefunc.py +45 -13
  314. mindspore/ops/_utils/utils.py +6 -1
  315. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  316. mindspore/ops/_vmap/vmap_base.py +3 -3
  317. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  318. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  319. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  320. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  321. mindspore/ops/arg_dtype_cast.py +54 -0
  322. mindspore/ops/composite/base.py +37 -10
  323. mindspore/ops/composite/math_ops.py +5 -4
  324. mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
  325. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  326. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  327. mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
  328. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  329. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  330. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  331. mindspore/ops/deprecated.py +304 -0
  332. mindspore/ops/function/__init__.py +4 -1
  333. mindspore/ops/function/array_func.py +174 -193
  334. mindspore/ops/function/clip_func.py +81 -13
  335. mindspore/ops/function/debug_func.py +1 -1
  336. mindspore/ops/function/grad/grad_func.py +18 -9
  337. mindspore/ops/function/image_func.py +10 -4
  338. mindspore/ops/function/linalg_func.py +5 -5
  339. mindspore/ops/function/math_func.py +575 -386
  340. mindspore/ops/function/nn_func.py +568 -260
  341. mindspore/ops/function/random_func.py +88 -57
  342. mindspore/ops/function/sparse_func.py +1 -1
  343. mindspore/ops/function/sparse_unary_func.py +14 -12
  344. mindspore/ops/function/vmap_func.py +6 -5
  345. mindspore/ops/functional.py +15 -10
  346. mindspore/ops/op_info_register.py +244 -25
  347. mindspore/ops/operations/__init__.py +31 -19
  348. mindspore/ops/operations/_grad_ops.py +71 -7
  349. mindspore/ops/operations/_inner_ops.py +350 -17
  350. mindspore/ops/operations/_quant_ops.py +4 -8
  351. mindspore/ops/operations/_sequence_ops.py +42 -0
  352. mindspore/ops/operations/array_ops.py +68 -282
  353. mindspore/ops/operations/comm_ops.py +107 -59
  354. mindspore/ops/operations/custom_ops.py +94 -70
  355. mindspore/ops/operations/debug_ops.py +8 -4
  356. mindspore/ops/operations/image_ops.py +18 -12
  357. mindspore/ops/operations/inner_ops.py +26 -3
  358. mindspore/ops/operations/math_ops.py +192 -144
  359. mindspore/ops/operations/nn_ops.py +857 -489
  360. mindspore/ops/operations/other_ops.py +0 -22
  361. mindspore/ops/operations/random_ops.py +53 -111
  362. mindspore/ops/operations/sparse_ops.py +3 -1
  363. mindspore/ops/primitive.py +24 -18
  364. mindspore/parallel/_auto_parallel_context.py +68 -8
  365. mindspore/parallel/_cost_model_context.py +2 -2
  366. mindspore/parallel/_offload_context.py +17 -3
  367. mindspore/parallel/_parallel_serialization.py +12 -5
  368. mindspore/parallel/_ps_context.py +12 -0
  369. mindspore/parallel/_tensor.py +18 -13
  370. mindspore/parallel/_transformer/layers.py +5 -3
  371. mindspore/parallel/_transformer/loss.py +1 -0
  372. mindspore/parallel/_transformer/moe.py +2 -2
  373. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  374. mindspore/parallel/_transformer/transformer.py +23 -3
  375. mindspore/parallel/_utils.py +11 -7
  376. mindspore/parallel/algo_parameter_config.py +85 -5
  377. mindspore/parallel/checkpoint_transform.py +19 -12
  378. mindspore/parallel/shard.py +21 -14
  379. mindspore/profiler/common/struct_type.py +3 -3
  380. mindspore/profiler/common/util.py +4 -2
  381. mindspore/profiler/envprofiling.py +1 -1
  382. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  383. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  384. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  385. mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
  386. mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
  387. mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
  388. mindspore/profiler/parser/ascend_op_generator.py +6 -6
  389. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  390. mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
  391. mindspore/profiler/parser/base_timeline_generator.py +10 -8
  392. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
  393. mindspore/profiler/parser/flops_parser.py +15 -11
  394. mindspore/profiler/parser/framework_parser.py +38 -22
  395. mindspore/profiler/parser/hccl_parser.py +16 -12
  396. mindspore/profiler/parser/integrator.py +22 -11
  397. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  398. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  399. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  400. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  401. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  402. mindspore/profiler/parser/optime_parser.py +1 -1
  403. mindspore/profiler/parser/profiler_info.py +21 -2
  404. mindspore/profiler/parser/step_trace_parser.py +11 -14
  405. mindspore/profiler/profiling.py +179 -89
  406. mindspore/rewrite/api/node.py +102 -19
  407. mindspore/rewrite/api/node_type.py +5 -1
  408. mindspore/rewrite/api/pattern_engine.py +1 -1
  409. mindspore/rewrite/api/scoped_value.py +9 -17
  410. mindspore/rewrite/api/symbol_tree.py +131 -47
  411. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  412. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  413. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  414. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  415. mindspore/rewrite/common/rewrite_elog.py +5 -1
  416. mindspore/rewrite/namer.py +33 -24
  417. mindspore/rewrite/namespace.py +14 -5
  418. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  419. mindspore/rewrite/node/call_function.py +79 -0
  420. mindspore/rewrite/node/cell_container.py +135 -0
  421. mindspore/rewrite/node/control_flow.py +88 -0
  422. mindspore/rewrite/{node.py → node/node.py} +273 -234
  423. mindspore/rewrite/node/node_manager.py +254 -0
  424. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  425. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  426. mindspore/rewrite/parsers/assign_parser.py +216 -221
  427. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  428. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  429. mindspore/rewrite/parsers/constant_parser.py +9 -6
  430. mindspore/rewrite/parsers/container_parser.py +9 -7
  431. mindspore/rewrite/parsers/for_parser.py +42 -21
  432. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  433. mindspore/rewrite/parsers/if_parser.py +28 -24
  434. mindspore/rewrite/parsers/module_parser.py +196 -25
  435. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  436. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  437. mindspore/rewrite/parsers/return_parser.py +6 -6
  438. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  439. mindspore/rewrite/sparsify/utils.py +1 -1
  440. mindspore/rewrite/symbol_tree.py +523 -578
  441. mindspore/rewrite/symbol_tree_builder.py +9 -193
  442. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  443. mindspore/run_check/_check_version.py +6 -4
  444. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  445. mindspore/safeguard/rewrite_obfuscation.py +541 -0
  446. mindspore/scipy/linalg.py +1 -1
  447. mindspore/scipy/ops.py +55 -5
  448. mindspore/scipy/optimize/__init__.py +3 -2
  449. mindspore/scipy/optimize/linear_sum_assignment.py +38 -33
  450. mindspore/scipy/optimize/minimize.py +7 -3
  451. mindspore/train/_utils.py +7 -3
  452. mindspore/train/amp.py +323 -123
  453. mindspore/train/anf_ir_pb2.py +14 -2
  454. mindspore/train/callback/_backup_and_restore.py +2 -12
  455. mindspore/train/callback/_callback.py +29 -4
  456. mindspore/train/callback/_checkpoint.py +23 -8
  457. mindspore/train/callback/_early_stop.py +2 -2
  458. mindspore/train/callback/_landscape.py +4 -4
  459. mindspore/train/callback/_loss_monitor.py +2 -2
  460. mindspore/train/callback/_on_request_exit.py +2 -2
  461. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  462. mindspore/train/callback/_summary_collector.py +15 -8
  463. mindspore/train/callback/_time_monitor.py +58 -5
  464. mindspore/train/data_sink.py +5 -11
  465. mindspore/train/dataset_helper.py +84 -57
  466. mindspore/train/loss_scale_manager.py +2 -2
  467. mindspore/train/metrics/__init__.py +3 -3
  468. mindspore/train/metrics/cosine_similarity.py +1 -1
  469. mindspore/train/metrics/hausdorff_distance.py +3 -2
  470. mindspore/train/metrics/mean_surface_distance.py +3 -2
  471. mindspore/train/metrics/metric.py +39 -19
  472. mindspore/train/metrics/roc.py +2 -2
  473. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  474. mindspore/train/mind_ir_pb2.py +85 -36
  475. mindspore/train/model.py +187 -47
  476. mindspore/train/serialization.py +487 -161
  477. mindspore/train/summary/_summary_adapter.py +1 -1
  478. mindspore/train/summary/_writer_pool.py +3 -2
  479. mindspore/train/summary/summary_record.py +37 -17
  480. mindspore/train/train_thor/convert_utils.py +3 -3
  481. mindspore/train/train_thor/dataset_helper.py +1 -1
  482. mindspore/version.py +1 -1
  483. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +8 -8
  484. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +488 -539
  485. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -1
  486. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  487. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  488. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  489. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  490. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  491. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  492. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  493. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  494. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  495. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  496. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  497. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  498. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  499. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  500. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  501. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  502. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  503. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  504. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  505. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  506. mindspore/_extends/graph_kernel/expander.py +0 -80
  507. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  508. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  509. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  510. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  511. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  512. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  513. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  514. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  515. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  516. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  517. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  518. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  519. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  520. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  521. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  522. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  523. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  524. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  525. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  526. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  527. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  528. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  529. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  530. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  531. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  532. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  533. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  534. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  535. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  536. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  537. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  538. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  539. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  540. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  541. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  542. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  543. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  544. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  545. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  546. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  547. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  548. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  549. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  550. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  551. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  552. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  553. mindspore/dataset/datapreprocess/__init__.py +0 -20
  554. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  555. mindspore/include/api/net.h +0 -142
  556. mindspore/nn/lr_scheduler.py +0 -262
  557. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  558. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  559. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  560. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  561. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  562. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -350
  563. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -409
  564. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -578
  565. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -199
  566. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -446
  567. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  568. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
  569. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
  570. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
  571. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  572. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  573. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  574. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  575. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  576. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  577. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  578. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  579. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  580. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  581. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  582. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  583. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  584. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  585. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  586. mindspore/rewrite/node_visitor.py +0 -44
  587. /mindspore/{ops/_op_impl/_custom_op/flash_attention → _akg/akg/utils/ascend_profilier}/__init__.py +0 -0
  588. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
  589. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
@@ -1,62 +0,0 @@
1
- # Copyright 2021-2022 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for cdiv"""
16
- from mindspore._extends.graph_kernel.model.model import DataFormat as DF
17
- from mindspore._extends.graph_kernel.expanders._utils import Expander, ExpanderInfoValidator as VLD
18
-
19
-
20
- @VLD.add_format(DF.DEFAULT, DF.DEFAULT)
21
- class CDiv(Expander):
22
- """CDiv expander"""
23
-
24
- def _expand(self, graph_builder):
25
- """CDiv Implementation"""
26
- input_x, input_y = self.inputs
27
- if input_x.dtype == input_y.dtype:
28
- x_real = graph_builder.emit('CReal', [input_x])
29
- y_real = graph_builder.emit('CReal', [input_y])
30
- x_imag = graph_builder.emit('CImag', [input_x])
31
- y_imag = graph_builder.emit('CImag', [input_y])
32
- squre_y_real = graph_builder.emit('Mul', [y_real, y_real])
33
- squre_y_imag = graph_builder.emit('Mul', [y_imag, y_imag])
34
- final_denominator = graph_builder.emit('Add', [squre_y_real, squre_y_imag])
35
- x_real_mul_y_real = graph_builder.emit('Mul', [x_real, y_real])
36
- x_imag_mul_y_imag = graph_builder.emit('Mul', [x_imag, y_imag])
37
- x_real_mul_y_imag = graph_builder.emit('Mul', [x_real, y_imag])
38
- x_imag_mul_y_real = graph_builder.emit('Mul', [x_imag, y_real])
39
- final_numerator_real = graph_builder.emit('Add', [x_real_mul_y_real, x_imag_mul_y_imag])
40
- final_numerator_imag = graph_builder.emit('Sub', [x_imag_mul_y_real, x_real_mul_y_imag])
41
- result_real = graph_builder.emit('RealDiv', [final_numerator_real, final_denominator])
42
- result_imag = graph_builder.emit('RealDiv', [final_numerator_imag, final_denominator])
43
- result = graph_builder.emit('Complex', [result_real, result_imag])
44
- elif input_x.dtype == "complex64" or input_x.dtype == "complex128":
45
- x_real = graph_builder.emit('CReal', [input_x])
46
- x_imag = graph_builder.emit('CImag', [input_x])
47
- x_real_div_y = graph_builder.emit('RealDiv', [x_real, input_y])
48
- x_imag_div_y = graph_builder.emit('RealDiv', [x_imag, input_y])
49
- result = graph_builder.emit('Complex', [x_real_div_y, x_imag_div_y])
50
- elif input_y.dtype == "complex64" or input_y.dtype == "complex128":
51
- y_real = graph_builder.emit('CReal', [input_y])
52
- y_imag = graph_builder.emit('CImag', [input_y])
53
- neg_y_imag = graph_builder.emit('Neg', [y_imag])
54
- squre_y_real = graph_builder.emit('Mul', [y_real, y_real])
55
- squre_y_imag = graph_builder.emit('Mul', [y_imag, y_imag])
56
- final_denominator = graph_builder.emit('Add', [squre_y_real, squre_y_imag])
57
- x_mul_y_real = graph_builder.emit('Mul', [input_x, y_real])
58
- x_mul_neg_y_imag = graph_builder.emit('Mul', [input_x, neg_y_imag])
59
- y_real_div_x = graph_builder.emit('RealDiv', [x_mul_y_real, final_denominator])
60
- y_imag_div_x = graph_builder.emit('RealDiv', [x_mul_neg_y_imag, final_denominator])
61
- result = graph_builder.emit('Complex', [y_real_div_x, y_imag_div_x])
62
- return result
@@ -1,52 +0,0 @@
1
- # Copyright 2021-2022 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for cmul"""
16
- from mindspore._extends.graph_kernel.model.model import DataFormat as DF
17
- from mindspore._extends.graph_kernel.expanders._utils import Expander, ExpanderInfoValidator as VLD
18
-
19
-
20
- @VLD.add_format(DF.DEFAULT, DF.DEFAULT)
21
- class CMul(Expander):
22
- """CMul expander"""
23
-
24
- def _expand(self, graph_builder):
25
- """CMul Implementation"""
26
- input_x, input_y = self.inputs
27
- if input_x.dtype == input_y.dtype:
28
- x_real = graph_builder.emit('CReal', [input_x])
29
- y_real = graph_builder.emit('CReal', [input_y])
30
- x_imag = graph_builder.emit('CImag', [input_x])
31
- y_imag = graph_builder.emit('CImag', [input_y])
32
- x_real_mul_y_real = graph_builder.emit('Mul', [x_real, y_real])
33
- x_imag_mul_y_imag = graph_builder.emit('Mul', [x_imag, y_imag])
34
- x_real_mul_y_imag = graph_builder.emit('Mul', [x_real, y_imag])
35
- x_imag_mul_y_real = graph_builder.emit('Mul', [x_imag, y_real])
36
- result_real = graph_builder.emit('Sub', [x_real_mul_y_real, x_imag_mul_y_imag])
37
- result_imag = graph_builder.emit('Add', [x_real_mul_y_imag, x_imag_mul_y_real])
38
- result = graph_builder.emit('Complex', [result_real, result_imag])
39
- elif input_x.dtype == "complex64" or input_x.dtype == "complex128":
40
- x_real = graph_builder.emit('CReal', [input_x])
41
- x_imag = graph_builder.emit('CImag', [input_x])
42
- x_real_mul_y = graph_builder.emit('Mul', [x_real, input_y])
43
- x_imag_mul_y = graph_builder.emit('Mul', [x_imag, input_y])
44
- result = graph_builder.emit('Complex', [x_real_mul_y, x_imag_mul_y])
45
- elif input_y.dtype == "complex64" or input_y.dtype == "complex128":
46
- y_real = graph_builder.emit('CReal', [input_y])
47
- y_imag = graph_builder.emit('CImag', [input_y])
48
- y_real_mul_x = graph_builder.emit('Mul', [y_real, input_x])
49
- y_imag_mul_x = graph_builder.emit('Mul', [y_imag, input_x])
50
- result = graph_builder.emit('Complex', [y_real_mul_x, y_imag_mul_x])
51
-
52
- return result
@@ -1,62 +0,0 @@
1
- # Copyright 2022 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for crealdiv"""
16
- from mindspore._extends.graph_kernel.model.model import DataFormat as DF
17
- from mindspore._extends.graph_kernel.expanders._utils import Expander, ExpanderInfoValidator as VLD
18
-
19
-
20
- @VLD.add_format(DF.DEFAULT, DF.DEFAULT)
21
- class CRealDiv(Expander):
22
- """CRealDiv expander"""
23
-
24
- def _expand(self, graph_builder):
25
- """CRealDiv Implementation"""
26
- input_x, input_y = self.inputs
27
- if input_x.dtype == input_y.dtype:
28
- x_real = graph_builder.emit('CReal', [input_x])
29
- y_real = graph_builder.emit('CReal', [input_y])
30
- x_imag = graph_builder.emit('CImag', [input_x])
31
- y_imag = graph_builder.emit('CImag', [input_y])
32
- squre_y_real = graph_builder.emit('Mul', [y_real, y_real])
33
- squre_y_imag = graph_builder.emit('Mul', [y_imag, y_imag])
34
- final_denominator = graph_builder.emit('Add', [squre_y_real, squre_y_imag])
35
- x_real_mul_y_real = graph_builder.emit('Mul', [x_real, y_real])
36
- x_imag_mul_y_imag = graph_builder.emit('Mul', [x_imag, y_imag])
37
- x_real_mul_y_imag = graph_builder.emit('Mul', [x_real, y_imag])
38
- x_imag_mul_y_real = graph_builder.emit('Mul', [x_imag, y_real])
39
- final_numerator_real = graph_builder.emit('Add', [x_real_mul_y_real, x_imag_mul_y_imag])
40
- final_numerator_imag = graph_builder.emit('Sub', [x_imag_mul_y_real, x_real_mul_y_imag])
41
- result_real = graph_builder.emit('RealDiv', [final_numerator_real, final_denominator])
42
- result_imag = graph_builder.emit('RealDiv', [final_numerator_imag, final_denominator])
43
- result = graph_builder.emit('Complex', [result_real, result_imag])
44
- elif input_x.dtype == "complex64" or input_x.dtype == "complex128":
45
- x_real = graph_builder.emit('CReal', [input_x])
46
- x_imag = graph_builder.emit('CImag', [input_x])
47
- x_real_div_y = graph_builder.emit('RealDiv', [x_real, input_y])
48
- x_imag_div_y = graph_builder.emit('RealDiv', [x_imag, input_y])
49
- result = graph_builder.emit('Complex', [x_real_div_y, x_imag_div_y])
50
- elif input_y.dtype == "complex64" or input_y.dtype == "complex128":
51
- y_real = graph_builder.emit('CReal', [input_y])
52
- y_imag = graph_builder.emit('CImag', [input_y])
53
- neg_y_imag = graph_builder.emit('Neg', [y_imag])
54
- squre_y_real = graph_builder.emit('Mul', [y_real, y_real])
55
- squre_y_imag = graph_builder.emit('Mul', [y_imag, y_imag])
56
- final_denominator = graph_builder.emit('Add', [squre_y_real, squre_y_imag])
57
- x_mul_y_real = graph_builder.emit('Mul', [input_x, y_real])
58
- x_mul_neg_y_imag = graph_builder.emit('Mul', [input_x, neg_y_imag])
59
- y_real_div_x = graph_builder.emit('RealDiv', [x_mul_y_real, final_denominator])
60
- y_imag_div_x = graph_builder.emit('RealDiv', [x_mul_neg_y_imag, final_denominator])
61
- result = graph_builder.emit('Complex', [y_real_div_x, y_imag_div_x])
62
- return result
@@ -1,45 +0,0 @@
1
- # Copyright 2021-2022 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for csub"""
16
- from mindspore._extends.graph_kernel.model.model import DataFormat as DF
17
- from mindspore._extends.graph_kernel.expanders._utils import Expander, ExpanderInfoValidator as VLD
18
-
19
-
20
- @VLD.add_format(DF.DEFAULT, DF.DEFAULT)
21
- class CSub(Expander):
22
- """CSub expander"""
23
-
24
- def _expand(self, graph_builder):
25
- input_x, input_y = self.inputs
26
- if input_x.dtype == input_y.dtype:
27
- x_real = graph_builder.emit('CReal', [input_x])
28
- y_real = graph_builder.emit('CReal', [input_y])
29
- x_imag = graph_builder.emit('CImag', [input_x])
30
- y_imag = graph_builder.emit('CImag', [input_y])
31
- result_real = graph_builder.emit('Sub', [x_real, y_real])
32
- result_imag = graph_builder.emit('Sub', [x_imag, y_imag])
33
- result = graph_builder.emit('Complex', [result_real, result_imag])
34
- elif input_x.dtype == "complex64" or input_x.dtype == "complex128":
35
- x_real = graph_builder.emit('CReal', [input_x])
36
- x_imag = graph_builder.emit('CImag', [input_x])
37
- x_real_sub_y = graph_builder.emit('Sub', [x_real, input_y])
38
- result = graph_builder.emit('Complex', [x_real_sub_y, x_imag])
39
- elif input_y.dtype == "complex64" or input_y.dtype == "complex128":
40
- y_real = graph_builder.emit('CReal', [input_y])
41
- y_imag = graph_builder.emit('CImag', [input_y])
42
- x_sub_y_real = graph_builder.emit('Sub', [input_x, y_real])
43
- y_imag = graph_builder.emit('Neg', [y_imag])
44
- result = graph_builder.emit('Complex', [x_sub_y_real, y_imag])
45
- return result
@@ -1,200 +0,0 @@
1
- # Copyright 2021-2022 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for Conv2D"""
16
- from mindspore._extends.graph_kernel.model.op_infer import check_format_any, check_nd, conv_had_pad
17
- from mindspore._extends.graph_kernel.model.model import DataFormat as DF
18
- from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
19
- from ._utils import Expander, ExpanderInfoValidator as VLD
20
-
21
-
22
- @VLD.add_format(DF.DEFAULT, DF.DEFAULT)
23
- @VLD.add_format(DF.NHWC, DF.NHWC)
24
- @VLD.check_attrs('format', 'pad_list', 'pad_mode', 'groups', 'group', 'kernel_size', 'stride', 'dilation')
25
- class Conv2D(Expander):
26
- """
27
- Conv2D expander
28
-
29
- Currently, only Conv2D that meets several conditions can be expanded, other cases will be skipped.
30
- Conditions to expand:
31
- inputs are NHWC format and float16.
32
- attr groups and group are 1.
33
- attr dilation are all 1.
34
- N channel of inputs > 16.
35
- C channel of inputs > 8.
36
- output N*H*W are multiplies of 128.
37
- """
38
- M_ALIGN = 32
39
- N_ALIGN = 32
40
- K_ALIGN = 16
41
- K_LIMIT = 800
42
- MNK_LIMIT = 3 * (10 ** 10)
43
- N0_CHANNEL_ALIGN = 32
44
- N1_CHANNEL_ALIGN = 32
45
- C_CHANNEL_ALIGN = 16
46
- OUT_NHW_ALIGN = 128
47
-
48
- def __init__(self, expand_info):
49
- super().__init__(expand_info)
50
- self.dst_type = self.outputs[0]['data_type']
51
- self.dst_format = self.outputs[0]['format']
52
- self.has_pad = False
53
- self.can_optimize_to_matmul = False
54
- self.shape_0_pad = self.inputs[0]['shape']
55
- self.shape_1_pad = self.inputs[1]['shape']
56
- self.m = 0
57
- self.n = 0
58
- self.k = 0
59
-
60
- def _optimize_to_matmul(self):
61
- stride = self.attrs['stride']
62
- dilation = self.attrs['dilation']
63
- _, h, w, _ = self.inputs[1]['shape']
64
- if h == 1 and w == 1 and stride == [1, 1, 1, 1] and dilation == [1, 1, 1, 1] and \
65
- self.m % self.M_ALIGN == 0 and self.n % self.N_ALIGN == 0 and self.k % self.K_ALIGN == 0:
66
- return True
67
- return False
68
-
69
- def _common_check(self):
70
- """common check for inputs and attrs"""
71
- type_0 = self.inputs[0]['data_type']
72
- type_1 = self.inputs[1]['data_type']
73
- if type_0 != "float16" or type_1 != "float16":
74
- raise GKException("For 'Conv2D', inputs data type should be both float16, but got {} and {}"
75
- .format(type_0, type_1))
76
-
77
- formats = [self.inputs[0]['format'], self.inputs[1]['format'], self.attrs['format']]
78
- check_format_any(formats, DF.NHWC)
79
-
80
- groups = self.attrs['groups']
81
- group = self.attrs['group']
82
- if groups != 1 or group != 1:
83
- raise GKException("For 'Conv2D', value of attr 'groups' and 'group' should be both 1, but got {} and {}."
84
- .format(groups, group))
85
-
86
- dilation = self.attrs['dilation']
87
- check_nd(dilation, 4)
88
- if dilation != [1, 1, 1, 1]:
89
- raise GKException("For 'Conv2D', value of attr 'dilation' should be [1, 1, 1, 1], but got {}"
90
- .format(dilation))
91
-
92
- def _check(self):
93
- self._common_check()
94
-
95
- pad_list = self.attrs['pad_list']
96
- check_nd(pad_list, 4)
97
- self.has_pad = conv_had_pad(pad_list, self.attrs['pad_mode'])
98
-
99
- shape_0 = self.inputs[0]['shape']
100
- shape_1 = self.inputs[1]['shape']
101
- stride = self.attrs['stride']
102
- check_nd(shape_0, 4)
103
- check_nd(shape_1, 4)
104
- check_nd(stride, 4)
105
- n0, h0, w0, c0 = shape_0
106
- n1, h1, w1, c1 = shape_1
107
- if (n0 % self.N0_CHANNEL_ALIGN) != 0:
108
- raise GKException("For 'Conv2D', N channel of first input should be multiples of {}, but got {}"
109
- .format(self.N0_CHANNEL_ALIGN, n0))
110
- if (n1 % self.N1_CHANNEL_ALIGN) != 0:
111
- raise GKException("For 'Conv2D', N channel of second input should be multiples of {}, but got {}"
112
- .format(self.N1_CHANNEL_ALIGN, n1))
113
- if c0 != c1 or (c0 % self.C_CHANNEL_ALIGN) != 0:
114
- raise GKException("For 'Conv2D', C channel of inputs should be same and also be multiples of {}, but got "
115
- "{} and {}".format(self.C_CHANNEL_ALIGN, c0, c1))
116
- # n0 pad
117
- n0 = ((n0 + self.N0_CHANNEL_ALIGN - 1) //
118
- self.N0_CHANNEL_ALIGN) * self.N0_CHANNEL_ALIGN
119
- # h0, w0 pad
120
- if self.has_pad:
121
- h0 = h0 + pad_list[0] + pad_list[1]
122
- w0 = w0 + pad_list[2] + pad_list[3]
123
- # c0, c1 pad
124
- c0 = ((c0 + self.C_CHANNEL_ALIGN - 1) // self.C_CHANNEL_ALIGN) * self.C_CHANNEL_ALIGN
125
- c1 = c0
126
- # n1 pad
127
- n1 = ((n1 + self.N1_CHANNEL_ALIGN - 1) //
128
- self.N1_CHANNEL_ALIGN) * self.N1_CHANNEL_ALIGN
129
-
130
- # check if can optimize to matmul
131
- self.m, self.n, self.k = n0 * h0 * w0, n1, c1
132
- self.can_optimize_to_matmul = self._optimize_to_matmul()
133
-
134
- # requirements
135
- if self.can_optimize_to_matmul:
136
- if self.k > self.K_LIMIT:
137
- raise GKException("For 'Conv2D', if transformed to 'MatMul', C0 should not be larger than {}, but got "
138
- "{}".format(self.K_LIMIT, self.k))
139
- if self.m * self.n * self.k >= self.MNK_LIMIT:
140
- raise GKException("For 'Conv2D', if transformed to 'MatMul', The total size should not be larger than "
141
- "{}, but got {}".format(self.MNK_LIMIT, self.m * self.n * self.k))
142
- else:
143
- out_h, out_w = (h0 - h1) // stride[-2] + 1, (w0 - w1) // stride[-1] + 1
144
- if ((n0 * out_h * out_w) % self.OUT_NHW_ALIGN) != 0:
145
- raise GKException("For 'Conv2D', N({}) * H({}) * W({}) of output should be multiplies of {}"
146
- .format(n0, out_h, out_w, self.OUT_NHW_ALIGN))
147
- if stride != [1, 1, 2, 2]:
148
- raise GKException("For 'Conv2D', value of attr 'stride' should be [1, 1, 2, 2], but got {}"
149
- .format(stride))
150
-
151
- self.shape_0_pad = [n0, h0, w0, c0]
152
- self.shape_1_pad = [n1, h1, w1, c1]
153
-
154
- def _expand(self, graph_builder):
155
- input_0 = self.inputs[0]
156
- input_1 = self.inputs[1]
157
- n0, _, _, c0 = input_0.shape
158
- n1, _, _, c1 = input_1.shape
159
- n0_p, h0_p, w0_p, c0_p = self.shape_0_pad
160
- n1_p, _, _, c1_p = self.shape_1_pad
161
-
162
- pad_value = 0
163
- # input0 pad
164
- input_0_pad_before = [0, 0, 0, 0]
165
- input_0_pad_after = [0, 0, 0, 0]
166
- if self.has_pad:
167
- pad_list = self.attrs['pad_list']
168
- input_0_pad_before = [0, pad_list[0], pad_list[2], 0]
169
- input_0_pad_after = [0, pad_list[1], pad_list[3], 0]
170
- input_0_pad_after[0] = n0_p - n0
171
- input_0_pad_after[3] = c0_p - c0
172
- if input_0_pad_before != [0, 0, 0, 0] or input_0_pad_after != [0, 0, 0, 0]:
173
- input_0 = graph_builder.emit('PadAkg', [input_0], attrs={'head': input_0_pad_before,
174
- 'tail': input_0_pad_after,
175
- 'pad_val': pad_value})
176
- # input1 pad
177
- input_1_pad_after = [n1_p - n1, 0, 0, c1_p - c1]
178
- if input_1_pad_after != [0, 0, 0, 0]:
179
- input_1 = graph_builder.emit('PadAkg', [input_1], attrs={'head': [0, 0, 0, 0],
180
- 'tail': input_1_pad_after,
181
- 'pad_val': pad_value})
182
- if self.can_optimize_to_matmul:
183
- a = graph_builder.emit('Reshape', [input_0], attrs={'shape': [self.m, self.k]})
184
- b = graph_builder.emit('Reshape', [input_1], attrs={'shape': [self.n, self.k]})
185
- c = graph_builder.emit('MatMul', [a, b], attrs={'transpose_a': False,
186
- 'transpose_b': True,
187
- 'dst_type': self.dst_type})
188
- result = graph_builder.emit('Reshape', [c], attrs={'shape': [n0_p, h0_p, w0_p, n1_p],
189
- 'format': self.dst_format})
190
- else:
191
- attrs = self.attrs
192
- attrs['pad_list'] = [0, 0, 0, 0]
193
- attrs['dst_type'] = self.dst_type
194
- result = graph_builder.emit('Conv2D', [input_0, input_1], attrs=attrs)
195
- # unpad
196
- unpad_after = [input_0_pad_after[0], 0, 0, input_1_pad_after[0]]
197
- if unpad_after != [0, 0, 0, 0]:
198
- result = graph_builder.emit('UnPadAkg', [result], attrs={'tail': unpad_after})
199
-
200
- return result
@@ -1,30 +0,0 @@
1
- # Copyright 2020-2021 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for DropoutGrad"""
16
- from ._utils import Expander, ExpanderInfoValidator as VLD
17
-
18
-
19
- @VLD.check_all_formats_same
20
- @VLD.check_attrs('keep_prob')
21
- class DropoutGrad(Expander):
22
- """DropoutGrad expander"""
23
-
24
- def _expand(self, graph_builder):
25
- input_dy, input_mask = self.inputs
26
- keep_prob = self.attrs['keep_prob']
27
- r_keep_prob = graph_builder.value(input_dy.dtype, 1.0 / keep_prob)
28
- result = graph_builder.emit('Mul', [input_dy, r_keep_prob])
29
- result = graph_builder.emit('Mul', [result, input_mask])
30
- return result
@@ -1,50 +0,0 @@
1
- # Copyright 2021-2022 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for equal_count"""
16
- from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
17
- from ._utils import Expander, ExpanderInfoValidator as VLD
18
-
19
-
20
- @VLD.check_all_formats_same
21
- class EqualCount(Expander):
22
- """EqualCount expander"""
23
-
24
- def __init__(self, expand_info):
25
- super().__init__(expand_info)
26
- self.shape_x = self.inputs[0]['shape']
27
- self.shape_y = self.inputs[1]['shape']
28
- self.dtype_x = self.inputs[0]['data_type']
29
- self.dtype_y = self.inputs[1]['data_type']
30
-
31
- def _check(self):
32
- if self.shape_x != self.shape_y:
33
- raise GKException("For 'EqualCount', the inputs shape should be same, but got {} and {}"
34
- .format(self.shape_x, self.shape_y))
35
- if self.dtype_x != self.dtype_y:
36
- raise GKException("For 'EqualCount', the inputs data type should be same, but got {} and {}"
37
- .format(self.dtype_x, self.dtype_y))
38
-
39
- def _expand(self, graph_builder):
40
- input_x = self.inputs[0]
41
- input_y = self.inputs[1]
42
-
43
- eql_val = graph_builder.emit('Equal', [input_x, input_y])
44
- cast_val = graph_builder.emit('Cast', [eql_val], attrs={'dst_type': 'float32'})
45
- axis = list(range(len(input_x.shape)))
46
- result = graph_builder.emit('ReduceSum', [cast_val], attrs={'reduce_axis': axis, 'keep_dims': False})
47
-
48
- if result.dtype != input_x.dtype:
49
- result = graph_builder.emit('Cast', [result], attrs={'dst_type': input_x.dtype})
50
- return result
@@ -1,35 +0,0 @@
1
- # Copyright 2021 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for erfc"""
16
- from ._utils import Expander
17
-
18
-
19
- class Erfc(Expander):
20
- """Erfc expander"""
21
-
22
- def _expand(self, graph_builder):
23
- input_x = self.inputs[0]
24
- result = None
25
- if input_x.dtype == "float16":
26
- const_one = graph_builder.value("float32", 1)
27
- input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': "float32"})
28
- erf_result = graph_builder.emit('Erf', [input_x])
29
- result = graph_builder.emit('Sub', [const_one, erf_result])
30
- result = graph_builder.emit('Cast', [result], attrs={'dst_type': "float16"})
31
- return result
32
- const_one = graph_builder.value(input_x.dtype, 1)
33
- erf_result = graph_builder.emit('Erf', [input_x])
34
- result = graph_builder.emit('Sub', [const_one, erf_result])
35
- return result
@@ -1,50 +0,0 @@
1
- # Copyright 2021-2022 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for expand_dims"""
16
- from ._utils import Expander, ExpanderInfoValidator as VLD
17
-
18
-
19
- @VLD.check_attrs('axis')
20
- class ExpandDims(Expander):
21
- """ExpandDims expander"""
22
-
23
- def _expand(self, graph_builder):
24
- input_x = self.inputs[0]
25
- shape = self.infer_shape(input_x.shape, self.attrs['axis'])
26
- result = graph_builder.emit('Reshape', [input_x], attrs={'shape': shape})
27
-
28
- return result
29
-
30
- @staticmethod
31
- def infer_shape(shape, axis):
32
- """infer shape for expand_dims"""
33
- def insert_axis(shape, axis):
34
- if not isinstance(axis, int) or axis > len(shape) or axis < -len(shape) - 1:
35
- raise ValueError("For 'ExpandDims', value of attr 'axis' should be of type int and in the range [{}, "
36
- "{}], but got {} with type {}".format(-len(shape) - 1, len(shape), axis, type(axis)))
37
- if axis >= 0:
38
- shape.insert(axis, 1)
39
- else:
40
- shape.insert(axis + len(shape) + 1, 1)
41
- return shape
42
- out_shape = shape[:]
43
- if isinstance(axis, int):
44
- return insert_axis(out_shape, axis)
45
- if isinstance(axis, (list, tuple)):
46
- for i in axis:
47
- out_shape = insert_axis(out_shape, i)
48
- return out_shape
49
- raise ValueError("For 'ExpandDims', type of attr 'axis' should be one of ['int', 'list', 'tuple'], but got {} "
50
- "with type {}".format(axis, type(axis)))
@@ -1,44 +0,0 @@
1
- # Copyright 2020-2021 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for fused_adam"""
16
- from ._utils import Expander, ExpanderInfoValidator as VLD
17
-
18
-
19
- @VLD.check_all_formats_same
20
- class FusedAdam(Expander):
21
- """FusedAdam expander"""
22
-
23
- def _expand(self, graph_builder):
24
- beta_1, one_sub_beta_1, beta_2, one_sub_beta_2, eps, lr, param, m, v, gradient = self.inputs
25
-
26
- beta_1_mul_m = graph_builder.emit('Mul', [beta_1, m])
27
- one_sub_beta_1_mul_grad = graph_builder.emit('Mul', [one_sub_beta_1, gradient])
28
- next_m = graph_builder.emit('Add', [beta_1_mul_m, one_sub_beta_1_mul_grad])
29
- beta_2_mul_v = graph_builder.emit('Mul', [beta_2, v])
30
- grad_square = graph_builder.emit('Mul', [gradient, gradient])
31
- one_sub_beta_2_mul_grad_square = graph_builder.emit('Mul', [one_sub_beta_2, grad_square])
32
- next_v = graph_builder.emit('Add', [beta_2_mul_v, one_sub_beta_2_mul_grad_square])
33
- sqrt_next_v = graph_builder.emit('Sqrt', [next_v])
34
- sqrt_next_v_add_eps = graph_builder.emit('Add', [sqrt_next_v, eps])
35
- update = graph_builder.emit('RealDiv', [next_m, sqrt_next_v_add_eps])
36
- update_with_lr = graph_builder.emit('Mul', [lr, update])
37
- next_para = graph_builder.emit('Sub', [param, update_with_lr])
38
-
39
- param_result = graph_builder.emit(
40
- 'InplaceAssign', [param, next_para, next_para], attrs={'fake_output': True})
41
- param_result = graph_builder.emit('InplaceAssign', [m, next_m, param_result], attrs={'fake_output': True})
42
- param_result = graph_builder.emit('InplaceAssign', [v, next_v, param_result], attrs={'fake_output': True})
43
-
44
- return param_result