mindspore 2.1.0__cp38-cp38-manylinux1_x86_64.whl → 2.2.11__cp38-cp38-manylinux1_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (589) hide show
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_akg/akg/build_module.py +5 -6
  4. mindspore/_akg/akg/composite/build_module.py +139 -22
  5. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  6. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  7. mindspore/_akg/akg/tvm/api.py +4 -3
  8. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  9. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  10. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  11. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  12. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  13. mindspore/_akg/akg/tvm/build_module.py +16 -1
  14. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  15. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  16. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  17. mindspore/_akg/akg/tvm/module.py +1 -2
  18. mindspore/_akg/akg/tvm/stmt.py +2 -2
  19. mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
  20. mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
  21. mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
  22. mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
  23. mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
  24. mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
  25. mindspore/_akg/akg/utils/composite_op_helper.py +16 -12
  26. mindspore/_akg/akg/utils/dump_ascend_meta.py +22 -3
  27. mindspore/_akg/akg/utils/kernel_exec.py +98 -274
  28. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  29. mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
  30. mindspore/_akg/akg/utils/util.py +56 -1
  31. mindspore/_c_dataengine.cpython-38-x86_64-linux-gnu.so +0 -0
  32. mindspore/_c_expression.cpython-38-x86_64-linux-gnu.so +0 -0
  33. mindspore/_c_mindrecord.cpython-38-x86_64-linux-gnu.so +0 -0
  34. mindspore/_check_jit_forbidden_api.py +3 -1
  35. mindspore/_checkparam.py +23 -29
  36. mindspore/_extends/graph_kernel/__init__.py +0 -1
  37. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  38. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  39. mindspore/_extends/graph_kernel/splitter.py +4 -11
  40. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  41. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
  42. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  43. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  44. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  45. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
  46. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  47. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  48. mindspore/_extends/parse/__init__.py +13 -15
  49. mindspore/_extends/parse/namespace.py +7 -33
  50. mindspore/_extends/parse/parser.py +67 -72
  51. mindspore/_extends/parse/resources.py +1 -1
  52. mindspore/_extends/parse/standard_method.py +86 -106
  53. mindspore/_extends/parse/trope.py +1 -1
  54. mindspore/_extends/remote/kernel_build_server.py +25 -7
  55. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  56. mindspore/_install_custom.py +43 -0
  57. mindspore/_mindspore_offline_debug.cpython-38-x86_64-linux-gnu.so +0 -0
  58. mindspore/amp.py +47 -11
  59. mindspore/bin/cache_admin +0 -0
  60. mindspore/bin/cache_server +0 -0
  61. mindspore/boost/boost.py +1 -8
  62. mindspore/boost/boost_cell_wrapper.py +3 -2
  63. mindspore/boost/grad_accumulation.py +1 -1
  64. mindspore/boost/group_loss_scale_manager.py +8 -7
  65. mindspore/common/__init__.py +5 -3
  66. mindspore/common/_jit_fallback_utils.py +6 -0
  67. mindspore/common/_register_for_adapter.py +2 -0
  68. mindspore/common/_register_for_tensor.py +2 -2
  69. mindspore/common/_stub_tensor.py +13 -0
  70. mindspore/common/_utils.py +29 -0
  71. mindspore/common/api.py +174 -259
  72. mindspore/common/auto_dynamic_shape.py +494 -0
  73. mindspore/common/dtype.py +18 -11
  74. mindspore/common/dump.py +6 -4
  75. mindspore/common/initializer.py +14 -14
  76. mindspore/common/jit_config.py +33 -15
  77. mindspore/common/lazy_inline.py +126 -7
  78. mindspore/common/mindir_util.py +101 -0
  79. mindspore/common/parameter.py +51 -41
  80. mindspore/common/seed.py +4 -4
  81. mindspore/common/sparse_tensor.py +13 -14
  82. mindspore/common/tensor.py +243 -165
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +83 -4
  85. mindspore/communication/management.py +152 -84
  86. mindspore/config/op_info.config +14 -3
  87. mindspore/config/super_bar_config.json +4 -2
  88. mindspore/context.py +152 -61
  89. mindspore/dataset/__init__.py +5 -5
  90. mindspore/dataset/audio/__init__.py +2 -2
  91. mindspore/dataset/audio/transforms.py +52 -52
  92. mindspore/dataset/callback/ds_callback.py +16 -2
  93. mindspore/dataset/core/config.py +68 -51
  94. mindspore/dataset/engine/cache_client.py +33 -7
  95. mindspore/dataset/engine/datasets.py +250 -112
  96. mindspore/dataset/engine/datasets_audio.py +43 -211
  97. mindspore/dataset/engine/datasets_standard_format.py +16 -35
  98. mindspore/dataset/engine/datasets_text.py +43 -67
  99. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  100. mindspore/dataset/engine/datasets_vision.py +219 -1029
  101. mindspore/dataset/engine/iterators.py +11 -4
  102. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  103. mindspore/dataset/engine/obs/util.py +3 -0
  104. mindspore/dataset/engine/samplers.py +1 -1
  105. mindspore/dataset/engine/validators.py +19 -5
  106. mindspore/dataset/text/__init__.py +3 -3
  107. mindspore/dataset/text/transforms.py +101 -127
  108. mindspore/dataset/text/utils.py +205 -138
  109. mindspore/dataset/transforms/__init__.py +1 -1
  110. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  111. mindspore/dataset/transforms/transforms.py +95 -40
  112. mindspore/dataset/utils/browse_dataset.py +8 -2
  113. mindspore/dataset/utils/line_reader.py +17 -19
  114. mindspore/dataset/vision/__init__.py +3 -3
  115. mindspore/dataset/vision/c_transforms.py +6 -3
  116. mindspore/dataset/vision/transforms.py +409 -287
  117. mindspore/dataset/vision/utils.py +13 -14
  118. mindspore/dataset/vision/validators.py +11 -1
  119. mindspore/experimental/map_parameter.py +14 -0
  120. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  121. mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
  122. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  123. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  124. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  125. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  126. mindspore/gen_ops.py +273 -0
  127. mindspore/include/OWNERS +0 -1
  128. mindspore/include/api/data_type.h +2 -1
  129. mindspore/include/api/graph.h +0 -15
  130. mindspore/include/api/kernel.h +2 -0
  131. mindspore/include/api/kernel_api.h +37 -12
  132. mindspore/include/api/model.h +17 -14
  133. mindspore/include/api/status.h +8 -3
  134. mindspore/include/api/types.h +37 -4
  135. mindspore/include/c_api/ms/abstract.h +67 -0
  136. mindspore/include/c_api/ms/attribute.h +197 -0
  137. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  138. mindspore/include/c_api/ms/base/macros.h +32 -0
  139. mindspore/include/c_api/ms/base/status.h +33 -0
  140. mindspore/include/c_api/ms/base/types.h +282 -0
  141. mindspore/include/c_api/ms/context.h +102 -0
  142. mindspore/include/c_api/ms/graph.h +160 -0
  143. mindspore/include/c_api/ms/node.h +606 -0
  144. mindspore/include/c_api/ms/tensor.h +161 -0
  145. mindspore/include/c_api/ms/value.h +84 -0
  146. mindspore/include/dataset/constants.h +6 -5
  147. mindspore/include/dataset/execute.h +23 -13
  148. mindspore/include/dataset/text.h +26 -26
  149. mindspore/include/dataset/transforms.h +13 -13
  150. mindspore/include/dataset/vision.h +60 -60
  151. mindspore/include/dataset/vision_ascend.h +5 -6
  152. mindspore/include/dataset/vision_lite.h +17 -17
  153. mindspore/include/mindapi/base/type_id.h +1 -0
  154. mindspore/include/mindapi/base/types.h +1 -0
  155. mindspore/lib/libdnnl.so.2 +0 -0
  156. mindspore/lib/libjemalloc.so.2 +0 -0
  157. mindspore/lib/libmindspore.so +0 -0
  158. mindspore/lib/libmindspore_backend.so +0 -0
  159. mindspore/lib/libmindspore_common.so +0 -0
  160. mindspore/lib/libmindspore_core.so +0 -0
  161. mindspore/lib/libmindspore_glog.so.0 +0 -0
  162. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  163. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  164. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  165. mindspore/lib/libmindspore_shared_lib.so +0 -0
  166. mindspore/lib/libnnacl.so +0 -0
  167. mindspore/lib/libopencv_core.so.4.5 +0 -0
  168. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  169. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  170. mindspore/lib/libps_cache.so +0 -0
  171. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
  172. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
  173. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
  174. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
  175. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  176. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  177. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  178. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  179. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  180. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  181. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  182. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  183. mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
  184. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  185. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  186. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8998 -0
  187. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  188. mindspore/lib/plugin/ascend/libakg.so +0 -0
  189. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  190. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  191. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  192. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  193. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  194. mindspore/lib/plugin/cpu/libakg.so +0 -0
  195. mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
  196. mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
  197. mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
  198. mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
  199. mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
  200. mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
  201. mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
  202. mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
  203. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  204. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  205. mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
  206. mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
  207. mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
  208. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  209. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  210. mindspore/nn/__init__.py +0 -2
  211. mindspore/nn/cell.py +313 -74
  212. mindspore/nn/dynamic_lr.py +21 -21
  213. mindspore/nn/layer/activation.py +22 -30
  214. mindspore/nn/layer/basic.py +15 -13
  215. mindspore/nn/layer/channel_shuffle.py +1 -1
  216. mindspore/nn/layer/container.py +271 -9
  217. mindspore/nn/layer/conv.py +323 -204
  218. mindspore/nn/layer/dense.py +8 -5
  219. mindspore/nn/layer/embedding.py +33 -27
  220. mindspore/nn/layer/flash_attention.py +61 -95
  221. mindspore/nn/layer/image.py +8 -6
  222. mindspore/nn/layer/math.py +16 -25
  223. mindspore/nn/layer/normalization.py +107 -66
  224. mindspore/nn/layer/padding.py +1 -1
  225. mindspore/nn/layer/pooling.py +131 -109
  226. mindspore/nn/layer/rnn_cells.py +27 -22
  227. mindspore/nn/layer/rnns.py +13 -16
  228. mindspore/nn/layer/thor_layer.py +1 -1
  229. mindspore/nn/layer/transformer.py +221 -154
  230. mindspore/nn/learning_rate_schedule.py +9 -1
  231. mindspore/nn/loss/loss.py +235 -174
  232. mindspore/nn/optim/ada_grad.py +2 -1
  233. mindspore/nn/optim/adadelta.py +1 -0
  234. mindspore/nn/optim/adafactor.py +2 -1
  235. mindspore/nn/optim/adam.py +7 -4
  236. mindspore/nn/optim/adamax.py +3 -2
  237. mindspore/nn/optim/adasum.py +2 -2
  238. mindspore/nn/optim/asgd.py +2 -3
  239. mindspore/nn/optim/ftrl.py +6 -5
  240. mindspore/nn/optim/lamb.py +7 -4
  241. mindspore/nn/optim/lars.py +1 -1
  242. mindspore/nn/optim/lazyadam.py +5 -3
  243. mindspore/nn/optim/momentum.py +2 -1
  244. mindspore/nn/optim/optimizer.py +53 -4
  245. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  246. mindspore/nn/optim/rmsprop.py +4 -3
  247. mindspore/nn/optim/rprop.py +23 -12
  248. mindspore/nn/optim/sgd.py +26 -11
  249. mindspore/nn/optim/thor.py +9 -7
  250. mindspore/nn/probability/bijector/bijector.py +5 -5
  251. mindspore/nn/probability/bijector/power_transform.py +27 -27
  252. mindspore/nn/probability/bijector/softplus.py +3 -3
  253. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  254. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  255. mindspore/nn/probability/distribution/beta.py +3 -3
  256. mindspore/nn/probability/distribution/categorical.py +7 -7
  257. mindspore/nn/probability/distribution/cauchy.py +0 -1
  258. mindspore/nn/probability/distribution/distribution.py +3 -3
  259. mindspore/nn/probability/distribution/gamma.py +3 -3
  260. mindspore/nn/probability/distribution/geometric.py +4 -4
  261. mindspore/nn/probability/distribution/gumbel.py +4 -4
  262. mindspore/nn/probability/distribution/log_normal.py +2 -2
  263. mindspore/nn/probability/distribution/logistic.py +2 -2
  264. mindspore/nn/probability/distribution/poisson.py +4 -4
  265. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  266. mindspore/nn/probability/distribution/uniform.py +6 -6
  267. mindspore/nn/wrap/__init__.py +4 -2
  268. mindspore/nn/wrap/cell_wrapper.py +87 -34
  269. mindspore/nn/wrap/grad_reducer.py +8 -5
  270. mindspore/nn/wrap/loss_scale.py +105 -42
  271. mindspore/numpy/array_creations.py +1 -2
  272. mindspore/numpy/array_ops.py +3 -2
  273. mindspore/numpy/utils_const.py +5 -5
  274. mindspore/offline_debug/convert_async.py +2 -2
  275. mindspore/ops/_grad_experimental/__init__.py +0 -5
  276. mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
  277. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  278. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  279. mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
  280. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  281. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
  282. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  283. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  284. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  285. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  286. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  287. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  288. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  289. mindspore/ops/_op_impl/{_custom_op/flash_attention/constants.py → aicpu/eps.py} +18 -27
  290. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  291. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
  292. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  293. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  294. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  295. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  296. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  297. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  298. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  299. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  300. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  301. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  302. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  303. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  304. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  305. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  306. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  307. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  308. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  309. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  310. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  311. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  312. mindspore/ops/_primitive_cache.py +1 -1
  313. mindspore/ops/_tracefunc.py +45 -13
  314. mindspore/ops/_utils/utils.py +6 -1
  315. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  316. mindspore/ops/_vmap/vmap_base.py +3 -3
  317. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  318. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  319. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  320. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  321. mindspore/ops/arg_dtype_cast.py +54 -0
  322. mindspore/ops/composite/base.py +37 -10
  323. mindspore/ops/composite/math_ops.py +5 -4
  324. mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
  325. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  326. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  327. mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
  328. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  329. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  330. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  331. mindspore/ops/deprecated.py +304 -0
  332. mindspore/ops/function/__init__.py +4 -1
  333. mindspore/ops/function/array_func.py +174 -193
  334. mindspore/ops/function/clip_func.py +81 -13
  335. mindspore/ops/function/debug_func.py +1 -1
  336. mindspore/ops/function/grad/grad_func.py +18 -9
  337. mindspore/ops/function/image_func.py +10 -4
  338. mindspore/ops/function/linalg_func.py +5 -5
  339. mindspore/ops/function/math_func.py +575 -386
  340. mindspore/ops/function/nn_func.py +568 -260
  341. mindspore/ops/function/random_func.py +88 -57
  342. mindspore/ops/function/sparse_func.py +1 -1
  343. mindspore/ops/function/sparse_unary_func.py +14 -12
  344. mindspore/ops/function/vmap_func.py +6 -5
  345. mindspore/ops/functional.py +15 -10
  346. mindspore/ops/op_info_register.py +244 -25
  347. mindspore/ops/operations/__init__.py +31 -19
  348. mindspore/ops/operations/_grad_ops.py +71 -7
  349. mindspore/ops/operations/_inner_ops.py +350 -17
  350. mindspore/ops/operations/_quant_ops.py +4 -8
  351. mindspore/ops/operations/_sequence_ops.py +42 -0
  352. mindspore/ops/operations/array_ops.py +68 -282
  353. mindspore/ops/operations/comm_ops.py +107 -59
  354. mindspore/ops/operations/custom_ops.py +94 -70
  355. mindspore/ops/operations/debug_ops.py +8 -4
  356. mindspore/ops/operations/image_ops.py +18 -12
  357. mindspore/ops/operations/inner_ops.py +26 -3
  358. mindspore/ops/operations/math_ops.py +192 -144
  359. mindspore/ops/operations/nn_ops.py +857 -489
  360. mindspore/ops/operations/other_ops.py +0 -22
  361. mindspore/ops/operations/random_ops.py +53 -111
  362. mindspore/ops/operations/sparse_ops.py +3 -1
  363. mindspore/ops/primitive.py +24 -18
  364. mindspore/parallel/_auto_parallel_context.py +68 -8
  365. mindspore/parallel/_cost_model_context.py +2 -2
  366. mindspore/parallel/_offload_context.py +17 -3
  367. mindspore/parallel/_parallel_serialization.py +12 -5
  368. mindspore/parallel/_ps_context.py +12 -0
  369. mindspore/parallel/_tensor.py +18 -13
  370. mindspore/parallel/_transformer/layers.py +5 -3
  371. mindspore/parallel/_transformer/loss.py +1 -0
  372. mindspore/parallel/_transformer/moe.py +2 -2
  373. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  374. mindspore/parallel/_transformer/transformer.py +23 -3
  375. mindspore/parallel/_utils.py +11 -7
  376. mindspore/parallel/algo_parameter_config.py +85 -5
  377. mindspore/parallel/checkpoint_transform.py +19 -12
  378. mindspore/parallel/shard.py +21 -14
  379. mindspore/profiler/common/struct_type.py +3 -3
  380. mindspore/profiler/common/util.py +4 -2
  381. mindspore/profiler/envprofiling.py +1 -1
  382. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  383. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  384. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  385. mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
  386. mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
  387. mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
  388. mindspore/profiler/parser/ascend_op_generator.py +6 -6
  389. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  390. mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
  391. mindspore/profiler/parser/base_timeline_generator.py +10 -8
  392. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
  393. mindspore/profiler/parser/flops_parser.py +15 -11
  394. mindspore/profiler/parser/framework_parser.py +38 -22
  395. mindspore/profiler/parser/hccl_parser.py +16 -12
  396. mindspore/profiler/parser/integrator.py +22 -11
  397. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  398. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  399. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  400. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  401. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  402. mindspore/profiler/parser/optime_parser.py +1 -1
  403. mindspore/profiler/parser/profiler_info.py +21 -2
  404. mindspore/profiler/parser/step_trace_parser.py +11 -14
  405. mindspore/profiler/profiling.py +179 -89
  406. mindspore/rewrite/api/node.py +102 -19
  407. mindspore/rewrite/api/node_type.py +5 -1
  408. mindspore/rewrite/api/pattern_engine.py +1 -1
  409. mindspore/rewrite/api/scoped_value.py +9 -17
  410. mindspore/rewrite/api/symbol_tree.py +131 -47
  411. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  412. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  413. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  414. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  415. mindspore/rewrite/common/rewrite_elog.py +5 -1
  416. mindspore/rewrite/namer.py +33 -24
  417. mindspore/rewrite/namespace.py +14 -5
  418. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  419. mindspore/rewrite/node/call_function.py +79 -0
  420. mindspore/rewrite/node/cell_container.py +135 -0
  421. mindspore/rewrite/node/control_flow.py +88 -0
  422. mindspore/rewrite/{node.py → node/node.py} +273 -234
  423. mindspore/rewrite/node/node_manager.py +254 -0
  424. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  425. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  426. mindspore/rewrite/parsers/assign_parser.py +216 -221
  427. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  428. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  429. mindspore/rewrite/parsers/constant_parser.py +9 -6
  430. mindspore/rewrite/parsers/container_parser.py +9 -7
  431. mindspore/rewrite/parsers/for_parser.py +42 -21
  432. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  433. mindspore/rewrite/parsers/if_parser.py +28 -24
  434. mindspore/rewrite/parsers/module_parser.py +196 -25
  435. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  436. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  437. mindspore/rewrite/parsers/return_parser.py +6 -6
  438. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  439. mindspore/rewrite/sparsify/utils.py +1 -1
  440. mindspore/rewrite/symbol_tree.py +523 -578
  441. mindspore/rewrite/symbol_tree_builder.py +9 -193
  442. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  443. mindspore/run_check/_check_version.py +6 -4
  444. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  445. mindspore/safeguard/rewrite_obfuscation.py +541 -0
  446. mindspore/scipy/linalg.py +1 -1
  447. mindspore/scipy/ops.py +55 -5
  448. mindspore/scipy/optimize/__init__.py +3 -2
  449. mindspore/scipy/optimize/linear_sum_assignment.py +38 -33
  450. mindspore/scipy/optimize/minimize.py +7 -3
  451. mindspore/train/_utils.py +7 -3
  452. mindspore/train/amp.py +323 -123
  453. mindspore/train/anf_ir_pb2.py +14 -2
  454. mindspore/train/callback/_backup_and_restore.py +2 -12
  455. mindspore/train/callback/_callback.py +29 -4
  456. mindspore/train/callback/_checkpoint.py +23 -8
  457. mindspore/train/callback/_early_stop.py +2 -2
  458. mindspore/train/callback/_landscape.py +4 -4
  459. mindspore/train/callback/_loss_monitor.py +2 -2
  460. mindspore/train/callback/_on_request_exit.py +2 -2
  461. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  462. mindspore/train/callback/_summary_collector.py +15 -8
  463. mindspore/train/callback/_time_monitor.py +58 -5
  464. mindspore/train/data_sink.py +5 -11
  465. mindspore/train/dataset_helper.py +84 -57
  466. mindspore/train/loss_scale_manager.py +2 -2
  467. mindspore/train/metrics/__init__.py +3 -3
  468. mindspore/train/metrics/cosine_similarity.py +1 -1
  469. mindspore/train/metrics/hausdorff_distance.py +3 -2
  470. mindspore/train/metrics/mean_surface_distance.py +3 -2
  471. mindspore/train/metrics/metric.py +39 -19
  472. mindspore/train/metrics/roc.py +2 -2
  473. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  474. mindspore/train/mind_ir_pb2.py +85 -36
  475. mindspore/train/model.py +187 -47
  476. mindspore/train/serialization.py +487 -161
  477. mindspore/train/summary/_summary_adapter.py +1 -1
  478. mindspore/train/summary/_writer_pool.py +3 -2
  479. mindspore/train/summary/summary_record.py +37 -17
  480. mindspore/train/train_thor/convert_utils.py +3 -3
  481. mindspore/train/train_thor/dataset_helper.py +1 -1
  482. mindspore/version.py +1 -1
  483. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +8 -8
  484. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +488 -539
  485. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -1
  486. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  487. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  488. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  489. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  490. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  491. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  492. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  493. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  494. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  495. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  496. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  497. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  498. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  499. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  500. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  501. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  502. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  503. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  504. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  505. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  506. mindspore/_extends/graph_kernel/expander.py +0 -80
  507. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  508. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  509. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  510. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  511. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  512. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  513. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  514. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  515. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  516. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  517. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  518. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  519. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  520. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  521. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  522. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  523. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  524. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  525. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  526. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  527. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  528. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  529. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  530. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  531. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  532. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  533. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  534. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  535. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  536. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  537. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  538. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  539. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  540. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  541. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  542. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  543. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  544. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  545. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  546. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  547. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  548. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  549. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  550. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  551. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  552. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  553. mindspore/dataset/datapreprocess/__init__.py +0 -20
  554. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  555. mindspore/include/api/net.h +0 -142
  556. mindspore/nn/lr_scheduler.py +0 -262
  557. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  558. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  559. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  560. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  561. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  562. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -350
  563. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -409
  564. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -578
  565. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -199
  566. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -446
  567. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  568. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
  569. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
  570. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
  571. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  572. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  573. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  574. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  575. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  576. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  577. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  578. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  579. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  580. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  581. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  582. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  583. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  584. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  585. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  586. mindspore/rewrite/node_visitor.py +0 -44
  587. /mindspore/{ops/_op_impl/_custom_op/flash_attention → _akg/akg/utils/ascend_profilier}/__init__.py +0 -0
  588. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
  589. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
@@ -1,142 +0,0 @@
1
- /**
2
- * Copyright 2022-2023 Huawei Technologies Co., Ltd
3
- *
4
- * Licensed under the Apache License, Version 2.0 (the "License");
5
- * you may not use this file except in compliance with the License.
6
- * You may obtain a copy of the License at
7
- *
8
- * http://www.apache.org/licenses/LICENSE-2.0
9
- *
10
- * Unless required by applicable law or agreed to in writing, software
11
- * distributed under the License is distributed on an "AS IS" BASIS,
12
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- * See the License for the specific language governing permissions and
14
- * limitations under the License.
15
- */
16
-
17
- #ifndef MINDSPORE_INCLUDE_API_NET_H
18
- #define MINDSPORE_INCLUDE_API_NET_H
19
-
20
- #include <memory>
21
- #include <vector>
22
- #include <unordered_set>
23
- #include <string>
24
- #include "include/api/types.h"
25
- #include "include/api/data_type.h"
26
- #include "include/api/cfg.h"
27
-
28
- namespace mindspore {
29
- /// \brief Register node or sub network
30
- #define REG(_name) Register(_name, #_name)
31
-
32
- class Expr;
33
- class NodeImpl;
34
- class NetImpl;
35
- class NodeSet;
36
- class Graph;
37
- class NetData;
38
-
39
- class MS_API NetBase {
40
- public:
41
- NetBase() = default;
42
- virtual std::vector<Expr *> operator()(const std::vector<Expr *> &inputs) = 0;
43
- virtual uint32_t type() = 0;
44
- };
45
-
46
- class MS_API Node : public NetBase {
47
- public:
48
- Node();
49
- virtual ~Node();
50
- /// \brief Create output expression from node
51
-
52
- /// \param[in] name Name of input (like "labels" etc.)
53
- ///
54
- /// \return Expression
55
- Expr *Create(std::string name);
56
- /// \brief Run node on inputs. This operator is used in Net::construct()
57
- ///
58
- /// \param[in] inputs Inputs expression for the node.
59
- /// \return Output node expression vector
60
- std::vector<Expr *> operator()(const std::vector<Expr *> &inputs) override;
61
- uint32_t type() final;
62
-
63
- private:
64
- friend NodeImpl;
65
- std::shared_ptr<NodeImpl> impl_ = nullptr;
66
- };
67
-
68
- class MS_API Net : public NetBase, public std::enable_shared_from_this<Net> {
69
- public:
70
- Net();
71
- virtual ~Net();
72
- explicit Net(std::string name);
73
- explicit Net(const Graph &g);
74
- /// \brief Define the relation between network inputs and outputs
75
- ///
76
- /// \param[in] inputs expression vector
77
- ///
78
- /// \return expression vector
79
-
80
- virtual std::vector<Expr *> construct(const std::vector<Expr *> &inputs);
81
- /// \brief Addition operation
82
- ///
83
- /// \param[in] inputs Two elements to add
84
- ///
85
- /// \return expression vector (single element)
86
-
87
- /// \brief Execution operator. Connect inputs to outputs via user defined construct
88
- ///
89
- /// \return expression vector
90
-
91
- std::vector<Expr *> operator()(const std::vector<Expr *> &inputs);
92
- void Register(Net *net, std::string &&name);
93
- void Register(Node *node, std::string &&name);
94
- /// \brief Find the trainable params for the trained network
95
- ///
96
- /// \return NodeSet for all trainable nodes
97
- std::shared_ptr<NodeSet> trainable_params();
98
- virtual void Add(NetBase *element);
99
- /// \brief Input shape
100
- ///
101
- /// \param[in] idx input index
102
- ///
103
- /// \return Specific input shape vector
104
- const std::vector<int> InputShape(int idx);
105
- /// \brief Output shape
106
- ///
107
- /// \param[in] idx Output index
108
- ///
109
- /// \return Specific output shape vector
110
- const std::vector<int> OutputShape(int idx);
111
- uint32_t type() final;
112
-
113
- private:
114
- friend NetImpl;
115
- friend NetData;
116
- std::shared_ptr<NetImpl> impl_;
117
- };
118
-
119
- class MS_API SoftMaxCrossEntropyCfg {
120
- public:
121
- std::string reduction = "mean"; /**< Specifies reduction mode. The optional values are "none", "mean", "sum" */
122
- };
123
-
124
- class MS_API AdamConfig {
125
- public:
126
- float learning_rate_ = 1e-3;
127
- float beta1_ = 0.9;
128
- float beta2_ = 0.999;
129
- float eps_ = 1e-08;
130
- bool use_nesterov_ = false;
131
- };
132
-
133
- namespace NN {
134
- MS_API Net *NetWithLoss(Net *net, Node *loss);
135
- MS_API Graph *GraphWithLoss(Graph *g, Node *loss);
136
- MS_API Node *Adam(std::shared_ptr<NodeSet> learn, const AdamConfig &cfg);
137
- MS_API Node *SoftmaxCrossEntropy(const SoftMaxCrossEntropyCfg &cfg);
138
- MS_API std::unique_ptr<Node> Input(std::vector<int> dims, DataType data_type = DataType::kNumberTypeFloat32,
139
- int fmt = NHWC);
140
- }; // namespace NN
141
- } // namespace mindspore
142
- #endif // MINDSPORE_INCLUDE_API_NET_H
@@ -1,262 +0,0 @@
1
- # Copyright 2023 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ============================================================================
15
- """LRScheduler."""
16
- from mindspore import ops
17
- from mindspore.nn.optim_ex.optimizer import Optimizer
18
- from mindspore.common.api import jit_class
19
- from mindspore.common.parameter import Parameter
20
- from mindspore.common import Tensor
21
- import mindspore.common.dtype as mstype
22
- from mindspore.ops import functional as F
23
- from mindspore import _checkparam as Validator
24
-
25
-
26
- __all__ = ['StepLR', 'LinearLR', 'LRScheduler']
27
-
28
-
29
- @jit_class
30
- class LRScheduler():
31
- r"""
32
- Basic class of learning rate schedule.
33
-
34
- .. warning::
35
- This is an experimental lr scheduler module that is subject to change.
36
- This module must be used with optimizers in `Experimental Optimizer
37
- <https://www.mindspore.cn/docs/en/r2.1/api_python/mindspore.nn.html#experimental-optimizer>`_ .
38
-
39
- Args:
40
- optimizer (:class:`mindspore.nn.optim_ex.Optimizer`): The optimizer instance.
41
- last_epoch (int, optional): The epoch/step number. Default: ``-1``.
42
- verbose (bool, optional): Whether to print lr information. Default: ``False``.
43
-
44
- Raises:
45
- TypeError: If `optimizer` is not an Optimizer.
46
- TypeError: If `last_epoch` is not greater than -1.
47
- ValueError: If `verbose` is not bool.
48
-
49
- Supported Platforms:
50
- ``Ascend`` ``GPU`` ``CPU``
51
- """
52
-
53
- def __init__(self, optimizer, last_epoch=-1, verbose=False):
54
- if not isinstance(optimizer, Optimizer):
55
- raise TypeError('{} is not an Optimizer'.format(
56
- type(optimizer).__name__))
57
- Validator.check_value_type("last_epoch", last_epoch, [int])
58
- if last_epoch < -1:
59
- raise ValueError("Invalid last_epoch: {}".format(last_epoch))
60
- Validator.check_value_type("verbose", verbose, [bool])
61
-
62
- self.optimizer = optimizer
63
- self._last_lr = []
64
- self.groups_num = len(optimizer.param_groups)
65
- self.verbose = verbose
66
- self.last_epoch = Parameter(Tensor(last_epoch, dtype=mstype.float32),
67
- name='last_epoch_' + self.__class__.__name__)
68
- self.increase_tensor = Tensor(1, mstype.int32)
69
- self.assignadd = ops.AssignAdd()
70
- self.step()
71
-
72
- @staticmethod
73
- def _get_lr():
74
- """
75
- Compute current lr.
76
-
77
- This method must be overridden by all subclasses.
78
- """
79
- raise NotImplementedError
80
-
81
- @staticmethod
82
- def _print_lr(is_verbose, group, lr):
83
- """
84
- Display the current learning rate.
85
- """
86
- if is_verbose:
87
- print('Adjusting learning rate of group %s to %s.' % (group, lr.value()))
88
-
89
- def get_last_lr(self):
90
- """
91
- Return last computed learning rate by current scheduler.
92
- """
93
- return [group["lr"].value() for group in self.optimizer.param_groups]
94
-
95
- def step(self):
96
- """
97
- Get the current learning rate and change the learning rate.
98
- """
99
- self.assignadd(self.last_epoch, self.increase_tensor)
100
- values = self._get_lr()
101
- for i in range(self.groups_num):
102
- lr = values[i]
103
- lr = F.depend(lr, F.assign(self.optimizer.param_groups[i]["lr"], lr))
104
- self._print_lr(self.verbose, i, lr)
105
-
106
-
107
- @jit_class
108
- class StepLR(LRScheduler):
109
- """Decays the learning rate of each parameter group by gamma every
110
- step_size epochs. Notice that such decay can happen simultaneously with
111
- other changes to the learning rate from outside this scheduler.
112
-
113
- .. warning::
114
- This is an experimental lr scheduler module that is subject to change.
115
- This module must be used with optimizers in `Experimental Optimizer
116
- <https://www.mindspore.cn/docs/en/r2.1/api_python/mindspore.nn.html#experimental-optimizer>`_ .
117
-
118
- Args:
119
- optimizer (:class:`mindspore.nn.optim_ex.Optimizer`): Wrapped optimizer.
120
- step_size (int): Period of learning rate decay.
121
- gamma (float, optional): Multiplicative factor of learning rate decay.
122
- Default: ``0.1``.
123
- last_epoch (int, optional): The index of last epoch. Default: ``-1``.
124
- verbose (bool, optional): If ``True``, prints a message to stdout for
125
- each update. Default: ``False``.
126
-
127
- Supported Platforms:
128
- ``Ascend`` ``GPU`` ``CPU``
129
-
130
- Examples:
131
- >>> import mindspore
132
- >>> from mindspore import nn
133
- >>> # Define the network structure of LeNet5. Refer to
134
- >>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
135
- >>> net = LeNet5()
136
- >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
137
- >>> optimizer = nn.optim_ex.Adam(net.trainable_params(), lr=0.05)
138
- >>> # Assuming optimizer uses lr = 0.05 for all groups
139
- >>> # lr = 0.05 if epoch < 2
140
- >>> # lr = 0.005 if 2 <= epoch < 4
141
- >>> # lr = 0.0005 if 4 <= epoch < 6
142
- >>> scheduler = nn.StepLR(optimizer, step_size=2, gamma=0.1)
143
- >>> def forward_fn(data, label):
144
- ... logits = net(data)
145
- ... loss = loss_fn(logits, label)
146
- ... return loss, logits
147
- >>> grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
148
- >>> def train_step(data, label):
149
- ... (loss, _), grads = grad_fn(data, label)
150
- ... optimizer(grads)
151
- ... return loss
152
- >>> for epoch in range(6):
153
- ... # Create the dataset taking MNIST as an example. Refer to
154
- ... # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/mnist.py
155
- ... for data, label in create_dataset():
156
- ... train_step(data, label)
157
- ... scheduler.step()
158
- ... current_lr = scheduler.get_last_lr()
159
- """
160
- def __init__(self, optimizer, step_size, gamma=0.5, last_epoch=-1, verbose=False):
161
- self.step_size = step_size
162
- self.gamma = gamma
163
- super(StepLR, self).__init__(optimizer, last_epoch, verbose)
164
-
165
- def _get_lr(self):
166
- if (self.last_epoch == Tensor(0, mstype.float32)) or (
167
- self.last_epoch % self.step_size != Tensor(0, mstype.float32)):
168
- return [group['lr'] * 1. for group in self.optimizer.param_groups]
169
- return [group['lr'] * self.gamma
170
- for group in self.optimizer.param_groups]
171
-
172
-
173
- @jit_class
174
- class LinearLR(LRScheduler):
175
- """Decays the learning rate of each parameter group by linearly changing small
176
- multiplicative factor until the number of epoch reaches a pre-defined milestone: total_iters.
177
- Notice that such decay can happen simultaneously with other changes to the learning rate
178
- from outside this scheduler.
179
-
180
- .. warning::
181
- This is an experimental lr scheduler module that is subject to change.
182
- This module must be used with optimizers in `Experimental Optimizer
183
- <https://www.mindspore.cn/docs/en/r2.1/api_python/mindspore.nn.html#experimental-optimizer>`_ .
184
-
185
- Args:
186
- optimizer (:class:`mindspore.nn.optim_ex.Optimizer`): Wrapped optimizer.
187
- start_factor (float, optional): The number we multiply learning rate in the first epoch.
188
- The multiplication factor changes towards `end_factor` in the following epochs.
189
- Default: ``1.0 /3``.
190
- end_factor (float, optional): The number we multiply learning rate at the end of linear changing
191
- process. Default: ``1.0``.
192
- total_iters (int, optional): The number of iterations that multiplicative factor reaches to 1.
193
- Default: ``5``.
194
- last_epoch (int, optional): The index of the last epoch. Default: ``-1``.
195
- verbose (bool, optional): If ``True``, prints a message to stdout for
196
- each update. Default: ``False``.
197
-
198
- Raises:
199
- ValueError: If `start_factor` is not in the range of (0, 1].
200
- ValueError: If `end_factor` is not in the range of [0, 1].
201
-
202
- Supported Platforms:
203
- ``Ascend`` ``GPU`` ``CPU``
204
-
205
- Examples:
206
- >>> import mindspore
207
- >>> from mindspore.nn import LinearLR
208
- >>> from mindspore import nn
209
- >>> # Define the network structure of LeNet5. Refer to
210
- >>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
211
- >>> net = LeNet5()
212
- >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
213
- >>> optimizer = nn.optim_ex.Adam(net.trainable_params(), lr=0.05)
214
- >>> # Assuming optimizer uses lr = 0.05 for all groups
215
- >>> # lr = 0.025 if epoch == 0
216
- >>> # lr = 0.03125 if epoch == 1
217
- >>> # lr = 0.0375 if epoch == 2
218
- >>> # lr = 0.04375 if epoch == 3
219
- >>> # lr = 0.05 if epoch >= 4
220
- >>> scheduler = LinearLR(optimizer, start_factor=0.5, total_iters=4)
221
- >>> def forward_fn(data, label):
222
- ... logits = net(data)
223
- ... loss = loss_fn(logits, label)
224
- ... return loss, logits
225
- >>> grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
226
- >>> def train_step(data, label):
227
- ... (loss, _), grads = grad_fn(data, label)
228
- ... optimizer(grads)
229
- ... return loss
230
- >>> for epoch in range(5):
231
- ... # Create the dataset taking MNIST as an example. Refer to
232
- ... # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/mnist.py
233
- ... for data, label in create_dataset():
234
- ... train_step(data, label)
235
- ... scheduler.step()
236
- ... current_lr = scheduler.get_last_lr()
237
- """
238
-
239
- def __init__(self, optimizer, start_factor=1.0 / 3, end_factor=1.0, total_iters=5, last_epoch=-1,
240
- verbose=False):
241
-
242
- if start_factor > 1.0 or start_factor <= 0:
243
- raise ValueError('Starting multiplicative factor expected to be greater than 0 and less or equal to 1.')
244
-
245
- if end_factor > 1.0 or end_factor < 0:
246
- raise ValueError('Ending multiplicative factor expected to be between 0 and 1.')
247
-
248
- self.start_factor = start_factor
249
- self.end_factor = end_factor
250
- self.total_iters = total_iters
251
- super(LinearLR, self).__init__(optimizer, last_epoch, verbose)
252
-
253
- def _get_lr(self):
254
- if self.last_epoch == Tensor(0, mstype.float32):
255
- return [group['lr'] * self.start_factor for group in self.optimizer.param_groups]
256
-
257
- if self.last_epoch > self.total_iters:
258
- return [group['lr'] * 1. for group in self.optimizer.param_groups]
259
-
260
- return [group['lr'] * (1. + (self.end_factor - self.start_factor) /
261
- (self.total_iters * self.start_factor + (self.last_epoch - 1) *
262
- (self.end_factor - self.start_factor))) for group in self.optimizer.param_groups]
@@ -1,248 +0,0 @@
1
- # Copyright 2021-2022 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ============================================================================
15
-
16
- """image_ops"""
17
-
18
- from mindspore import Tensor
19
- from mindspore.common import dtype as mstype
20
- from mindspore.ops._grad_experimental.grad_base import bprop_getters
21
- from mindspore.ops import operations as P
22
- from mindspore.ops import functional as F
23
- from mindspore.ops.operations import _grad_ops as G
24
- from mindspore.ops.operations.image_ops import ResizeBicubic
25
- from mindspore.ops.operations._grad_ops import ResizeBicubicGrad
26
- from mindspore.ops.operations.image_ops import ResizeV2
27
- from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
28
- from mindspore.ops.operations.image_ops import CropAndResize
29
- from mindspore.ops.operations.image_ops import CropAndResizeGradImage
30
- from mindspore.ops.operations.image_ops import CropAndResizeGradBoxes
31
- from mindspore.ops.operations.image_ops import RGBToHSV
32
- from mindspore.ops.operations.image_ops import ScaleAndTranslate
33
- from mindspore import context
34
-
35
-
36
- @bprop_getters.register(ResizeBicubic)
37
- def get_bprop_resize_bicubic(self):
38
- """Grad definition for `ResizeBicubic` operation."""
39
- resize_bicubic_grad = ResizeBicubicGrad(align_corners=self.align_corners,
40
- half_pixel_centers=self.half_pixel_centers)
41
-
42
- def bprop(images, size, out, dout):
43
- dx = resize_bicubic_grad(dout, images)
44
- return (dx, P.ZerosLike()(size))
45
- return bprop
46
-
47
-
48
- @bprop_getters.register(ResizeV2)
49
- def get_bprop_resize_v2(self):
50
- """Grad definition for `ResizeV2` operation."""
51
- resize_v2_grad = G.ResizeV2Grad(coordinate_transformation_mode=self.coordinate_transformation_mode,
52
- mode=self.mode)
53
-
54
- def bprop(x, roi, scales, sizes, out, dout):
55
- input_size = P.Shape()(x)
56
- dx = resize_v2_grad(dout, roi, scales, Tensor(input_size))
57
- return (dx, zeros_like(roi), zeros_like(scales), zeros_like(sizes))
58
- return bprop
59
-
60
-
61
- @bprop_getters.register(CropAndResize)
62
- def get_bprop_crop_and_resize(self):
63
- """Grad definition for `CropAndResize` operation."""
64
- allowed_types = [mstype.float16, mstype.float32, mstype.float64]
65
- gradboxes = CropAndResizeGradBoxes(method="bilinear")
66
- method_ = self.method
67
-
68
- is_ascend_cpu = context.get_context('device_target') in ("Ascend", "CPU")
69
-
70
- def bprop(x, boxes, box_index, crop_size, out, dout):
71
- if method_ != "bilinear":
72
- if not is_ascend_cpu:
73
- return (zeros_like(x), zeros_like(boxes), zeros_like(box_index), zeros_like(crop_size))
74
- image_type = x.dtype
75
- if image_type not in allowed_types:
76
- x = F.cast(x, mstype.float32)
77
- dimage_type = image_type
78
- gradimage = CropAndResizeGradImage(dimage_type, method=method_)
79
- image_shape = x.shape
80
- if F.is_sequence_value_unknown(image_shape):
81
- image_size = P.TensorShape()(x)
82
- image_size = F.cast(image_size, mstype.int32)
83
- else:
84
- image_size = Tensor(image_shape, dtype=mstype.int32)
85
- dimage = gradimage(dout, boxes, box_index, image_size)
86
- dbox = gradboxes(dout, x, boxes, box_index)
87
- return (dimage, dbox, zeros_like(box_index), zeros_like(crop_size))
88
- return bprop
89
-
90
-
91
- def crcp(x):
92
- """Grad definition for `RGBToHSV` operations."""
93
- return P.DivNoNan()(1, x)
94
-
95
-
96
- def function1_rgbtohsv(images, out, dout):
97
- """Grad definition for `RGBToHSV` operations."""
98
- dout = P.Cast()(dout, mstype.float32)
99
- images = P.Cast()(images, mstype.float32)
100
- out = P.Cast()(out, mstype.float32)
101
- return images, out, dout
102
-
103
-
104
- def function2_rgbtohsv(images):
105
- """Grad definition for `RGBToHSV` operations."""
106
- # Input Channels
107
- reds = images[..., 0]
108
- greens = images[..., 1]
109
- blues = images[..., 2]
110
- return reds, greens, blues
111
-
112
-
113
- def function3_rgbtohsv(out, reds):
114
- """Grad definition for `RGBToHSV` operations."""
115
- # Output Channels
116
- saturation = out[..., 1]
117
- value = out[..., 2]
118
- dsr1 = P.Cast()(reds > 0, mstype.float32)
119
- return dsr1, saturation, value
120
-
121
-
122
- def function4_rgbtohsv(reds, greens, blues):
123
- """Grad definition for `RGBToHSV` operations."""
124
- r_b = P.LogicalAnd()((reds >= blues), (reds >= greens))
125
- red_biggest = P.Cast()(r_b, mstype.float32)
126
- g_b = P.LogicalAnd()((greens > reds), (greens >= blues))
127
- green_biggest = P.Cast()(g_b, mstype.float32)
128
- b_b = P.LogicalAnd()((blues > reds), (blues > greens))
129
- blue_biggest = P.Cast()(b_b, mstype.float32)
130
- return red_biggest, green_biggest, blue_biggest
131
-
132
-
133
- def function5_rgbtohsv(reds, greens, blues):
134
- """Grad definition for `RGBToHSV` operations."""
135
- r_s = P.LogicalAnd()((reds < blues), (reds < greens))
136
- red_smallest = P.Cast()(r_s, mstype.float32)
137
- g_s = P.LogicalAnd()((greens <= reds), (greens < blues))
138
- green_smallest = P.Cast()(g_s, mstype.float32)
139
- b_s = P.LogicalAnd()((blues <= reds), (blues <= greens))
140
- blue_smallest = P.Cast()(b_s, mstype.float32)
141
- return red_smallest, green_smallest, blue_smallest
142
-
143
-
144
- def function6_rgbtohsv(red_biggest, green_biggest, blue_biggest):
145
- """Grad definition for `RGBToHSV` operations."""
146
- dv_dr = red_biggest
147
- dv_dg = green_biggest
148
- dv_db = blue_biggest
149
- return dv_dr, dv_dg, dv_db
150
-
151
-
152
- def function7_rgbtohsv(greens, green_biggest, dhb5, dh_db_1, dh_db_2, dh_db_3, dh_db_4,\
153
- dout, dv_dr, dv_dg, dv_db, ds_dr, ds_dg, ds_db, dh_dr, dh_dg):
154
- """Grad definition for `RGBToHSV` operations."""
155
- dh_db_5 = 60 * (P.Cast()((greens > 0), mstype.float32) * green_biggest * dhb5)
156
-
157
- dh_db = dh_db_1 + dh_db_2 + dh_db_3 + dh_db_4 + dh_db_5
158
-
159
- dh_db = dh_db / 360
160
-
161
- dv_drgb = P.Stack(-1)(
162
- [dout[..., 2] * dv_dr, dout[..., 2] * dv_dg, dout[..., 2] * dv_db])
163
- ds_drgb = P.Stack(-1)(
164
- [dout[..., 1] * ds_dr, dout[..., 1] * ds_dg, dout[..., 1] * ds_db])
165
- dh_drgb = P.Stack(-1)(
166
- [dout[..., 0] * dh_dr, dout[..., 0] * dh_dg, dout[..., 0] * dh_db])
167
- dvds_drgb = P.Add()(dv_drgb, ds_drgb)
168
- doutient_input = P.Add()(dvds_drgb, dh_drgb)
169
- return (doutient_input,)
170
-
171
-
172
- @bprop_getters.register(RGBToHSV)
173
- def get_bprop_rgb_to_hsv(self):
174
- """dout definition for 'RGBToHSV' operation"""
175
-
176
- def bprop(images, out, dout):
177
- images, out, dout = function1_rgbtohsv(images, out, dout)
178
- reds, greens, blues = function2_rgbtohsv(images)
179
- dsr1, saturation, value = function3_rgbtohsv(out, reds)
180
- red_biggest, green_biggest, blue_biggest = function4_rgbtohsv(reds, greens, blues)
181
- red_smallest, green_smallest, blue_smallest = function5_rgbtohsv(reds, greens, blues)
182
- dv_dr, dv_dg, dv_db = function6_rgbtohsv(red_biggest, green_biggest, blue_biggest)
183
- dsr2 = red_biggest * P.Add()(green_smallest * greens, blue_smallest * blues) * crcp(P.Square()(reds))
184
- dsr3 = red_smallest * -1 * crcp((green_biggest * greens) + (blue_biggest * blues))
185
- ds_dr = dsr1 * P.Add()(dsr2, dsr3)
186
- dsg1 = P.Cast()((greens > 0), mstype.float32)
187
- dsg2 = green_biggest * P.Add()(red_smallest * reds, blue_smallest * blues) * crcp(P.Square()(greens))
188
- dsg3 = green_smallest * -1 * crcp((red_biggest * reds) + (blue_biggest * blues))
189
- ds_dg = dsg1 * P.Add()(dsg2, dsg3)
190
-
191
- dsb1 = P.Cast()((blues > 0), mstype.float32)
192
- dsb2 = blue_biggest * P.Add()(green_smallest * greens, red_smallest * reds) * crcp(P.Square()(blues))
193
- dsb3 = blue_smallest * -1 * crcp((green_biggest * greens) + (red_biggest * reds))
194
- ds_db = dsb1 * P.Add()(dsb2, dsb3)
195
-
196
- dhr1 = (greens - blues) * crcp(P.Square()(saturation)) * crcp(P.Square()(value))
197
- dh_dr_1 = 60 * (P.Cast()((reds > 0), mstype.float32) * red_biggest * -1 * dhr1)
198
- dhr2 = red_smallest * (blues - greens) * crcp(P.Square()(reds - greens))
199
- dh_dr_2 = 60 * (P.Cast()((greens > 0), mstype.float32) * green_biggest * dhr2)
200
- dhr3 = blue_smallest * -1 * crcp(greens - blues)
201
- dh_dr_3 = 60 * (P.Cast()((greens > 0), mstype.float32) * green_biggest * dhr3)
202
- dhr4 = red_smallest * (blues - greens) * crcp(P.Square()(blues - reds))
203
- dh_dr_4 = 60 * (P.Cast()((blues > 0), mstype.float32) * blue_biggest * dhr4)
204
- dhr5 = green_smallest * crcp(blues - greens)
205
- dh_dr_5 = 60 * (P.Cast()((blues > 0), mstype.float32) * blue_biggest * dhr5)
206
-
207
- dh_dr = (dh_dr_1 + dh_dr_2 + dh_dr_3 + dh_dr_4 + dh_dr_5) / 360
208
-
209
- dhg1 = (blues - reds) * crcp(P.Square()(saturation)) * crcp(P.Square()(value))
210
- dh_dg_1 = 60 * (P.Cast()((greens > 0), mstype.float32) * green_biggest * -1 * dhg1)
211
- dhg2 = green_smallest * (reds - blues) * crcp(P.Square()(reds - greens))
212
- dh_dg_2 = 60 * (P.Cast()((reds > 0), mstype.float32) * red_biggest * dhg2)
213
- dhg3 = blue_smallest * crcp(reds - blues)
214
- dh_dg_3 = 60 * (P.Cast()((reds > 0), mstype.float32) * red_biggest * dhg3)
215
- dhg4 = green_smallest * (reds - blues) * crcp(P.Square()(blues - greens))
216
- dh_dg_4 = 60 * (P.Cast()((blues > 0), mstype.float32) * blue_biggest * dhg4)
217
- dhg5 = red_smallest * -1 * crcp(blues - reds)
218
- dh_dg_5 = 60 * (P.Cast()((blues > 0), mstype.float32) * blue_biggest * dhg5)
219
-
220
- dh_dg = (dh_dg_1 + dh_dg_2 + dh_dg_3 + dh_dg_4 + dh_dg_5) / 360
221
-
222
- dhb1 = (reds - greens) * crcp(P.Square()(saturation)) * crcp(P.Square()(value))
223
- dh_db_1 = 60 * (P.Cast()((blues > 0), mstype.float32) * blue_biggest * -1 * dhb1)
224
- dhb2 = blue_smallest * (greens - reds) * crcp(P.Square()(reds - blues))
225
- dh_db_2 = 60 * (P.Cast()((reds > 0), mstype.float32) * red_biggest * dhb2)
226
- dhb3 = green_smallest * -1 * crcp(reds - greens)
227
- dh_db_3 = 60 * (P.Cast()((reds > 0), mstype.float32) * red_biggest * dhb3)
228
- dhb4 = blue_smallest * (greens - reds) * crcp(P.Square()(greens - blues))
229
- dh_db_4 = 60 * (P.Cast()((greens > 0), mstype.float32) * green_biggest * dhb4)
230
- dhb5 = red_smallest * crcp(greens - reds)
231
- return function7_rgbtohsv(greens, green_biggest, dhb5, dh_db_1, dh_db_2, dh_db_3,\
232
- dh_db_4, dout, dv_dr, dv_dg, dv_db, ds_dr, ds_dg, ds_db, dh_dr, dh_dg)
233
- return bprop
234
-
235
-
236
- @bprop_getters.register(ScaleAndTranslate)
237
- def get_bprop_scale_and_translate(self):
238
- """Grad definition for `ScaleAndTranslate` operation"""
239
- scale_and_translate_grad = G.ScaleAndTranslateGrad(self.kernel_type, self.antialias)
240
-
241
- def bprop(images, size, scale, translation, out, dout):
242
- images_fp32 = F.cast(images, mstype.float32)
243
- grad0_fp32 = scale_and_translate_grad(dout, images_fp32, scale, translation)
244
- grad0 = F.cast(grad0_fp32, F.dtype(images))
245
- result = (grad0, F.zeros_like(size), F.zeros_like(scale), F.zeros_like(translation))
246
- return result
247
-
248
- return bprop