mindspore 2.1.0__cp39-cp39-win_amd64.whl → 2.2.11__cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic; see the package's registry page for more details.

Files changed (488)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  6. mindspore/_check_jit_forbidden_api.py +3 -1
  7. mindspore/_checkparam.py +23 -29
  8. mindspore/_extends/graph_kernel/__init__.py +0 -1
  9. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  10. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  11. mindspore/_extends/graph_kernel/splitter.py +4 -11
  12. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  13. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
  14. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  15. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  16. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  17. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
  18. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  19. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  20. mindspore/_extends/parse/__init__.py +13 -15
  21. mindspore/_extends/parse/namespace.py +7 -33
  22. mindspore/_extends/parse/parser.py +67 -72
  23. mindspore/_extends/parse/resources.py +1 -1
  24. mindspore/_extends/parse/standard_method.py +86 -106
  25. mindspore/_extends/parse/trope.py +1 -1
  26. mindspore/_extends/remote/kernel_build_server.py +25 -7
  27. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  28. mindspore/_install_custom.py +43 -0
  29. mindspore/amp.py +47 -11
  30. mindspore/boost/boost.py +1 -8
  31. mindspore/boost/boost_cell_wrapper.py +3 -2
  32. mindspore/boost/grad_accumulation.py +1 -1
  33. mindspore/boost/group_loss_scale_manager.py +8 -7
  34. mindspore/common/__init__.py +5 -3
  35. mindspore/common/_jit_fallback_utils.py +6 -0
  36. mindspore/common/_register_for_adapter.py +2 -0
  37. mindspore/common/_register_for_tensor.py +2 -2
  38. mindspore/common/_stub_tensor.py +13 -0
  39. mindspore/common/_utils.py +29 -0
  40. mindspore/common/api.py +174 -259
  41. mindspore/common/auto_dynamic_shape.py +494 -0
  42. mindspore/common/dtype.py +18 -11
  43. mindspore/common/dump.py +6 -4
  44. mindspore/common/initializer.py +14 -14
  45. mindspore/common/jit_config.py +33 -15
  46. mindspore/common/lazy_inline.py +126 -7
  47. mindspore/common/mindir_util.py +101 -0
  48. mindspore/common/parameter.py +51 -41
  49. mindspore/common/seed.py +4 -4
  50. mindspore/common/sparse_tensor.py +13 -14
  51. mindspore/common/tensor.py +243 -165
  52. mindspore/communication/__init__.py +7 -4
  53. mindspore/communication/_comm_helper.py +83 -4
  54. mindspore/communication/management.py +152 -84
  55. mindspore/config/op_info.config +14 -3
  56. mindspore/context.py +152 -61
  57. mindspore/dataset/__init__.py +5 -5
  58. mindspore/dataset/audio/__init__.py +2 -2
  59. mindspore/dataset/audio/transforms.py +52 -52
  60. mindspore/dataset/callback/ds_callback.py +16 -2
  61. mindspore/dataset/core/config.py +68 -51
  62. mindspore/dataset/engine/cache_client.py +33 -7
  63. mindspore/dataset/engine/datasets.py +250 -112
  64. mindspore/dataset/engine/datasets_audio.py +43 -211
  65. mindspore/dataset/engine/datasets_standard_format.py +16 -35
  66. mindspore/dataset/engine/datasets_text.py +43 -67
  67. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  68. mindspore/dataset/engine/datasets_vision.py +219 -1029
  69. mindspore/dataset/engine/iterators.py +11 -4
  70. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  71. mindspore/dataset/engine/obs/util.py +3 -0
  72. mindspore/dataset/engine/samplers.py +1 -1
  73. mindspore/dataset/engine/validators.py +19 -5
  74. mindspore/dataset/text/__init__.py +3 -3
  75. mindspore/dataset/text/transforms.py +101 -127
  76. mindspore/dataset/text/utils.py +205 -138
  77. mindspore/dataset/transforms/__init__.py +1 -1
  78. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  79. mindspore/dataset/transforms/transforms.py +95 -40
  80. mindspore/dataset/utils/browse_dataset.py +8 -2
  81. mindspore/dataset/utils/line_reader.py +17 -19
  82. mindspore/dataset/vision/__init__.py +3 -3
  83. mindspore/dataset/vision/c_transforms.py +6 -3
  84. mindspore/dataset/vision/transforms.py +409 -287
  85. mindspore/dataset/vision/utils.py +13 -14
  86. mindspore/dataset/vision/validators.py +11 -1
  87. mindspore/dnnl.dll +0 -0
  88. mindspore/experimental/map_parameter.py +14 -0
  89. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  90. mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
  91. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  92. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  93. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  94. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  95. mindspore/gen_ops.py +273 -0
  96. mindspore/include/OWNERS +0 -1
  97. mindspore/include/api/data_type.h +2 -1
  98. mindspore/include/api/graph.h +0 -15
  99. mindspore/include/api/kernel.h +2 -0
  100. mindspore/include/api/kernel_api.h +37 -12
  101. mindspore/include/api/model.h +17 -14
  102. mindspore/include/api/status.h +8 -3
  103. mindspore/include/api/types.h +37 -4
  104. mindspore/include/c_api/ms/abstract.h +67 -0
  105. mindspore/include/c_api/ms/attribute.h +197 -0
  106. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  107. mindspore/include/c_api/ms/base/macros.h +32 -0
  108. mindspore/include/c_api/ms/base/status.h +33 -0
  109. mindspore/include/c_api/ms/base/types.h +282 -0
  110. mindspore/include/c_api/ms/context.h +102 -0
  111. mindspore/include/c_api/ms/graph.h +160 -0
  112. mindspore/include/c_api/ms/node.h +606 -0
  113. mindspore/include/c_api/ms/tensor.h +161 -0
  114. mindspore/include/c_api/ms/value.h +84 -0
  115. mindspore/include/dataset/constants.h +6 -5
  116. mindspore/include/dataset/execute.h +23 -13
  117. mindspore/include/dataset/text.h +26 -26
  118. mindspore/include/dataset/transforms.h +13 -13
  119. mindspore/include/dataset/vision.h +60 -60
  120. mindspore/include/dataset/vision_ascend.h +5 -6
  121. mindspore/include/dataset/vision_lite.h +17 -17
  122. mindspore/jpeg62.dll +0 -0
  123. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  124. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  125. mindspore/mindspore_backend.dll +0 -0
  126. mindspore/mindspore_common.dll +0 -0
  127. mindspore/mindspore_core.dll +0 -0
  128. mindspore/mindspore_glog.dll +0 -0
  129. mindspore/mindspore_shared_lib.dll +0 -0
  130. mindspore/nn/__init__.py +0 -2
  131. mindspore/nn/cell.py +313 -74
  132. mindspore/nn/dynamic_lr.py +21 -21
  133. mindspore/nn/layer/activation.py +22 -30
  134. mindspore/nn/layer/basic.py +15 -13
  135. mindspore/nn/layer/channel_shuffle.py +1 -1
  136. mindspore/nn/layer/container.py +271 -9
  137. mindspore/nn/layer/conv.py +323 -204
  138. mindspore/nn/layer/dense.py +8 -5
  139. mindspore/nn/layer/embedding.py +33 -27
  140. mindspore/nn/layer/flash_attention.py +61 -95
  141. mindspore/nn/layer/image.py +8 -6
  142. mindspore/nn/layer/math.py +16 -25
  143. mindspore/nn/layer/normalization.py +107 -66
  144. mindspore/nn/layer/padding.py +1 -1
  145. mindspore/nn/layer/pooling.py +131 -109
  146. mindspore/nn/layer/rnn_cells.py +27 -22
  147. mindspore/nn/layer/rnns.py +13 -16
  148. mindspore/nn/layer/thor_layer.py +1 -1
  149. mindspore/nn/layer/transformer.py +221 -154
  150. mindspore/nn/learning_rate_schedule.py +9 -1
  151. mindspore/nn/loss/loss.py +235 -174
  152. mindspore/nn/optim/ada_grad.py +2 -1
  153. mindspore/nn/optim/adadelta.py +1 -0
  154. mindspore/nn/optim/adafactor.py +2 -1
  155. mindspore/nn/optim/adam.py +7 -4
  156. mindspore/nn/optim/adamax.py +3 -2
  157. mindspore/nn/optim/adasum.py +2 -2
  158. mindspore/nn/optim/asgd.py +2 -3
  159. mindspore/nn/optim/ftrl.py +6 -5
  160. mindspore/nn/optim/lamb.py +7 -4
  161. mindspore/nn/optim/lars.py +1 -1
  162. mindspore/nn/optim/lazyadam.py +5 -3
  163. mindspore/nn/optim/momentum.py +2 -1
  164. mindspore/nn/optim/optimizer.py +53 -4
  165. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  166. mindspore/nn/optim/rmsprop.py +4 -3
  167. mindspore/nn/optim/rprop.py +23 -12
  168. mindspore/nn/optim/sgd.py +26 -11
  169. mindspore/nn/optim/thor.py +9 -7
  170. mindspore/nn/probability/bijector/bijector.py +5 -5
  171. mindspore/nn/probability/bijector/power_transform.py +27 -27
  172. mindspore/nn/probability/bijector/softplus.py +3 -3
  173. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  174. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  175. mindspore/nn/probability/distribution/beta.py +3 -3
  176. mindspore/nn/probability/distribution/categorical.py +7 -7
  177. mindspore/nn/probability/distribution/cauchy.py +0 -1
  178. mindspore/nn/probability/distribution/distribution.py +3 -3
  179. mindspore/nn/probability/distribution/gamma.py +3 -3
  180. mindspore/nn/probability/distribution/geometric.py +4 -4
  181. mindspore/nn/probability/distribution/gumbel.py +4 -4
  182. mindspore/nn/probability/distribution/log_normal.py +2 -2
  183. mindspore/nn/probability/distribution/logistic.py +2 -2
  184. mindspore/nn/probability/distribution/poisson.py +4 -4
  185. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  186. mindspore/nn/probability/distribution/uniform.py +6 -6
  187. mindspore/nn/wrap/__init__.py +4 -2
  188. mindspore/nn/wrap/cell_wrapper.py +87 -34
  189. mindspore/nn/wrap/grad_reducer.py +8 -5
  190. mindspore/nn/wrap/loss_scale.py +105 -42
  191. mindspore/numpy/array_creations.py +1 -2
  192. mindspore/numpy/array_ops.py +3 -2
  193. mindspore/numpy/utils_const.py +5 -5
  194. mindspore/opencv_core452.dll +0 -0
  195. mindspore/opencv_imgcodecs452.dll +0 -0
  196. mindspore/opencv_imgproc452.dll +0 -0
  197. mindspore/ops/_grad_experimental/__init__.py +0 -5
  198. mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
  199. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  200. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  201. mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
  202. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  203. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
  204. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  205. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  206. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  207. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  208. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  209. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  210. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  211. mindspore/ops/_op_impl/{_custom_op/flash_attention/constants.py → aicpu/eps.py} +18 -27
  212. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  213. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
  214. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  215. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  216. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  217. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  218. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  219. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  220. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  221. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  222. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  223. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  224. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  225. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  226. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  227. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  228. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  229. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  230. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  231. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  232. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  233. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  234. mindspore/ops/_primitive_cache.py +1 -1
  235. mindspore/ops/_tracefunc.py +45 -13
  236. mindspore/ops/_utils/utils.py +6 -1
  237. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  238. mindspore/ops/_vmap/vmap_base.py +3 -3
  239. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  240. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  241. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  242. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  243. mindspore/ops/arg_dtype_cast.py +54 -0
  244. mindspore/ops/composite/base.py +37 -10
  245. mindspore/ops/composite/math_ops.py +5 -4
  246. mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
  247. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  248. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  249. mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
  250. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  251. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  252. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  253. mindspore/ops/deprecated.py +304 -0
  254. mindspore/ops/function/__init__.py +4 -1
  255. mindspore/ops/function/array_func.py +174 -193
  256. mindspore/ops/function/clip_func.py +81 -13
  257. mindspore/ops/function/debug_func.py +1 -1
  258. mindspore/ops/function/grad/grad_func.py +18 -9
  259. mindspore/ops/function/image_func.py +10 -4
  260. mindspore/ops/function/linalg_func.py +5 -5
  261. mindspore/ops/function/math_func.py +575 -386
  262. mindspore/ops/function/nn_func.py +568 -260
  263. mindspore/ops/function/random_func.py +88 -57
  264. mindspore/ops/function/sparse_func.py +1 -1
  265. mindspore/ops/function/sparse_unary_func.py +14 -12
  266. mindspore/ops/function/vmap_func.py +6 -5
  267. mindspore/ops/functional.py +15 -10
  268. mindspore/ops/op_info_register.py +244 -25
  269. mindspore/ops/operations/__init__.py +31 -19
  270. mindspore/ops/operations/_grad_ops.py +71 -7
  271. mindspore/ops/operations/_inner_ops.py +350 -17
  272. mindspore/ops/operations/_quant_ops.py +4 -8
  273. mindspore/ops/operations/_sequence_ops.py +42 -0
  274. mindspore/ops/operations/array_ops.py +68 -282
  275. mindspore/ops/operations/comm_ops.py +107 -59
  276. mindspore/ops/operations/custom_ops.py +94 -70
  277. mindspore/ops/operations/debug_ops.py +8 -4
  278. mindspore/ops/operations/image_ops.py +18 -12
  279. mindspore/ops/operations/inner_ops.py +26 -3
  280. mindspore/ops/operations/math_ops.py +192 -144
  281. mindspore/ops/operations/nn_ops.py +857 -489
  282. mindspore/ops/operations/other_ops.py +0 -22
  283. mindspore/ops/operations/random_ops.py +53 -111
  284. mindspore/ops/operations/sparse_ops.py +3 -1
  285. mindspore/ops/primitive.py +24 -18
  286. mindspore/parallel/_auto_parallel_context.py +68 -8
  287. mindspore/parallel/_cost_model_context.py +2 -2
  288. mindspore/parallel/_offload_context.py +17 -3
  289. mindspore/parallel/_parallel_serialization.py +12 -5
  290. mindspore/parallel/_ps_context.py +12 -0
  291. mindspore/parallel/_tensor.py +18 -13
  292. mindspore/parallel/_transformer/layers.py +5 -3
  293. mindspore/parallel/_transformer/loss.py +1 -0
  294. mindspore/parallel/_transformer/moe.py +2 -2
  295. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  296. mindspore/parallel/_transformer/transformer.py +23 -3
  297. mindspore/parallel/_utils.py +11 -7
  298. mindspore/parallel/algo_parameter_config.py +85 -5
  299. mindspore/parallel/checkpoint_transform.py +19 -12
  300. mindspore/parallel/shard.py +21 -14
  301. mindspore/profiler/common/struct_type.py +3 -3
  302. mindspore/profiler/common/util.py +4 -2
  303. mindspore/profiler/envprofiling.py +1 -1
  304. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  305. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  306. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  307. mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
  308. mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
  309. mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
  310. mindspore/profiler/parser/ascend_op_generator.py +6 -6
  311. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  312. mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
  313. mindspore/profiler/parser/base_timeline_generator.py +10 -8
  314. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
  315. mindspore/profiler/parser/flops_parser.py +15 -11
  316. mindspore/profiler/parser/framework_parser.py +38 -22
  317. mindspore/profiler/parser/hccl_parser.py +16 -12
  318. mindspore/profiler/parser/integrator.py +22 -11
  319. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  320. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  321. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  322. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  323. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  324. mindspore/profiler/parser/optime_parser.py +1 -1
  325. mindspore/profiler/parser/profiler_info.py +21 -2
  326. mindspore/profiler/parser/step_trace_parser.py +11 -14
  327. mindspore/profiler/profiling.py +179 -89
  328. mindspore/rewrite/api/node.py +102 -19
  329. mindspore/rewrite/api/node_type.py +5 -1
  330. mindspore/rewrite/api/pattern_engine.py +1 -1
  331. mindspore/rewrite/api/scoped_value.py +9 -17
  332. mindspore/rewrite/api/symbol_tree.py +131 -47
  333. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  334. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  335. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  336. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  337. mindspore/rewrite/common/rewrite_elog.py +5 -1
  338. mindspore/rewrite/namer.py +33 -24
  339. mindspore/rewrite/namespace.py +14 -5
  340. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  341. mindspore/rewrite/node/call_function.py +79 -0
  342. mindspore/rewrite/node/cell_container.py +135 -0
  343. mindspore/rewrite/node/control_flow.py +88 -0
  344. mindspore/rewrite/{node.py → node/node.py} +273 -234
  345. mindspore/rewrite/node/node_manager.py +254 -0
  346. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  347. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  348. mindspore/rewrite/parsers/assign_parser.py +216 -221
  349. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  350. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  351. mindspore/rewrite/parsers/constant_parser.py +9 -6
  352. mindspore/rewrite/parsers/container_parser.py +9 -7
  353. mindspore/rewrite/parsers/for_parser.py +42 -21
  354. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  355. mindspore/rewrite/parsers/if_parser.py +28 -24
  356. mindspore/rewrite/parsers/module_parser.py +196 -25
  357. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  358. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  359. mindspore/rewrite/parsers/return_parser.py +6 -6
  360. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  361. mindspore/rewrite/sparsify/utils.py +1 -1
  362. mindspore/rewrite/symbol_tree.py +523 -578
  363. mindspore/rewrite/symbol_tree_builder.py +9 -193
  364. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  365. mindspore/run_check/_check_version.py +6 -4
  366. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  367. mindspore/safeguard/rewrite_obfuscation.py +541 -0
  368. mindspore/tinyxml2.dll +0 -0
  369. mindspore/train/_utils.py +7 -3
  370. mindspore/train/amp.py +323 -123
  371. mindspore/train/anf_ir_pb2.py +14 -2
  372. mindspore/train/callback/_backup_and_restore.py +2 -12
  373. mindspore/train/callback/_callback.py +29 -4
  374. mindspore/train/callback/_checkpoint.py +23 -8
  375. mindspore/train/callback/_early_stop.py +2 -2
  376. mindspore/train/callback/_landscape.py +4 -4
  377. mindspore/train/callback/_loss_monitor.py +2 -2
  378. mindspore/train/callback/_on_request_exit.py +2 -2
  379. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  380. mindspore/train/callback/_summary_collector.py +15 -8
  381. mindspore/train/callback/_time_monitor.py +58 -5
  382. mindspore/train/data_sink.py +5 -11
  383. mindspore/train/dataset_helper.py +84 -57
  384. mindspore/train/loss_scale_manager.py +2 -2
  385. mindspore/train/metrics/__init__.py +3 -3
  386. mindspore/train/metrics/cosine_similarity.py +1 -1
  387. mindspore/train/metrics/hausdorff_distance.py +3 -2
  388. mindspore/train/metrics/mean_surface_distance.py +3 -2
  389. mindspore/train/metrics/metric.py +39 -19
  390. mindspore/train/metrics/roc.py +2 -2
  391. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  392. mindspore/train/mind_ir_pb2.py +85 -36
  393. mindspore/train/model.py +187 -47
  394. mindspore/train/serialization.py +487 -161
  395. mindspore/train/summary/_summary_adapter.py +1 -1
  396. mindspore/train/summary/_writer_pool.py +3 -2
  397. mindspore/train/summary/summary_record.py +37 -17
  398. mindspore/train/train_thor/convert_utils.py +3 -3
  399. mindspore/train/train_thor/dataset_helper.py +1 -1
  400. mindspore/turbojpeg.dll +0 -0
  401. mindspore/version.py +1 -1
  402. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +7 -4
  403. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +406 -463
  404. mindspore/_extends/graph_kernel/expander.py +0 -80
  405. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  406. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  407. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  408. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  409. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  410. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  411. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  412. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  413. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  414. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  415. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  416. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  417. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  418. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  419. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  420. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  421. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  422. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  423. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  424. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  425. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  426. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  427. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  428. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  429. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  430. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  431. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  432. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  433. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  434. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  435. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  436. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  437. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  438. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  439. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  440. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  441. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  442. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  443. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  444. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  445. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  446. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  447. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  448. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  449. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  450. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  451. mindspore/dataset/datapreprocess/__init__.py +0 -20
  452. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  453. mindspore/include/api/net.h +0 -142
  454. mindspore/nn/lr_scheduler.py +0 -262
  455. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  456. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  457. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  458. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  459. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  460. mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
  461. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -350
  462. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -409
  463. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -578
  464. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -199
  465. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -446
  466. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  467. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
  468. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
  469. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
  470. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  471. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  472. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  473. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  474. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  475. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  476. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  477. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  478. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  479. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  480. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  481. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  482. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  483. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  484. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  485. mindspore/rewrite/node_visitor.py +0 -44
  486. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
  487. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -0
  488. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
@@ -1,105 +0,0 @@
1
- # Copyright 2021 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for BatchNormGrad"""
16
- from mindspore._extends.graph_kernel.model.model import DataFormat as DF
17
- from ._utils import Expander, ExpanderInfoValidator as VLD
18
- from .expand_dims import ExpandDims
19
-
20
-
21
- @VLD.add_format(DF.NHWC, DF.NHWC, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
22
- @VLD.add_format(DF.NCHW, DF.NCHW, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
23
- @VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
24
- @VLD.check_attrs('is_training', 'epsilon')
25
- class BatchNormGrad(Expander):
26
- """BatchNormGrad expander"""
27
-
28
- def _expand(self, graph_builder):
29
- # get op info
30
- input_dy = self.inputs[0]
31
- input_x = self.inputs[1]
32
- input_scale = self.inputs[2]
33
- input_save_mean = self.inputs[3]
34
- input_save_inv_variance = self.inputs[4]
35
-
36
- reduce_axis = ()
37
- shape_x = input_x.shape
38
- if input_x.data_format == DF.NHWC:
39
- reduce_axis = (0, 1, 2)
40
- num = shape_x[0] * shape_x[1] * shape_x[2]
41
- else:
42
- reduce_axis = (0, 2, 3)
43
- num = shape_x[0] * shape_x[2] * shape_x[3]
44
- ori_type = input_x.dtype
45
- if ori_type == 'float16':
46
- input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': 'float32'})
47
- if input_dy.dtype == 'float16':
48
- input_dy = graph_builder.emit('Cast', [input_dy], attrs={'dst_type': 'float32'})
49
- num_rec = -1.0 / num
50
- num_rec_v = graph_builder.value(input_scale.dtype, num_rec)
51
- dbeta = graph_builder.emit('ReduceSum', [input_dy], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
52
-
53
- # in training input_save_inv_variance means 1 / sqrt(variance + epsilon), which is calculated in forward pass
54
- if self.attrs['is_training']:
55
- inv_variance = input_save_inv_variance
56
- else:
57
- epsilon_v = graph_builder.value(input_scale.dtype, self.attrs['epsilon'])
58
- var_add = graph_builder.emit('Add', [input_save_inv_variance, epsilon_v])
59
- sqrt_var_eps = graph_builder.emit('Sqrt', [var_add])
60
- scalar_one = 1.0
61
- scalar_one_v = graph_builder.value(input_scale.dtype, scalar_one)
62
- inv_variance = graph_builder.emit('RealDiv', [scalar_one_v, sqrt_var_eps])
63
-
64
- # compute dgamma
65
- if input_x.data_format in (DF.DEFAULT, DF.NCHW):
66
- input_save_mean = graph_builder.emit(
67
- 'Reshape', [input_save_mean], attrs={'shape': ExpandDims.infer_shape(input_save_mean.shape, [-1, -1])})
68
- inv_variance = graph_builder.emit(
69
- 'Reshape', [inv_variance], attrs={'shape': ExpandDims.infer_shape(inv_variance.shape, [-1, -1])})
70
- input_scale = graph_builder.emit(
71
- 'Reshape', [input_scale], attrs={'shape': ExpandDims.infer_shape(input_scale.shape, [-1, -1])})
72
- x_sub_mean = graph_builder.emit('Sub', [input_x, input_save_mean])
73
- x_div = graph_builder.emit('Mul', [x_sub_mean, inv_variance])
74
- dgamma_param = graph_builder.emit('Mul', [input_dy, x_div])
75
- dgamma = graph_builder.emit(
76
- 'ReduceSum', [dgamma_param], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
77
-
78
- # compute dx
79
- if self.attrs['is_training']:
80
- tmp_b = graph_builder.emit('Mul', [num_rec_v, dbeta])
81
- if input_x.data_format in (DF.DEFAULT, DF.NCHW):
82
- dgamma_expand = graph_builder.emit(
83
- 'Reshape', [dgamma], attrs={'shape': ExpandDims.infer_shape(dgamma.shape, [-1, -1])})
84
- tmp_b = graph_builder.emit(
85
- 'Reshape', [tmp_b], attrs={'shape': ExpandDims.infer_shape(tmp_b.shape, [-1, -1])})
86
- else:
87
- dgamma_expand = dgamma
88
- x_sub_mean_dgamma_mul = graph_builder.emit('Mul', [x_div, dgamma_expand])
89
- tmp_c = graph_builder.emit('Mul', [num_rec_v, x_sub_mean_dgamma_mul])
90
- tmp_ab_add = graph_builder.emit('Add', [input_dy, tmp_b])
91
- tmp_abc_add = graph_builder.emit('Add', [tmp_ab_add, tmp_c])
92
- gamma_mul = graph_builder.emit('Mul', [input_scale, tmp_abc_add])
93
- dx = graph_builder.emit('Mul', [inv_variance, gamma_mul])
94
- else:
95
- y_scale = graph_builder.emit('Mul', [input_scale, input_dy])
96
- dx = graph_builder.emit('Mul', [inv_variance, y_scale])
97
- if ori_type == 'float16':
98
- dx = graph_builder.emit('Cast', [dx], attrs={'dst_type': 'float16'})
99
-
100
- # set output tensors' data_format
101
- dx.data_format = self.outputs[0]['format']
102
- dgamma.data_format = self.outputs[1]['format']
103
- dbeta.data_format = self.outputs[2]['format']
104
-
105
- return dx, dgamma, dbeta
@@ -1,33 +0,0 @@
1
- # Copyright 2020-2021 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for ClipByNormNoDivSum"""
16
- from ._utils import Expander, ExpanderInfoValidator as VLD
17
-
18
-
19
- @VLD.check_all_formats_same
20
- class ClipByNormNoDivSum(Expander):
21
- """ClipByNormNoDivSum expander"""
22
-
23
- def _expand(self, graph_builder):
24
- input_x0, input_x1, input_x2, input_x3 = self.inputs
25
-
26
- # cal result
27
- greater_res = graph_builder.emit('Greater', [input_x0, input_x1])
28
- select_res0 = graph_builder.emit('Select', [greater_res, input_x0, input_x2])
29
- sqrt_res = graph_builder.emit('Sqrt', [select_res0])
30
- select_res1 = graph_builder.emit('Select', [greater_res, sqrt_res, input_x0])
31
- result = graph_builder.emit('Maximum', [select_res1, input_x3])
32
-
33
- return result
@@ -1,30 +0,0 @@
1
- # Copyright 2021 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for cabs"""
16
- from mindspore._extends.graph_kernel.expanders._utils import Expander
17
-
18
-
19
- class CAbs(Expander):
20
- """CAbs expander"""
21
-
22
- def _expand(self, graph_builder):
23
- input_x = self.inputs[0]
24
- x_real = graph_builder.emit('CReal', [input_x])
25
- x_imag = graph_builder.emit('CImag', [input_x])
26
- squre_x_real = graph_builder.emit('Mul', [x_real, x_real])
27
- squre_x_imag = graph_builder.emit('Mul', [x_imag, x_imag])
28
- squre_sum = graph_builder.emit('Add', [squre_x_real, squre_x_imag])
29
- result = graph_builder.emit('Sqrt', [squre_sum])
30
- return result
@@ -1,44 +0,0 @@
1
- # Copyright 2021-2022 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for cadd"""
16
- from mindspore._extends.graph_kernel.model.model import DataFormat as DF
17
- from mindspore._extends.graph_kernel.expanders._utils import Expander, ExpanderInfoValidator as VLD
18
-
19
-
20
- @VLD.add_format(DF.DEFAULT, DF.DEFAULT)
21
- class CAdd(Expander):
22
- """CAdd expander"""
23
-
24
- def _expand(self, graph_builder):
25
- input_x, input_y = self.inputs
26
- if input_x.dtype == input_y.dtype:
27
- x_real = graph_builder.emit('CReal', [input_x])
28
- y_real = graph_builder.emit('CReal', [input_y])
29
- x_imag = graph_builder.emit('CImag', [input_x])
30
- y_imag = graph_builder.emit('CImag', [input_y])
31
- result_real = graph_builder.emit('Add', [x_real, y_real])
32
- result_imag = graph_builder.emit('Add', [x_imag, y_imag])
33
- result = graph_builder.emit('Complex', [result_real, result_imag])
34
- elif input_x.dtype == "complex64" or input_x.dtype == "complex128":
35
- x_real = graph_builder.emit('CReal', [input_x])
36
- x_imag = graph_builder.emit('CImag', [input_x])
37
- x_real_add_y = graph_builder.emit('Add', [x_real, input_y])
38
- result = graph_builder.emit('Complex', [x_real_add_y, x_imag])
39
- elif input_y.dtype == "complex64" or input_y.dtype == "complex128":
40
- y_real = graph_builder.emit('CReal', [input_y])
41
- y_imag = graph_builder.emit('CImag', [input_y])
42
- y_real_add_x = graph_builder.emit('Add', [y_real, input_x])
43
- result = graph_builder.emit('Complex', [y_real_add_x, y_imag])
44
- return result
@@ -1,62 +0,0 @@
1
- # Copyright 2021-2022 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for cdiv"""
16
- from mindspore._extends.graph_kernel.model.model import DataFormat as DF
17
- from mindspore._extends.graph_kernel.expanders._utils import Expander, ExpanderInfoValidator as VLD
18
-
19
-
20
- @VLD.add_format(DF.DEFAULT, DF.DEFAULT)
21
- class CDiv(Expander):
22
- """CDiv expander"""
23
-
24
- def _expand(self, graph_builder):
25
- """CDiv Implementation"""
26
- input_x, input_y = self.inputs
27
- if input_x.dtype == input_y.dtype:
28
- x_real = graph_builder.emit('CReal', [input_x])
29
- y_real = graph_builder.emit('CReal', [input_y])
30
- x_imag = graph_builder.emit('CImag', [input_x])
31
- y_imag = graph_builder.emit('CImag', [input_y])
32
- squre_y_real = graph_builder.emit('Mul', [y_real, y_real])
33
- squre_y_imag = graph_builder.emit('Mul', [y_imag, y_imag])
34
- final_denominator = graph_builder.emit('Add', [squre_y_real, squre_y_imag])
35
- x_real_mul_y_real = graph_builder.emit('Mul', [x_real, y_real])
36
- x_imag_mul_y_imag = graph_builder.emit('Mul', [x_imag, y_imag])
37
- x_real_mul_y_imag = graph_builder.emit('Mul', [x_real, y_imag])
38
- x_imag_mul_y_real = graph_builder.emit('Mul', [x_imag, y_real])
39
- final_numerator_real = graph_builder.emit('Add', [x_real_mul_y_real, x_imag_mul_y_imag])
40
- final_numerator_imag = graph_builder.emit('Sub', [x_imag_mul_y_real, x_real_mul_y_imag])
41
- result_real = graph_builder.emit('RealDiv', [final_numerator_real, final_denominator])
42
- result_imag = graph_builder.emit('RealDiv', [final_numerator_imag, final_denominator])
43
- result = graph_builder.emit('Complex', [result_real, result_imag])
44
- elif input_x.dtype == "complex64" or input_x.dtype == "complex128":
45
- x_real = graph_builder.emit('CReal', [input_x])
46
- x_imag = graph_builder.emit('CImag', [input_x])
47
- x_real_div_y = graph_builder.emit('RealDiv', [x_real, input_y])
48
- x_imag_div_y = graph_builder.emit('RealDiv', [x_imag, input_y])
49
- result = graph_builder.emit('Complex', [x_real_div_y, x_imag_div_y])
50
- elif input_y.dtype == "complex64" or input_y.dtype == "complex128":
51
- y_real = graph_builder.emit('CReal', [input_y])
52
- y_imag = graph_builder.emit('CImag', [input_y])
53
- neg_y_imag = graph_builder.emit('Neg', [y_imag])
54
- squre_y_real = graph_builder.emit('Mul', [y_real, y_real])
55
- squre_y_imag = graph_builder.emit('Mul', [y_imag, y_imag])
56
- final_denominator = graph_builder.emit('Add', [squre_y_real, squre_y_imag])
57
- x_mul_y_real = graph_builder.emit('Mul', [input_x, y_real])
58
- x_mul_neg_y_imag = graph_builder.emit('Mul', [input_x, neg_y_imag])
59
- y_real_div_x = graph_builder.emit('RealDiv', [x_mul_y_real, final_denominator])
60
- y_imag_div_x = graph_builder.emit('RealDiv', [x_mul_neg_y_imag, final_denominator])
61
- result = graph_builder.emit('Complex', [y_real_div_x, y_imag_div_x])
62
- return result
@@ -1,52 +0,0 @@
1
- # Copyright 2021-2022 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for cmul"""
16
- from mindspore._extends.graph_kernel.model.model import DataFormat as DF
17
- from mindspore._extends.graph_kernel.expanders._utils import Expander, ExpanderInfoValidator as VLD
18
-
19
-
20
- @VLD.add_format(DF.DEFAULT, DF.DEFAULT)
21
- class CMul(Expander):
22
- """CMul expander"""
23
-
24
- def _expand(self, graph_builder):
25
- """CMul Implementation"""
26
- input_x, input_y = self.inputs
27
- if input_x.dtype == input_y.dtype:
28
- x_real = graph_builder.emit('CReal', [input_x])
29
- y_real = graph_builder.emit('CReal', [input_y])
30
- x_imag = graph_builder.emit('CImag', [input_x])
31
- y_imag = graph_builder.emit('CImag', [input_y])
32
- x_real_mul_y_real = graph_builder.emit('Mul', [x_real, y_real])
33
- x_imag_mul_y_imag = graph_builder.emit('Mul', [x_imag, y_imag])
34
- x_real_mul_y_imag = graph_builder.emit('Mul', [x_real, y_imag])
35
- x_imag_mul_y_real = graph_builder.emit('Mul', [x_imag, y_real])
36
- result_real = graph_builder.emit('Sub', [x_real_mul_y_real, x_imag_mul_y_imag])
37
- result_imag = graph_builder.emit('Add', [x_real_mul_y_imag, x_imag_mul_y_real])
38
- result = graph_builder.emit('Complex', [result_real, result_imag])
39
- elif input_x.dtype == "complex64" or input_x.dtype == "complex128":
40
- x_real = graph_builder.emit('CReal', [input_x])
41
- x_imag = graph_builder.emit('CImag', [input_x])
42
- x_real_mul_y = graph_builder.emit('Mul', [x_real, input_y])
43
- x_imag_mul_y = graph_builder.emit('Mul', [x_imag, input_y])
44
- result = graph_builder.emit('Complex', [x_real_mul_y, x_imag_mul_y])
45
- elif input_y.dtype == "complex64" or input_y.dtype == "complex128":
46
- y_real = graph_builder.emit('CReal', [input_y])
47
- y_imag = graph_builder.emit('CImag', [input_y])
48
- y_real_mul_x = graph_builder.emit('Mul', [y_real, input_x])
49
- y_imag_mul_x = graph_builder.emit('Mul', [y_imag, input_x])
50
- result = graph_builder.emit('Complex', [y_real_mul_x, y_imag_mul_x])
51
-
52
- return result
@@ -1,62 +0,0 @@
1
- # Copyright 2022 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for crealdiv"""
16
- from mindspore._extends.graph_kernel.model.model import DataFormat as DF
17
- from mindspore._extends.graph_kernel.expanders._utils import Expander, ExpanderInfoValidator as VLD
18
-
19
-
20
- @VLD.add_format(DF.DEFAULT, DF.DEFAULT)
21
- class CRealDiv(Expander):
22
- """CRealDiv expander"""
23
-
24
- def _expand(self, graph_builder):
25
- """CRealDiv Implementation"""
26
- input_x, input_y = self.inputs
27
- if input_x.dtype == input_y.dtype:
28
- x_real = graph_builder.emit('CReal', [input_x])
29
- y_real = graph_builder.emit('CReal', [input_y])
30
- x_imag = graph_builder.emit('CImag', [input_x])
31
- y_imag = graph_builder.emit('CImag', [input_y])
32
- squre_y_real = graph_builder.emit('Mul', [y_real, y_real])
33
- squre_y_imag = graph_builder.emit('Mul', [y_imag, y_imag])
34
- final_denominator = graph_builder.emit('Add', [squre_y_real, squre_y_imag])
35
- x_real_mul_y_real = graph_builder.emit('Mul', [x_real, y_real])
36
- x_imag_mul_y_imag = graph_builder.emit('Mul', [x_imag, y_imag])
37
- x_real_mul_y_imag = graph_builder.emit('Mul', [x_real, y_imag])
38
- x_imag_mul_y_real = graph_builder.emit('Mul', [x_imag, y_real])
39
- final_numerator_real = graph_builder.emit('Add', [x_real_mul_y_real, x_imag_mul_y_imag])
40
- final_numerator_imag = graph_builder.emit('Sub', [x_imag_mul_y_real, x_real_mul_y_imag])
41
- result_real = graph_builder.emit('RealDiv', [final_numerator_real, final_denominator])
42
- result_imag = graph_builder.emit('RealDiv', [final_numerator_imag, final_denominator])
43
- result = graph_builder.emit('Complex', [result_real, result_imag])
44
- elif input_x.dtype == "complex64" or input_x.dtype == "complex128":
45
- x_real = graph_builder.emit('CReal', [input_x])
46
- x_imag = graph_builder.emit('CImag', [input_x])
47
- x_real_div_y = graph_builder.emit('RealDiv', [x_real, input_y])
48
- x_imag_div_y = graph_builder.emit('RealDiv', [x_imag, input_y])
49
- result = graph_builder.emit('Complex', [x_real_div_y, x_imag_div_y])
50
- elif input_y.dtype == "complex64" or input_y.dtype == "complex128":
51
- y_real = graph_builder.emit('CReal', [input_y])
52
- y_imag = graph_builder.emit('CImag', [input_y])
53
- neg_y_imag = graph_builder.emit('Neg', [y_imag])
54
- squre_y_real = graph_builder.emit('Mul', [y_real, y_real])
55
- squre_y_imag = graph_builder.emit('Mul', [y_imag, y_imag])
56
- final_denominator = graph_builder.emit('Add', [squre_y_real, squre_y_imag])
57
- x_mul_y_real = graph_builder.emit('Mul', [input_x, y_real])
58
- x_mul_neg_y_imag = graph_builder.emit('Mul', [input_x, neg_y_imag])
59
- y_real_div_x = graph_builder.emit('RealDiv', [x_mul_y_real, final_denominator])
60
- y_imag_div_x = graph_builder.emit('RealDiv', [x_mul_neg_y_imag, final_denominator])
61
- result = graph_builder.emit('Complex', [y_real_div_x, y_imag_div_x])
62
- return result
@@ -1,45 +0,0 @@
1
- # Copyright 2021-2022 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for csub"""
16
- from mindspore._extends.graph_kernel.model.model import DataFormat as DF
17
- from mindspore._extends.graph_kernel.expanders._utils import Expander, ExpanderInfoValidator as VLD
18
-
19
-
20
- @VLD.add_format(DF.DEFAULT, DF.DEFAULT)
21
- class CSub(Expander):
22
- """CSub expander"""
23
-
24
- def _expand(self, graph_builder):
25
- input_x, input_y = self.inputs
26
- if input_x.dtype == input_y.dtype:
27
- x_real = graph_builder.emit('CReal', [input_x])
28
- y_real = graph_builder.emit('CReal', [input_y])
29
- x_imag = graph_builder.emit('CImag', [input_x])
30
- y_imag = graph_builder.emit('CImag', [input_y])
31
- result_real = graph_builder.emit('Sub', [x_real, y_real])
32
- result_imag = graph_builder.emit('Sub', [x_imag, y_imag])
33
- result = graph_builder.emit('Complex', [result_real, result_imag])
34
- elif input_x.dtype == "complex64" or input_x.dtype == "complex128":
35
- x_real = graph_builder.emit('CReal', [input_x])
36
- x_imag = graph_builder.emit('CImag', [input_x])
37
- x_real_sub_y = graph_builder.emit('Sub', [x_real, input_y])
38
- result = graph_builder.emit('Complex', [x_real_sub_y, x_imag])
39
- elif input_y.dtype == "complex64" or input_y.dtype == "complex128":
40
- y_real = graph_builder.emit('CReal', [input_y])
41
- y_imag = graph_builder.emit('CImag', [input_y])
42
- x_sub_y_real = graph_builder.emit('Sub', [input_x, y_real])
43
- y_imag = graph_builder.emit('Neg', [y_imag])
44
- result = graph_builder.emit('Complex', [x_sub_y_real, y_imag])
45
- return result
@@ -1,200 +0,0 @@
1
- # Copyright 2021-2022 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for Conv2D"""
16
- from mindspore._extends.graph_kernel.model.op_infer import check_format_any, check_nd, conv_had_pad
17
- from mindspore._extends.graph_kernel.model.model import DataFormat as DF
18
- from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
19
- from ._utils import Expander, ExpanderInfoValidator as VLD
20
-
21
-
22
- @VLD.add_format(DF.DEFAULT, DF.DEFAULT)
23
- @VLD.add_format(DF.NHWC, DF.NHWC)
24
- @VLD.check_attrs('format', 'pad_list', 'pad_mode', 'groups', 'group', 'kernel_size', 'stride', 'dilation')
25
- class Conv2D(Expander):
26
- """
27
- Conv2D expander
28
-
29
- Currently, only Conv2D that meets several conditions can be expanded, other cases will be skipped.
30
- Conditions to expand:
31
- inputs are NHWC format and float16.
32
- attr groups and group are 1.
33
- attr dilation are all 1.
34
- N channel of inputs > 16.
35
- C channel of inputs > 8.
36
- output N*H*W are multiplies of 128.
37
- """
38
- M_ALIGN = 32
39
- N_ALIGN = 32
40
- K_ALIGN = 16
41
- K_LIMIT = 800
42
- MNK_LIMIT = 3 * (10 ** 10)
43
- N0_CHANNEL_ALIGN = 32
44
- N1_CHANNEL_ALIGN = 32
45
- C_CHANNEL_ALIGN = 16
46
- OUT_NHW_ALIGN = 128
47
-
48
- def __init__(self, expand_info):
49
- super().__init__(expand_info)
50
- self.dst_type = self.outputs[0]['data_type']
51
- self.dst_format = self.outputs[0]['format']
52
- self.has_pad = False
53
- self.can_optimize_to_matmul = False
54
- self.shape_0_pad = self.inputs[0]['shape']
55
- self.shape_1_pad = self.inputs[1]['shape']
56
- self.m = 0
57
- self.n = 0
58
- self.k = 0
59
-
60
- def _optimize_to_matmul(self):
61
- stride = self.attrs['stride']
62
- dilation = self.attrs['dilation']
63
- _, h, w, _ = self.inputs[1]['shape']
64
- if h == 1 and w == 1 and stride == [1, 1, 1, 1] and dilation == [1, 1, 1, 1] and \
65
- self.m % self.M_ALIGN == 0 and self.n % self.N_ALIGN == 0 and self.k % self.K_ALIGN == 0:
66
- return True
67
- return False
68
-
69
- def _common_check(self):
70
- """common check for inputs and attrs"""
71
- type_0 = self.inputs[0]['data_type']
72
- type_1 = self.inputs[1]['data_type']
73
- if type_0 != "float16" or type_1 != "float16":
74
- raise GKException("For 'Conv2D', inputs data type should be both float16, but got {} and {}"
75
- .format(type_0, type_1))
76
-
77
- formats = [self.inputs[0]['format'], self.inputs[1]['format'], self.attrs['format']]
78
- check_format_any(formats, DF.NHWC)
79
-
80
- groups = self.attrs['groups']
81
- group = self.attrs['group']
82
- if groups != 1 or group != 1:
83
- raise GKException("For 'Conv2D', value of attr 'groups' and 'group' should be both 1, but got {} and {}."
84
- .format(groups, group))
85
-
86
- dilation = self.attrs['dilation']
87
- check_nd(dilation, 4)
88
- if dilation != [1, 1, 1, 1]:
89
- raise GKException("For 'Conv2D', value of attr 'dilation' should be [1, 1, 1, 1], but got {}"
90
- .format(dilation))
91
-
92
- def _check(self):
93
- self._common_check()
94
-
95
- pad_list = self.attrs['pad_list']
96
- check_nd(pad_list, 4)
97
- self.has_pad = conv_had_pad(pad_list, self.attrs['pad_mode'])
98
-
99
- shape_0 = self.inputs[0]['shape']
100
- shape_1 = self.inputs[1]['shape']
101
- stride = self.attrs['stride']
102
- check_nd(shape_0, 4)
103
- check_nd(shape_1, 4)
104
- check_nd(stride, 4)
105
- n0, h0, w0, c0 = shape_0
106
- n1, h1, w1, c1 = shape_1
107
- if (n0 % self.N0_CHANNEL_ALIGN) != 0:
108
- raise GKException("For 'Conv2D', N channel of first input should be multiples of {}, but got {}"
109
- .format(self.N0_CHANNEL_ALIGN, n0))
110
- if (n1 % self.N1_CHANNEL_ALIGN) != 0:
111
- raise GKException("For 'Conv2D', N channel of second input should be multiples of {}, but got {}"
112
- .format(self.N1_CHANNEL_ALIGN, n1))
113
- if c0 != c1 or (c0 % self.C_CHANNEL_ALIGN) != 0:
114
- raise GKException("For 'Conv2D', C channel of inputs should be same and also be multiples of {}, but got "
115
- "{} and {}".format(self.C_CHANNEL_ALIGN, c0, c1))
116
- # n0 pad
117
- n0 = ((n0 + self.N0_CHANNEL_ALIGN - 1) //
118
- self.N0_CHANNEL_ALIGN) * self.N0_CHANNEL_ALIGN
119
- # h0, w0 pad
120
- if self.has_pad:
121
- h0 = h0 + pad_list[0] + pad_list[1]
122
- w0 = w0 + pad_list[2] + pad_list[3]
123
- # c0, c1 pad
124
- c0 = ((c0 + self.C_CHANNEL_ALIGN - 1) // self.C_CHANNEL_ALIGN) * self.C_CHANNEL_ALIGN
125
- c1 = c0
126
- # n1 pad
127
- n1 = ((n1 + self.N1_CHANNEL_ALIGN - 1) //
128
- self.N1_CHANNEL_ALIGN) * self.N1_CHANNEL_ALIGN
129
-
130
- # check if can optimize to matmul
131
- self.m, self.n, self.k = n0 * h0 * w0, n1, c1
132
- self.can_optimize_to_matmul = self._optimize_to_matmul()
133
-
134
- # requirements
135
- if self.can_optimize_to_matmul:
136
- if self.k > self.K_LIMIT:
137
- raise GKException("For 'Conv2D', if transformed to 'MatMul', C0 should not be larger than {}, but got "
138
- "{}".format(self.K_LIMIT, self.k))
139
- if self.m * self.n * self.k >= self.MNK_LIMIT:
140
- raise GKException("For 'Conv2D', if transformed to 'MatMul', The total size should not be larger than "
141
- "{}, but got {}".format(self.MNK_LIMIT, self.m * self.n * self.k))
142
- else:
143
- out_h, out_w = (h0 - h1) // stride[-2] + 1, (w0 - w1) // stride[-1] + 1
144
- if ((n0 * out_h * out_w) % self.OUT_NHW_ALIGN) != 0:
145
- raise GKException("For 'Conv2D', N({}) * H({}) * W({}) of output should be multiplies of {}"
146
- .format(n0, out_h, out_w, self.OUT_NHW_ALIGN))
147
- if stride != [1, 1, 2, 2]:
148
- raise GKException("For 'Conv2D', value of attr 'stride' should be [1, 1, 2, 2], but got {}"
149
- .format(stride))
150
-
151
- self.shape_0_pad = [n0, h0, w0, c0]
152
- self.shape_1_pad = [n1, h1, w1, c1]
153
-
154
- def _expand(self, graph_builder):
155
- input_0 = self.inputs[0]
156
- input_1 = self.inputs[1]
157
- n0, _, _, c0 = input_0.shape
158
- n1, _, _, c1 = input_1.shape
159
- n0_p, h0_p, w0_p, c0_p = self.shape_0_pad
160
- n1_p, _, _, c1_p = self.shape_1_pad
161
-
162
- pad_value = 0
163
- # input0 pad
164
- input_0_pad_before = [0, 0, 0, 0]
165
- input_0_pad_after = [0, 0, 0, 0]
166
- if self.has_pad:
167
- pad_list = self.attrs['pad_list']
168
- input_0_pad_before = [0, pad_list[0], pad_list[2], 0]
169
- input_0_pad_after = [0, pad_list[1], pad_list[3], 0]
170
- input_0_pad_after[0] = n0_p - n0
171
- input_0_pad_after[3] = c0_p - c0
172
- if input_0_pad_before != [0, 0, 0, 0] or input_0_pad_after != [0, 0, 0, 0]:
173
- input_0 = graph_builder.emit('PadAkg', [input_0], attrs={'head': input_0_pad_before,
174
- 'tail': input_0_pad_after,
175
- 'pad_val': pad_value})
176
- # input1 pad
177
- input_1_pad_after = [n1_p - n1, 0, 0, c1_p - c1]
178
- if input_1_pad_after != [0, 0, 0, 0]:
179
- input_1 = graph_builder.emit('PadAkg', [input_1], attrs={'head': [0, 0, 0, 0],
180
- 'tail': input_1_pad_after,
181
- 'pad_val': pad_value})
182
- if self.can_optimize_to_matmul:
183
- a = graph_builder.emit('Reshape', [input_0], attrs={'shape': [self.m, self.k]})
184
- b = graph_builder.emit('Reshape', [input_1], attrs={'shape': [self.n, self.k]})
185
- c = graph_builder.emit('MatMul', [a, b], attrs={'transpose_a': False,
186
- 'transpose_b': True,
187
- 'dst_type': self.dst_type})
188
- result = graph_builder.emit('Reshape', [c], attrs={'shape': [n0_p, h0_p, w0_p, n1_p],
189
- 'format': self.dst_format})
190
- else:
191
- attrs = self.attrs
192
- attrs['pad_list'] = [0, 0, 0, 0]
193
- attrs['dst_type'] = self.dst_type
194
- result = graph_builder.emit('Conv2D', [input_0, input_1], attrs=attrs)
195
- # unpad
196
- unpad_after = [input_0_pad_after[0], 0, 0, input_1_pad_after[0]]
197
- if unpad_after != [0, 0, 0, 0]:
198
- result = graph_builder.emit('UnPadAkg', [result], attrs={'tail': unpad_after})
199
-
200
- return result