mindspore 2.1.0__cp38-cp38-manylinux1_x86_64.whl → 2.2.0__cp38-cp38-manylinux1_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (550) hide show
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_akg/akg/build_module.py +5 -6
  4. mindspore/_akg/akg/composite/build_module.py +49 -16
  5. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  6. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  7. mindspore/_akg/akg/tvm/api.py +4 -3
  8. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  9. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  10. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  11. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  12. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  13. mindspore/_akg/akg/tvm/build_module.py +16 -1
  14. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  15. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  16. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  17. mindspore/_akg/akg/tvm/module.py +1 -2
  18. mindspore/_akg/akg/tvm/stmt.py +2 -2
  19. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  20. mindspore/_akg/akg/utils/kernel_exec.py +58 -260
  21. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  22. mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
  23. mindspore/_c_dataengine.cpython-38-x86_64-linux-gnu.so +0 -0
  24. mindspore/_c_expression.cpython-38-x86_64-linux-gnu.so +0 -0
  25. mindspore/_c_mindrecord.cpython-38-x86_64-linux-gnu.so +0 -0
  26. mindspore/_check_jit_forbidden_api.py +3 -1
  27. mindspore/_checkparam.py +26 -32
  28. mindspore/_extends/graph_kernel/__init__.py +0 -1
  29. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  30. mindspore/_extends/graph_kernel/splitter.py +1 -9
  31. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  32. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
  33. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  34. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  35. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +4 -4
  36. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  37. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  38. mindspore/_extends/parse/__init__.py +12 -15
  39. mindspore/_extends/parse/namespace.py +7 -33
  40. mindspore/_extends/parse/parser.py +61 -71
  41. mindspore/_extends/parse/resources.py +1 -1
  42. mindspore/_extends/parse/standard_method.py +72 -95
  43. mindspore/_extends/parse/trope.py +1 -1
  44. mindspore/_extends/remote/kernel_build_server.py +24 -7
  45. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  46. mindspore/_install_custom.py +43 -0
  47. mindspore/_mindspore_offline_debug.cpython-38-x86_64-linux-gnu.so +0 -0
  48. mindspore/amp.py +47 -11
  49. mindspore/bin/cache_admin +0 -0
  50. mindspore/bin/cache_server +0 -0
  51. mindspore/boost/boost.py +1 -8
  52. mindspore/boost/boost_cell_wrapper.py +3 -2
  53. mindspore/boost/grad_accumulation.py +1 -1
  54. mindspore/boost/group_loss_scale_manager.py +8 -7
  55. mindspore/common/__init__.py +5 -3
  56. mindspore/common/_jit_fallback_utils.py +6 -0
  57. mindspore/common/_register_for_adapter.py +2 -0
  58. mindspore/common/_register_for_tensor.py +2 -2
  59. mindspore/common/_stub_tensor.py +13 -0
  60. mindspore/common/_utils.py +13 -0
  61. mindspore/common/api.py +173 -258
  62. mindspore/common/auto_dynamic_shape.py +498 -0
  63. mindspore/common/dtype.py +18 -11
  64. mindspore/common/dump.py +6 -4
  65. mindspore/common/initializer.py +14 -14
  66. mindspore/common/jit_config.py +33 -15
  67. mindspore/common/lazy_inline.py +126 -7
  68. mindspore/common/mindir_util.py +101 -0
  69. mindspore/common/parameter.py +51 -41
  70. mindspore/common/seed.py +4 -4
  71. mindspore/common/sparse_tensor.py +13 -14
  72. mindspore/common/tensor.py +240 -145
  73. mindspore/communication/__init__.py +7 -4
  74. mindspore/communication/_comm_helper.py +83 -4
  75. mindspore/communication/management.py +152 -84
  76. mindspore/config/op_info.config +13 -2
  77. mindspore/config/super_bar_config.json +4 -2
  78. mindspore/context.py +143 -59
  79. mindspore/dataset/__init__.py +5 -5
  80. mindspore/dataset/audio/__init__.py +2 -2
  81. mindspore/dataset/audio/transforms.py +52 -52
  82. mindspore/dataset/callback/ds_callback.py +16 -2
  83. mindspore/dataset/core/config.py +68 -51
  84. mindspore/dataset/engine/cache_client.py +28 -5
  85. mindspore/dataset/engine/datasets.py +250 -112
  86. mindspore/dataset/engine/datasets_audio.py +43 -211
  87. mindspore/dataset/engine/datasets_standard_format.py +11 -35
  88. mindspore/dataset/engine/datasets_text.py +43 -67
  89. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  90. mindspore/dataset/engine/datasets_vision.py +219 -1029
  91. mindspore/dataset/engine/iterators.py +11 -4
  92. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  93. mindspore/dataset/engine/obs/util.py +3 -0
  94. mindspore/dataset/engine/samplers.py +1 -1
  95. mindspore/dataset/engine/validators.py +19 -5
  96. mindspore/dataset/text/__init__.py +3 -3
  97. mindspore/dataset/text/transforms.py +101 -127
  98. mindspore/dataset/text/utils.py +205 -138
  99. mindspore/dataset/transforms/__init__.py +1 -1
  100. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  101. mindspore/dataset/transforms/transforms.py +95 -40
  102. mindspore/dataset/utils/browse_dataset.py +8 -2
  103. mindspore/dataset/utils/line_reader.py +17 -19
  104. mindspore/dataset/vision/__init__.py +3 -3
  105. mindspore/dataset/vision/c_transforms.py +6 -3
  106. mindspore/dataset/vision/transforms.py +409 -287
  107. mindspore/dataset/vision/utils.py +13 -14
  108. mindspore/dataset/vision/validators.py +11 -1
  109. mindspore/experimental/map_parameter.py +14 -0
  110. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  111. mindspore/{nn/optim_ex → experimental/optim}/adam.py +59 -66
  112. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  113. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  114. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  115. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  116. mindspore/gen_ops.py +273 -0
  117. mindspore/include/OWNERS +0 -1
  118. mindspore/include/api/data_type.h +2 -1
  119. mindspore/include/api/graph.h +0 -15
  120. mindspore/include/api/kernel.h +2 -0
  121. mindspore/include/api/kernel_api.h +37 -12
  122. mindspore/include/api/model.h +0 -14
  123. mindspore/include/api/types.h +37 -4
  124. mindspore/include/c_api/ms/abstract.h +67 -0
  125. mindspore/include/c_api/ms/attribute.h +197 -0
  126. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  127. mindspore/include/c_api/ms/base/macros.h +32 -0
  128. mindspore/include/c_api/ms/base/status.h +33 -0
  129. mindspore/include/c_api/ms/base/types.h +282 -0
  130. mindspore/include/c_api/ms/context.h +102 -0
  131. mindspore/include/c_api/ms/graph.h +160 -0
  132. mindspore/include/c_api/ms/node.h +606 -0
  133. mindspore/include/c_api/ms/tensor.h +161 -0
  134. mindspore/include/c_api/ms/value.h +84 -0
  135. mindspore/include/dataset/constants.h +6 -5
  136. mindspore/include/dataset/execute.h +23 -13
  137. mindspore/include/dataset/text.h +26 -26
  138. mindspore/include/dataset/transforms.h +13 -13
  139. mindspore/include/dataset/vision.h +60 -60
  140. mindspore/include/dataset/vision_ascend.h +5 -6
  141. mindspore/include/dataset/vision_lite.h +17 -17
  142. mindspore/include/mindapi/base/type_id.h +1 -0
  143. mindspore/include/mindapi/base/types.h +1 -0
  144. mindspore/lib/libdnnl.so.2 +0 -0
  145. mindspore/lib/libjemalloc.so.2 +0 -0
  146. mindspore/lib/libmindspore.so +0 -0
  147. mindspore/lib/libmindspore_backend.so +0 -0
  148. mindspore/lib/libmindspore_common.so +0 -0
  149. mindspore/lib/libmindspore_core.so +0 -0
  150. mindspore/lib/libmindspore_glog.so.0 +0 -0
  151. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  152. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  153. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  154. mindspore/lib/libmindspore_shared_lib.so +0 -0
  155. mindspore/lib/libnnacl.so +0 -0
  156. mindspore/lib/libopencv_core.so.4.5 +0 -0
  157. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  158. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  159. mindspore/lib/libps_cache.so +0 -0
  160. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  161. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  162. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
  163. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  164. mindspore/lib/plugin/ascend/libakg.so +0 -0
  165. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  166. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  167. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  168. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  169. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  170. mindspore/lib/plugin/cpu/libakg.so +0 -0
  171. mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
  172. mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
  173. mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
  174. mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
  175. mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
  176. mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
  177. mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
  178. mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
  179. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  180. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  181. mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
  182. mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
  183. mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
  184. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  185. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  186. mindspore/nn/__init__.py +0 -2
  187. mindspore/nn/cell.py +316 -74
  188. mindspore/nn/dynamic_lr.py +21 -21
  189. mindspore/nn/layer/activation.py +21 -28
  190. mindspore/nn/layer/basic.py +15 -13
  191. mindspore/nn/layer/channel_shuffle.py +1 -1
  192. mindspore/nn/layer/container.py +271 -9
  193. mindspore/nn/layer/conv.py +310 -207
  194. mindspore/nn/layer/dense.py +8 -5
  195. mindspore/nn/layer/embedding.py +33 -27
  196. mindspore/nn/layer/flash_attention.py +82 -41
  197. mindspore/nn/layer/image.py +8 -6
  198. mindspore/nn/layer/math.py +13 -18
  199. mindspore/nn/layer/normalization.py +107 -66
  200. mindspore/nn/layer/padding.py +1 -1
  201. mindspore/nn/layer/pooling.py +131 -109
  202. mindspore/nn/layer/rnn_cells.py +22 -17
  203. mindspore/nn/layer/rnns.py +13 -16
  204. mindspore/nn/layer/thor_layer.py +1 -1
  205. mindspore/nn/layer/transformer.py +221 -154
  206. mindspore/nn/learning_rate_schedule.py +9 -1
  207. mindspore/nn/loss/loss.py +235 -174
  208. mindspore/nn/optim/ada_grad.py +2 -1
  209. mindspore/nn/optim/adadelta.py +1 -0
  210. mindspore/nn/optim/adafactor.py +2 -1
  211. mindspore/nn/optim/adam.py +7 -4
  212. mindspore/nn/optim/adamax.py +3 -2
  213. mindspore/nn/optim/adasum.py +2 -2
  214. mindspore/nn/optim/asgd.py +2 -3
  215. mindspore/nn/optim/ftrl.py +6 -5
  216. mindspore/nn/optim/lamb.py +7 -4
  217. mindspore/nn/optim/lars.py +1 -1
  218. mindspore/nn/optim/lazyadam.py +5 -3
  219. mindspore/nn/optim/momentum.py +2 -1
  220. mindspore/nn/optim/optimizer.py +53 -4
  221. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  222. mindspore/nn/optim/rmsprop.py +4 -3
  223. mindspore/nn/optim/rprop.py +23 -12
  224. mindspore/nn/optim/sgd.py +26 -11
  225. mindspore/nn/optim/thor.py +9 -7
  226. mindspore/nn/probability/bijector/bijector.py +5 -5
  227. mindspore/nn/probability/bijector/power_transform.py +27 -27
  228. mindspore/nn/probability/bijector/softplus.py +3 -3
  229. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  230. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  231. mindspore/nn/probability/distribution/beta.py +3 -3
  232. mindspore/nn/probability/distribution/categorical.py +7 -7
  233. mindspore/nn/probability/distribution/cauchy.py +0 -1
  234. mindspore/nn/probability/distribution/distribution.py +3 -3
  235. mindspore/nn/probability/distribution/gamma.py +3 -3
  236. mindspore/nn/probability/distribution/geometric.py +4 -4
  237. mindspore/nn/probability/distribution/gumbel.py +4 -4
  238. mindspore/nn/probability/distribution/log_normal.py +2 -2
  239. mindspore/nn/probability/distribution/logistic.py +2 -2
  240. mindspore/nn/probability/distribution/poisson.py +4 -4
  241. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  242. mindspore/nn/probability/distribution/uniform.py +6 -6
  243. mindspore/nn/wrap/cell_wrapper.py +78 -34
  244. mindspore/nn/wrap/grad_reducer.py +8 -5
  245. mindspore/nn/wrap/loss_scale.py +105 -42
  246. mindspore/numpy/array_creations.py +1 -2
  247. mindspore/numpy/array_ops.py +3 -2
  248. mindspore/offline_debug/convert_async.py +2 -2
  249. mindspore/ops/_grad_experimental/__init__.py +0 -5
  250. mindspore/ops/_grad_experimental/grad_array_ops.py +1 -2
  251. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  252. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  253. mindspore/ops/_grad_experimental/grad_implementations.py +10 -0
  254. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  255. mindspore/ops/_grad_experimental/grad_math_ops.py +0 -181
  256. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  257. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  258. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
  259. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
  260. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
  261. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
  262. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
  263. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
  264. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  265. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  266. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  267. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  268. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  269. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  270. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  271. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  272. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  273. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  274. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  275. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  276. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  277. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  278. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  279. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  280. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  281. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  282. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  283. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  284. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  285. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  286. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  287. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  288. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  289. mindspore/ops/_primitive_cache.py +1 -1
  290. mindspore/ops/_tracefunc.py +45 -13
  291. mindspore/ops/_utils/utils.py +4 -1
  292. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  293. mindspore/ops/_vmap/vmap_base.py +3 -3
  294. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  295. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  296. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  297. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  298. mindspore/ops/arg_dtype_cast.py +54 -0
  299. mindspore/ops/composite/base.py +37 -10
  300. mindspore/ops/composite/math_ops.py +5 -4
  301. mindspore/ops/composite/multitype_ops/_compile_utils.py +273 -72
  302. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  303. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  304. mindspore/ops/composite/multitype_ops/getitem_impl.py +40 -2
  305. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  306. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  307. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  308. mindspore/ops/deprecated.py +304 -0
  309. mindspore/ops/function/__init__.py +4 -1
  310. mindspore/ops/function/array_func.py +167 -189
  311. mindspore/ops/function/clip_func.py +81 -13
  312. mindspore/ops/function/debug_func.py +1 -1
  313. mindspore/ops/function/grad/grad_func.py +18 -8
  314. mindspore/ops/function/image_func.py +10 -4
  315. mindspore/ops/function/linalg_func.py +5 -5
  316. mindspore/ops/function/math_func.py +575 -386
  317. mindspore/ops/function/nn_func.py +470 -251
  318. mindspore/ops/function/random_func.py +86 -56
  319. mindspore/ops/function/sparse_func.py +1 -1
  320. mindspore/ops/function/sparse_unary_func.py +14 -12
  321. mindspore/ops/function/vmap_func.py +6 -5
  322. mindspore/ops/functional.py +15 -10
  323. mindspore/ops/op_info_register.py +235 -19
  324. mindspore/ops/operations/__init__.py +25 -17
  325. mindspore/ops/operations/_grad_ops.py +52 -7
  326. mindspore/ops/operations/_inner_ops.py +213 -12
  327. mindspore/ops/operations/_quant_ops.py +4 -8
  328. mindspore/ops/operations/_sequence_ops.py +42 -0
  329. mindspore/ops/operations/array_ops.py +64 -280
  330. mindspore/ops/operations/comm_ops.py +105 -57
  331. mindspore/ops/operations/custom_ops.py +10 -3
  332. mindspore/ops/operations/debug_ops.py +8 -4
  333. mindspore/ops/operations/image_ops.py +18 -12
  334. mindspore/ops/operations/math_ops.py +185 -138
  335. mindspore/ops/operations/nn_ops.py +716 -492
  336. mindspore/ops/operations/other_ops.py +0 -22
  337. mindspore/ops/operations/random_ops.py +53 -111
  338. mindspore/ops/operations/sparse_ops.py +3 -1
  339. mindspore/ops/primitive.py +24 -18
  340. mindspore/parallel/_auto_parallel_context.py +68 -8
  341. mindspore/parallel/_cost_model_context.py +2 -2
  342. mindspore/parallel/_offload_context.py +17 -3
  343. mindspore/parallel/_parallel_serialization.py +2 -2
  344. mindspore/parallel/_ps_context.py +12 -0
  345. mindspore/parallel/_tensor.py +14 -12
  346. mindspore/parallel/_transformer/layers.py +5 -3
  347. mindspore/parallel/_transformer/loss.py +1 -0
  348. mindspore/parallel/_transformer/moe.py +2 -2
  349. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  350. mindspore/parallel/_transformer/transformer.py +23 -3
  351. mindspore/parallel/_utils.py +11 -7
  352. mindspore/parallel/algo_parameter_config.py +85 -5
  353. mindspore/parallel/checkpoint_transform.py +6 -10
  354. mindspore/parallel/shard.py +4 -4
  355. mindspore/profiler/common/struct_type.py +3 -3
  356. mindspore/profiler/common/util.py +3 -2
  357. mindspore/profiler/envprofiling.py +1 -1
  358. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  359. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  360. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  361. mindspore/profiler/parser/ascend_hccl_generator.py +17 -12
  362. mindspore/profiler/parser/ascend_msprof_exporter.py +104 -252
  363. mindspore/profiler/parser/ascend_msprof_generator.py +8 -8
  364. mindspore/profiler/parser/ascend_op_generator.py +5 -5
  365. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  366. mindspore/profiler/parser/ascend_timeline_generator.py +9 -6
  367. mindspore/profiler/parser/base_timeline_generator.py +9 -7
  368. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +14 -10
  369. mindspore/profiler/parser/flops_parser.py +15 -11
  370. mindspore/profiler/parser/framework_parser.py +37 -21
  371. mindspore/profiler/parser/hccl_parser.py +16 -12
  372. mindspore/profiler/parser/integrator.py +22 -11
  373. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  374. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  375. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  376. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  377. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  378. mindspore/profiler/parser/optime_parser.py +1 -1
  379. mindspore/profiler/parser/profiler_info.py +2 -2
  380. mindspore/profiler/parser/step_trace_parser.py +11 -14
  381. mindspore/profiler/profiling.py +139 -71
  382. mindspore/rewrite/api/node.py +102 -19
  383. mindspore/rewrite/api/node_type.py +5 -1
  384. mindspore/rewrite/api/scoped_value.py +9 -17
  385. mindspore/rewrite/api/symbol_tree.py +131 -47
  386. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  387. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  388. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  389. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  390. mindspore/rewrite/common/rewrite_elog.py +5 -1
  391. mindspore/rewrite/namer.py +33 -24
  392. mindspore/rewrite/namespace.py +14 -5
  393. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  394. mindspore/rewrite/node/call_function.py +79 -0
  395. mindspore/rewrite/node/cell_container.py +135 -0
  396. mindspore/rewrite/node/control_flow.py +88 -0
  397. mindspore/rewrite/{node.py → node/node.py} +273 -234
  398. mindspore/rewrite/node/node_manager.py +254 -0
  399. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  400. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  401. mindspore/rewrite/parsers/assign_parser.py +216 -221
  402. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  403. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  404. mindspore/rewrite/parsers/constant_parser.py +9 -6
  405. mindspore/rewrite/parsers/container_parser.py +9 -7
  406. mindspore/rewrite/parsers/for_parser.py +36 -15
  407. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  408. mindspore/rewrite/parsers/if_parser.py +28 -24
  409. mindspore/rewrite/parsers/module_parser.py +196 -25
  410. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  411. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  412. mindspore/rewrite/parsers/return_parser.py +6 -6
  413. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  414. mindspore/rewrite/sparsify/utils.py +1 -1
  415. mindspore/rewrite/symbol_tree.py +525 -577
  416. mindspore/rewrite/symbol_tree_builder.py +9 -193
  417. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  418. mindspore/run_check/_check_version.py +2 -2
  419. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  420. mindspore/safeguard/rewrite_obfuscation.py +517 -0
  421. mindspore/scipy/linalg.py +1 -1
  422. mindspore/scipy/optimize/minimize.py +7 -3
  423. mindspore/train/_utils.py +7 -3
  424. mindspore/train/amp.py +323 -123
  425. mindspore/train/anf_ir_pb2.py +14 -2
  426. mindspore/train/callback/_backup_and_restore.py +2 -12
  427. mindspore/train/callback/_callback.py +29 -4
  428. mindspore/train/callback/_checkpoint.py +23 -8
  429. mindspore/train/callback/_early_stop.py +2 -2
  430. mindspore/train/callback/_landscape.py +4 -4
  431. mindspore/train/callback/_loss_monitor.py +2 -2
  432. mindspore/train/callback/_on_request_exit.py +2 -2
  433. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  434. mindspore/train/callback/_summary_collector.py +14 -7
  435. mindspore/train/callback/_time_monitor.py +58 -5
  436. mindspore/train/data_sink.py +5 -11
  437. mindspore/train/dataset_helper.py +83 -57
  438. mindspore/train/loss_scale_manager.py +2 -2
  439. mindspore/train/metrics/__init__.py +3 -3
  440. mindspore/train/metrics/cosine_similarity.py +1 -1
  441. mindspore/train/metrics/hausdorff_distance.py +3 -2
  442. mindspore/train/metrics/mean_surface_distance.py +3 -2
  443. mindspore/train/metrics/metric.py +39 -19
  444. mindspore/train/metrics/roc.py +2 -2
  445. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  446. mindspore/train/mind_ir_pb2.py +85 -36
  447. mindspore/train/model.py +185 -45
  448. mindspore/train/serialization.py +390 -150
  449. mindspore/train/summary/_writer_pool.py +3 -2
  450. mindspore/train/summary/summary_record.py +14 -10
  451. mindspore/train/train_thor/convert_utils.py +3 -3
  452. mindspore/train/train_thor/dataset_helper.py +1 -1
  453. mindspore/version.py +1 -1
  454. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/METADATA +6 -7
  455. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/RECORD +458 -518
  456. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
  457. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  458. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  459. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  460. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  461. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  462. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  463. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  464. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  465. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  466. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  467. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  468. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  469. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  470. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  471. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  472. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  473. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  474. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  475. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  476. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  477. mindspore/_extends/graph_kernel/expander.py +0 -80
  478. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  479. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  480. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  481. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  482. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  483. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  484. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  485. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  486. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  487. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  488. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  489. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  490. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  491. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  492. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  493. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  494. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  495. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  496. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  497. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  498. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  499. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  500. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  501. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  502. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  503. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  504. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  505. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  506. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  507. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  508. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  509. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  510. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  511. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  512. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  513. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  514. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  515. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  516. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  517. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  518. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  519. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  520. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  521. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  522. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  523. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  524. mindspore/dataset/datapreprocess/__init__.py +0 -20
  525. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  526. mindspore/include/api/net.h +0 -142
  527. mindspore/nn/lr_scheduler.py +0 -262
  528. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  529. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  530. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  531. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  532. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  533. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  534. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  535. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  536. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  537. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  538. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  539. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  540. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  541. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  542. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  543. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  544. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  545. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  546. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  547. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  548. mindspore/rewrite/node_visitor.py +0 -44
  549. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
  550. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
@@ -1,269 +0,0 @@
1
- # Copyright 2021-2022 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """GraphKernel expander utils"""
16
- from abc import ABCMeta, abstractmethod
17
- from mindspore._extends.graph_kernel.model import model_builder as builder
18
- from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
19
-
20
-
21
class Expander(metaclass=ABCMeta):
    """
    Expander is the base class of expanders.

    An expander turns the operator described by `expand_info` into a graph of
    basic operators built with `model_builder.GraphBuilder`.
    The method `_expand` should be overridden to implement the operator detail.
    """

    def __init__(self, expand_info):
        """
        Args:
            expand_info (dict): operator description; the keys read here are
                "name", "input_desc", "output_desc", "attr" and "process".
        """
        self.name = expand_info["name"]
        self.inputs = expand_info["input_desc"]
        self.outputs = expand_info["output_desc"]
        self.attrs = expand_info["attr"]
        self.processor = expand_info["process"]

    def run(self):
        """
        Expand the operator to a graph.

        `_check` runs first, while `self.inputs` still holds the raw input
        descriptions (dicts with 'shape', 'data_type' and 'format'); after that
        `self.inputs` is replaced by the graph's input tensors and `_expand` is
        called to build the body.

        Returns:
            The expanded graph, with its processor set to `self.processor`.

        Raises:
            GraphKernelUnsupportedException: if a check fails or the expanded
                outputs do not match `self.outputs`.
        """
        self._check()
        graph_builder = builder.GraphBuilder()
        with graph_builder.graph_scope(self.name) as graph_scope:
            # transform input_desc to Tensor
            self.inputs = [graph_builder.tensor(inp['shape'], inp['data_type'], inp['format'])
                           for inp in self.inputs]
            graph_scope.set_input(*self.inputs)
            outputs = self._expand(graph_builder)
            if isinstance(outputs, (list, tuple)):
                self._check_output_same(outputs)
                graph_scope.set_output(*outputs)
            else:
                self._check_output_same([outputs])
                graph_scope.set_output(outputs)

        graph = graph_builder.get()[0]
        graph.set_processor(self.processor)
        return graph

    def _check(self):
        """Check inputs. Validator decorators chain their checks onto this hook."""

    def _check_output_same(self, outputs):
        """
        Verify the tensors produced by `_expand` against the expected outputs.

        Args:
            outputs (list or tuple): tensors returned by `_expand`.

        Raises:
            GraphKernelUnsupportedException: if fewer outputs were produced than
                described in `self.outputs`, or if any shape, data type or
                format differs from its description.
        """
        # Without this guard a short `outputs` surfaces as a bare IndexError
        # rather than the exception type a failed check is documented to raise.
        if len(outputs) < len(self.outputs):
            raise GKException("{} 's output number {} is wrong. Expected at least: {}".format(
                self.__class__.__name__, len(outputs), len(self.outputs)))
        for index, value in enumerate(self.outputs):
            if list(outputs[index].shape) != list(value['shape']):
                raise GKException("{} 's output shape {} is wrong. Expected:{}".format(
                    self.__class__.__name__, list(outputs[index].shape), list(value['shape'])))
            if outputs[index].dtype != value['data_type']:
                raise GKException("{} 's output data_type {} is wrong. Expected: {}".format(
                    self.__class__.__name__, outputs[index].dtype, value['data_type']))
            if outputs[index].data_format != value['format']:
                raise GKException("{} 's output format {} is wrong. Expected: {}".format(
                    self.__class__.__name__, outputs[index].data_format, value['format']))

    @abstractmethod
    def _expand(self, graph_builder):
        """Expand operator, this function should be overridden in subclass"""
        raise Exception("_expand() is not implemented in {}".format(self.__class__.__name__))
77
-
78
-
79
class ExpanderInfoValidator:
    """ExpanderInfoValidator is the utility class which defines the validator decorator for expanders"""

    def __init__(self):
        """Init"""

    @staticmethod
    def _add_check_function(kls, func):
        """
        Rewrite `kls._check` so that `func(obj)` runs after the checks the
        current `_check` already performs.

        Args:
            kls (type): the expander class to patch.
            func (callable): a check that takes the expander instance.
        """
        old_check = getattr(kls, "_check")

        def new_check(obj):
            old_check(obj)
            func(obj)

        setattr(kls, "_check", new_check)

    @staticmethod
    def add_format(*input_format):
        """
        Add new supported format for the operator.

        The decorator keeps a list `__supported_formats` on the expander class,
        the whitelist of input-format tuples this op supports, and extends
        `_check` to reject inputs whose formats match none of them.

        Args:
            *input_format: one format per operator input, in input order.

        Raises:
            Exception: at decoration time, if the class is not an Expander subclass.
            GraphKernelUnsupportedException: at check time, if a registered tuple's
                length differs from the number of inputs, or no tuple matches.
        """
        format_list_name = "__supported_formats"

        def _check_format(obj):
            inp_formats = [inp['format'] for inp in obj.inputs]
            for formats in getattr(obj, format_list_name):
                if len(formats) != len(inp_formats):
                    raise GKException("For '{}', length of registered format is different from the length of inputs "
                                      "format: {} vs {}".format(obj.name, len(formats), len(inp_formats)))
                if all((fmt == inp for fmt, inp in zip(formats, inp_formats))):
                    return
            raise GKException("Unregistered format ({}) for op {}".format(','.join(inp_formats), obj.name))

        def wrapper(cls):
            if not issubclass(cls, Expander):
                raise Exception("{} should be subclass of Expander.".format(cls.__name__))
            if format_list_name not in vars(cls):
                inherited = getattr(cls, format_list_name, None)
                if inherited is None:
                    # First registration in this class hierarchy: create the
                    # whitelist and hook the format check into `_check` once.
                    setattr(cls, format_list_name, list())
                    ExpanderInfoValidator._add_check_function(cls, _check_format)
                else:
                    # A base class already registered formats, and the format
                    # check was hooked into that base's `_check`. Copy the list so
                    # the append below does not also widen the base's whitelist.
                    setattr(cls, format_list_name, list(inherited))
            getattr(cls, format_list_name).append(input_format)
            return cls

        return wrapper

    @staticmethod
    def check_all_formats_same(kls):
        """
        Class decorator: extend `_check` to require that all inputs share one format.

        Raises:
            Exception: at decoration time, if `kls` is not an Expander subclass.
            GraphKernelUnsupportedException: at check time, if the input formats differ.
        """

        def _check_format(obj):
            inp_formats = [inp['format'] for inp in obj.inputs]
            if all((fmt == inp_formats[0] for fmt in inp_formats[1:])):
                return
            raise GKException("[check_all_formats_same] unmatched formats ({}) for op {}".format(
                ','.join(inp_formats), obj.name))

        if not issubclass(kls, Expander):
            raise Exception("{} should be subclass of Expander.".format(kls.__name__))
        ExpanderInfoValidator._add_check_function(kls, _check_format)
        return kls

    @staticmethod
    def check_attrs(*args):
        """
        Class decorator: extend `_check` to require that every name in `args`
        is a key of the expander's `attrs`.

        Raises:
            Exception: at decoration time, if the class is not an Expander subclass.
            GraphKernelUnsupportedException: at check time, for the first missing attr.
        """

        def _check_attr(obj):
            for a in args:
                if a not in obj.attrs:
                    raise GKException("attr '{}' does not exist. {}".format(a, obj.name))

        def wrapper(cls):
            if not issubclass(cls, Expander):
                raise Exception("{} should be subclass of Expander.".format(cls.__name__))
            ExpanderInfoValidator._add_check_function(cls, _check_attr)
            return cls

        return wrapper
170
-
171
-
172
def to_frac_z_axis(ori_shape, ori_axis):
    """
    Map axes of an original shape to the corresponding axes of its FRACTAL_NZ shape.

    Going by `infer_shape_from_fractalnz`, a fractal shape [..., n1, m1, m0, n0]
    corresponds to the original shape [..., m1 * m0, n1 * n0]. Each of the last two
    original axes therefore maps to two fractal axes, while leading axes keep
    their index.

    Parameters
    ----------
    ori_shape: list or tuple
        original shape of input
    ori_axis: list or tuple
        original axis of original shape to operate; negative values are allowed
    Returns
    -------
    output: list
        axis of the fractal Nz shape. Each entry of `ori_axis` is replaced, at
        the same position, by its (non-negative) first fractal axis. For each
        of the last two original axes, the second fractal axis is appended at
        the end.
    """
    frac_z_axis = list(ori_axis)
    shape_len = len(ori_shape)
    axis_count = len(frac_z_axis)
    last_axis = shape_len - 1
    second_last_axis = shape_len - 2
    # Only the original entries are visited; axes appended below are final.
    for i in range(axis_count):
        axis_index = (frac_z_axis[i] + shape_len) % shape_len
        if axis_index == last_axis:
            # n axis -> n1 (shape_len - 2) and n0 (shape_len + 1)
            frac_z_axis[i] = axis_index - 1
            frac_z_axis.append(axis_index + 2)
        elif axis_index == second_last_axis:
            # m axis -> m1 (shape_len - 1) and m0 (shape_len)
            frac_z_axis[i] = axis_index + 1
            frac_z_axis.append(axis_index + 2)
        else:
            frac_z_axis[i] = axis_index
    return frac_z_axis
206
-
207
-
208
def infer_shape_from_fractalnz(fractal):
    """
    Recover the original shape from a FRACTAL_NZ shape.

    The trailing four dims are read as [n1, m1, m0, n0] and folded back into
    [m1 * m0, n1 * n0]; any dims in front of them are copied unchanged.
    Always returns a new list.
    """
    rank = len(fractal)
    leading = [fractal[idx] for idx in range(rank - 4)]
    rows = fractal[rank - 3] * fractal[rank - 2]
    cols = fractal[rank - 4] * fractal[rank - 1]
    return leading + [rows, cols]
220
-
221
-
222
def get_reduced_ori_shape(shape, axis):
    """
    Return the original shape after reduction, with reduced dims kept as size 1.

    Every dimension whose index appears in `axis` becomes 1; all other
    dimensions are copied unchanged into a new list.
    """
    return [1 if dim_idx in axis else dim for dim_idx, dim in enumerate(shape)]
231
-
232
-
233
def get_reduce_axis_shape(shape, data_format, axis):
    """
    Get the reduce axis under format `data_format` and original reduced shape.
    Parameters
    ----------
    shape: list or tuple
        shape of input
    data_format: str
        data format of input
    axis: None, int, list or tuple
        reduce axis of the original shape. None or an empty sequence selects
        every axis; negative values count from the end. The caller's `axis`
        object is never modified.
    Returns
    -------
    reduce_axis: list
        reduce axis of the `data_format` shape
    ori_reduced_shape: list
        original reduced shape, with every reduced dim set to 1
    """
    ori_shape = shape
    if data_format == "FRACTAL_NZ":
        # `axis` refers to the original (pre-fractal) shape
        ori_shape = infer_shape_from_fractalnz(shape)
    rank = len(ori_shape)

    if isinstance(axis, int):
        # Wrap first so that axis=0 means "axis 0" instead of hitting the
        # falsy "no axis given" case below.
        axis = [axis]
    if not axis:
        axis = list(range(rank))
    else:
        # Build a fresh list: normalizing in place would mutate the caller's
        # list and raise TypeError for a tuple.
        axis = [ax + rank if ax < 0 else ax for ax in axis]

    ori_reduced_shape = get_reduced_ori_shape(ori_shape, axis)
    reduce_axis = axis
    if data_format == "FRACTAL_NZ":
        reduce_axis = to_frac_z_axis(ori_shape, axis)
    return reduce_axis, ori_reduced_shape
@@ -1,33 +0,0 @@
1
- # Copyright 2021-2022 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for addn"""
16
- from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
17
- from ._utils import Expander, ExpanderInfoValidator as VLD
18
-
19
-
20
@VLD.check_all_formats_same
class AddN(Expander):
    """Expand AddN into a left-to-right chain of binary Add ops: ((x0 + x1) + x2) + ..."""

    def _check(self):
        # The decorator chains its format check after this method.
        input_num = len(self.inputs)
        if input_num < 2:
            raise GKException("For 'AddN', the inputs num should be greater than 1, but got {}"
                              .format(input_num))

    def _expand(self, graph_builder):
        acc, remaining = self.inputs[0], self.inputs[1:]
        for operand in remaining:
            acc = graph_builder.emit('Add', [acc, operand])
        return acc
@@ -1,152 +0,0 @@
1
- # Copyright 2021-2022 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for BatchNorm"""
16
- from mindspore._extends.graph_kernel.model.model import DataFormat as DF
17
- from ._utils import Expander, ExpanderInfoValidator as VLD
18
- from .expand_dims import ExpandDims
19
-
20
-
21
@VLD.add_format(DF.NHWC, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
@VLD.add_format(DF.NCHW, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
@VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
@VLD.check_attrs('is_training', 'momentum', 'epsilon')
class BatchNorm(Expander):
    """
    BatchNorm expander.

    Inputs, in order: x, scale, offset, mean, variance.
    Outputs: y followed by four auxiliary tensors whose meaning depends on the
    `is_training` attribute (see `_expand`).
    """

    @staticmethod
    def _reshape_for_x(graph_builder, tensor, data_format):
        """
        Make a per-channel tensor broadcastable against x.

        For DEFAULT/NCHW x, the tensor is reshaped to
        `ExpandDims.infer_shape(tensor.shape, [-1, -1])` (presumably appending two
        trailing size-1 dims -- TODO confirm). For any other format (NHWC per
        the registered formats), it is returned unchanged.
        """
        if data_format in (DF.DEFAULT, DF.NCHW):
            return graph_builder.emit(
                'Reshape', [tensor], attrs={'shape': ExpandDims.infer_shape(tensor.shape, [-1, -1])})
        return tensor

    def _expand(self, graph_builder):
        """
        Emit the BatchNorm graph.

        If x is float16 while all other inputs are float32, x is cast to
        float32 for the computation and y is cast back to float16.

        Training mode returns (y, updated running mean, updated running
        variance, batch mean, 1 / sqrt(batch variance + epsilon)); see
        `_bn_train`. Inference mode computes
        y = scale * (x - mean) / sqrt(variance + epsilon) + offset and returns
        `variance + epsilon` in each of the four auxiliary slots.
        """
        # get op info
        input_x = self.inputs[0]
        input_scale = self.inputs[1]
        input_offset = self.inputs[2]
        input_mean = self.inputs[3]
        input_variance = self.inputs[4]
        epsilon_v = graph_builder.value(input_scale.dtype, self.attrs['epsilon'])

        input_x_ori_type = input_x.dtype
        input_x_new_type = input_x.dtype
        if input_x.dtype == "float16" and all(
                t.dtype == "float32" for t in (input_scale, input_offset, input_mean, input_variance)):
            input_x_new_type = "float32"
        if input_x_new_type != input_x_ori_type:
            input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': input_x_new_type})

        if self.attrs['is_training']:
            # _bn_train reads x from self.inputs, so hand it the (possibly cast) tensor
            self.inputs[0] = input_x
            res_y, mean_res, variance_res, mean_muls, y_sqrt_rec = self._bn_train(graph_builder)
            if input_x_new_type != input_x_ori_type:
                res_y = graph_builder.emit('Cast', [res_y], attrs={'dst_type': input_x_ori_type})
            return res_y, mean_res, variance_res, mean_muls, y_sqrt_rec

        # infer mode
        data_format = input_x.data_format
        input_mean = self._reshape_for_x(graph_builder, input_mean, data_format)
        input_scale = self._reshape_for_x(graph_builder, input_scale, data_format)
        input_offset = self._reshape_for_x(graph_builder, input_offset, data_format)
        x_sub = graph_builder.emit('Sub', [input_x, input_mean])
        x_sub_mul = graph_builder.emit('Mul', [input_scale, x_sub])
        var_add = graph_builder.emit('Add', [epsilon_v, input_variance])
        var_add_sqrt = graph_builder.emit('Sqrt', [var_add])
        var_add_sqrt = self._reshape_for_x(graph_builder, var_add_sqrt, data_format)
        x_div = graph_builder.emit('RealDiv', [x_sub_mul, var_add_sqrt])
        res_y = graph_builder.emit('Add', [input_offset, x_div])
        if input_x_new_type != input_x_ori_type:
            res_y = graph_builder.emit('Cast', [res_y], attrs={'dst_type': input_x_ori_type})
        # var_add fills all four auxiliary outputs (presumably placeholders,
        # since inference does not update any statistics)
        return res_y, var_add, var_add, var_add, var_add

    def _bn_train(self, graph_builder):
        """
        Expand BatchNorm for training mode.

        Returns:
            (y, Assign of the updated running mean, Assign of the updated
            running variance, batch mean, 1 / sqrt(batch variance + epsilon)).
            Running statistics are updated as
            (1 - momentum) * running + momentum * batch.

        Raises:
            GraphKernelUnsupportedException: if the number of elements reduced
                per channel is 0 or 1.
        """
        input_x = self.inputs[0]
        input_scale = self.inputs[1]
        input_offset = self.inputs[2]
        input_mean = self.inputs[3]
        input_variance = self.inputs[4]
        epsilon_v = graph_builder.value(input_scale.dtype, self.attrs['epsilon'])
        data_format = input_x.data_format
        shape_x = input_x.shape
        if data_format == DF.NHWC:
            reduce_axis = (0, 1, 2)
            num = shape_x[0] * shape_x[1] * shape_x[2]
        else:
            reduce_axis = (0, 2, 3)
            num = shape_x[0] * shape_x[2] * shape_x[3]
        if num in (0, 1):
            # 1.0 / num and the unbiased factor num / (num - 1) below would raise
            # ZeroDivisionError; report the case as unsupported instead.
            from mindspore._extends.graph_kernel.model.model import \
                GraphKernelUnsupportedException as GKException
            raise GKException("For 'BatchNorm', training mode needs more than one value per channel, "
                              "but got {}".format(num))
        num_rec = 1.0 / num
        num_rec_v = graph_builder.value(input_scale.dtype, num_rec)

        # compute mean value of input_x
        mean_sum = graph_builder.emit(
            'ReduceSum', [input_x], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
        mean_muls = graph_builder.emit('Mul', [mean_sum, num_rec_v])

        # compute (biased) variance of input_x
        mean_muls_expand = self._reshape_for_x(graph_builder, mean_muls, data_format)
        var_sub = graph_builder.emit('Sub', [input_x, mean_muls_expand])
        var_mul = graph_builder.emit('Mul', [var_sub, var_sub])
        var_sum = graph_builder.emit('ReduceSum', [var_mul], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
        var_mul = graph_builder.emit('Mul', [var_sum, num_rec_v])

        # y_sqrt_rec = 1 / sqrt(variance + epsilon); it is also returned as an
        # output (presumably consumed by the backward pass)
        scalar_one = 1.0
        scalar_one_v = graph_builder.value(input_scale.dtype, scalar_one)
        y_add = graph_builder.emit('Add', [var_mul, epsilon_v])
        y_sqrt = graph_builder.emit('Sqrt', [y_add])
        y_sqrt_rec = graph_builder.emit('RealDiv', [scalar_one_v, y_sqrt])

        # compute res_y = scale * (x - mean) * y_sqrt_rec + offset
        tmp_sub = graph_builder.emit('Sub', [input_x, mean_muls_expand])
        y_sqrt_rec_expand = self._reshape_for_x(graph_builder, y_sqrt_rec, data_format)
        y_norm = graph_builder.emit('Mul', [tmp_sub, y_sqrt_rec_expand])
        input_scale_expand = self._reshape_for_x(graph_builder, input_scale, data_format)
        res_y_mul = graph_builder.emit('Mul', [input_scale_expand, y_norm])
        input_offset_expand = self._reshape_for_x(graph_builder, input_offset, data_format)
        res_y = graph_builder.emit('Add', [res_y_mul, input_offset_expand])

        # compute mean_res
        momentum_sub = scalar_one - self.attrs['momentum']
        momentum_v_sub = graph_builder.value(input_scale.dtype, momentum_sub)
        new_running_mean_tmp = graph_builder.emit('Mul', [momentum_v_sub, input_mean])
        momentum_v = graph_builder.value(input_scale.dtype, self.attrs['momentum'])
        current_mean_tmp = graph_builder.emit('Mul', [momentum_v, mean_muls])
        updated_moving_mean = graph_builder.emit('Add', [new_running_mean_tmp, current_mean_tmp])
        mean_res = graph_builder.emit('Assign', [input_mean, updated_moving_mean])

        # variance_res is calculated by sample variance, and need to multiply by num / (num - 1)
        var_num = float(num) / (num - 1)
        var_num_v = graph_builder.value(input_scale.dtype, var_num)
        var_mul_update = graph_builder.emit('Mul', [var_num_v, var_mul])
        new_running_var_tmp = graph_builder.emit('Mul', [momentum_v_sub, input_variance])
        current_var_tmp = graph_builder.emit('Mul', [momentum_v, var_mul_update])
        updated_moving_variance = graph_builder.emit('Add', [new_running_var_tmp, current_var_tmp])
        variance_res = graph_builder.emit('Assign', [input_variance, updated_moving_variance])
        return res_y, mean_res, variance_res, mean_muls, y_sqrt_rec
@@ -1,105 +0,0 @@
1
- # Copyright 2021 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for BatchNormGrad"""
16
- from mindspore._extends.graph_kernel.model.model import DataFormat as DF
17
- from ._utils import Expander, ExpanderInfoValidator as VLD
18
- from .expand_dims import ExpandDims
19
-
20
-
21
@VLD.add_format(DF.NHWC, DF.NHWC, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
@VLD.add_format(DF.NCHW, DF.NCHW, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
@VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
@VLD.check_attrs('is_training', 'epsilon')
class BatchNormGrad(Expander):
    """
    BatchNormGrad expander.

    Inputs 0-4 are read as: dy, x, scale, saved mean, and a saved variance term.
    In training mode the saved variance term is already
    1 / sqrt(variance + epsilon); otherwise it is treated as the variance itself.
    The registered formats cover six inputs, but the sixth is not read here.
    Produces (dx, dgamma, dbeta).
    """

    def _expand(self, graph_builder):
        dy = self.inputs[0]
        x = self.inputs[1]
        scale = self.inputs[2]
        saved_mean = self.inputs[3]
        saved_var_term = self.inputs[4]
        training = self.attrs['is_training']

        # Reduce over N, H, W according to x's layout.
        if x.data_format == DF.NHWC:
            reduce_axis = (0, 1, 2)
        else:
            reduce_axis = (0, 2, 3)
        num = 1
        for axis_idx in reduce_axis:
            num *= x.shape[axis_idx]

        # float16 tensors are computed in float32; dx is cast back at the end.
        ori_type = x.dtype
        if ori_type == 'float16':
            x = graph_builder.emit('Cast', [x], attrs={'dst_type': 'float32'})
        if dy.dtype == 'float16':
            dy = graph_builder.emit('Cast', [dy], attrs={'dst_type': 'float32'})
        neg_num_rec_v = graph_builder.value(scale.dtype, -1.0 / num)
        dbeta = graph_builder.emit('ReduceSum', [dy], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})

        if training:
            # forward pass already produced 1 / sqrt(variance + epsilon)
            inv_variance = saved_var_term
        else:
            epsilon_v = graph_builder.value(scale.dtype, self.attrs['epsilon'])
            var_eps = graph_builder.emit('Add', [saved_var_term, epsilon_v])
            sqrt_var_eps = graph_builder.emit('Sqrt', [var_eps])
            one_v = graph_builder.value(scale.dtype, 1.0)
            inv_variance = graph_builder.emit('RealDiv', [one_v, sqrt_var_eps])

        def unsqueeze(tensor):
            """Reshape via ExpandDims.infer_shape(shape, [-1, -1]) so it lines up with x."""
            return graph_builder.emit(
                'Reshape', [tensor], attrs={'shape': ExpandDims.infer_shape(tensor.shape, [-1, -1])})

        needs_reshape = x.data_format in (DF.DEFAULT, DF.NCHW)

        # dgamma = sum(dy * (x - mean) * inv_variance)
        if needs_reshape:
            saved_mean = unsqueeze(saved_mean)
            inv_variance = unsqueeze(inv_variance)
            scale = unsqueeze(scale)
        centered = graph_builder.emit('Sub', [x, saved_mean])
        x_hat = graph_builder.emit('Mul', [centered, inv_variance])
        dy_x_hat = graph_builder.emit('Mul', [dy, x_hat])
        dgamma = graph_builder.emit(
            'ReduceSum', [dy_x_hat], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})

        # dx
        if not training:
            scaled_dy = graph_builder.emit('Mul', [scale, dy])
            dx = graph_builder.emit('Mul', [inv_variance, scaled_dy])
        else:
            mean_term = graph_builder.emit('Mul', [neg_num_rec_v, dbeta])
            if needs_reshape:
                dgamma_b = unsqueeze(dgamma)
                mean_term = unsqueeze(mean_term)
            else:
                dgamma_b = dgamma
            x_hat_dgamma = graph_builder.emit('Mul', [x_hat, dgamma_b])
            var_term = graph_builder.emit('Mul', [neg_num_rec_v, x_hat_dgamma])
            partial_sum = graph_builder.emit('Add', [dy, mean_term])
            full_sum = graph_builder.emit('Add', [partial_sum, var_term])
            scaled_sum = graph_builder.emit('Mul', [scale, full_sum])
            dx = graph_builder.emit('Mul', [inv_variance, scaled_sum])
        if ori_type == 'float16':
            dx = graph_builder.emit('Cast', [dx], attrs={'dst_type': 'float16'})

        # the output descriptions dictate each result's data_format
        dx.data_format = self.outputs[0]['format']
        dgamma.data_format = self.outputs[1]['format']
        dbeta.data_format = self.outputs[2]['format']

        return dx, dgamma, dbeta
@@ -1,33 +0,0 @@
1
- # Copyright 2020-2021 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for ClipByNormNoDivSum"""
16
- from ._utils import Expander, ExpanderInfoValidator as VLD
17
-
18
-
19
@VLD.check_all_formats_same
class ClipByNormNoDivSum(Expander):
    """
    ClipByNormNoDivSum expander.

    For inputs (x0, x1, x2, x3) it emits:
        mask = x0 > x1
        y = maximum(select(mask, sqrt(select(mask, x0, x2)), x0), x3)
    """

    def _expand(self, graph_builder):
        emit = graph_builder.emit
        x0, x1, x2, x3 = self.inputs

        mask = emit('Greater', [x0, x1])
        # Where the mask is off, Sqrt sees x2 instead of x0 (presumably to keep
        # its input valid); the following Select discards that lane anyway.
        sqrt_input = emit('Select', [mask, x0, x2])
        rooted = emit('Sqrt', [sqrt_input])
        picked = emit('Select', [mask, rooted, x0])
        return emit('Maximum', [picked, x3])
@@ -1,30 +0,0 @@
1
- # Copyright 2021 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for cabs"""
16
- from mindspore._extends.graph_kernel.expanders._utils import Expander
17
-
18
-
19
class CAbs(Expander):
    """CAbs expander: |z| = sqrt(real(z) * real(z) + imag(z) * imag(z))."""

    def _expand(self, graph_builder):
        z = self.inputs[0]
        real_part = graph_builder.emit('CReal', [z])
        imag_part = graph_builder.emit('CImag', [z])
        real_sq = graph_builder.emit('Mul', [real_part, real_part])
        imag_sq = graph_builder.emit('Mul', [imag_part, imag_part])
        magnitude_sq = graph_builder.emit('Add', [real_sq, imag_sq])
        return graph_builder.emit('Sqrt', [magnitude_sq])
@@ -1,44 +0,0 @@
1
- # Copyright 2021-2022 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for cadd"""
16
- from mindspore._extends.graph_kernel.model.model import DataFormat as DF
17
- from mindspore._extends.graph_kernel.expanders._utils import Expander, ExpanderInfoValidator as VLD
18
-
19
-
20
@VLD.add_format(DF.DEFAULT, DF.DEFAULT)
class CAdd(Expander):
    """
    CAdd expander.

    Supported operand dtype combinations:
      * both operands share a dtype: real and imaginary parts are added separately;
      * exactly one operand is complex64/complex128: the other operand is added
        to its real part, and its imaginary part is kept.
    Any other combination raises GraphKernelUnsupportedException.
    """

    def _expand(self, graph_builder):
        from mindspore._extends.graph_kernel.model.model import \
            GraphKernelUnsupportedException as GKException

        complex_types = ("complex64", "complex128")
        input_x, input_y = self.inputs
        x_is_complex = input_x.dtype in complex_types
        y_is_complex = input_y.dtype in complex_types
        if input_x.dtype == input_y.dtype:
            x_real = graph_builder.emit('CReal', [input_x])
            y_real = graph_builder.emit('CReal', [input_y])
            x_imag = graph_builder.emit('CImag', [input_x])
            y_imag = graph_builder.emit('CImag', [input_y])
            result_real = graph_builder.emit('Add', [x_real, y_real])
            result_imag = graph_builder.emit('Add', [x_imag, y_imag])
            result = graph_builder.emit('Complex', [result_real, result_imag])
        elif x_is_complex and not y_is_complex:
            x_real = graph_builder.emit('CReal', [input_x])
            x_imag = graph_builder.emit('CImag', [input_x])
            x_real_add_y = graph_builder.emit('Add', [x_real, input_y])
            result = graph_builder.emit('Complex', [x_real_add_y, x_imag])
        elif y_is_complex and not x_is_complex:
            y_real = graph_builder.emit('CReal', [input_y])
            y_imag = graph_builder.emit('CImag', [input_y])
            y_real_add_x = graph_builder.emit('Add', [y_real, input_x])
            result = graph_builder.emit('Complex', [y_real_add_x, y_imag])
        else:
            # Mixed complex precisions (complex64 with complex128) or two distinct
            # real dtypes have no expansion here. Report them as unsupported rather
            # than adding a complex tensor to a real part or reaching the return
            # with `result` unbound.
            raise GKException("For 'CAdd', unsupported input dtypes: {} and {}".format(
                input_x.dtype, input_y.dtype))
        return result