mindspore 2.1.0-cp39-cp39-win_amd64.whl → 2.2.11-cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mindspore might be problematic.

Files changed (488)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  6. mindspore/_check_jit_forbidden_api.py +3 -1
  7. mindspore/_checkparam.py +23 -29
  8. mindspore/_extends/graph_kernel/__init__.py +0 -1
  9. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  10. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  11. mindspore/_extends/graph_kernel/splitter.py +4 -11
  12. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  13. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
  14. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  15. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  16. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  17. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
  18. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  19. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  20. mindspore/_extends/parse/__init__.py +13 -15
  21. mindspore/_extends/parse/namespace.py +7 -33
  22. mindspore/_extends/parse/parser.py +67 -72
  23. mindspore/_extends/parse/resources.py +1 -1
  24. mindspore/_extends/parse/standard_method.py +86 -106
  25. mindspore/_extends/parse/trope.py +1 -1
  26. mindspore/_extends/remote/kernel_build_server.py +25 -7
  27. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  28. mindspore/_install_custom.py +43 -0
  29. mindspore/amp.py +47 -11
  30. mindspore/boost/boost.py +1 -8
  31. mindspore/boost/boost_cell_wrapper.py +3 -2
  32. mindspore/boost/grad_accumulation.py +1 -1
  33. mindspore/boost/group_loss_scale_manager.py +8 -7
  34. mindspore/common/__init__.py +5 -3
  35. mindspore/common/_jit_fallback_utils.py +6 -0
  36. mindspore/common/_register_for_adapter.py +2 -0
  37. mindspore/common/_register_for_tensor.py +2 -2
  38. mindspore/common/_stub_tensor.py +13 -0
  39. mindspore/common/_utils.py +29 -0
  40. mindspore/common/api.py +174 -259
  41. mindspore/common/auto_dynamic_shape.py +494 -0
  42. mindspore/common/dtype.py +18 -11
  43. mindspore/common/dump.py +6 -4
  44. mindspore/common/initializer.py +14 -14
  45. mindspore/common/jit_config.py +33 -15
  46. mindspore/common/lazy_inline.py +126 -7
  47. mindspore/common/mindir_util.py +101 -0
  48. mindspore/common/parameter.py +51 -41
  49. mindspore/common/seed.py +4 -4
  50. mindspore/common/sparse_tensor.py +13 -14
  51. mindspore/common/tensor.py +243 -165
  52. mindspore/communication/__init__.py +7 -4
  53. mindspore/communication/_comm_helper.py +83 -4
  54. mindspore/communication/management.py +152 -84
  55. mindspore/config/op_info.config +14 -3
  56. mindspore/context.py +152 -61
  57. mindspore/dataset/__init__.py +5 -5
  58. mindspore/dataset/audio/__init__.py +2 -2
  59. mindspore/dataset/audio/transforms.py +52 -52
  60. mindspore/dataset/callback/ds_callback.py +16 -2
  61. mindspore/dataset/core/config.py +68 -51
  62. mindspore/dataset/engine/cache_client.py +33 -7
  63. mindspore/dataset/engine/datasets.py +250 -112
  64. mindspore/dataset/engine/datasets_audio.py +43 -211
  65. mindspore/dataset/engine/datasets_standard_format.py +16 -35
  66. mindspore/dataset/engine/datasets_text.py +43 -67
  67. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  68. mindspore/dataset/engine/datasets_vision.py +219 -1029
  69. mindspore/dataset/engine/iterators.py +11 -4
  70. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  71. mindspore/dataset/engine/obs/util.py +3 -0
  72. mindspore/dataset/engine/samplers.py +1 -1
  73. mindspore/dataset/engine/validators.py +19 -5
  74. mindspore/dataset/text/__init__.py +3 -3
  75. mindspore/dataset/text/transforms.py +101 -127
  76. mindspore/dataset/text/utils.py +205 -138
  77. mindspore/dataset/transforms/__init__.py +1 -1
  78. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  79. mindspore/dataset/transforms/transforms.py +95 -40
  80. mindspore/dataset/utils/browse_dataset.py +8 -2
  81. mindspore/dataset/utils/line_reader.py +17 -19
  82. mindspore/dataset/vision/__init__.py +3 -3
  83. mindspore/dataset/vision/c_transforms.py +6 -3
  84. mindspore/dataset/vision/transforms.py +409 -287
  85. mindspore/dataset/vision/utils.py +13 -14
  86. mindspore/dataset/vision/validators.py +11 -1
  87. mindspore/dnnl.dll +0 -0
  88. mindspore/experimental/map_parameter.py +14 -0
  89. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  90. mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
  91. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  92. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  93. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  94. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  95. mindspore/gen_ops.py +273 -0
  96. mindspore/include/OWNERS +0 -1
  97. mindspore/include/api/data_type.h +2 -1
  98. mindspore/include/api/graph.h +0 -15
  99. mindspore/include/api/kernel.h +2 -0
  100. mindspore/include/api/kernel_api.h +37 -12
  101. mindspore/include/api/model.h +17 -14
  102. mindspore/include/api/status.h +8 -3
  103. mindspore/include/api/types.h +37 -4
  104. mindspore/include/c_api/ms/abstract.h +67 -0
  105. mindspore/include/c_api/ms/attribute.h +197 -0
  106. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  107. mindspore/include/c_api/ms/base/macros.h +32 -0
  108. mindspore/include/c_api/ms/base/status.h +33 -0
  109. mindspore/include/c_api/ms/base/types.h +282 -0
  110. mindspore/include/c_api/ms/context.h +102 -0
  111. mindspore/include/c_api/ms/graph.h +160 -0
  112. mindspore/include/c_api/ms/node.h +606 -0
  113. mindspore/include/c_api/ms/tensor.h +161 -0
  114. mindspore/include/c_api/ms/value.h +84 -0
  115. mindspore/include/dataset/constants.h +6 -5
  116. mindspore/include/dataset/execute.h +23 -13
  117. mindspore/include/dataset/text.h +26 -26
  118. mindspore/include/dataset/transforms.h +13 -13
  119. mindspore/include/dataset/vision.h +60 -60
  120. mindspore/include/dataset/vision_ascend.h +5 -6
  121. mindspore/include/dataset/vision_lite.h +17 -17
  122. mindspore/jpeg62.dll +0 -0
  123. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  124. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  125. mindspore/mindspore_backend.dll +0 -0
  126. mindspore/mindspore_common.dll +0 -0
  127. mindspore/mindspore_core.dll +0 -0
  128. mindspore/mindspore_glog.dll +0 -0
  129. mindspore/mindspore_shared_lib.dll +0 -0
  130. mindspore/nn/__init__.py +0 -2
  131. mindspore/nn/cell.py +313 -74
  132. mindspore/nn/dynamic_lr.py +21 -21
  133. mindspore/nn/layer/activation.py +22 -30
  134. mindspore/nn/layer/basic.py +15 -13
  135. mindspore/nn/layer/channel_shuffle.py +1 -1
  136. mindspore/nn/layer/container.py +271 -9
  137. mindspore/nn/layer/conv.py +323 -204
  138. mindspore/nn/layer/dense.py +8 -5
  139. mindspore/nn/layer/embedding.py +33 -27
  140. mindspore/nn/layer/flash_attention.py +61 -95
  141. mindspore/nn/layer/image.py +8 -6
  142. mindspore/nn/layer/math.py +16 -25
  143. mindspore/nn/layer/normalization.py +107 -66
  144. mindspore/nn/layer/padding.py +1 -1
  145. mindspore/nn/layer/pooling.py +131 -109
  146. mindspore/nn/layer/rnn_cells.py +27 -22
  147. mindspore/nn/layer/rnns.py +13 -16
  148. mindspore/nn/layer/thor_layer.py +1 -1
  149. mindspore/nn/layer/transformer.py +221 -154
  150. mindspore/nn/learning_rate_schedule.py +9 -1
  151. mindspore/nn/loss/loss.py +235 -174
  152. mindspore/nn/optim/ada_grad.py +2 -1
  153. mindspore/nn/optim/adadelta.py +1 -0
  154. mindspore/nn/optim/adafactor.py +2 -1
  155. mindspore/nn/optim/adam.py +7 -4
  156. mindspore/nn/optim/adamax.py +3 -2
  157. mindspore/nn/optim/adasum.py +2 -2
  158. mindspore/nn/optim/asgd.py +2 -3
  159. mindspore/nn/optim/ftrl.py +6 -5
  160. mindspore/nn/optim/lamb.py +7 -4
  161. mindspore/nn/optim/lars.py +1 -1
  162. mindspore/nn/optim/lazyadam.py +5 -3
  163. mindspore/nn/optim/momentum.py +2 -1
  164. mindspore/nn/optim/optimizer.py +53 -4
  165. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  166. mindspore/nn/optim/rmsprop.py +4 -3
  167. mindspore/nn/optim/rprop.py +23 -12
  168. mindspore/nn/optim/sgd.py +26 -11
  169. mindspore/nn/optim/thor.py +9 -7
  170. mindspore/nn/probability/bijector/bijector.py +5 -5
  171. mindspore/nn/probability/bijector/power_transform.py +27 -27
  172. mindspore/nn/probability/bijector/softplus.py +3 -3
  173. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  174. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  175. mindspore/nn/probability/distribution/beta.py +3 -3
  176. mindspore/nn/probability/distribution/categorical.py +7 -7
  177. mindspore/nn/probability/distribution/cauchy.py +0 -1
  178. mindspore/nn/probability/distribution/distribution.py +3 -3
  179. mindspore/nn/probability/distribution/gamma.py +3 -3
  180. mindspore/nn/probability/distribution/geometric.py +4 -4
  181. mindspore/nn/probability/distribution/gumbel.py +4 -4
  182. mindspore/nn/probability/distribution/log_normal.py +2 -2
  183. mindspore/nn/probability/distribution/logistic.py +2 -2
  184. mindspore/nn/probability/distribution/poisson.py +4 -4
  185. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  186. mindspore/nn/probability/distribution/uniform.py +6 -6
  187. mindspore/nn/wrap/__init__.py +4 -2
  188. mindspore/nn/wrap/cell_wrapper.py +87 -34
  189. mindspore/nn/wrap/grad_reducer.py +8 -5
  190. mindspore/nn/wrap/loss_scale.py +105 -42
  191. mindspore/numpy/array_creations.py +1 -2
  192. mindspore/numpy/array_ops.py +3 -2
  193. mindspore/numpy/utils_const.py +5 -5
  194. mindspore/opencv_core452.dll +0 -0
  195. mindspore/opencv_imgcodecs452.dll +0 -0
  196. mindspore/opencv_imgproc452.dll +0 -0
  197. mindspore/ops/_grad_experimental/__init__.py +0 -5
  198. mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
  199. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  200. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  201. mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
  202. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  203. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
  204. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  205. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  206. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  207. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  208. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  209. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  210. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  211. mindspore/ops/_op_impl/{_custom_op/flash_attention/constants.py → aicpu/eps.py} +18 -27
  212. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  213. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
  214. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  215. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  216. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  217. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  218. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  219. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  220. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  221. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  222. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  223. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  224. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  225. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  226. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  227. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  228. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  229. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  230. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  231. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  232. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  233. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  234. mindspore/ops/_primitive_cache.py +1 -1
  235. mindspore/ops/_tracefunc.py +45 -13
  236. mindspore/ops/_utils/utils.py +6 -1
  237. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  238. mindspore/ops/_vmap/vmap_base.py +3 -3
  239. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  240. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  241. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  242. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  243. mindspore/ops/arg_dtype_cast.py +54 -0
  244. mindspore/ops/composite/base.py +37 -10
  245. mindspore/ops/composite/math_ops.py +5 -4
  246. mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
  247. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  248. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  249. mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
  250. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  251. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  252. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  253. mindspore/ops/deprecated.py +304 -0
  254. mindspore/ops/function/__init__.py +4 -1
  255. mindspore/ops/function/array_func.py +174 -193
  256. mindspore/ops/function/clip_func.py +81 -13
  257. mindspore/ops/function/debug_func.py +1 -1
  258. mindspore/ops/function/grad/grad_func.py +18 -9
  259. mindspore/ops/function/image_func.py +10 -4
  260. mindspore/ops/function/linalg_func.py +5 -5
  261. mindspore/ops/function/math_func.py +575 -386
  262. mindspore/ops/function/nn_func.py +568 -260
  263. mindspore/ops/function/random_func.py +88 -57
  264. mindspore/ops/function/sparse_func.py +1 -1
  265. mindspore/ops/function/sparse_unary_func.py +14 -12
  266. mindspore/ops/function/vmap_func.py +6 -5
  267. mindspore/ops/functional.py +15 -10
  268. mindspore/ops/op_info_register.py +244 -25
  269. mindspore/ops/operations/__init__.py +31 -19
  270. mindspore/ops/operations/_grad_ops.py +71 -7
  271. mindspore/ops/operations/_inner_ops.py +350 -17
  272. mindspore/ops/operations/_quant_ops.py +4 -8
  273. mindspore/ops/operations/_sequence_ops.py +42 -0
  274. mindspore/ops/operations/array_ops.py +68 -282
  275. mindspore/ops/operations/comm_ops.py +107 -59
  276. mindspore/ops/operations/custom_ops.py +94 -70
  277. mindspore/ops/operations/debug_ops.py +8 -4
  278. mindspore/ops/operations/image_ops.py +18 -12
  279. mindspore/ops/operations/inner_ops.py +26 -3
  280. mindspore/ops/operations/math_ops.py +192 -144
  281. mindspore/ops/operations/nn_ops.py +857 -489
  282. mindspore/ops/operations/other_ops.py +0 -22
  283. mindspore/ops/operations/random_ops.py +53 -111
  284. mindspore/ops/operations/sparse_ops.py +3 -1
  285. mindspore/ops/primitive.py +24 -18
  286. mindspore/parallel/_auto_parallel_context.py +68 -8
  287. mindspore/parallel/_cost_model_context.py +2 -2
  288. mindspore/parallel/_offload_context.py +17 -3
  289. mindspore/parallel/_parallel_serialization.py +12 -5
  290. mindspore/parallel/_ps_context.py +12 -0
  291. mindspore/parallel/_tensor.py +18 -13
  292. mindspore/parallel/_transformer/layers.py +5 -3
  293. mindspore/parallel/_transformer/loss.py +1 -0
  294. mindspore/parallel/_transformer/moe.py +2 -2
  295. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  296. mindspore/parallel/_transformer/transformer.py +23 -3
  297. mindspore/parallel/_utils.py +11 -7
  298. mindspore/parallel/algo_parameter_config.py +85 -5
  299. mindspore/parallel/checkpoint_transform.py +19 -12
  300. mindspore/parallel/shard.py +21 -14
  301. mindspore/profiler/common/struct_type.py +3 -3
  302. mindspore/profiler/common/util.py +4 -2
  303. mindspore/profiler/envprofiling.py +1 -1
  304. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  305. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  306. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  307. mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
  308. mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
  309. mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
  310. mindspore/profiler/parser/ascend_op_generator.py +6 -6
  311. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  312. mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
  313. mindspore/profiler/parser/base_timeline_generator.py +10 -8
  314. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
  315. mindspore/profiler/parser/flops_parser.py +15 -11
  316. mindspore/profiler/parser/framework_parser.py +38 -22
  317. mindspore/profiler/parser/hccl_parser.py +16 -12
  318. mindspore/profiler/parser/integrator.py +22 -11
  319. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  320. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  321. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  322. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  323. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  324. mindspore/profiler/parser/optime_parser.py +1 -1
  325. mindspore/profiler/parser/profiler_info.py +21 -2
  326. mindspore/profiler/parser/step_trace_parser.py +11 -14
  327. mindspore/profiler/profiling.py +179 -89
  328. mindspore/rewrite/api/node.py +102 -19
  329. mindspore/rewrite/api/node_type.py +5 -1
  330. mindspore/rewrite/api/pattern_engine.py +1 -1
  331. mindspore/rewrite/api/scoped_value.py +9 -17
  332. mindspore/rewrite/api/symbol_tree.py +131 -47
  333. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  334. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  335. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  336. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  337. mindspore/rewrite/common/rewrite_elog.py +5 -1
  338. mindspore/rewrite/namer.py +33 -24
  339. mindspore/rewrite/namespace.py +14 -5
  340. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  341. mindspore/rewrite/node/call_function.py +79 -0
  342. mindspore/rewrite/node/cell_container.py +135 -0
  343. mindspore/rewrite/node/control_flow.py +88 -0
  344. mindspore/rewrite/{node.py → node/node.py} +273 -234
  345. mindspore/rewrite/node/node_manager.py +254 -0
  346. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  347. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  348. mindspore/rewrite/parsers/assign_parser.py +216 -221
  349. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  350. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  351. mindspore/rewrite/parsers/constant_parser.py +9 -6
  352. mindspore/rewrite/parsers/container_parser.py +9 -7
  353. mindspore/rewrite/parsers/for_parser.py +42 -21
  354. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  355. mindspore/rewrite/parsers/if_parser.py +28 -24
  356. mindspore/rewrite/parsers/module_parser.py +196 -25
  357. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  358. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  359. mindspore/rewrite/parsers/return_parser.py +6 -6
  360. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  361. mindspore/rewrite/sparsify/utils.py +1 -1
  362. mindspore/rewrite/symbol_tree.py +523 -578
  363. mindspore/rewrite/symbol_tree_builder.py +9 -193
  364. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  365. mindspore/run_check/_check_version.py +6 -4
  366. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  367. mindspore/safeguard/rewrite_obfuscation.py +541 -0
  368. mindspore/tinyxml2.dll +0 -0
  369. mindspore/train/_utils.py +7 -3
  370. mindspore/train/amp.py +323 -123
  371. mindspore/train/anf_ir_pb2.py +14 -2
  372. mindspore/train/callback/_backup_and_restore.py +2 -12
  373. mindspore/train/callback/_callback.py +29 -4
  374. mindspore/train/callback/_checkpoint.py +23 -8
  375. mindspore/train/callback/_early_stop.py +2 -2
  376. mindspore/train/callback/_landscape.py +4 -4
  377. mindspore/train/callback/_loss_monitor.py +2 -2
  378. mindspore/train/callback/_on_request_exit.py +2 -2
  379. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  380. mindspore/train/callback/_summary_collector.py +15 -8
  381. mindspore/train/callback/_time_monitor.py +58 -5
  382. mindspore/train/data_sink.py +5 -11
  383. mindspore/train/dataset_helper.py +84 -57
  384. mindspore/train/loss_scale_manager.py +2 -2
  385. mindspore/train/metrics/__init__.py +3 -3
  386. mindspore/train/metrics/cosine_similarity.py +1 -1
  387. mindspore/train/metrics/hausdorff_distance.py +3 -2
  388. mindspore/train/metrics/mean_surface_distance.py +3 -2
  389. mindspore/train/metrics/metric.py +39 -19
  390. mindspore/train/metrics/roc.py +2 -2
  391. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  392. mindspore/train/mind_ir_pb2.py +85 -36
  393. mindspore/train/model.py +187 -47
  394. mindspore/train/serialization.py +487 -161
  395. mindspore/train/summary/_summary_adapter.py +1 -1
  396. mindspore/train/summary/_writer_pool.py +3 -2
  397. mindspore/train/summary/summary_record.py +37 -17
  398. mindspore/train/train_thor/convert_utils.py +3 -3
  399. mindspore/train/train_thor/dataset_helper.py +1 -1
  400. mindspore/turbojpeg.dll +0 -0
  401. mindspore/version.py +1 -1
  402. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +7 -4
  403. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +406 -463
  404. mindspore/_extends/graph_kernel/expander.py +0 -80
  405. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  406. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  407. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  408. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  409. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  410. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  411. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  412. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  413. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  414. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  415. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  416. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  417. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  418. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  419. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  420. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  421. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  422. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  423. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  424. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  425. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  426. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  427. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  428. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  429. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  430. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  431. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  432. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  433. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  434. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  435. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  436. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  437. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  438. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  439. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  440. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  441. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  442. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  443. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  444. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  445. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  446. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  447. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  448. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  449. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  450. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  451. mindspore/dataset/datapreprocess/__init__.py +0 -20
  452. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  453. mindspore/include/api/net.h +0 -142
  454. mindspore/nn/lr_scheduler.py +0 -262
  455. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  456. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  457. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  458. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  459. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  460. mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
  461. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -350
  462. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -409
  463. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -578
  464. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -199
  465. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -446
  466. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  467. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
  468. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
  469. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
  470. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  471. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  472. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  473. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  474. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  475. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  476. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  477. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  478. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  479. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  480. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  481. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  482. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  483. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  484. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  485. mindspore/rewrite/node_visitor.py +0 -44
  486. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
  487. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -0
  488. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
mindspore/_extends/graph_kernel/expanders/logsoftmax.py (deleted)
@@ -1,46 +0,0 @@
- # Copyright 2020-2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ===========================================================================
- """generate json desc for LogSoftmax"""
- from mindspore._extends.graph_kernel.model.model import DataFormat as DF
- from ._utils import Expander, ExpanderInfoValidator as VLD
-
-
- @VLD.add_format(DF.DEFAULT)
- @VLD.check_attrs('axis')
- class LogSoftmax(Expander):
-     """LogSoftmax expander"""
-
-     def _expand(self, graph_builder):
-         input_x = self.inputs[0]
-         axis = self.attrs['axis']
-         processor = self.processor
-
-         if isinstance(axis, int):
-             axis = (axis,)
-
-         ori_dtype = input_x.dtype
-         if ori_dtype != "float16" and processor == "aicore":
-             input_x_f16 = graph_builder.emit('Cast', [input_x], attrs={'dst_type': 'float16'})
-             max_x_f16 = graph_builder.emit('ReduceMax', [input_x_f16], attrs={'reduce_axis': axis, 'keep_dims': True})
-             max_x = graph_builder.emit('Cast', [max_x_f16], attrs={'dst_type': ori_dtype})
-         else:
-             max_x = graph_builder.emit('ReduceMax', [input_x], attrs={'reduce_axis': axis, 'keep_dims': True})
-         data_sub = graph_builder.emit('Sub', [input_x, max_x])
-         data_exp = graph_builder.emit('Exp', [data_sub])
-         data_expsum = graph_builder.emit('ReduceSum', [data_exp], attrs={'reduce_axis': axis, 'keep_dims': True})
-         log_expsum = graph_builder.emit('Log', [data_expsum])
-         result = graph_builder.emit('Sub', [data_sub, log_expsum])
-
-         return result
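The deleted LogSoftmax expander decomposes the op into ReduceMax, Sub, Exp, ReduceSum, Log and Sub. As a rough illustration (a NumPy sketch, not the MindSpore API), the identity it relies on is log_softmax(x) = (x - max(x)) - log(sum(exp(x - max(x)))):

import numpy as np

def log_softmax_decomposed(x, axis=-1):
    # mirrors the emitted sequence: ReduceMax -> Sub -> Exp -> ReduceSum -> Log -> Sub
    max_x = np.max(x, axis=axis, keepdims=True)
    shifted = x - max_x
    log_sum = np.log(np.sum(np.exp(shifted), axis=axis, keepdims=True))
    return shifted - log_sum

x = np.random.randn(2, 5).astype(np.float32)
reference = x - np.log(np.sum(np.exp(x), axis=-1, keepdims=True))
assert np.allclose(log_softmax_decomposed(x), reference, atol=1e-5)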
mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py (deleted)
@@ -1,36 +0,0 @@
- # Copyright 2020-2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ===========================================================================
- """generate json desc for LogSoftmaxGrad"""
- from mindspore._extends.graph_kernel.model.model import DataFormat as DF
- from ._utils import Expander, ExpanderInfoValidator as VLD
-
-
- @VLD.add_format(DF.DEFAULT, DF.DEFAULT)
- @VLD.check_attrs('axis')
- class LogSoftmaxGrad(Expander):
-     """LogSoftmaxGrad expander"""
-
-     def _expand(self, graph_builder):
-         input_logits, input_dy = self.inputs
-         axis = self.attrs['axis']
-         if isinstance(axis, int):
-             axis = (axis,)
-
-         softmax = graph_builder.emit('Exp', [input_logits])
-         dy_sum = graph_builder.emit('ReduceSum', [input_dy], attrs={'reduce_axis': axis, 'keep_dims': True})
-         mul_result = graph_builder.emit('Mul', [softmax, dy_sum])
-         result = graph_builder.emit('Sub', [input_dy, mul_result])
-
-         return result
mindspore/_extends/graph_kernel/expanders/matmul.py (deleted)
@@ -1,80 +0,0 @@
- # Copyright 2021-2022 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ===========================================================================
- """generate json desc for BatchMatMul and MatMul"""
- from mindspore._extends.graph_kernel.model.model import DataFormat as DF
- from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
- from ._utils import Expander, ExpanderInfoValidator as VLD
-
-
- @VLD.check_attrs('transpose_a', 'transpose_b', 'left_format', 'right_format')
- class MatMul(Expander):
-     """
-     MatMul expander
-     """
-
-     def __init__(self, expand_info):
-         super(MatMul, self).__init__(expand_info)
-         self.shape_a = self.inputs[0]['shape']
-         self.shape_b = self.inputs[1]['shape']
-         self.transpose_a = False
-         self.transpose_b = False
-         self.left_format = ""
-         self.right_format = ""
-
-     def _optimize_to_mul(self):
-         """check if matmul can be replace by mul"""
-         if self.processor != 'aicore' or self.left_format != DF.DEFAULT or self.right_format != DF.DEFAULT:
-             return False
-         k_a = self.shape_a[-2] if self.transpose_a else self.shape_a[-1]
-         k_b = self.shape_b[-1] if self.transpose_b else self.shape_b[-2]
-         if k_a != 1 or k_b != 1:
-             return False
-         return True
-
-     def _check(self):
-         input_num = len(self.inputs)
-         if input_num < 2:
-             raise GKException("For 'MatMul', inputs number should bigger than 1, but got {}.".format(input_num))
-
-     def _expand(self, graph_builder):
-         self.transpose_a = self.attrs['transpose_a']
-         self.transpose_b = self.attrs['transpose_b']
-         self.left_format = self.attrs['left_format']
-         self.right_format = self.attrs['right_format']
-
-         def transpose(shape):
-             trans_shape = list(shape)
-             trans_shape[-2] = shape[-1]
-             trans_shape[-1] = shape[-2]
-             return trans_shape
-         if not self._optimize_to_mul():
-             raise GKException("MatMul/BatchMatMul do not need to be replaced by Mul")
-         # Matmul is replaced by Mul([b m k], [b k n]) when k==1
-         input_a = self.inputs[0]
-         input_b = self.inputs[1]
-         if self.transpose_a:
-             shape_a_trans = transpose(self.shape_a)
-             input_a = graph_builder.emit('Reshape', [input_a], attrs={'shape': shape_a_trans})
-         if self.transpose_b:
-             shape_b_trans = transpose(self.shape_b)
-             input_b = graph_builder.emit('Reshape', [input_b], attrs={'shape': shape_b_trans})
-         result = graph_builder.emit('Mul', [input_a, input_b])
-         if 'dst_type' in self.attrs and self.inputs[0].dtype != self.attrs['dst_type']:
-             result = graph_builder.emit('Cast', [result], attrs={'dst_type': self.attrs['dst_type']})
-         return result
-
-
- class BatchMatMul(MatMul):
-     """BatchMatMul expander"""
mindspore/_extends/graph_kernel/expanders/maximum_grad.py (deleted)
@@ -1,59 +0,0 @@
- # Copyright 2020-2022 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ===========================================================================
- """generate json desc for maximum_grad"""
- from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
- from ._utils import Expander, ExpanderInfoValidator as VLD
- from .minimum_grad import MinimumGrad
-
-
- @VLD.check_all_formats_same
- class MaximumGrad(Expander):
-     """MaximumGrad expander"""
-
-     def _check(self):
-         if not self.attrs.get('grad_x', True) and not self.attrs.get('grad_y', True):
-             raise GKException("For 'MaximumGrad', value of attr 'grad_x' and 'grad_y' should be False, but got {} and "
-                               "{}".format(self.attrs.get('grad_x'), self.attrs.get('grad_y')))
-         return super()._check()
-
-     def _expand(self, graph_builder):
-         input_x, input_y, input_dout = self.inputs
-         ge_result = graph_builder.emit('GreaterEqual', [input_x, input_y])
-         ge_result = graph_builder.emit('Cast', [ge_result], attrs={'dst_type': input_x.dtype})
-         dx = graph_builder.emit('Mul', [ge_result, input_dout])
-         dy = graph_builder.emit('Sub', [input_dout, dx])
-
-         reduce_axis_x = MinimumGrad.get_reduce_axis(input_x.shape, dx.shape)
-         reduce_axis_y = MinimumGrad.get_reduce_axis(input_y.shape, dy.shape)
-         if reduce_axis_x:
-             dx_reduce = graph_builder.emit('ReduceSum', [dx], attrs={'reduce_axis': reduce_axis_x, 'keep_dims': False})
-             if dx_reduce.shape != input_x.shape:
-                 dx_out = graph_builder.emit('Reshape', [dx_reduce], attrs={'shape': input_x.shape})
-             else:
-                 dx_out = dx_reduce
-         else:
-             dx_out = dx
-
-         if reduce_axis_y:
-             dy_reduce = graph_builder.emit('ReduceSum', [dy], attrs={'reduce_axis': reduce_axis_y, 'keep_dims': False})
-             if dy_reduce.shape != input_y.shape:
-                 dy_out = graph_builder.emit('Reshape', [dy_reduce], attrs={'shape': input_y.shape})
-             else:
-                 dy_out = dy_reduce
-         else:
-             dy_out = dy
-
-         # output two results, regardless of grad_x and grad_y
-         return dx_out, dy_out
mindspore/_extends/graph_kernel/expanders/minimum_grad.py (deleted)
@@ -1,80 +0,0 @@
- # Copyright 2020-2022 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ===========================================================================
- """generate json desc for minimum_grad"""
- from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
- from ._utils import Expander, ExpanderInfoValidator as VLD
-
-
- @VLD.check_all_formats_same
- class MinimumGrad(Expander):
-     """MinimumGrad expander"""
-
-     def _check(self):
-         if not self.attrs.get('grad_x', True) and not self.attrs.get('grad_y', True):
-             raise GKException("For 'MinimumGrad', value of attr 'grad_x' and 'grad_y' should be False, but got {} and "
-                               "{}".format(self.attrs.get('grad_x'), self.attrs.get('grad_y')))
-         return super(MinimumGrad, self)._check()
-
-     def _expand(self, graph_builder):
-         input_x, input_y, input_dout = self.inputs
-
-         le_result = graph_builder.emit('LessEqual', [input_x, input_y])
-         le_result = graph_builder.emit('Cast', [le_result], attrs={'dst_type': input_x.dtype})
-         dx = graph_builder.emit('Mul', [le_result, input_dout])
-         dy = graph_builder.emit('Sub', [input_dout, dx])
-
-         # for minimumgrad op, output_shape should be equal to input_shape,
-         # but some elementwise operating may broadcast input_shape
-         # then output_shape not equal to original input_shape, so need to reduce output to let them equal
-         reduce_axis_x = self.get_reduce_axis(input_x.shape, dx.shape)
-         reduce_axis_y = self.get_reduce_axis(input_y.shape, dy.shape)
-         if reduce_axis_x:
-             dx_reduce = graph_builder.emit('ReduceSum', [dx], attrs={'reduce_axis': reduce_axis_x, 'keep_dims': False})
-             if dx_reduce.shape != input_x.shape:
-                 dx_out = graph_builder.emit('Reshape', [dx_reduce], attrs={'shape': input_x.shape})
-             else:
-                 dx_out = dx_reduce
-         else:
-             dx_out = dx
-
-         if reduce_axis_y:
-             dy_reduce = graph_builder.emit('ReduceSum', [dy], attrs={'reduce_axis': reduce_axis_y, 'keep_dims': False})
-             if dy_reduce.shape != input_y.shape:
-                 dy_out = graph_builder.emit('Reshape', [dy_reduce], attrs={'shape': input_y.shape})
-             else:
-                 dy_out = dy_reduce
-         else:
-             dy_out = dy
-
-         # output two results, regardless of grad_x and grad_y
-         return dx_out, dy_out
-
-     @staticmethod
-     def get_reduce_axis(original_shape, broadcast_shape):
-         """compute reduce axis for final output_shape"""
-         if len(original_shape) > len(broadcast_shape):
-             raise ValueError("For 'MinimumGrad', the length of original_shape should be less than or equal to the "
-                              "length of broadcast_shape, but got {} and {}".format(original_shape, broadcast_shape))
-
-         tmp_shape = [1] * (len(broadcast_shape) - len(original_shape)) + original_shape
-         reduce_axis = []
-         for idx, _ in enumerate(tmp_shape):
-             if tmp_shape[idx] != broadcast_shape[idx]:
-                 if tmp_shape[idx] == 1:
-                     reduce_axis.append(idx)
-                 else:
-                     raise ValueError("For 'MinimumGrad', original_shape {} and broadcast_shape {} can not broadcast."
-                                      .format(original_shape, broadcast_shape))
-         return reduce_axis
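get_reduce_axis above recovers which axes were broadcast so that the gradient, computed at the broadcast shape, can be summed back to the original input shape (ReduceSum followed by Reshape). A small NumPy sketch of the same logic, offered as an illustration only and not as MindSpore code:

import numpy as np

def get_reduce_axis(original_shape, broadcast_shape):
    # pad the original shape with leading 1s, then collect the axes that were broadcast
    pad = [1] * (len(broadcast_shape) - len(original_shape)) + list(original_shape)
    return [i for i, (o, b) in enumerate(zip(pad, broadcast_shape)) if o == 1 and o != b]

dout = np.ones((2, 3, 4))                          # gradient at the broadcast shape
axis = get_reduce_axis((3, 1), (2, 3, 4))          # -> [0, 2]
dx_out = dout.sum(axis=tuple(axis)).reshape(3, 1)  # ReduceSum + Reshape, as emitted above
assert axis == [0, 2] and dx_out.shape == (3, 1)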
mindspore/_extends/graph_kernel/expanders/oneslike.py (deleted)
@@ -1,26 +0,0 @@
- # Copyright 2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ===========================================================================
- """generate json desc for OnesLike"""
- from ._utils import Expander
-
-
- class OnesLike(Expander):
-     """OnesLike expander"""
-
-     def _expand(self, graph_builder):
-         input_x = self.inputs[0]
-         const_one = graph_builder.value(input_x.dtype, 1)
-         result = graph_builder.emit('BroadcastTo', [const_one], attrs={'shape': input_x.shape})
-         return result
mindspore/_extends/graph_kernel/expanders/reduce_mean.py (deleted)
@@ -1,43 +0,0 @@
- # Copyright 2020-2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ===========================================================================
- """generate json desc for reduce_mean"""
- from mindspore._extends.graph_kernel.model.model import DataFormat as DF
- from ._utils import Expander, ExpanderInfoValidator as VLD
-
-
- @VLD.add_format(DF.DEFAULT)
- @VLD.check_attrs('axis', 'keep_dims')
- class ReduceMean(Expander):
-     """ReduceMean expander"""
-
-     def _expand(self, graph_builder):
-         x = self.inputs[0]
-         axis = self.attrs['axis']
-         keep_dims = self.attrs['keep_dims']
-
-         if not isinstance(axis, (tuple, list)):
-             axis = (axis,)
-         elif not axis:
-             axis = list(range(len(x.shape)))
-         reduce_size = 1.0
-         for idx in axis:
-             reduce_size *= x.shape[idx]
-
-         reduce_size_value = graph_builder.value(x.dtype, reduce_size)
-
-         sum_x = graph_builder.emit('ReduceSum', [x], attrs={'reduce_axis': axis, 'keep_dims': keep_dims})
-         result = graph_builder.emit('RealDiv', [sum_x, reduce_size_value])
-
-         return result
mindspore/_extends/graph_kernel/expanders/relu_grad.py (deleted)
@@ -1,32 +0,0 @@
- # Copyright 2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ===========================================================================
- """generate json desc for relu_grad"""
- from ._utils import Expander, ExpanderInfoValidator as VLD
-
-
- @VLD.check_all_formats_same
- class ReluGrad(Expander):
-     """ReLU expander"""
-
-     def _expand(self, graph_builder):
-         input_x = self.inputs[0]
-         input_y = self.inputs[1]
-
-         const_zero = graph_builder.value(input_y.dtype, 0)
-         ge_result = graph_builder.emit('Greater', [input_y, const_zero])
-         ge_result = graph_builder.emit('Cast', [ge_result], attrs={'dst_type': input_x.dtype})
-         result = graph_builder.emit('Mul', [ge_result, input_x])
-
-         return result
mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py (deleted)
@@ -1,41 +0,0 @@
- # Copyright 2021-2022 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ===========================================================================
- """generate json desc for SigmoidCrossEntropyWithLogits"""
- from ._utils import Expander, ExpanderInfoValidator as VLD
-
-
- @VLD.check_all_formats_same
- class SigmoidCrossEntropyWithLogits(Expander):
-     """SigmoidCrossEntropyWithLogits expander"""
-
-     def _expand(self, graph_builder):
-         logits, labels = self.inputs
-         # Calculate sigmoid_cross_entropy_with_logits(logits, labels)
-         # formula of sigmoid_cross_entropy_with_logits is:
-         # -(labels * log(sigmoid(logits)) + (1 - labels) * log(1 - sigmoid(logits)))
-         # To ensure stability and avoid overflow, the formula equal to :
-         # max(logits, 0) - logits * labels + log(1 + exp(-abs(logits)))
-         const_one = graph_builder.value(logits.dtype, 1.0)
-         const_zero = graph_builder.value(logits.dtype, 0.0)
-         max_logits = graph_builder.emit('Maximum', [logits, const_zero])
-         logits_mul_labels = graph_builder.emit('Mul', [logits, labels])
-         abs_logits = graph_builder.emit('Abs', [logits])
-         neg_abs_logits = graph_builder.emit('Neg', [abs_logits])
-         exp_neg_abs_logits = graph_builder.emit('Exp', [neg_abs_logits])
-         one_add_exp_neg_abs_logits = graph_builder.emit('Add', [const_one, exp_neg_abs_logits])
-         log_one_add_exp_neg_abs_logits = graph_builder.emit('Log', [one_add_exp_neg_abs_logits])
-         res_tmp = graph_builder.emit('Sub', [max_logits, logits_mul_labels])
-         res = graph_builder.emit('Add', [res_tmp, log_one_add_exp_neg_abs_logits])
-         return res
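The comment block above states the identity the expander depends on; a quick NumPy check (illustrative only) that the numerically stable form matches the naive formula for moderate logits:

import numpy as np

x = np.random.randn(8)                         # logits
z = np.random.randint(0, 2, 8).astype(float)   # labels

sig = 1.0 / (1.0 + np.exp(-x))
naive = -(z * np.log(sig) + (1.0 - z) * np.log(1.0 - sig))
stable = np.maximum(x, 0.0) - x * z + np.log(1.0 + np.exp(-np.abs(x)))
assert np.allclose(naive, stable)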
mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py (deleted)
@@ -1,35 +0,0 @@
- # Copyright 2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ===========================================================================
- """generate json desc for SigmoidCrossEntropyWithLogitsGrad"""
- from ._utils import Expander, ExpanderInfoValidator as VLD
-
-
- @VLD.check_all_formats_same
- class SigmoidCrossEntropyWithLogitsGrad(Expander):
-     """SigmoidCrossEntropyWithLogitsGrad expander"""
-
-     def _expand(self, graph_builder):
-         logits, label, dout = self.inputs
-         # Calculate sigmoid_cross_entropy_with_logits_grad(logits, label, dout)
-         # formula of sigmoid_cross_entropy_with_logits_grad is :
-         # (sigmoid(logits) - label) * dout
-         const_one = graph_builder.value(logits.dtype, 1.0)
-         neg_x = graph_builder.emit('Neg', [logits])
-         exp_neg_x = graph_builder.emit('Exp', [neg_x])
-         add_exp = graph_builder.emit('Add', [const_one, exp_neg_x])
-         sigmoid_res = graph_builder.emit('RealDiv', [const_one, add_exp])
-         sigmoid_res_sub_label = graph_builder.emit('Sub', [sigmoid_res, label])
-         res = graph_builder.emit('Mul', [sigmoid_res_sub_label, dout])
-         return res
mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py (deleted)
@@ -1,31 +0,0 @@
- # Copyright 2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ===========================================================================
- """generate json desc for SigmoidGrad"""
- from ._utils import Expander, ExpanderInfoValidator as VLD
-
-
- @VLD.check_all_formats_same
- class SigmoidGrad(Expander):
-     """SigmoidGrad expander"""
-
-     def _expand(self, graph_builder):
-         input_y, dy = self.inputs
-         # Calculate sigmoid_grad(y, dy)
-         # formula of sigmoid_grad is : (1 - y) * y * dy
-         const_one = graph_builder.value(input_y.dtype, 1.0)
-         one_mins_y = graph_builder.emit('Sub', [const_one, input_y])
-         y_mul_dy = graph_builder.emit('Mul', [input_y, dy])
-         res = graph_builder.emit('Mul', [one_mins_y, y_mul_dy])
-         return res
mindspore/_extends/graph_kernel/expanders/slice.py (deleted)
@@ -1,35 +0,0 @@
- # Copyright 2021-2022 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ===========================================================================
- """generate json desc for slice"""
- from ._utils import Expander, ExpanderInfoValidator as VLD
-
-
- @VLD.check_attrs('begin', 'size')
- class Slice(Expander):
-     """Slice expander"""
-
-     def _expand(self, graph_builder):
-         input_x = self.inputs[0]
-         begin = self.attrs['begin']
-         size = self.attrs['size']
-         end = []
-         strides = []
-         for i, begin_idx in enumerate(begin):
-             strides.append(1)
-             end.append(begin_idx + size[i])
-         output = graph_builder.tensor(size, input_x.dtype, input_x.data_format)
-         graph_builder.op('StridedSlice', output, [input_x], attrs={'begin': begin, 'end': end, 'strides': strides})
-
-         return output
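The loop above converts Slice attributes into StridedSlice attributes via end[i] = begin[i] + size[i] with all strides equal to 1. A tiny NumPy sketch (illustrative only) of the same mapping:

import numpy as np

x = np.arange(24).reshape(2, 3, 4)
begin, size = (0, 1, 2), (2, 2, 2)
end = tuple(b + s for b, s in zip(begin, size))            # end = begin + size
out = x[tuple(slice(b, e, 1) for b, e in zip(begin, end))]
assert out.shape == size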
mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py (deleted)
@@ -1,42 +0,0 @@
- # Copyright 2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ===========================================================================
- """generate json desc for SoftmaxCrossEntropyWithLogits"""
- from mindspore._extends.graph_kernel.model.model import DataFormat as DF
- from ._utils import Expander, ExpanderInfoValidator as VLD
-
-
- @VLD.add_format(DF.DEFAULT, DF.DEFAULT)
- class SoftmaxCrossEntropyWithLogits(Expander):
-     """SoftmaxCrossEntropyWithLogits expander"""
-
-     def _expand(self, graph_builder):
-         logits, label = self.inputs
-         # Calculate softmax_cross_entropy_with_logits(logits, label)
-         # formula of softmax_cross_entropy_with_logits is : -reduce_sum(label * log(softmax(logits)))
-         axis = (-1,)
-         max_x = graph_builder.emit('ReduceMax', [logits], attrs={'reduce_axis': axis, 'keep_dims': True})
-         data_sub = graph_builder.emit('Sub', [logits, max_x])
-         data_exp = graph_builder.emit('Exp', [data_sub])
-         data_expsum = graph_builder.emit('ReduceSum', [data_exp], attrs={'reduce_axis': axis, 'keep_dims': True})
-         data_softmax = graph_builder.emit('RealDiv', [data_exp, data_expsum])
-         const_eps = graph_builder.value(logits.dtype, 0.000001)
-         data_softmax_safety = graph_builder.emit("Maximum", [data_softmax, const_eps])
-         softmax_log = graph_builder.emit('Log', [data_softmax_safety])
-         label_mul_log = graph_builder.emit('Mul', [label, softmax_log])
-         tmp_res = data_expsum = graph_builder.emit('ReduceSum', [label_mul_log], attrs={
-             'reduce_axis': axis, 'keep_dims': False})
-         loss = graph_builder.emit('Neg', [tmp_res])
-         dlogits = graph_builder.emit('Sub', [data_softmax, label])
-         return loss, dlogits
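The expander returns two outputs: the per-sample loss -sum(label * log(softmax(logits))) and the gradient softmax(logits) - label, with a small epsilon clamp before the log. A NumPy sketch of the same computation, for illustration only:

import numpy as np

logits = np.random.randn(4, 10)
label = np.eye(10)[np.random.randint(0, 10, 4)]          # one-hot labels

shifted = logits - logits.max(axis=-1, keepdims=True)
softmax = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)
loss = -np.sum(label * np.log(np.maximum(softmax, 1e-6)), axis=-1)
dlogits = softmax - label
assert loss.shape == (4,) and dlogits.shape == (4, 10)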
mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py (deleted)
@@ -1,41 +0,0 @@
- # Copyright 2021-2022 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ===========================================================================
- """generate json desc for SoftmaxGradExt"""
- from mindspore._extends.graph_kernel.model.model import DataFormat as DF
- from ._utils import Expander, ExpanderInfoValidator as VLD
- from ._utils import get_reduce_axis_shape
-
-
- @VLD.add_format(DF.FRAC_NZ, DF.FRAC_NZ, DF.DEFAULT)
- @VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
- @VLD.check_attrs('axis')
- class SoftmaxGradExt(Expander):
-     """SoftmaxGradExt expander"""
-
-     def _expand(self, graph_builder):
-         x, y, z = self.inputs
-         axis = self.attrs['axis']
-
-         reduce_axis, ori_reduced_shape = get_reduce_axis_shape(x.shape, x.data_format, axis)
-
-         data_mul = graph_builder.emit('Mul', [x, y])
-         data_sum = graph_builder.emit('ReduceSum', [data_mul],
-                                       attrs={'reduce_axis': reduce_axis, 'keep_dims': True, 'reduce_output_fuse': True})
-         if x.data_format == DF.FRAC_NZ:
-             data_sum = graph_builder.emit('Reshape', [data_sum], attrs={'shape': ori_reduced_shape})
-         data_sub = graph_builder.emit('Sub', [x, data_sum])
-         data_mul2 = graph_builder.emit('Mul', [data_sub, y])
-         result = graph_builder.emit('Mul', [data_mul2, z])
-         return result