mindspore 2.1.0__cp38-cp38-win_amd64.whl → 2.2.11__cp38-cp38-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (511) hide show
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +4 -1
  5. mindspore/_c_dataengine.cp38-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp38-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp38-win_amd64.pyd +0 -0
  8. mindspore/_check_jit_forbidden_api.py +3 -1
  9. mindspore/_checkparam.py +23 -29
  10. mindspore/_extends/graph_kernel/__init__.py +0 -1
  11. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  12. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  13. mindspore/_extends/graph_kernel/splitter.py +4 -11
  14. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  15. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
  16. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  17. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  18. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  19. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
  20. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  21. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  22. mindspore/_extends/parse/__init__.py +13 -15
  23. mindspore/_extends/parse/namespace.py +7 -33
  24. mindspore/_extends/parse/parser.py +67 -72
  25. mindspore/_extends/parse/resources.py +1 -1
  26. mindspore/_extends/parse/standard_method.py +86 -106
  27. mindspore/_extends/parse/trope.py +1 -1
  28. mindspore/_extends/remote/kernel_build_server.py +25 -7
  29. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  30. mindspore/_install_custom.py +43 -0
  31. mindspore/amp.py +47 -11
  32. mindspore/atlprov.dll +0 -0
  33. mindspore/boost/boost.py +1 -8
  34. mindspore/boost/boost_cell_wrapper.py +3 -2
  35. mindspore/boost/grad_accumulation.py +1 -1
  36. mindspore/boost/group_loss_scale_manager.py +8 -7
  37. mindspore/c1.dll +0 -0
  38. mindspore/c1xx.dll +0 -0
  39. mindspore/c2.dll +0 -0
  40. mindspore/common/__init__.py +5 -3
  41. mindspore/common/_jit_fallback_utils.py +6 -0
  42. mindspore/common/_register_for_adapter.py +2 -0
  43. mindspore/common/_register_for_tensor.py +2 -2
  44. mindspore/common/_stub_tensor.py +13 -0
  45. mindspore/common/_utils.py +29 -0
  46. mindspore/common/api.py +174 -259
  47. mindspore/common/auto_dynamic_shape.py +494 -0
  48. mindspore/common/dtype.py +18 -11
  49. mindspore/common/dump.py +6 -4
  50. mindspore/common/initializer.py +14 -14
  51. mindspore/common/jit_config.py +33 -15
  52. mindspore/common/lazy_inline.py +126 -7
  53. mindspore/common/mindir_util.py +101 -0
  54. mindspore/common/parameter.py +51 -41
  55. mindspore/common/seed.py +4 -4
  56. mindspore/common/sparse_tensor.py +13 -14
  57. mindspore/common/tensor.py +243 -165
  58. mindspore/communication/__init__.py +7 -4
  59. mindspore/communication/_comm_helper.py +83 -4
  60. mindspore/communication/management.py +152 -84
  61. mindspore/config/op_info.config +14 -3
  62. mindspore/context.py +152 -61
  63. mindspore/dataset/__init__.py +5 -5
  64. mindspore/dataset/audio/__init__.py +2 -2
  65. mindspore/dataset/audio/transforms.py +52 -52
  66. mindspore/dataset/callback/ds_callback.py +16 -2
  67. mindspore/dataset/core/config.py +68 -51
  68. mindspore/dataset/engine/cache_client.py +33 -7
  69. mindspore/dataset/engine/datasets.py +250 -112
  70. mindspore/dataset/engine/datasets_audio.py +43 -211
  71. mindspore/dataset/engine/datasets_standard_format.py +16 -35
  72. mindspore/dataset/engine/datasets_text.py +43 -67
  73. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  74. mindspore/dataset/engine/datasets_vision.py +219 -1029
  75. mindspore/dataset/engine/iterators.py +11 -4
  76. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  77. mindspore/dataset/engine/obs/util.py +3 -0
  78. mindspore/dataset/engine/samplers.py +1 -1
  79. mindspore/dataset/engine/validators.py +19 -5
  80. mindspore/dataset/text/__init__.py +3 -3
  81. mindspore/dataset/text/transforms.py +101 -127
  82. mindspore/dataset/text/utils.py +205 -138
  83. mindspore/dataset/transforms/__init__.py +1 -1
  84. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  85. mindspore/dataset/transforms/transforms.py +95 -40
  86. mindspore/dataset/utils/browse_dataset.py +8 -2
  87. mindspore/dataset/utils/line_reader.py +17 -19
  88. mindspore/dataset/vision/__init__.py +3 -3
  89. mindspore/dataset/vision/c_transforms.py +6 -3
  90. mindspore/dataset/vision/transforms.py +409 -287
  91. mindspore/dataset/vision/utils.py +13 -14
  92. mindspore/dataset/vision/validators.py +11 -1
  93. mindspore/dnnl.dll +0 -0
  94. mindspore/dpcmi.dll +0 -0
  95. mindspore/experimental/map_parameter.py +14 -0
  96. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  97. mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
  98. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  99. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  100. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  101. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  102. mindspore/gen_ops.py +273 -0
  103. mindspore/include/OWNERS +0 -1
  104. mindspore/include/api/data_type.h +2 -1
  105. mindspore/include/api/graph.h +0 -15
  106. mindspore/include/api/kernel.h +2 -0
  107. mindspore/include/api/kernel_api.h +37 -12
  108. mindspore/include/api/model.h +17 -14
  109. mindspore/include/api/status.h +8 -3
  110. mindspore/include/api/types.h +37 -4
  111. mindspore/include/c_api/ms/abstract.h +67 -0
  112. mindspore/include/c_api/ms/attribute.h +197 -0
  113. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  114. mindspore/include/c_api/ms/base/macros.h +32 -0
  115. mindspore/include/c_api/ms/base/status.h +33 -0
  116. mindspore/include/c_api/ms/base/types.h +282 -0
  117. mindspore/include/c_api/ms/context.h +102 -0
  118. mindspore/include/c_api/ms/graph.h +160 -0
  119. mindspore/include/c_api/ms/node.h +606 -0
  120. mindspore/include/c_api/ms/tensor.h +161 -0
  121. mindspore/include/c_api/ms/value.h +84 -0
  122. mindspore/include/dataset/constants.h +6 -5
  123. mindspore/include/dataset/execute.h +23 -13
  124. mindspore/include/dataset/text.h +26 -26
  125. mindspore/include/dataset/transforms.h +13 -13
  126. mindspore/include/dataset/vision.h +60 -60
  127. mindspore/include/dataset/vision_ascend.h +5 -6
  128. mindspore/include/dataset/vision_lite.h +17 -17
  129. mindspore/jpeg62.dll +0 -0
  130. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  131. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  132. mindspore/mindspore_backend.dll +0 -0
  133. mindspore/mindspore_common.dll +0 -0
  134. mindspore/mindspore_core.dll +0 -0
  135. mindspore/mindspore_glog.dll +0 -0
  136. mindspore/mindspore_shared_lib.dll +0 -0
  137. mindspore/msobj140.dll +0 -0
  138. mindspore/mspdb140.dll +0 -0
  139. mindspore/mspdbcore.dll +0 -0
  140. mindspore/mspdbst.dll +0 -0
  141. mindspore/mspft140.dll +0 -0
  142. mindspore/msvcdis140.dll +0 -0
  143. mindspore/msvcp140_1.dll +0 -0
  144. mindspore/msvcp140_2.dll +0 -0
  145. mindspore/msvcp140_atomic_wait.dll +0 -0
  146. mindspore/msvcp140_codecvt_ids.dll +0 -0
  147. mindspore/nn/__init__.py +0 -2
  148. mindspore/nn/cell.py +313 -74
  149. mindspore/nn/dynamic_lr.py +21 -21
  150. mindspore/nn/layer/activation.py +22 -30
  151. mindspore/nn/layer/basic.py +15 -13
  152. mindspore/nn/layer/channel_shuffle.py +1 -1
  153. mindspore/nn/layer/container.py +271 -9
  154. mindspore/nn/layer/conv.py +323 -204
  155. mindspore/nn/layer/dense.py +8 -5
  156. mindspore/nn/layer/embedding.py +33 -27
  157. mindspore/nn/layer/flash_attention.py +61 -95
  158. mindspore/nn/layer/image.py +8 -6
  159. mindspore/nn/layer/math.py +16 -25
  160. mindspore/nn/layer/normalization.py +107 -66
  161. mindspore/nn/layer/padding.py +1 -1
  162. mindspore/nn/layer/pooling.py +131 -109
  163. mindspore/nn/layer/rnn_cells.py +27 -22
  164. mindspore/nn/layer/rnns.py +13 -16
  165. mindspore/nn/layer/thor_layer.py +1 -1
  166. mindspore/nn/layer/transformer.py +221 -154
  167. mindspore/nn/learning_rate_schedule.py +9 -1
  168. mindspore/nn/loss/loss.py +235 -174
  169. mindspore/nn/optim/ada_grad.py +2 -1
  170. mindspore/nn/optim/adadelta.py +1 -0
  171. mindspore/nn/optim/adafactor.py +2 -1
  172. mindspore/nn/optim/adam.py +7 -4
  173. mindspore/nn/optim/adamax.py +3 -2
  174. mindspore/nn/optim/adasum.py +2 -2
  175. mindspore/nn/optim/asgd.py +2 -3
  176. mindspore/nn/optim/ftrl.py +6 -5
  177. mindspore/nn/optim/lamb.py +7 -4
  178. mindspore/nn/optim/lars.py +1 -1
  179. mindspore/nn/optim/lazyadam.py +5 -3
  180. mindspore/nn/optim/momentum.py +2 -1
  181. mindspore/nn/optim/optimizer.py +53 -4
  182. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  183. mindspore/nn/optim/rmsprop.py +4 -3
  184. mindspore/nn/optim/rprop.py +23 -12
  185. mindspore/nn/optim/sgd.py +26 -11
  186. mindspore/nn/optim/thor.py +9 -7
  187. mindspore/nn/probability/bijector/bijector.py +5 -5
  188. mindspore/nn/probability/bijector/power_transform.py +27 -27
  189. mindspore/nn/probability/bijector/softplus.py +3 -3
  190. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  191. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  192. mindspore/nn/probability/distribution/beta.py +3 -3
  193. mindspore/nn/probability/distribution/categorical.py +7 -7
  194. mindspore/nn/probability/distribution/cauchy.py +0 -1
  195. mindspore/nn/probability/distribution/distribution.py +3 -3
  196. mindspore/nn/probability/distribution/gamma.py +3 -3
  197. mindspore/nn/probability/distribution/geometric.py +4 -4
  198. mindspore/nn/probability/distribution/gumbel.py +4 -4
  199. mindspore/nn/probability/distribution/log_normal.py +2 -2
  200. mindspore/nn/probability/distribution/logistic.py +2 -2
  201. mindspore/nn/probability/distribution/poisson.py +4 -4
  202. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  203. mindspore/nn/probability/distribution/uniform.py +6 -6
  204. mindspore/nn/wrap/__init__.py +4 -2
  205. mindspore/nn/wrap/cell_wrapper.py +87 -34
  206. mindspore/nn/wrap/grad_reducer.py +8 -5
  207. mindspore/nn/wrap/loss_scale.py +105 -42
  208. mindspore/numpy/array_creations.py +1 -2
  209. mindspore/numpy/array_ops.py +3 -2
  210. mindspore/numpy/utils_const.py +5 -5
  211. mindspore/opencv_core452.dll +0 -0
  212. mindspore/opencv_imgcodecs452.dll +0 -0
  213. mindspore/opencv_imgproc452.dll +0 -0
  214. mindspore/ops/_grad_experimental/__init__.py +0 -5
  215. mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
  216. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  217. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  218. mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
  219. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  220. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
  221. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  222. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  223. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  224. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  225. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  226. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  227. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  228. mindspore/ops/_op_impl/{_custom_op/flash_attention/constants.py → aicpu/eps.py} +18 -27
  229. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  230. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
  231. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  232. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  233. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  234. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  235. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  236. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  237. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  238. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  239. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  240. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  241. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  242. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  243. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  244. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  245. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  246. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  247. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  248. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  249. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  250. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  251. mindspore/ops/_primitive_cache.py +1 -1
  252. mindspore/ops/_tracefunc.py +45 -13
  253. mindspore/ops/_utils/utils.py +6 -1
  254. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  255. mindspore/ops/_vmap/vmap_base.py +3 -3
  256. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  257. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  258. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  259. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  260. mindspore/ops/arg_dtype_cast.py +54 -0
  261. mindspore/ops/composite/base.py +37 -10
  262. mindspore/ops/composite/math_ops.py +5 -4
  263. mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
  264. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  265. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  266. mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
  267. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  268. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  269. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  270. mindspore/ops/deprecated.py +304 -0
  271. mindspore/ops/function/__init__.py +4 -1
  272. mindspore/ops/function/array_func.py +174 -193
  273. mindspore/ops/function/clip_func.py +81 -13
  274. mindspore/ops/function/debug_func.py +1 -1
  275. mindspore/ops/function/grad/grad_func.py +18 -9
  276. mindspore/ops/function/image_func.py +10 -4
  277. mindspore/ops/function/linalg_func.py +5 -5
  278. mindspore/ops/function/math_func.py +575 -386
  279. mindspore/ops/function/nn_func.py +568 -260
  280. mindspore/ops/function/random_func.py +88 -57
  281. mindspore/ops/function/sparse_func.py +1 -1
  282. mindspore/ops/function/sparse_unary_func.py +14 -12
  283. mindspore/ops/function/vmap_func.py +6 -5
  284. mindspore/ops/functional.py +15 -10
  285. mindspore/ops/op_info_register.py +244 -25
  286. mindspore/ops/operations/__init__.py +31 -19
  287. mindspore/ops/operations/_grad_ops.py +71 -7
  288. mindspore/ops/operations/_inner_ops.py +350 -17
  289. mindspore/ops/operations/_quant_ops.py +4 -8
  290. mindspore/ops/operations/_sequence_ops.py +42 -0
  291. mindspore/ops/operations/array_ops.py +68 -282
  292. mindspore/ops/operations/comm_ops.py +107 -59
  293. mindspore/ops/operations/custom_ops.py +94 -70
  294. mindspore/ops/operations/debug_ops.py +8 -4
  295. mindspore/ops/operations/image_ops.py +18 -12
  296. mindspore/ops/operations/inner_ops.py +26 -3
  297. mindspore/ops/operations/math_ops.py +192 -144
  298. mindspore/ops/operations/nn_ops.py +857 -489
  299. mindspore/ops/operations/other_ops.py +0 -22
  300. mindspore/ops/operations/random_ops.py +53 -111
  301. mindspore/ops/operations/sparse_ops.py +3 -1
  302. mindspore/ops/primitive.py +24 -18
  303. mindspore/parallel/_auto_parallel_context.py +68 -8
  304. mindspore/parallel/_cost_model_context.py +2 -2
  305. mindspore/parallel/_offload_context.py +17 -3
  306. mindspore/parallel/_parallel_serialization.py +12 -5
  307. mindspore/parallel/_ps_context.py +12 -0
  308. mindspore/parallel/_tensor.py +18 -13
  309. mindspore/parallel/_transformer/layers.py +5 -3
  310. mindspore/parallel/_transformer/loss.py +1 -0
  311. mindspore/parallel/_transformer/moe.py +2 -2
  312. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  313. mindspore/parallel/_transformer/transformer.py +23 -3
  314. mindspore/parallel/_utils.py +11 -7
  315. mindspore/parallel/algo_parameter_config.py +85 -5
  316. mindspore/parallel/checkpoint_transform.py +19 -12
  317. mindspore/parallel/shard.py +21 -14
  318. mindspore/pgodb140.dll +0 -0
  319. mindspore/pgort140.dll +0 -0
  320. mindspore/profiler/common/struct_type.py +3 -3
  321. mindspore/profiler/common/util.py +4 -2
  322. mindspore/profiler/envprofiling.py +1 -1
  323. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  324. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  325. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  326. mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
  327. mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
  328. mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
  329. mindspore/profiler/parser/ascend_op_generator.py +6 -6
  330. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  331. mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
  332. mindspore/profiler/parser/base_timeline_generator.py +10 -8
  333. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
  334. mindspore/profiler/parser/flops_parser.py +15 -11
  335. mindspore/profiler/parser/framework_parser.py +38 -22
  336. mindspore/profiler/parser/hccl_parser.py +16 -12
  337. mindspore/profiler/parser/integrator.py +22 -11
  338. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  339. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  340. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  341. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  342. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  343. mindspore/profiler/parser/optime_parser.py +1 -1
  344. mindspore/profiler/parser/profiler_info.py +21 -2
  345. mindspore/profiler/parser/step_trace_parser.py +11 -14
  346. mindspore/profiler/profiling.py +179 -89
  347. mindspore/rewrite/api/node.py +102 -19
  348. mindspore/rewrite/api/node_type.py +5 -1
  349. mindspore/rewrite/api/pattern_engine.py +1 -1
  350. mindspore/rewrite/api/scoped_value.py +9 -17
  351. mindspore/rewrite/api/symbol_tree.py +131 -47
  352. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  353. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  354. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  355. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  356. mindspore/rewrite/common/rewrite_elog.py +5 -1
  357. mindspore/rewrite/namer.py +33 -24
  358. mindspore/rewrite/namespace.py +14 -5
  359. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  360. mindspore/rewrite/node/call_function.py +79 -0
  361. mindspore/rewrite/node/cell_container.py +135 -0
  362. mindspore/rewrite/node/control_flow.py +88 -0
  363. mindspore/rewrite/{node.py → node/node.py} +273 -234
  364. mindspore/rewrite/node/node_manager.py +254 -0
  365. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  366. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  367. mindspore/rewrite/parsers/assign_parser.py +216 -221
  368. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  369. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  370. mindspore/rewrite/parsers/constant_parser.py +9 -6
  371. mindspore/rewrite/parsers/container_parser.py +9 -7
  372. mindspore/rewrite/parsers/for_parser.py +42 -21
  373. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  374. mindspore/rewrite/parsers/if_parser.py +28 -24
  375. mindspore/rewrite/parsers/module_parser.py +196 -25
  376. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  377. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  378. mindspore/rewrite/parsers/return_parser.py +6 -6
  379. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  380. mindspore/rewrite/sparsify/utils.py +1 -1
  381. mindspore/rewrite/symbol_tree.py +523 -578
  382. mindspore/rewrite/symbol_tree_builder.py +9 -193
  383. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  384. mindspore/run_check/_check_version.py +6 -4
  385. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  386. mindspore/safeguard/rewrite_obfuscation.py +541 -0
  387. mindspore/tbbmalloc.dll +0 -0
  388. mindspore/tinyxml2.dll +0 -0
  389. mindspore/train/_utils.py +7 -3
  390. mindspore/train/amp.py +323 -123
  391. mindspore/train/anf_ir_pb2.py +14 -2
  392. mindspore/train/callback/_backup_and_restore.py +2 -12
  393. mindspore/train/callback/_callback.py +29 -4
  394. mindspore/train/callback/_checkpoint.py +23 -8
  395. mindspore/train/callback/_early_stop.py +2 -2
  396. mindspore/train/callback/_landscape.py +4 -4
  397. mindspore/train/callback/_loss_monitor.py +2 -2
  398. mindspore/train/callback/_on_request_exit.py +2 -2
  399. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  400. mindspore/train/callback/_summary_collector.py +15 -8
  401. mindspore/train/callback/_time_monitor.py +58 -5
  402. mindspore/train/data_sink.py +5 -11
  403. mindspore/train/dataset_helper.py +84 -57
  404. mindspore/train/loss_scale_manager.py +2 -2
  405. mindspore/train/metrics/__init__.py +3 -3
  406. mindspore/train/metrics/cosine_similarity.py +1 -1
  407. mindspore/train/metrics/hausdorff_distance.py +3 -2
  408. mindspore/train/metrics/mean_surface_distance.py +3 -2
  409. mindspore/train/metrics/metric.py +39 -19
  410. mindspore/train/metrics/roc.py +2 -2
  411. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  412. mindspore/train/mind_ir_pb2.py +85 -36
  413. mindspore/train/model.py +187 -47
  414. mindspore/train/serialization.py +487 -161
  415. mindspore/train/summary/_summary_adapter.py +1 -1
  416. mindspore/train/summary/_writer_pool.py +3 -2
  417. mindspore/train/summary/summary_record.py +37 -17
  418. mindspore/train/train_thor/convert_utils.py +3 -3
  419. mindspore/train/train_thor/dataset_helper.py +1 -1
  420. mindspore/turbojpeg.dll +0 -0
  421. mindspore/vcmeta.dll +0 -0
  422. mindspore/vcruntime140.dll +0 -0
  423. mindspore/vcruntime140_1.dll +0 -0
  424. mindspore/version.py +1 -1
  425. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +7 -4
  426. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +429 -486
  427. mindspore/_extends/graph_kernel/expander.py +0 -80
  428. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  429. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  430. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  431. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  432. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  433. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  434. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  435. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  436. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  437. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  438. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  439. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  440. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  441. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  442. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  443. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  444. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  445. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  446. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  447. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  448. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  449. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  450. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  451. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  452. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  453. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  454. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  455. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  456. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  457. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  458. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  459. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  460. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  461. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  462. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  463. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  464. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  465. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  466. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  467. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  468. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  469. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  470. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  471. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  472. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  473. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  474. mindspore/dataset/datapreprocess/__init__.py +0 -20
  475. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  476. mindspore/include/api/net.h +0 -142
  477. mindspore/nn/lr_scheduler.py +0 -262
  478. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  479. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  480. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  481. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  482. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  483. mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
  484. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -350
  485. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -409
  486. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -578
  487. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -199
  488. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -446
  489. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  490. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
  491. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
  492. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
  493. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  494. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  495. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  496. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  497. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  498. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  499. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  500. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  501. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  502. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  503. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  504. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  505. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  506. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  507. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  508. mindspore/rewrite/node_visitor.py +0 -44
  509. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
  510. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -0
  511. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
@@ -1,80 +0,0 @@
1
- # Copyright 2020-2022 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ============================================================================
15
- """generate json desc for graph kernel ops"""
16
- import json
17
- import json.decoder as jd
18
- import traceback
19
- from mindspore import log as logger
20
- import mindspore._extends.graph_kernel.expanders as expanders
21
- from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException
22
-
23
-
24
- def create_expander(expand_info):
25
- """Create an expander according to op name"""
26
- def call_func(func, arg):
27
- return func(arg)
28
- op_name = str(expand_info['name'])
29
- if not hasattr(expanders, op_name):
30
- raise GraphKernelUnsupportedException("Expander does not support op: {}".format(op_name))
31
- expander = getattr(expanders, op_name)
32
- return call_func(expander, expand_info)
33
-
34
-
35
- def extract_expand_info(kernel_info):
36
- """Convert the json into a more friendly format"""
37
- input_desc = []
38
- if 'input_desc' in kernel_info and kernel_info['input_desc']:
39
- for desc in kernel_info['input_desc']:
40
- input_desc += desc
41
- attrs = {}
42
- if 'attr' in kernel_info and kernel_info['attr']:
43
- for attr in kernel_info["attr"]:
44
- attrs[attr["name"]] = attr["value"]
45
- expand_info = {
46
- "name": kernel_info["name"],
47
- "input_desc": input_desc,
48
- "output_desc": kernel_info["output_desc"],
49
- "attr": attrs,
50
- "process": kernel_info["process"],
51
- }
52
- return expand_info
53
-
54
-
55
- def get_op_expander(json_str: str):
56
- """get op expander by json info"""
57
- try:
58
- kernel_info = json.loads(json_str)
59
- expand_info = extract_expand_info(kernel_info)
60
-
61
- expander = create_expander(expand_info)
62
- graph = expander.run()
63
-
64
- # dump graph to json desc.
65
- desc = graph.dump()
66
- return json.dumps(desc)
67
-
68
- except jd.JSONDecodeError:
69
- logger.error("Decode input json str failed in expander, json is: {}".format(json_str))
70
- logger.error(traceback.format_exc())
71
- return ""
72
- except GraphKernelUnsupportedException as e:
73
- logger.info(e.message)
74
- return ""
75
-
76
-
77
- def get_expander_op_list():
78
- """get supported expander op list"""
79
- op_list = [name for name in dir(expanders) if name[0].isupper()]
80
- return ' '.join(op_list)
@@ -1,54 +0,0 @@
1
- # Copyright 2020-2022 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ============================================================================
15
- """expanders init. Deprecated, please add the new operators in the c++ file"""
16
-
17
-
18
- from .addn import AddN
19
- from .batchnorm import BatchNorm
20
- from .batchnorm_grad import BatchNormGrad
21
- from .clip_by_norm_no_div_sum import ClipByNormNoDivSum
22
- from .conv2d import Conv2D
23
- from .complex import CAbs, CAdd, CDiv, CMul, CSub, CRealDiv
24
- from .dropout_grad import DropoutGrad
25
- from .equal_count import EqualCount
26
- from .erfc import Erfc
27
- from .fused_adam import FusedAdam
28
- from .fused_adam_weight_decay import FusedAdamWeightDecay
29
- from .fused_mul_add import FusedMulAdd
30
- from .gelu_grad import GeLUGrad
31
- from .gkdropout import GkDropout
32
- from .identity import Identity
33
- from .layernorm import LayerNorm
34
- from .layernorm_grad import LayerNormGrad
35
- from .logsoftmax import LogSoftmax
36
- from .logsoftmax_grad import LogSoftmaxGrad
37
- from .matmul import BatchMatMul, MatMul
38
- from .maximum_grad import MaximumGrad
39
- from .minimum_grad import MinimumGrad
40
- from .oneslike import OnesLike
41
- from .reduce_mean import ReduceMean
42
- from .relu_grad import ReluGrad
43
- from .sigmoid_cross_entropy_with_logits import SigmoidCrossEntropyWithLogits
44
- from .sigmoid_cross_entropy_with_logits_grad import SigmoidCrossEntropyWithLogitsGrad
45
- from .sigmoid_grad import SigmoidGrad
46
- from .slice import Slice
47
- from .softmax_cross_entropy_with_logits import SoftmaxCrossEntropyWithLogits
48
- from .softmax_grad_ext import SoftmaxGradExt
49
- from .sqrt_grad import SqrtGrad
50
- from .squared_difference import SquaredDifference
51
- from .square_sum_v1 import SquareSumV1
52
- from .square_sum_all import SquareSumAll
53
- from .tanh_grad import TanhGrad
54
- from .softsign import Softsign
@@ -1,269 +0,0 @@
1
- # Copyright 2021-2022 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """GraphKernel expander utils"""
16
- from abc import ABCMeta, abstractmethod
17
- from mindspore._extends.graph_kernel.model import model_builder as builder
18
- from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
19
-
20
-
21
- class Expander(metaclass=ABCMeta):
22
- """
23
- Expander is the base class of expanders.
24
-
25
- The method `_expand` should be overridden to implement the operator detail.
26
- """
27
- def __init__(self, expand_info):
28
- self.name = expand_info["name"]
29
- self.inputs = expand_info["input_desc"]
30
- self.outputs = expand_info["output_desc"]
31
- self.attrs = expand_info["attr"]
32
- self.processor = expand_info["process"]
33
-
34
- def run(self):
35
- """
36
- Expand the operator to a graph.
37
-
38
- `GraphKernelUnsupportedException` would be raised if check failed.
39
- """
40
- self._check()
41
- graph_builder = builder.GraphBuilder()
42
- with graph_builder.graph_scope(self.name) as graph_scope:
43
- # transform input_desc to Tensor
44
- self.inputs = [graph_builder.tensor(inp['shape'], inp['data_type'], inp['format']) for inp in self.inputs]
45
- graph_scope.set_input(*self.inputs)
46
- outputs = self._expand(graph_builder)
47
- if isinstance(outputs, (list, tuple)):
48
- self._check_output_same(outputs)
49
- graph_scope.set_output(*outputs)
50
- else:
51
- self._check_output_same([outputs])
52
- graph_scope.set_output(outputs)
53
-
54
- graph = graph_builder.get()[0]
55
- graph.set_processor(self.processor)
56
- return graph
57
-
58
- def _check(self):
59
- """Check inputs"""
60
-
61
- def _check_output_same(self, outputs):
62
- for index, value in enumerate(self.outputs):
63
- if list(outputs[index].shape) != list(value['shape']):
64
- raise GKException("{} 's output shape {} is wrong. Expected:{}".format(
65
- self.__class__.__name__, list(outputs[index].shape), list(value['shape'])))
66
- if outputs[index].dtype != value['data_type']:
67
- raise GKException("{} 's output data_type {} is wrong. Expected: {}".format(
68
- self.__class__.__name__, outputs[index].dtype, value['data_type']))
69
- if outputs[index].data_format != value['format']:
70
- raise GKException("{} 's output format {} is wrong. Expected: {}".format(
71
- self.__class__.__name__, outputs[index].data_format, value['format']))
72
-
73
- @abstractmethod
74
- def _expand(self, graph_builder):
75
- """Expand operator, this function should be overridden in subclass"""
76
- raise Exception("_expand() is not implemented in {}".format(self.__class__.__name__))
77
-
78
-
79
- class ExpanderInfoValidator:
80
- """ExpanderInfoValidator is the utility class which defines the validator decorator for expanders"""
81
-
82
- def __init__(self):
83
- """Init"""
84
-
85
- @staticmethod
86
- def _add_check_function(kls, func):
87
- """
88
- Rewrite the function `_check` in class Expander
89
- to append the new `func` after the original checks.
90
- """
91
- old_check = getattr(kls, "_check")
92
-
93
- def new_check(obj):
94
- old_check(obj)
95
- func(obj)
96
-
97
- setattr(kls, "_check", new_check)
98
-
99
- @staticmethod
100
- def add_format(*input_format):
101
- """
102
- Add new supported format for the operator
103
-
104
- this function will add a list `__supported_formats` into the expander,
105
- saving the whitelist of formats that this op supports.
106
- it also rewrites the `_check` function to check the formats.
107
- """
108
- format_list_name = "__supported_formats"
109
-
110
- def _check_format(obj):
111
- inp_formats = [inp['format'] for inp in obj.inputs]
112
- for formats in getattr(obj, format_list_name):
113
- if len(formats) != len(inp_formats):
114
- raise GKException("For '{}', length of registered format is different from the length of inputs "
115
- "format: {} vs {}".format(obj.name, len(formats), len(inp_formats)))
116
- if all((fmt == inp for fmt, inp in zip(formats, inp_formats))):
117
- return
118
- raise GKException("Unregistered format ({}) for op {}".format(','.join(inp_formats), obj.name))
119
-
120
- def wrapper(cls):
121
- if not issubclass(cls, Expander):
122
- raise Exception("{} should be subclass of Expander.".format(cls.__name__))
123
- if not hasattr(cls, format_list_name):
124
- setattr(cls, format_list_name, list())
125
- ExpanderInfoValidator._add_check_function(cls, _check_format)
126
- getattr(cls, format_list_name).append(input_format)
127
- return cls
128
-
129
- return wrapper
130
-
131
- @staticmethod
132
- def check_all_formats_same(kls):
133
- """Check that all formats are the same"""
134
-
135
- # Ensure no args case can return a class
136
- def _check(*args):
137
- def _check_format(obj):
138
- inp_formats = [inp['format'] for inp in obj.inputs]
139
- if all((fmt == inp_formats[0] for fmt in inp_formats[1:])):
140
- return
141
- raise GKException("[check_all_formats_same] unmatched formats ({}) for op {}".format(
142
- ','.join(inp_formats), obj.name))
143
-
144
- def wrapper(cls):
145
- if not issubclass(cls, Expander):
146
- raise Exception("{} should be subclass of Expander.".format(cls.__name__))
147
- ExpanderInfoValidator._add_check_function(cls, _check_format)
148
- return cls
149
-
150
- return wrapper
151
-
152
- return _check()(kls)
153
-
154
- @staticmethod
155
- def check_attrs(*args):
156
- """Check the attrs exist"""
157
-
158
- def _check_attr(obj):
159
- for a in args:
160
- if a not in obj.attrs:
161
- raise GKException("attr '{}' does not exist. {}".format(a, obj.name))
162
-
163
- def wrapper(cls):
164
- if not issubclass(cls, Expander):
165
- raise Exception("{} should be subclass of Expander.".format(cls.__name__))
166
- ExpanderInfoValidator._add_check_function(cls, _check_attr)
167
- return cls
168
-
169
- return wrapper
170
-
171
-
172
- def to_frac_z_axis(ori_shape, ori_axis):
173
- """
174
- judge the format is fractal NZ
175
- Parameters
176
- ----------
177
- ori_shape: list or tuple
178
- original shape of input
179
- ori_axis: list or tuple
180
- original axis of original shape to operate
181
- Returns
182
- -------
183
- output: list
184
- axis of the fractal Nz shape
185
- """
186
- frac_z_axis = list(ori_axis)
187
- shape_len = len(ori_shape)
188
- axis_count = len(frac_z_axis)
189
- axis_negative_1 = shape_len - 1
190
- axis_negative_2 = shape_len - 2
191
- for i in range(axis_count):
192
- axis_index = (frac_z_axis[i] + shape_len) % shape_len
193
- if axis_index == axis_negative_1:
194
- if frac_z_axis[i] > shape_len - 2: # akg:[2,3] [1,4] tbe:[2,4] [1,3]
195
- frac_z_axis[i] = axis_index - 1
196
- frac_z_axis.append(axis_index + 2)
197
- else: # no case cover this branch now
198
- frac_z_axis[i] = axis_index - 1
199
- frac_z_axis.append(axis_index + 2)
200
- elif axis_index == axis_negative_2:
201
- frac_z_axis[i] = axis_index + 1
202
- frac_z_axis.append(axis_index + 2)
203
- else:
204
- frac_z_axis[i] = axis_index
205
- return frac_z_axis
206
-
207
-
208
- def infer_shape_from_fractalnz(fractal):
209
- "get original shape from fractalnz shape"
210
- shape = []
211
- dims = len(fractal)
212
- batch = dims - 4
213
- for i in range(batch):
214
- shape.append(fractal[i])
215
- m = fractal[dims - 3] * fractal[dims - 2]
216
- n = fractal[dims - 4] * fractal[dims - 1]
217
- shape.append(m)
218
- shape.append(n)
219
- return shape
220
-
221
-
222
- def get_reduced_ori_shape(shape, axis):
223
- "get shape after reduced which is based on original shape"
224
- reduced_ori_shape = []
225
- for i, value in enumerate(shape):
226
- if i in axis:
227
- reduced_ori_shape.append(1)
228
- else:
229
- reduced_ori_shape.append(value)
230
- return reduced_ori_shape
231
-
232
-
233
- def get_reduce_axis_shape(shape, data_format, axis):
234
- """
235
- Get the reduce axis under format `data_format` and original reduced shape.
236
- Parameters
237
- ----------
238
- shape: list or tuple
239
- shape of input
240
- data_format: str
241
- data format of input
242
- axis: None, int, list or tuple
243
- reduce axis of the original shape
244
- Returns
245
- -------
246
- reduce_axis: list
247
- reduce axis of the `data_format` shape
248
- ori_reduced_shape: list
249
- original reduced shape
250
- """
251
- ori_shape = shape
252
- if data_format == "FRACTAL_NZ":
253
- ori_shape = infer_shape_from_fractalnz(shape)
254
- if not axis:
255
- axis = []
256
- for i, _ in enumerate(ori_shape):
257
- axis.append(i)
258
- else:
259
- if isinstance(axis, int):
260
- axis = [axis]
261
- for i, _ in enumerate(list(axis)):
262
- if axis[i] < 0:
263
- axis[i] += len(ori_shape)
264
-
265
- ori_reduced_shape = get_reduced_ori_shape(ori_shape, axis)
266
- reduce_axis = axis
267
- if data_format == "FRACTAL_NZ":
268
- reduce_axis = to_frac_z_axis(ori_shape, axis)
269
- return reduce_axis, ori_reduced_shape
@@ -1,33 +0,0 @@
1
- # Copyright 2021-2022 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for addn"""
16
- from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
17
- from ._utils import Expander, ExpanderInfoValidator as VLD
18
-
19
-
20
- @VLD.check_all_formats_same
21
- class AddN(Expander):
22
- """Expand AddN to multiple Adds"""
23
-
24
- def _check(self):
25
- if len(self.inputs) < 2:
26
- raise GKException("For 'AddN', the inputs num should be greater than 1, but got {}"
27
- .format(len(self.inputs)))
28
-
29
- def _expand(self, graph_builder):
30
- result = self.inputs[0]
31
- for inp in self.inputs[1:]:
32
- result = graph_builder.emit('Add', [result, inp])
33
- return result
@@ -1,152 +0,0 @@
1
- # Copyright 2021-2022 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ===========================================================================
15
- """generate json desc for BatchNorm"""
16
- from mindspore._extends.graph_kernel.model.model import DataFormat as DF
17
- from ._utils import Expander, ExpanderInfoValidator as VLD
18
- from .expand_dims import ExpandDims
19
-
20
-
21
- @VLD.add_format(DF.NHWC, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
22
- @VLD.add_format(DF.NCHW, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
23
- @VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
24
- @VLD.check_attrs('is_training', 'momentum', 'epsilon')
25
- class BatchNorm(Expander):
26
- """BatchNorm expander"""
27
-
28
- def _expand(self, graph_builder):
29
- # get op info
30
- input_x = self.inputs[0]
31
- input_scale = self.inputs[1]
32
- input_offset = self.inputs[2]
33
- input_mean = self.inputs[3]
34
- input_variance = self.inputs[4]
35
- epsilon_v = graph_builder.value(input_scale.dtype, self.attrs['epsilon'])
36
-
37
- input_x_ori_type = input_x.dtype
38
- input_x_new_type = input_x.dtype
39
- if input_x.dtype == "float16" and input_scale.dtype == "float32" and input_offset.dtype == "float32" and \
40
- input_mean.dtype == "float32" and input_variance.dtype == "float32":
41
- input_x_new_type = "float32"
42
- if input_x_new_type != input_x_ori_type:
43
- input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': input_x_new_type})
44
-
45
- if self.attrs['is_training']:
46
- self.inputs[0] = input_x
47
- res_y, mean_res, variance_res, mean_muls, y_sqrt_rec = self._bn_train(graph_builder)
48
- if input_x_new_type != input_x_ori_type:
49
- res_y = graph_builder.emit('Cast', [res_y], attrs={'dst_type': input_x_ori_type})
50
- return res_y, mean_res, variance_res, mean_muls, y_sqrt_rec
51
- # infer mode
52
- if input_x.data_format in (DF.DEFAULT, DF.NCHW):
53
- input_mean = graph_builder.emit(
54
- 'Reshape', [input_mean], attrs={'shape': ExpandDims.infer_shape(input_mean.shape, [-1, -1])})
55
- input_scale = graph_builder.emit(
56
- 'Reshape', [input_scale], attrs={'shape': ExpandDims.infer_shape(input_scale.shape, [-1, -1])})
57
- input_offset = graph_builder.emit(
58
- 'Reshape', [input_offset], attrs={'shape': ExpandDims.infer_shape(input_offset.shape, [-1, -1])})
59
- x_sub = graph_builder.emit('Sub', [input_x, input_mean])
60
- x_sub_mul = graph_builder.emit('Mul', [input_scale, x_sub])
61
- var_add = graph_builder.emit('Add', [epsilon_v, input_variance])
62
- var_add_sqrt = graph_builder.emit('Sqrt', [var_add])
63
- if input_x.data_format in (DF.DEFAULT, DF.NCHW):
64
- var_add_sqrt = graph_builder.emit(
65
- 'Reshape', [var_add_sqrt], attrs={'shape': ExpandDims.infer_shape(var_add_sqrt.shape, [-1, -1])})
66
- x_div = graph_builder.emit('RealDiv', [x_sub_mul, var_add_sqrt])
67
- res_y = graph_builder.emit('Add', [input_offset, x_div])
68
- if input_x_new_type != input_x_ori_type:
69
- res_y = graph_builder.emit('Cast', [res_y], attrs={'dst_type': input_x_ori_type})
70
- return res_y, var_add, var_add, var_add, var_add
71
-
72
- def _bn_train(self, graph_builder):
73
- """expand BatchNorm for training mode"""
74
- input_x = self.inputs[0]
75
- input_scale = self.inputs[1]
76
- input_offset = self.inputs[2]
77
- input_mean = self.inputs[3]
78
- input_variance = self.inputs[4]
79
- epsilon_v = graph_builder.value(input_scale.dtype, self.attrs['epsilon'])
80
- reduce_axis = ()
81
- shape_x = input_x.shape
82
- if input_x.data_format == DF.NHWC:
83
- reduce_axis = (0, 1, 2)
84
- num = shape_x[0] * shape_x[1] * shape_x[2]
85
- else:
86
- reduce_axis = (0, 2, 3)
87
- num = shape_x[0] * shape_x[2] * shape_x[3]
88
- num_rec = 1.0 / num
89
- num_rec_v = graph_builder.value(input_scale.dtype, num_rec)
90
-
91
- # compute mean value of input_x
92
- mean_sum = graph_builder.emit(
93
- 'ReduceSum', [input_x], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
94
- mean_muls = graph_builder.emit('Mul', [mean_sum, num_rec_v])
95
-
96
- # compute variance of input_x
97
- if input_x.data_format in (DF.DEFAULT, DF.NCHW):
98
- mean_muls_expand = graph_builder.emit(
99
- 'Reshape', [mean_muls], attrs={'shape': ExpandDims.infer_shape(mean_muls.shape, [-1, -1])})
100
- else:
101
- mean_muls_expand = mean_muls
102
- var_sub = graph_builder.emit('Sub', [input_x, mean_muls_expand])
103
- var_mul = graph_builder.emit('Mul', [var_sub, var_sub])
104
- var_sum = graph_builder.emit('ReduceSum', [var_mul], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
105
- var_mul = graph_builder.emit('Mul', [var_sum, num_rec_v])
106
-
107
- # y_sqrt_rec means 1 / sqrt(variance + epsilon), which is calculated in backward pass
108
- scalar_one = 1.0
109
- scalar_one_v = graph_builder.value(input_scale.dtype, scalar_one)
110
- y_add = graph_builder.emit('Add', [var_mul, epsilon_v])
111
- y_sqrt = graph_builder.emit('Sqrt', [y_add])
112
- y_sqrt_rec = graph_builder.emit('RealDiv', [scalar_one_v, y_sqrt])
113
-
114
- # compute res_y
115
- tmp_sub = graph_builder.emit('Sub', [input_x, mean_muls_expand])
116
- if input_x.data_format in (DF.DEFAULT, DF.NCHW):
117
- y_sqrt_rec_expand = graph_builder.emit(
118
- 'Reshape', [y_sqrt_rec], attrs={'shape': ExpandDims.infer_shape(y_sqrt_rec.shape, [-1, -1])})
119
- else:
120
- y_sqrt_rec_expand = y_sqrt_rec
121
- y_norm = graph_builder.emit('Mul', [tmp_sub, y_sqrt_rec_expand])
122
- if input_x.data_format in (DF.DEFAULT, DF.NCHW):
123
- input_scale_expand = graph_builder.emit(
124
- 'Reshape', [input_scale], attrs={'shape': ExpandDims.infer_shape(input_scale.shape, [-1, -1])})
125
- else:
126
- input_scale_expand = input_scale
127
- res_y_mul = graph_builder.emit('Mul', [input_scale_expand, y_norm])
128
- if input_x.data_format in (DF.DEFAULT, DF.NCHW):
129
- input_offset_expand = graph_builder.emit(
130
- 'Reshape', [input_offset], attrs={'shape': ExpandDims.infer_shape(input_offset.shape, [-1, -1])})
131
- else:
132
- input_offset_expand = input_offset
133
- res_y = graph_builder.emit('Add', [res_y_mul, input_offset_expand])
134
-
135
- # compute mean_res
136
- momentum_sub = scalar_one - self.attrs['momentum']
137
- momentum_v_sub = graph_builder.value(input_scale.dtype, momentum_sub)
138
- new_running_mean_tmp = graph_builder.emit('Mul', [momentum_v_sub, input_mean])
139
- momentum_v = graph_builder.value(input_scale.dtype, self.attrs['momentum'])
140
- current_mean_tmp = graph_builder.emit('Mul', [momentum_v, mean_muls])
141
- updated_moving_mean = graph_builder.emit('Add', [new_running_mean_tmp, current_mean_tmp])
142
- mean_res = graph_builder.emit('Assign', [input_mean, updated_moving_mean])
143
-
144
- # variance_res is calculated by sample variance, and need to multiply by num / (num - 1)
145
- var_num = float(num) / (num - 1)
146
- var_num_v = graph_builder.value(input_scale.dtype, var_num)
147
- var_mul_update = graph_builder.emit('Mul', [var_num_v, var_mul])
148
- new_running_var_tmp = graph_builder.emit('Mul', [momentum_v_sub, input_variance])
149
- current_var_tmp = graph_builder.emit('Mul', [momentum_v, var_mul_update])
150
- updated_moving_variance = graph_builder.emit('Add', [new_running_var_tmp, current_var_tmp])
151
- variance_res = graph_builder.emit('Assign', [input_variance, updated_moving_variance])
152
- return res_y, mean_res, variance_res, mean_muls, y_sqrt_rec